Make hybrid elimination a method of HGFG

release/4.3a0
Frank Dellaert 2024-09-28 19:35:11 -07:00
parent dac90db441
commit ea54525d37
3 changed files with 46 additions and 29 deletions

View File

@ -366,18 +366,17 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
}
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
keysForDiscreteVariables.end());
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
GaussianFactorGraphTree factorGraphTree = assembleGraphTree();
// Convert factor graphs with a nullptr to an empty factor graph.
// This is done after assembly since it is non-trivial to keep track of which
@ -392,7 +391,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
}
// Expensive elimination of product factor.
auto result = EliminatePreferCholesky(graph, frontalKeys);
auto result = EliminatePreferCholesky(graph, keys);
// Record whether there any continuous variables left
someContinuousLeft |= !result.second->empty();
@ -436,7 +435,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
*/
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
const Ordering &keys) {
// NOTE: Because we are in the Conditional Gaussian regime there are only
// a few cases:
// 1. continuous variable, make a hybrid Gaussian conditional if there are
@ -510,20 +509,13 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
if (only_discrete) {
// Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys);
return discreteElimination(factors, keys);
} else if (only_continuous) {
// Case 2: we are only dealing with continuous
return continuousElimination(factors, frontalKeys);
return continuousElimination(factors, keys);
} else {
// Case 3: We are now in the hybrid land!
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
// Find all discrete keys.
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();
return hybridElimination(factors, frontalKeys, discreteSeparator);
return factors.eliminate(keys);
}
}

View File

@ -217,6 +217,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
GaussianFactorGraphTree assembleGraphTree() const;
/**
* @brief Eliminate the given continuous keys.
*
* @param keys The continuous keys to eliminate.
* @return The conditional on the keys and a factor on the separator.
*/
std::pair<std::shared_ptr<HybridConditional>, std::shared_ptr<Factor>>
eliminate(const Ordering& keys) const;
/// @}
/// Get the GaussianFactorGraph at a given discrete assignment.

View File

@ -18,6 +18,7 @@
#include <CppUnitLite/Test.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
@ -42,6 +43,7 @@
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <vector>
@ -120,6 +122,25 @@ std::vector<GaussianFactor::shared_ptr> components(Key key) {
}
} // namespace two
/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
HybridGaussianFactorGraph hfg;
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));
auto result = hfg.eliminate({X(1)});
// Check that we have a valid Gaussian conditional.
auto hgc = result.first->asHybrid();
CHECK(hgc);
const HybridValues values{{{X(1), Z_3x1}}, {{M(1), 1}}};
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
// Check that factor is discrete and correct
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
CHECK(factor);
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor));
}
/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
HybridGaussianFactorGraph hfg;
@ -221,20 +242,16 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
{
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
}
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
{
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
}
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
auto ordering_full =
Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});