diff --git a/gtsam/hybrid/CGMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp similarity index 77% rename from gtsam/hybrid/CGMixtureFactor.cpp rename to gtsam/hybrid/GaussianMixtureFactor.cpp index 08ae8b6e9..61cfb1d70 100644 --- a/gtsam/hybrid/CGMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file CGMixtureFactor.cpp + * @file GaussianMixtureFactor.cpp * @brief A set of Gaussian factors indexed by a set of discrete keys. * @author Fan Jiang * @author Varun Agrawal @@ -18,7 +18,7 @@ * @date Mar 12, 2022 */ -#include +#include #include #include @@ -27,15 +27,15 @@ namespace gtsam { -CGMixtureFactor::CGMixtureFactor(const KeyVector &continuousKeys, +GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const Factors &factors) : Base(continuousKeys, discreteKeys), factors_(factors) {} -bool CGMixtureFactor::equals(const HybridFactor &lf, double tol) const { +bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { return false; } -void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { +void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); factors_.print( "mixture = ", @@ -49,12 +49,12 @@ void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) }); } -const CGMixtureFactor::Factors& CGMixtureFactor::factors() { +const GaussianMixtureFactor::Factors& GaussianMixtureFactor::factors() { return factors_; } /* *******************************************************************************/ -CGMixtureFactor::Sum CGMixtureFactor::addTo(const CGMixtureFactor::Sum &sum) const { +GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(const GaussianMixtureFactor::Sum &sum) const { using Y = GaussianFactorGraph; auto add = [](const Y &graph1, const Y &graph2) { auto result = graph1; @@ -66,7 +66,7 @@ CGMixtureFactor::Sum CGMixtureFactor::addTo(const CGMixtureFactor::Sum &sum) con } /* *******************************************************************************/ -CGMixtureFactor::Sum CGMixtureFactor::wrappedFactors() const { +GaussianMixtureFactor::Sum GaussianMixtureFactor::wrappedFactors() const { auto wrap = [](const GaussianFactor::shared_ptr &factor) { GaussianFactorGraph result; result.push_back(factor); @@ -74,4 +74,4 @@ CGMixtureFactor::Sum CGMixtureFactor::wrappedFactors() const { }; return {factors_, wrap}; } -} \ No newline at end of file +} diff --git a/gtsam/hybrid/CGMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h similarity index 87% rename from gtsam/hybrid/CGMixtureFactor.h rename to gtsam/hybrid/GaussianMixtureFactor.h index fd81cdd91..4cd5e71de 100644 --- a/gtsam/hybrid/CGMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file CGMixtureFactor.h + * @file GaussianMixtureFactor.h * @brief A set of Gaussian factors indexed by a set of discrete keys. * @author Fan Jiang * @author Varun Agrawal @@ -29,19 +29,19 @@ namespace gtsam { class GaussianFactorGraph; -class CGMixtureFactor : public HybridFactor { +class GaussianMixtureFactor : public HybridFactor { public: using Base = HybridFactor; - using This = CGMixtureFactor; + using This = GaussianMixtureFactor; using shared_ptr = boost::shared_ptr; using Factors = DecisionTree; Factors factors_; - CGMixtureFactor() = default; + GaussianMixtureFactor() = default; - CGMixtureFactor(const KeyVector &continuousKeys, + GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const Factors &factors); using Sum = DecisionTree; diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 53d10518f..d7e2f33af 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -25,8 +25,8 @@ namespace gtsam { /** - * A hybrid Bayes net can have discrete conditionals, Gaussian mixtures, - * or pure Gaussian conditionals. + * A hybrid Bayes net is a collection of HybridConditionals, which can have + * discrete conditionals, Gaussian mixtures, or pure Gaussian conditionals. */ class GTSAM_EXPORT HybridBayesNet : public BayesNet { public: @@ -40,4 +40,4 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { HybridBayesNet() : Base() {} }; -} // namespace gtsam \ No newline at end of file +} // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 74b967fc8..ee51bdd6c 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -33,7 +33,9 @@ class HybridConditional; class VectorValues; /* ************************************************************************* */ -/** A clique in a HybridBayesTree */ +/** A clique in a HybridBayesTree + * which is a HybridConditional internally. + */ class GTSAM_EXPORT HybridBayesTreeClique : public BayesTreeCliqueBase { public: diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index f182f90a9..c55c5ecf0 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -23,6 +23,10 @@ namespace gtsam { +/** + * A HybridDiscreteFactor is a wrapper for DiscreteFactor, so we hide the + * implementation of DiscreteFactor, and thus avoiding diamond inheritance. + */ class HybridDiscreteFactor : public HybridFactor { public: using Base = HybridFactor; diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 653d4270a..3d5bd7b21 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -34,6 +34,11 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /** * Base class for hybrid probabilistic factors + * Examples: + * - HybridGaussianFactor + * - HybridDiscreteFactor + * - GaussianMixtureFactor + * - GaussianMixture */ class GTSAM_EXPORT HybridFactor : public Factor { public: diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index 735b8cfa9..574c3b781 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -56,14 +56,14 @@ static std::string GREEN = "\033[0;32m"; static std::string GREEN_BOLD = "\033[1;32m"; static std::string RESET = "\033[0m"; -static CGMixtureFactor::Sum &addGaussian( - CGMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { +static GaussianMixtureFactor::Sum &addGaussian( + GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { using Y = GaussianFactorGraph; // If the decision tree is not intiialized, then intialize it. if (sum.empty()) { GaussianFactorGraph result; result.push_back(factor); - sum = CGMixtureFactor::Sum(result); + sum = GaussianMixtureFactor::Sum(result); } else { auto add = [&factor](const Y &graph) { @@ -307,7 +307,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // continue; // } - // auto ptr_mf = boost::dynamic_pointer_cast(factor); + // auto ptr_mf = boost::dynamic_pointer_cast(factor); // if (ptr_mf) gf.push_back(ptr_mf->factors_(new_assignment)); // auto ptr_gm = boost::dynamic_pointer_cast(factor); @@ -329,13 +329,13 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { std::cout << RED_BOLD << "HYBRID ELIM." << RESET << "\n"; - CGMixtureFactor::Sum sum; + GaussianMixtureFactor::Sum sum; std::vector deferredFactors; for (auto &f : factors) { if (f->isHybrid_) { - auto cgmf = boost::dynamic_pointer_cast(f); + auto cgmf = boost::dynamic_pointer_cast(f); if (cgmf) { sum = cgmf->addTo(sum); } @@ -395,7 +395,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { auto pair = unzip(eliminationResults); const GaussianMixture::Conditionals &conditionals = pair.first; - const CGMixtureFactor::Factors &separatorFactors = pair.second; + const GaussianMixtureFactor::Factors &separatorFactors = pair.second; // Create the GaussianMixture from the conditionals auto conditional = boost::make_shared( @@ -429,7 +429,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { } else { // Create a resulting DCGaussianMixture on the separator. - auto factor = boost::make_shared( + auto factor = boost::make_shared( frontalKeys, discreteSeparator, separatorFactors); return {boost::make_shared(conditional), factor}; } diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 03c49564e..86c87a0ec 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -23,6 +23,10 @@ namespace gtsam { +/** + * A HybridGaussianFactor is a wrapper for GaussianFactor so that we do not have + * a diamond inheritance. + */ class HybridGaussianFactor : public HybridFactor { public: using Base = HybridFactor; diff --git a/gtsam/hybrid/tests/CMakeLists.txt b/gtsam/hybrid/tests/CMakeLists.txt index b52e18586..06ad2c505 100644 --- a/gtsam/hybrid/tests/CMakeLists.txt +++ b/gtsam/hybrid/tests/CMakeLists.txt @@ -1 +1 @@ -gtsamAddTestsGlob(hybrid "test*.cpp" "" "gtsam") \ No newline at end of file +gtsamAddTestsGlob(hybrid "test*.cpp" "" "gtsam") diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp index 32cb05325..876105510 100644 --- a/gtsam/hybrid/tests/testHybridConditional.cpp +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -17,7 +17,7 @@ #include #include -#include +#include #include #include #include @@ -108,8 +108,8 @@ TEST(HybridFactorGraph, eliminateFullSequentialSimple) { C(1), boost::make_shared(X(1), I_3x3, Z_3x1), boost::make_shared(X(1), I_3x3, Vector3::Ones())); - hfg.add(CGMixtureFactor({X(1)}, {c1}, dt)); - // hfg.add(CGMixtureFactor({X(0)}, {c1}, dt)); + hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + // hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt)); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8}))); hfg.add(HybridDiscreteFactor( DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); @@ -137,8 +137,8 @@ TEST(HybridFactorGraph, eliminateFullMultifrontalSimple) { C(1), boost::make_shared(X(1), I_3x3, Z_3x1), boost::make_shared(X(1), I_3x3, Vector3::Ones())); - hfg.add(CGMixtureFactor({X(1)}, {c1}, dt)); - // hfg.add(CGMixtureFactor({X(0)}, {c1}, dt)); + hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + // hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt)); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8}))); hfg.add(HybridDiscreteFactor( DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); @@ -167,7 +167,7 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalCLG) { C(1), boost::make_shared(X(1), I_3x3, Z_3x1), boost::make_shared(X(1), I_3x3, Vector3::Ones())); - hfg.add(CGMixtureFactor({X(1)}, {c}, dt)); + hfg.add(GaussianMixtureFactor({X(1)}, {c}, dt)); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); // hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 // 2 3 4"))); @@ -203,13 +203,13 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { C(0), boost::make_shared(X(0), I_3x3, Z_3x1), boost::make_shared(X(0), I_3x3, Vector3::Ones())); - hfg.add(CGMixtureFactor({X(0)}, {{C(0), 2}}, dt)); + hfg.add(GaussianMixtureFactor({X(0)}, {{C(0), 2}}, dt)); DecisionTree dt1( C(1), boost::make_shared(X(2), I_3x3, Z_3x1), boost::make_shared(X(2), I_3x3, Vector3::Ones())); - hfg.add(CGMixtureFactor({X(2)}, {{C(1), 2}}, dt1)); + hfg.add(GaussianMixtureFactor({X(2)}, {{C(1), 2}}, dt1)); } // hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); @@ -224,13 +224,13 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { C(3), boost::make_shared(X(3), I_3x3, Z_3x1), boost::make_shared(X(3), I_3x3, Vector3::Ones())); - hfg.add(CGMixtureFactor({X(3)}, {{C(3), 2}}, dt)); + hfg.add(GaussianMixtureFactor({X(3)}, {{C(3), 2}}, dt)); DecisionTree dt1( C(2), boost::make_shared(X(5), I_3x3, Z_3x1), boost::make_shared(X(5), I_3x3, Vector3::Ones())); - hfg.add(CGMixtureFactor({X(5)}, {{C(2), 2}}, dt1)); + hfg.add(GaussianMixtureFactor({X(5)}, {{C(2), 2}}, dt1)); } auto ordering_full =