Make hybrid elimination a method of HGFG

release/4.3a0
Frank Dellaert 2024-09-28 19:35:11 -07:00
parent dac90db441
commit ea54525d37
3 changed files with 46 additions and 29 deletions

View File

@ -366,18 +366,17 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors); return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
} }
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors, HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
const Ordering &frontalKeys, // Since we eliminate all continuous variables first,
const std::set<DiscreteKey> &discreteSeparatorSet) { // the discrete separator will be *all* the discrete keys.
// NOTE: since we use the special JunctionTree, const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
// only possibility is continuous conditioned on discrete. DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), keysForDiscreteVariables.end());
discreteSeparatorSet.end());
// Collect all the factors to create a set of Gaussian factor graphs in a // Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved. // decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree(); GaussianFactorGraphTree factorGraphTree = assembleGraphTree();
// Convert factor graphs with a nullptr to an empty factor graph. // Convert factor graphs with a nullptr to an empty factor graph.
// This is done after assembly since it is non-trivial to keep track of which // This is done after assembly since it is non-trivial to keep track of which
@ -392,7 +391,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
} }
// Expensive elimination of product factor. // Expensive elimination of product factor.
auto result = EliminatePreferCholesky(graph, frontalKeys); auto result = EliminatePreferCholesky(graph, keys);
// Record whether there any continuous variables left // Record whether there any continuous variables left
someContinuousLeft |= !result.second->empty(); someContinuousLeft |= !result.second->empty();
@ -436,7 +435,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
*/ */
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> // std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
EliminateHybrid(const HybridGaussianFactorGraph &factors, EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &keys) {
// NOTE: Because we are in the Conditional Gaussian regime there are only // NOTE: Because we are in the Conditional Gaussian regime there are only
// a few cases: // a few cases:
// 1. continuous variable, make a hybrid Gaussian conditional if there are // 1. continuous variable, make a hybrid Gaussian conditional if there are
@ -510,20 +509,13 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
if (only_discrete) { if (only_discrete) {
// Case 1: we are only dealing with discrete // Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys); return discreteElimination(factors, keys);
} else if (only_continuous) { } else if (only_continuous) {
// Case 2: we are only dealing with continuous // Case 2: we are only dealing with continuous
return continuousElimination(factors, frontalKeys); return continuousElimination(factors, keys);
} else { } else {
// Case 3: We are now in the hybrid land! // Case 3: We are now in the hybrid land!
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end()); return factors.eliminate(keys);
// Find all discrete keys.
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();
return hybridElimination(factors, frontalKeys, discreteSeparator);
} }
} }

View File

@ -217,6 +217,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/ */
GaussianFactorGraphTree assembleGraphTree() const; GaussianFactorGraphTree assembleGraphTree() const;
/**
* @brief Eliminate the given continuous keys.
*
* @param keys The continuous keys to eliminate.
* @return The conditional on the keys and a factor on the separator.
*/
std::pair<std::shared_ptr<HybridConditional>, std::shared_ptr<Factor>>
eliminate(const Ordering& keys) const;
/// @} /// @}
/// Get the GaussianFactorGraph at a given discrete assignment. /// Get the GaussianFactorGraph at a given discrete assignment.

View File

@ -18,6 +18,7 @@
#include <CppUnitLite/Test.h> #include <CppUnitLite/Test.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
@ -42,6 +43,7 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
#include <memory>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
@ -120,6 +122,25 @@ std::vector<GaussianFactor::shared_ptr> components(Key key) {
} }
} // namespace two } // namespace two
/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
HybridGaussianFactorGraph hfg;
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));
auto result = hfg.eliminate({X(1)});
// Check that we have a valid Gaussian conditional.
auto hgc = result.first->asHybrid();
CHECK(hgc);
const HybridValues values{{{X(1), Z_3x1}}, {{M(1), 1}}};
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
// Check that factor is discrete and correct
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
CHECK(factor);
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) { TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
HybridGaussianFactorGraph hfg; HybridGaussianFactorGraph hfg;
@ -221,20 +242,16 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
{ hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0)))); hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
}
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
{ hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3)))); hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
}
auto ordering_full = auto ordering_full =
Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)}); Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});