diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index b04db4977..d3b26d4ef 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -27,14 +27,16 @@ #include #include +#include "gtsam/base/types.h" + namespace gtsam { /* *******************************************************************************/ -HybridGaussianFactor::Factors HybridGaussianFactor::augment( +HybridGaussianFactor::FactorValuePairs HybridGaussianFactor::augment( const FactorValuePairs &factors) { // Find the minimum value so we can "proselytize" to positive values. // Done because we can't have sqrt of negative numbers. - Factors gaussianFactors; + DecisionTree gaussianFactors; AlgebraicDecisionTree valueTree; std::tie(gaussianFactors, valueTree) = unzip(factors); @@ -42,16 +44,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment( double min_value = valueTree.min(); // Finally, update the [A|b] matrices. - auto update = [&min_value](const GaussianFactorValuePair &gfv) { + auto update = [&min_value](const auto &gfv) -> GaussianFactorValuePair { auto [gf, value] = gfv; auto jf = std::dynamic_pointer_cast(gf); - if (!jf) return gf; + if (!jf) return {gf, 0.0}; // should this be zero or infinite? double normalized_value = value - min_value; // If the value is 0, do nothing - if (normalized_value == 0.0) return gf; + if (normalized_value == 0.0) return {gf, 0.0}; GaussianFactorGraph gfg; gfg.push_back(jf); @@ -62,18 +64,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment( auto constantFactor = std::make_shared(c); gfg.push_back(constantFactor); - return std::dynamic_pointer_cast( - std::make_shared(gfg)); + return {std::make_shared(gfg), normalized_value}; }; - return Factors(factors, update); + return FactorValuePairs(factors, update); } /* *******************************************************************************/ struct HybridGaussianFactor::ConstructorHelper { KeyVector continuousKeys; // Continuous keys extracted from factors DiscreteKeys discreteKeys; // Discrete keys provided to the constructors - FactorValuePairs pairs; // Used only if factorsTree is empty - Factors factorsTree; + FactorValuePairs pairs; // The decision tree with factors and scalars ConstructorHelper(const DiscreteKey &discreteKey, const std::vector &factors) @@ -85,9 +85,10 @@ struct HybridGaussianFactor::ConstructorHelper { break; } } - - // Build the DecisionTree from the factor vector - factorsTree = Factors(discreteKeys, factors); + // Build the FactorValuePairs DecisionTree + pairs = FactorValuePairs( + DecisionTree(discreteKeys, factors), + [](const auto &f) { return std::pair{f, 0.0}; }); } ConstructorHelper(const DiscreteKey &discreteKey, @@ -109,6 +110,7 @@ struct HybridGaussianFactor::ConstructorHelper { const FactorValuePairs &factorPairs) : discreteKeys(discreteKeys) { // Extract continuous keys from the first non-null factor + // TODO: just stop after first non-null factor factorPairs.visit([&](const GaussianFactorValuePair &pair) { if (pair.first && continuousKeys.empty()) { continuousKeys = pair.first->keys(); @@ -123,14 +125,13 @@ struct HybridGaussianFactor::ConstructorHelper { /* *******************************************************************************/ HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper) : Base(helper.continuousKeys, helper.discreteKeys), - factors_(helper.factorsTree.empty() ? augment(helper.pairs) - : helper.factorsTree) {} + factors_(augment(helper.pairs)) {} /* *******************************************************************************/ HybridGaussianFactor::HybridGaussianFactor( const DiscreteKey &discreteKey, - const std::vector &factors) - : HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {} + const std::vector &factorPairs) + : HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {} /* *******************************************************************************/ HybridGaussianFactor::HybridGaussianFactor( @@ -140,8 +141,8 @@ HybridGaussianFactor::HybridGaussianFactor( /* *******************************************************************************/ HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys, - const FactorValuePairs &factors) - : HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {} + const FactorValuePairs &factorPairs) + : HybridGaussianFactor(ConstructorHelper(discreteKeys, factorPairs)) {} /* *******************************************************************************/ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { @@ -153,10 +154,12 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { if (factors_.empty() ^ e->factors_.empty()) return false; // Check the base and the factors: - return Base::equals(*e, tol) && - factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) { - return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); - }); + auto compareFunc = [tol](const auto &pair1, const auto &pair2) { + auto f1 = pair1.first, f2 = pair2.first; + bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); + return match && gtsam::equal(pair1.second, pair2.second, tol); + }; + return Base::equals(*e, tol) && factors_.equals(e->factors_, compareFunc); } /* *******************************************************************************/ @@ -171,15 +174,16 @@ void HybridGaussianFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const sharedFactor &gf) -> std::string { + [&](const auto &pair) -> std::string { RedirectCout rd; std::cout << ":\n"; - if (gf) { - gf->print("", formatter); + if (pair.first) { + pair.first->print("", formatter); return rd.str(); } else { return "nullptr"; } + std::cout << "scalar: " << pair.second << "\n"; }); } std::cout << "}" << std::endl; @@ -188,7 +192,7 @@ void HybridGaussianFactor::print(const std::string &s, /* *******************************************************************************/ HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()( const DiscreteValues &assignment) const { - return factors_(assignment); + return factors_(assignment).first; } /* *******************************************************************************/ @@ -207,7 +211,7 @@ GaussianFactorGraphTree HybridGaussianFactor::add( /* *******************************************************************************/ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; }; + auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; }; return {factors_, wrap}; } @@ -229,8 +233,8 @@ static double PotentiallyPrunedComponentError( AlgebraicDecisionTree HybridGaussianFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [&continuousValues](const sharedFactor &gf) { - return PotentiallyPrunedComponentError(gf, continuousValues); + auto errorFunc = [this, &continuousValues](const auto &pair) { + return PotentiallyPrunedComponentError(pair.first, continuousValues); }; DecisionTree error_tree(factors_, errorFunc); return error_tree; @@ -239,8 +243,8 @@ AlgebraicDecisionTree HybridGaussianFactor::errorTree( /* *******************************************************************************/ double HybridGaussianFactor::error(const HybridValues &values) const { // Directly index to get the component, no need to build the whole tree. - const sharedFactor gf = factors_(values.discrete()); - return PotentiallyPrunedComponentError(gf, values.continuous()); + const auto pair = factors_(values.discrete()); + return PotentiallyPrunedComponentError(pair.first, values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index e5a575409..15993f582 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -66,12 +66,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// typedef for Decision Tree of Gaussian factors and arbitrary value. using FactorValuePairs = DecisionTree; - /// typedef for Decision Tree of Gaussian factors. - using Factors = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. - Factors factors_; + FactorValuePairs factors_; public: /// @name Constructors @@ -110,10 +108,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m. * * @param discreteKeys Discrete variables and their cardinalities. - * @param factors The decision tree of Gaussian factor/scalar pairs. + * @param factorPairs The decision tree of Gaussian factor/scalar pairs. */ HybridGaussianFactor(const DiscreteKeys &discreteKeys, - const FactorValuePairs &factors); + const FactorValuePairs &factorPairs); /// @} /// @name Testable @@ -158,7 +156,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { double error(const HybridValues &values) const override; /// Getter for GaussianFactor decision tree - const Factors &factors() const { return factors_; } + const FactorValuePairs &factors() const { return factors_; } /// Add HybridNonlinearFactor to a Sum, syntactic sugar. friend GaussianFactorGraphTree &operator+=( @@ -184,10 +182,9 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * value in the `b` vector as an additional row. * * @param factors DecisionTree of GaussianFactors and arbitrary scalars. - * Gaussian factor in factors. - * @return HybridGaussianFactor::Factors + * @return FactorValuePairs */ - static Factors augment(const FactorValuePairs &factors); + static FactorValuePairs augment(const FactorValuePairs &factors); /// Helper struct to assist private constructor below. struct ConstructorHelper; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 7dfa56e77..957a85038 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -238,8 +238,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. // In this case, compute discrete probabilities. - auto logProbability = - [&](const GaussianFactor::shared_ptr &factor) -> double { + auto logProbability = [&](const auto &pair) -> double { + auto [factor, _] = pair; if (!factor) return 0.0; return factor->error(VectorValues()); }; diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index 9378d07fe..56711b313 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -196,8 +196,8 @@ std::shared_ptr HybridNonlinearFactor::linearize( } }; - DecisionTree> - linearized_factors(factors_, linearizeDT); + HybridGaussianFactor::FactorValuePairs linearized_factors(factors_, + linearizeDT); return std::make_shared(discreteKeys_, linearized_factors); diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp index e09669117..8258d8615 100644 --- a/gtsam/hybrid/tests/testSerializationHybrid.cpp +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -52,11 +52,11 @@ BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf"); BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor, "gtsam_HybridGaussianFactor"); -BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors, +BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs, "gtsam_HybridGaussianFactor_Factors"); -BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Leaf, +BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Leaf, "gtsam_HybridGaussianFactor_Factors_Leaf"); -BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Choice, +BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Choice, "gtsam_HybridGaussianFactor_Factors_Choice"); BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional,