From 7f2fa61fb53422aa195cd5815d51f3b5c4e47814 Mon Sep 17 00:00:00 2001 From: Fan Jiang Date: Fri, 25 Mar 2022 23:28:40 -0600 Subject: [PATCH] Added more Python examples --- gtsam/hybrid/HybridConditional.h | 34 +++++++++++++++----- gtsam/hybrid/HybridFactorGraph.cpp | 26 +++++++-------- gtsam/hybrid/HybridJunctionTree.cpp | 11 +++++++ gtsam/hybrid/hybrid.i | 11 +++++++ python/gtsam/tests/test_HybridFactorGraph.py | 12 +++++-- 5 files changed, 71 insertions(+), 23 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 5fdf5064a..dea8fa9f4 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -19,7 +19,10 @@ #include #include +#include #include +#include +#include #include #include @@ -27,11 +30,8 @@ #include #include #include -#include "gtsam/hybrid/GaussianMixture.h" -#include -#include -#include +#include "gtsam/hybrid/GaussianMixture.h" namespace gtsam { @@ -44,6 +44,19 @@ class HybridFactorGraph; * - DiscreteConditional * - GaussianConditional * - GaussianMixture + * + * The reason why this is important is that `Conditional` is a CRTP class. + * CRTP is static polymorphism such that all CRTP classes, while bearing the + * same name, are different classes not sharing a vtable. This prevents them + * from being contained in any container, and thus it is impossible to + * dynamically cast between them. A better option, as illustrated here, is + * treating them as an implementation detail - such that the hybrid mechanism + * does not know what is inside the HybridConditional. This prevents us from + * having diamond inheritances, and neutralized the need to change other + * components of GTSAM to make hybrid elimination work. + * + * A great reference to the type-erasure pattern is Edurado Madrid's CppCon + * talk. */ class GTSAM_EXPORT HybridConditional : public HybridFactor, @@ -76,15 +89,20 @@ class GTSAM_EXPORT HybridConditional const KeyVector& continuousParents, const DiscreteKeys& discreteParents); - HybridConditional(boost::shared_ptr continuousConditional); + HybridConditional( + boost::shared_ptr continuousConditional); HybridConditional(boost::shared_ptr discreteConditional); - + HybridConditional(boost::shared_ptr gaussianMixture); GaussianMixture::shared_ptr asMixture() { - if (!isHybrid_) throw std::invalid_argument("Not a mixture"); - return boost::static_pointer_cast(inner); + if (!isHybrid_) throw std::invalid_argument("Not a mixture"); + return boost::static_pointer_cast(inner); + } + + boost::shared_ptr getInner() { + return inner; } /// @} diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index 6c7599a12..7a26dfadb 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -57,7 +57,7 @@ static std::string GREEN = "\033[0;32m"; static std::string GREEN_BOLD = "\033[1;32m"; static std::string RESET = "\033[0m"; -static bool DEBUG = false; +constexpr bool DEBUG = false; static GaussianMixtureFactor::Sum &addGaussian( GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { @@ -123,7 +123,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // However this is also the case with iSAM2, so no pressure :) // PREPROCESS: Identify the nature of the current elimination - std::unordered_map discreteCardinalities; + std::unordered_map mapFromKeyToDiscreteKey; std::set discreteSeparatorSet; std::set discreteFrontals; @@ -137,7 +137,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { frontalKeys.print(); } - // This initializes separatorKeys and discreteCardinalities + // This initializes separatorKeys and mapFromKeyToDiscreteKey for (auto &&factor : factors) { if (DEBUG) { std::cout << ">>> Adding factor: " << GREEN; @@ -147,7 +147,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { separatorKeys.insert(factor->begin(), factor->end()); if (!factor->isContinuous_) { for (auto &k : factor->discreteKeys_) { - discreteCardinalities[k.first] = k; + mapFromKeyToDiscreteKey[k.first] = k; } } } @@ -159,8 +159,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // Fill in discrete frontals and continuous frontals for the end result for (auto &k : frontalKeys) { - if (discreteCardinalities.find(k) != discreteCardinalities.end()) { - discreteFrontals.insert(discreteCardinalities.at(k)); + if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { + discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k)); } else { continuousFrontals.insert(k); allContinuousKeys.insert(k); @@ -169,8 +169,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // Fill in discrete frontals and continuous frontals for the end result for (auto &k : separatorKeys) { - if (discreteCardinalities.find(k) != discreteCardinalities.end()) { - discreteSeparatorSet.insert(discreteCardinalities.at(k)); + if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { + discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); } else { continuousSeparator.insert(k); allContinuousKeys.insert(k); @@ -181,8 +181,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { if (DEBUG) { std::cout << RED_BOLD << "Keys: " << RESET; for (auto &f : frontalKeys) { - if (discreteCardinalities.find(f) != discreteCardinalities.end()) { - auto &key = discreteCardinalities.at(f); + if (mapFromKeyToDiscreteKey.find(f) != mapFromKeyToDiscreteKey.end()) { + auto &key = mapFromKeyToDiscreteKey.at(f); std::cout << boost::format(" (%1%,%2%),") % DefaultKeyFormatter(key.first) % key.second; } else { @@ -195,8 +195,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { } for (auto &f : separatorKeys) { - if (discreteCardinalities.find(f) != discreteCardinalities.end()) { - auto &key = discreteCardinalities.at(f); + if (mapFromKeyToDiscreteKey.find(f) != mapFromKeyToDiscreteKey.end()) { + auto &key = mapFromKeyToDiscreteKey.at(f); std::cout << boost::format(" (%1%,%2%),") % DefaultKeyFormatter(key.first) % key.second; } else { @@ -209,7 +209,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // NOTE: We should really defer the product here because of pruning // Case 1: we are only dealing with continuous - if (discreteCardinalities.empty() && !allContinuousKeys.empty()) { + if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) { if (DEBUG) { std::cout << RED_BOLD << "CONT. ONLY" << RESET << "\n"; } diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 7c48934fc..22b40ad05 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -59,10 +59,15 @@ struct HybridConstructorTraversalData { myData.myJTNode = boost::make_shared(node->key, node->factors); parentData.myJTNode->addChild(myData.myJTNode); +#ifndef NDEBUG std::cout << "Getting discrete info: "; +#endif for (HybridFactor::shared_ptr& f : node->factors) { for (auto& k : f->discreteKeys_) { +#ifndef NDEBUG std::cout << "DK: " << DefaultKeyFormatter(k.first) << "\n"; +#endif + myData.discreteKeys.insert(k.first); } } @@ -99,8 +104,10 @@ struct HybridConstructorTraversalData { boost::tie(myConditional, mySeparatorFactor) = internal::EliminateSymbolic(symbolicFactors, keyAsOrdering); +#ifndef NDEBUG std::cout << "Symbolic: "; myConditional->print(); +#endif // Store symbolic elimination results in the parent myData.parentData->childSymbolicConditionals.push_back(myConditional); @@ -129,15 +136,19 @@ struct HybridConstructorTraversalData { myData.discreteKeys.exists(myConditional->frontals()[0]); const bool theirType = myData.discreteKeys.exists(childConditionals[i]->frontals()[0]); +#ifndef NDEBUG std::cout << "Type: " << DefaultKeyFormatter(myConditional->frontals()[0]) << " vs " << DefaultKeyFormatter(childConditionals[i]->frontals()[0]) << "\n"; +#endif if (myType == theirType) { // Increment number of frontal variables myNrFrontals += nrFrontals[i]; +#ifndef NDEBUG std::cout << "Merging "; childConditionals[i]->print(); +#endif merge[i] = true; } } diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index c84cd82c5..052575011 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -15,6 +15,17 @@ virtual class HybridFactor { gtsam::KeyVector keys() const; }; +#include +virtual class HybridConditional { + void print(string s = "Hybrid Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const; + size_t nrFrontals() const; + size_t nrParents() const; + Factor* getInner(); +}; + #include class GaussianMixtureFactor : gtsam::HybridFactor { static GaussianMixtureFactor FromFactorList( diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index d8b0f1f33..48187b7a2 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -40,9 +40,17 @@ class TestHybridFactorGraph(GtsamTestCase): hfg.add(jf2) hfg.push_back(gmf) - hfg.eliminateSequential( + hbn = hfg.eliminateSequential( gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph( - hfg, [C(0)])).print() + hfg, [C(0)])) + + print("hbn = ", hbn) + + mixture = hbn.at(0).getInner() + print(mixture) + + discrete_conditional = hbn.at(hbn.size()-1).getInner() + print(discrete_conditional) if __name__ == "__main__":