From 4852949b754195c1974110e95ff3003a3f31ac4f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 31 Aug 2022 12:25:57 -0400 Subject: [PATCH] add helpers to get different types of keys from the hybrid graph --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 25 +++++++++++++++++-- gtsam/hybrid/HybridGaussianFactorGraph.h | 17 ++++++++----- .../tests/testGaussianHybridFactorGraph.cpp | 9 +++---- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 94e33890d..c031b9729 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -402,14 +402,35 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { } /* ************************************************************************ */ -const Ordering HybridGaussianFactorGraph::getHybridOrdering( - OptionalOrderingType orderingType) const { +const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const { KeySet discrete_keys; for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); } } + return discrete_keys; +} + +/* ************************************************************************ */ +const KeySet HybridGaussianFactorGraph::getContinuousKeys() const { + KeySet keys; + for (auto &factor : factors_) { + for (const Key &key : factor->continuousKeys()) { + keys.insert(key); + } + } + return keys; +} + +/* ************************************************************************ */ +const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { + KeySet discrete_keys = getDiscreteKeys(); + for (auto &factor : factors_) { + for (const DiscreteKey &k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } const VariableIndex index(factors_); Ordering ordering = Ordering::ColamdConstrainedLast( diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 56f9b7e07..ad5cde09b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,14 +161,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } + /// Get all the discrete keys in the factor graph. + const KeySet getDiscreteKeys() const; + + /// Get all the continuous keys in the factor graph. + const KeySet getContinuousKeys() const; + /** - * @brief - * - * @param orderingType - * @return const Ordering + * @brief Return a Colamd constrained ordering where the discrete keys are + * eliminated after the continuous keys. + * + * @return const Ordering */ - const Ordering getHybridOrdering( - OptionalOrderingType orderingType = boost::none) const; + const Ordering getHybridOrdering() const; }; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp index e4bd0e084..40da42412 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp @@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { hfg.add(DecisionTreeFactor(m1, {2, 8})); hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); - HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal( - Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)})); + HybridBayesTree::shared_ptr result = + hfg.eliminateMultifrontal(hfg.getHybridOrdering()); // The bayes tree should have 3 cliques EXPECT_LONGS_EQUAL(3, result->size()); @@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) { hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); // Get a constrained ordering keeping c1 last - auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)}); + auto ordering_full = hfg.getHybridOrdering(); // Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1) HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full); @@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } HybridBayesNet::shared_ptr hbn; HybridGaussianFactorGraph::shared_ptr remaining; - std::tie(hbn, remaining) = - hfg->eliminatePartialSequential(ordering_partial); + std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial); EXPECT_LONGS_EQUAL(14, hbn->size()); EXPECT_LONGS_EQUAL(11, remaining->size());