add helpers to get different types of keys from the hybrid graph

release/4.3a0
Varun Agrawal 2022-08-31 12:25:57 -04:00
parent fe595bf745
commit 4852949b75
3 changed files with 38 additions and 13 deletions

View File

@ -402,14 +402,35 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
} }
/* ************************************************************************ */ /* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering( const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
OptionalOrderingType orderingType) const {
KeySet discrete_keys; KeySet discrete_keys;
for (auto &factor : factors_) { for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) { for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first); discrete_keys.insert(k.first);
} }
} }
return discrete_keys;
}
/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
KeySet keys;
for (auto &factor : factors_) {
for (const Key &key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}
/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = getDiscreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
const VariableIndex index(factors_); const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast( Ordering ordering = Ordering::ColamdConstrainedLast(

View File

@ -161,14 +161,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
} }
} }
/// Get all the discrete keys in the factor graph.
const KeySet getDiscreteKeys() const;
/// Get all the continuous keys in the factor graph.
const KeySet getContinuousKeys() const;
/** /**
* @brief * @brief Return a Colamd constrained ordering where the discrete keys are
* * eliminated after the continuous keys.
* @param orderingType *
* @return const Ordering * @return const Ordering
*/ */
const Ordering getHybridOrdering( const Ordering getHybridOrdering() const;
OptionalOrderingType orderingType = boost::none) const;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
hfg.add(DecisionTreeFactor(m1, {2, 8})); hfg.add(DecisionTreeFactor(m1, {2, 8}));
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal( HybridBayesTree::shared_ptr result =
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)})); hfg.eliminateMultifrontal(hfg.getHybridOrdering());
// The bayes tree should have 3 cliques // The bayes tree should have 3 cliques
EXPECT_LONGS_EQUAL(3, result->size()); EXPECT_LONGS_EQUAL(3, result->size());
@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8}))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
// Get a constrained ordering keeping c1 last // Get a constrained ordering keeping c1 last
auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)}); auto ordering_full = hfg.getHybridOrdering();
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1) // Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full); HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
} }
HybridBayesNet::shared_ptr hbn; HybridBayesNet::shared_ptr hbn;
HybridGaussianFactorGraph::shared_ptr remaining; HybridGaussianFactorGraph::shared_ptr remaining;
std::tie(hbn, remaining) = std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial);
hfg->eliminatePartialSequential(ordering_partial);
EXPECT_LONGS_EQUAL(14, hbn->size()); EXPECT_LONGS_EQUAL(14, hbn->size());
EXPECT_LONGS_EQUAL(11, remaining->size()); EXPECT_LONGS_EQUAL(11, remaining->size());