diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index c60ab47aa..3826ef899 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -25,12 +25,12 @@ #include #include #include +#include #include #include #include #include -#include "gtsam/linear/GaussianConditional.h" namespace gtsam { /* *******************************************************************************/ @@ -162,7 +162,7 @@ HybridGaussianConditional::HybridGaussianConditional( /* *******************************************************************************/ const HybridGaussianConditional::Conditionals HybridGaussianConditional::conditionals() const { - return Conditionals(factors(), [](const auto& pair) { + return Conditionals(factors(), [](auto &&pair) { return std::dynamic_pointer_cast(pair.first); }); } @@ -170,7 +170,7 @@ HybridGaussianConditional::conditionals() const { /* *******************************************************************************/ size_t HybridGaussianConditional::nrComponents() const { size_t total = 0; - factors().visit([&total](const auto& node) { + factors().visit([&total](auto &&node) { if (node.first) total += 1; }); return total; @@ -178,8 +178,8 @@ size_t HybridGaussianConditional::nrComponents() const { /* *******************************************************************************/ GaussianConditional::shared_ptr HybridGaussianConditional::choose( - const DiscreteValues& discreteValues) const { - auto& [ptr, _] = factors()(discreteValues); + const DiscreteValues &discreteValues) const { + auto &[ptr, _] = factors()(discreteValues); if (!ptr) return nullptr; auto conditional = std::dynamic_pointer_cast(ptr); if (conditional) @@ -196,9 +196,10 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf, if (e == nullptr) return false; // Factors existence and scalar values are checked in BaseFactor::equals. - // Here we check additionally that the factors *are* conditionals and are equal. - auto compareFunc = [tol](const GaussianFactorValuePair& pair1, - const GaussianFactorValuePair& pair2) { + // Here we check additionally that the factors *are* conditionals + // and are equal. + auto compareFunc = [tol](const GaussianFactorValuePair &pair1, + const GaussianFactorValuePair &pair2) { auto c1 = std::dynamic_pointer_cast(pair1.first), c2 = std::dynamic_pointer_cast(pair2.first); return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol)); @@ -222,7 +223,8 @@ void HybridGaussianConditional::print(const std::string &s, "", [&](Key k) { return formatter(k); }, [&](const GaussianFactorValuePair &pair) -> std::string { RedirectCout rd; - if (auto gf = std::dynamic_pointer_cast(pair.first)) { + if (auto gf = + std::dynamic_pointer_cast(pair.first)) { gf->print("", formatter); return rd.str(); } else { @@ -323,7 +325,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( /* *******************************************************************************/ double HybridGaussianConditional::logProbability( - const HybridValues& values) const { + const HybridValues &values) const { auto [factor, _] = factors()(values.discrete()); if (auto conditional = std::dynamic_pointer_cast(factor)) return conditional->logProbability(values.continuous()); @@ -333,7 +335,7 @@ double HybridGaussianConditional::logProbability( } /* *******************************************************************************/ -double HybridGaussianConditional::evaluate(const HybridValues& values) const { +double HybridGaussianConditional::evaluate(const HybridValues &values) const { auto [factor, _] = factors()(values.discrete()); if (auto conditional = std::dynamic_pointer_cast(factor)) return conditional->evaluate(values.continuous()); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ecc317ab3..ceabe0871 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -48,8 +49,6 @@ #include #include -#include "gtsam/discrete/DecisionTreeFactor.h" - namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: @@ -367,6 +366,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // any difference in noise models used. HybridGaussianProductFactor productFactor = collectProductFactor(); + // Check if a factor is null auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }; // This is the elimination method on the leaf nodes