Make hybrid elimination a method of HGFG
parent
dac90db441
commit
ea54525d37
|
@ -366,18 +366,17 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
hybridElimination(const HybridGaussianFactorGraph &factors,
|
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
const Ordering &frontalKeys,
|
// Since we eliminate all continuous variables first,
|
||||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
// the discrete separator will be *all* the discrete keys.
|
||||||
// NOTE: since we use the special JunctionTree,
|
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
|
||||||
// only possibility is continuous conditioned on discrete.
|
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
|
||||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
keysForDiscreteVariables.end());
|
||||||
discreteSeparatorSet.end());
|
|
||||||
|
|
||||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||||
// decision tree indexed by all discrete keys involved.
|
// decision tree indexed by all discrete keys involved.
|
||||||
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
|
GaussianFactorGraphTree factorGraphTree = assembleGraphTree();
|
||||||
|
|
||||||
// Convert factor graphs with a nullptr to an empty factor graph.
|
// Convert factor graphs with a nullptr to an empty factor graph.
|
||||||
// This is done after assembly since it is non-trivial to keep track of which
|
// This is done after assembly since it is non-trivial to keep track of which
|
||||||
|
@ -392,7 +391,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expensive elimination of product factor.
|
// Expensive elimination of product factor.
|
||||||
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
auto result = EliminatePreferCholesky(graph, keys);
|
||||||
|
|
||||||
// Record whether there any continuous variables left
|
// Record whether there any continuous variables left
|
||||||
someContinuousLeft |= !result.second->empty();
|
someContinuousLeft |= !result.second->empty();
|
||||||
|
@ -436,7 +435,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
*/
|
*/
|
||||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
|
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
|
||||||
EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
const Ordering &frontalKeys) {
|
const Ordering &keys) {
|
||||||
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
||||||
// a few cases:
|
// a few cases:
|
||||||
// 1. continuous variable, make a hybrid Gaussian conditional if there are
|
// 1. continuous variable, make a hybrid Gaussian conditional if there are
|
||||||
|
@ -510,20 +509,13 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
if (only_discrete) {
|
if (only_discrete) {
|
||||||
// Case 1: we are only dealing with discrete
|
// Case 1: we are only dealing with discrete
|
||||||
return discreteElimination(factors, frontalKeys);
|
return discreteElimination(factors, keys);
|
||||||
} else if (only_continuous) {
|
} else if (only_continuous) {
|
||||||
// Case 2: we are only dealing with continuous
|
// Case 2: we are only dealing with continuous
|
||||||
return continuousElimination(factors, frontalKeys);
|
return continuousElimination(factors, keys);
|
||||||
} else {
|
} else {
|
||||||
// Case 3: We are now in the hybrid land!
|
// Case 3: We are now in the hybrid land!
|
||||||
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
|
return factors.eliminate(keys);
|
||||||
|
|
||||||
// Find all discrete keys.
|
|
||||||
// Since we eliminate all continuous variables first,
|
|
||||||
// the discrete separator will be *all* the discrete keys.
|
|
||||||
std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();
|
|
||||||
|
|
||||||
return hybridElimination(factors, frontalKeys, discreteSeparator);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -217,6 +217,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree assembleGraphTree() const;
|
GaussianFactorGraphTree assembleGraphTree() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Eliminate the given continuous keys.
|
||||||
|
*
|
||||||
|
* @param keys The continuous keys to eliminate.
|
||||||
|
* @return The conditional on the keys and a factor on the separator.
|
||||||
|
*/
|
||||||
|
std::pair<std::shared_ptr<HybridConditional>, std::shared_ptr<Factor>>
|
||||||
|
eliminate(const Ordering& keys) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
/// Get the GaussianFactorGraph at a given discrete assignment.
|
/// Get the GaussianFactorGraph at a given discrete assignment.
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <CppUnitLite/Test.h>
|
#include <CppUnitLite/Test.h>
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
@ -42,6 +43,7 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <memory>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -120,6 +122,25 @@ std::vector<GaussianFactor::shared_ptr> components(Key key) {
|
||||||
}
|
}
|
||||||
} // namespace two
|
} // namespace two
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));
|
||||||
|
|
||||||
|
auto result = hfg.eliminate({X(1)});
|
||||||
|
|
||||||
|
// Check that we have a valid Gaussian conditional.
|
||||||
|
auto hgc = result.first->asHybrid();
|
||||||
|
CHECK(hgc);
|
||||||
|
const HybridValues values{{{X(1), Z_3x1}}, {{M(1), 1}}};
|
||||||
|
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
|
||||||
|
|
||||||
|
// Check that factor is discrete and correct
|
||||||
|
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
||||||
|
CHECK(factor);
|
||||||
|
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
|
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
|
||||||
HybridGaussianFactorGraph hfg;
|
HybridGaussianFactorGraph hfg;
|
||||||
|
@ -221,20 +242,16 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
|
||||||
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||||
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
|
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
{
|
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
|
||||||
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
|
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
|
||||||
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
|
|
||||||
}
|
|
||||||
|
|
||||||
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||||
|
|
||||||
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
|
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
|
||||||
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
|
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
{
|
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
|
||||||
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
|
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
|
||||||
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ordering_full =
|
auto ordering_full =
|
||||||
Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});
|
Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});
|
||||||
|
|
Loading…
Reference in New Issue