Make hybrid elimination a method of HGFG
parent
dac90db441
commit
ea54525d37
|
@ -366,18 +366,17 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
|||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||
}
|
||||
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys,
|
||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||
// NOTE: since we use the special JunctionTree,
|
||||
// only possibility is continuous conditioned on discrete.
|
||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||
discreteSeparatorSet.end());
|
||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||
// Since we eliminate all continuous variables first,
|
||||
// the discrete separator will be *all* the discrete keys.
|
||||
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
|
||||
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
|
||||
keysForDiscreteVariables.end());
|
||||
|
||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||
// decision tree indexed by all discrete keys involved.
|
||||
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
|
||||
GaussianFactorGraphTree factorGraphTree = assembleGraphTree();
|
||||
|
||||
// Convert factor graphs with a nullptr to an empty factor graph.
|
||||
// This is done after assembly since it is non-trivial to keep track of which
|
||||
|
@ -392,7 +391,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
|
||||
// Expensive elimination of product factor.
|
||||
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
||||
auto result = EliminatePreferCholesky(graph, keys);
|
||||
|
||||
// Record whether there any continuous variables left
|
||||
someContinuousLeft |= !result.second->empty();
|
||||
|
@ -436,7 +435,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
*/
|
||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
|
||||
EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
const Ordering &keys) {
|
||||
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
||||
// a few cases:
|
||||
// 1. continuous variable, make a hybrid Gaussian conditional if there are
|
||||
|
@ -510,20 +509,13 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
if (only_discrete) {
|
||||
// Case 1: we are only dealing with discrete
|
||||
return discreteElimination(factors, frontalKeys);
|
||||
return discreteElimination(factors, keys);
|
||||
} else if (only_continuous) {
|
||||
// Case 2: we are only dealing with continuous
|
||||
return continuousElimination(factors, frontalKeys);
|
||||
return continuousElimination(factors, keys);
|
||||
} else {
|
||||
// Case 3: We are now in the hybrid land!
|
||||
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
|
||||
|
||||
// Find all discrete keys.
|
||||
// Since we eliminate all continuous variables first,
|
||||
// the discrete separator will be *all* the discrete keys.
|
||||
std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();
|
||||
|
||||
return hybridElimination(factors, frontalKeys, discreteSeparator);
|
||||
return factors.eliminate(keys);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -217,6 +217,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
*/
|
||||
GaussianFactorGraphTree assembleGraphTree() const;
|
||||
|
||||
/**
|
||||
* @brief Eliminate the given continuous keys.
|
||||
*
|
||||
* @param keys The continuous keys to eliminate.
|
||||
* @return The conditional on the keys and a factor on the separator.
|
||||
*/
|
||||
std::pair<std::shared_ptr<HybridConditional>, std::shared_ptr<Factor>>
|
||||
eliminate(const Ordering& keys) const;
|
||||
/// @}
|
||||
|
||||
/// Get the GaussianFactorGraph at a given discrete assignment.
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <CppUnitLite/Test.h>
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
|
@ -42,6 +43,7 @@
|
|||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
|
@ -120,6 +122,25 @@ std::vector<GaussianFactor::shared_ptr> components(Key key) {
|
|||
}
|
||||
} // namespace two
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
||||
HybridGaussianFactorGraph hfg;
|
||||
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));
|
||||
|
||||
auto result = hfg.eliminate({X(1)});
|
||||
|
||||
// Check that we have a valid Gaussian conditional.
|
||||
auto hgc = result.first->asHybrid();
|
||||
CHECK(hgc);
|
||||
const HybridValues values{{{X(1), Z_3x1}}, {{M(1), 1}}};
|
||||
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
|
||||
|
||||
// Check that factor is discrete and correct
|
||||
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
||||
CHECK(factor);
|
||||
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
|
||||
HybridGaussianFactorGraph hfg;
|
||||
|
@ -221,20 +242,16 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
|
|||
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
|
||||
|
||||
{
|
||||
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
|
||||
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
|
||||
}
|
||||
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
|
||||
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
|
||||
|
||||
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||
|
||||
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
|
||||
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
|
||||
|
||||
{
|
||||
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
|
||||
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
|
||||
}
|
||||
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
|
||||
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
|
||||
|
||||
auto ordering_full =
|
||||
Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});
|
||||
|
|
Loading…
Reference in New Issue