formatting and comments

release/4.3a0
Varun Agrawal 2024-10-23 11:37:55 -04:00
parent 977ac0d762
commit 4c74ec113a
2 changed files with 15 additions and 13 deletions

View File

@ -25,12 +25,12 @@
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/JacobianFactor.h>
#include <cstddef>
#include <memory>
#include "gtsam/linear/GaussianConditional.h"
namespace gtsam {
/* *******************************************************************************/
@ -162,7 +162,7 @@ HybridGaussianConditional::HybridGaussianConditional(
/* *******************************************************************************/
const HybridGaussianConditional::Conditionals
HybridGaussianConditional::conditionals() const {
return Conditionals(factors(), [](const auto& pair) {
return Conditionals(factors(), [](auto &&pair) {
return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
});
}
@ -170,7 +170,7 @@ HybridGaussianConditional::conditionals() const {
/* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const {
size_t total = 0;
factors().visit([&total](const auto& node) {
factors().visit([&total](auto &&node) {
if (node.first) total += 1;
});
return total;
@ -178,8 +178,8 @@ size_t HybridGaussianConditional::nrComponents() const {
/* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues& discreteValues) const {
auto& [ptr, _] = factors()(discreteValues);
const DiscreteValues &discreteValues) const {
auto &[ptr, _] = factors()(discreteValues);
if (!ptr) return nullptr;
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional)
@ -196,9 +196,10 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
if (e == nullptr) return false;
// Factors existence and scalar values are checked in BaseFactor::equals.
// Here we check additionally that the factors *are* conditionals and are equal.
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
const GaussianFactorValuePair& pair2) {
// Here we check additionally that the factors *are* conditionals
// and are equal.
auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
const GaussianFactorValuePair &pair2) {
auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
@ -222,7 +223,8 @@ void HybridGaussianConditional::print(const std::string &s,
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactorValuePair &pair) -> std::string {
RedirectCout rd;
if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
if (auto gf =
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
gf->print("", formatter);
return rd.str();
} else {
@ -323,7 +325,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
/* *******************************************************************************/
double HybridGaussianConditional::logProbability(
const HybridValues& values) const {
const HybridValues &values) const {
auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->logProbability(values.continuous());
@ -333,7 +335,7 @@ double HybridGaussianConditional::logProbability(
}
/* *******************************************************************************/
double HybridGaussianConditional::evaluate(const HybridValues& values) const {
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->evaluate(values.continuous());

View File

@ -20,6 +20,7 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
@ -48,8 +49,6 @@
#include <utility>
#include <vector>
#include "gtsam/discrete/DecisionTreeFactor.h"
namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -367,6 +366,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// any difference in noise models used.
HybridGaussianProductFactor productFactor = collectProductFactor();
// Check if a factor is null
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
// This is the elimination method on the leaf nodes