diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index dea00074d..297d5570d 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -69,6 +69,16 @@ namespace gtsam { push_back(key); return *this; } + + /// Print the keys and cardinalities. + void print(const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + for (auto&& dkey : *this) { + std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second + << std::endl; + } + } + }; // DiscreteKeys /// Create a list from two keys diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 94e33890d..c031b9729 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -402,14 +402,35 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { } /* ************************************************************************ */ -const Ordering HybridGaussianFactorGraph::getHybridOrdering( - OptionalOrderingType orderingType) const { +const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const { KeySet discrete_keys; for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); } } + return discrete_keys; +} + +/* ************************************************************************ */ +const KeySet HybridGaussianFactorGraph::getContinuousKeys() const { + KeySet keys; + for (auto &factor : factors_) { + for (const Key &key : factor->continuousKeys()) { + keys.insert(key); + } + } + return keys; +} + +/* ************************************************************************ */ +const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { + KeySet discrete_keys = getDiscreteKeys(); + for (auto &factor : factors_) { + for (const DiscreteKey &k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } const VariableIndex index(factors_); Ordering ordering = Ordering::ColamdConstrainedLast( diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 56f9b7e07..ad5cde09b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,14 +161,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } + /// Get all the discrete keys in the factor graph. + const KeySet getDiscreteKeys() const; + + /// Get all the continuous keys in the factor graph. + const KeySet getContinuousKeys() const; + /** - * @brief - * - * @param orderingType - * @return const Ordering + * @brief Return a Colamd constrained ordering where the discrete keys are + * eliminated after the continuous keys. + * + * @return const Ordering */ - const Ordering getHybridOrdering( - OptionalOrderingType orderingType = boost::none) const; + const Ordering getHybridOrdering() const; }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index d5c78f951..4928f9384 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -31,8 +31,8 @@ namespace gtsam { /** - * HybridValues represents a collection of DiscreteValues and VectorValues. It - * is typically used to store the variables of a HybridGaussianFactorGraph. + * HybridValues represents a collection of DiscreteValues and VectorValues. + * It is typically used to store the variables of a HybridGaussianFactorGraph. * Optimizing a HybridGaussianBayesNet returns this class. */ class GTSAM_EXPORT HybridValues { @@ -47,10 +47,10 @@ class GTSAM_EXPORT HybridValues { /// @name Standard Constructors /// @{ - // Default constructor creates an empty HybridValues. + /// Default constructor creates an empty HybridValues. HybridValues() = default; - // Construct from DiscreteValues and VectorValues. + /// Construct from DiscreteValues and VectorValues. HybridValues(const DiscreteValues& dv, const VectorValues& cv) : discrete_(dv), continuous_(cv){}; @@ -58,7 +58,7 @@ class GTSAM_EXPORT HybridValues { /// @name Testable /// @{ - // print required by Testable for unit testing + /// print required by Testable for unit testing void print(const std::string& s = "HybridValues", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { std::cout << s << ": \n"; @@ -67,7 +67,7 @@ class GTSAM_EXPORT HybridValues { keyFormatter); // print continuous components }; - // equals required by Testable for unit testing + /// equals required by Testable for unit testing bool equals(const HybridValues& other, double tol = 1e-9) const { return discrete_.equals(other.discrete_, tol) && continuous_.equals(other.continuous_, tol); @@ -83,13 +83,13 @@ class GTSAM_EXPORT HybridValues { /// Return the delta update for the continuous vectors VectorValues continuous() const { return continuous_; } - // Check whether a variable with key \c j exists in DiscreteValue. + /// Check whether a variable with key \c j exists in DiscreteValue. bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }; - // Check whether a variable with key \c j exists in VectorValue. + /// Check whether a variable with key \c j exists in VectorValue. bool existsVector(Key j) { return continuous_.exists(j); }; - // Check whether a variable with key \c j exists. + /// Check whether a variable with key \c j exists. bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; /** Insert a discrete \c value with key \c j. Replaces the existing value if diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index dfbf3919d..86029a48a 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -99,6 +99,8 @@ class HybridBayesTree { bool empty() const; const HybridBayesTreeClique* operator[](size_t j) const; + gtsam::HybridValues optimize() const; + string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; }; diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp index e4bd0e084..40da42412 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp @@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { hfg.add(DecisionTreeFactor(m1, {2, 8})); hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); - HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal( - Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)})); + HybridBayesTree::shared_ptr result = + hfg.eliminateMultifrontal(hfg.getHybridOrdering()); // The bayes tree should have 3 cliques EXPECT_LONGS_EQUAL(3, result->size()); @@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) { hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); // Get a constrained ordering keeping c1 last - auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)}); + auto ordering_full = hfg.getHybridOrdering(); // Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1) HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full); @@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } HybridBayesNet::shared_ptr hbn; HybridGaussianFactorGraph::shared_ptr remaining; - std::tie(hbn, remaining) = - hfg->eliminatePartialSequential(ordering_partial); + std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial); EXPECT_LONGS_EQUAL(14, hbn->size()); EXPECT_LONGS_EQUAL(11, remaining->size());