From 53551e051dedae2c245f47dbf69bf390844bb3ab Mon Sep 17 00:00:00 2001 From: Fan Jiang Date: Sat, 12 Mar 2022 11:51:48 -0500 Subject: [PATCH] Add shorthand for inserting raw JacobianFactor --- gtsam/hybrid/CGMixtureFactor.h | 22 +++++++++++-- gtsam/hybrid/HybridDiscreteFactor.cpp | 2 +- gtsam/hybrid/HybridDiscreteFactor.h | 2 ++ gtsam/hybrid/HybridFactor.h | 34 ++++++++++++++++++-- gtsam/hybrid/HybridFactorGraph.cpp | 3 ++ gtsam/hybrid/HybridFactorGraph.h | 6 ++++ gtsam/hybrid/HybridGaussianFactor.h | 2 ++ gtsam/hybrid/tests/testHybridConditional.cpp | 12 ++++--- 8 files changed, 72 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/CGMixtureFactor.h b/gtsam/hybrid/CGMixtureFactor.h index f2fa1e54a..b46146eb3 100644 --- a/gtsam/hybrid/CGMixtureFactor.h +++ b/gtsam/hybrid/CGMixtureFactor.h @@ -5,19 +5,35 @@ /** * @file CGMixtureFactor.h - * @brief Custom hybrid factor graph for discrete + continuous factors - * @author Kevin Doherty, kdoherty@mit.edu + * @brief A set of Gaussian factors indexed by a set of discrete keys. * @author Varun Agrawal * @author Fan Jiang + * @author Frank Dellaert * @date December 2021 */ #include +#include +#include +#include namespace gtsam { -class CGMixtureFactor : HybridFactor { +class CGMixtureFactor : public HybridFactor { +public: + using Base = HybridFactor; + using This = CGMixtureFactor; + using shared_ptr = boost::shared_ptr; + using Factors = DecisionTree; + + Factors factors_; + + CGMixtureFactor() = default; + + CGMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const Factors &factors) : Base(continuousKeys, discreteKeys) { + + } }; } diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 0bd2c9a64..13766933b 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -14,7 +14,7 @@ HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other) } HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) - : Base(dtf.keys()), + : Base(dtf.discreteKeys()), inner(boost::make_shared(std::move(dtf))) { } diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index cfb94bc69..0395c9512 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -24,6 +24,8 @@ namespace gtsam { class HybridDiscreteFactor : public HybridFactor { public: using Base = HybridFactor; + using This = HybridDiscreteFactor; + using shared_ptr = boost::shared_ptr; DiscreteFactor::shared_ptr inner; diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index eb28bc168..64b49f605 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -36,6 +37,12 @@ public: typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class typedef Factor Base; ///< Our base class + bool isDiscrete_ = false; + bool isContinuous_ = false; + bool isHybrid_ = false; + + DiscreteKeys discreteKeys_; + public: /// @name Standard Constructors @@ -46,8 +53,25 @@ public: /** Construct from container of keys. This constructor is used internally from derived factor * constructors, either from a container of keys or from a boost::assign::list_of. */ - template - HybridFactor(const CONTAINER &keys) : Base(keys) {} +// template +// HybridFactor(const CONTAINER &keys) : Base(keys) {} + + HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {} + + static KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) { + KeyVector allKeys; + std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys)); + std::transform(discreteKeys.begin(), + discreteKeys.end(), + std::back_inserter(allKeys), + [](const DiscreteKey &k) { return k.first; }); + return allKeys; + } + + HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base( + CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true) {} + + HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), isDiscrete_(true) {} /// Virtual destructor virtual ~HybridFactor() { @@ -64,7 +88,11 @@ public: void print( const std::string &s = "HybridFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override { - Base::print(s, formatter); + std::cout << s; + if (isContinuous_) std::cout << "Cont. "; + if (isDiscrete_) std::cout << "Disc. "; + if (isHybrid_) std::cout << "Hybr. "; + this->printKeys("", formatter); } /// @} diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index bd245b16a..080efc0e9 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -75,4 +75,7 @@ EliminateHybrid(const HybridFactorGraph &factors, boost::make_shared(std::move(sum))); } +void HybridFactorGraph::add(JacobianFactor &&factor) { + FactorGraph::add(boost::make_shared(std::move(factor))); +} } diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index fb0de644d..13bbd005d 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -33,6 +33,8 @@ class HybridEliminationTree; class HybridBayesTree; class HybridJunctionTree; +class JacobianFactor; + /** Main elimination function for HybridFactorGraph */ GTSAM_EXPORT std::pair, HybridFactor::shared_ptr> EliminateHybrid(const HybridFactorGraph& factors, const Ordering& keys); @@ -77,6 +79,10 @@ class HybridFactorGraph : public FactorGraph, public Eliminateable template HybridFactorGraph(const FactorGraph& graph) : Base(graph) {} + using FactorGraph::add; + + /// Add a factor directly using a shared_ptr. + void add(JacobianFactor &&factor); }; } diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index e40682dc1..d8a97ba30 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -24,6 +24,8 @@ namespace gtsam { class HybridGaussianFactor : public HybridFactor { public: using Base = HybridFactor; + using This = HybridGaussianFactor; + using shared_ptr = boost::shared_ptr; GaussianFactor::shared_ptr inner; diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp index 6eef100c1..a56bcadf0 100644 --- a/gtsam/hybrid/tests/testHybridConditional.cpp +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -26,6 +26,8 @@ #include #include +#include + #include #include @@ -34,6 +36,8 @@ using namespace boost::assign; using namespace std; using namespace gtsam; +using gtsam::symbol_shorthand::X; + /* ************************************************************************* */ TEST_UNSAFE(HybridFactorGraph, test) { HybridConditional test; @@ -58,12 +62,12 @@ TEST_UNSAFE(HybridFactorGraph, eliminate) { TEST(HybridFactorGraph, eliminateMultifrontal) { HybridFactorGraph hfg; - DiscreteKey X(1, 2); + DiscreteKey x(X(1), 2); - hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1))); - hfg.add(HybridDiscreteFactor(DecisionTreeFactor(X, {2, 8}))); + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(HybridDiscreteFactor(DecisionTreeFactor(x, {2, 8}))); - auto result = hfg.eliminatePartialMultifrontal({0}); + auto result = hfg.eliminatePartialMultifrontal({X(0)}); GTSAM_PRINT(*result.first); GTSAM_PRINT(*result.second);