Merge pull request #1466 from borglab/hybrid-support

release/4.3a0
Varun Agrawal 2023-02-15 21:30:04 -05:00 committed by GitHub
commit dafa0076ec
6 changed files with 108 additions and 14 deletions

gtsam/hybrid/HybridGaussianFactorGraph.cpp

@@ -106,7 +106,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
     // TODO(dellaert): just use a virtual method defined in HybridFactor.
     if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
       result = addGaussian(result, gf);
-    } else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+    } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+      result = gmf->add(result);
+    } else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
      result = gm->add(result);
     } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
       if (auto gm = hc->asMixture()) {
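Note: GaussianMixtureFactor (a discrete-indexed collection of Gaussian factors) and GaussianMixture (a conditional) are distinct types; the old code only had a branch for the former, so a bare GaussianMixture in the graph had no branch of its own. A minimal sketch of a graph that exercises the new dispatch, using only constructors that appear elsewhere in this PR (keys, noise levels, and shorthands are illustrative):

    // Sketch, not part of the PR: a GaussianMixture pushed into the graph is
    // now routed through its own dynamic_pointer_cast branch above.
    using namespace gtsam;
    using symbol_shorthand::M;
    using symbol_shorthand::X;
    using symbol_shorthand::Z;
    const DiscreteKey mode{M(0), 2};
    HybridGaussianFactorGraph fg;
    fg.push_back(boost::make_shared<GaussianMixture>(  // P(z0 | x0, mode)
        KeyVector{Z(0)}, KeyVector{X(0)}, DiscreteKeys{mode},
        std::vector<GaussianConditional::shared_ptr>{
            GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3.0),
            GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 0.5)}));
    const GaussianFactorGraphTree tree = fg.assembleGraphTree();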
@@ -283,17 +285,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // taking care to correct for conditional constant.
   // Correct for the normalization constant used up by the conditional
-  auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
+  auto correct = [&](const Result &pair) {
     const auto &factor = pair.second;
-    if (!factor) return factor;  // TODO(dellaert): not loving this.
+    if (!factor) return;
     auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
     if (!hf) throw std::runtime_error("Expected HessianFactor!");
     hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
-    return hf;
   };
-  GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
-                                                  correct);
+  eliminationResults.visit(correct);
 
   const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
       continuousSeparator, discreteSeparator, newFactors);
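Note on the 2.0 factor: a HessianFactor stores its constant f with error E(x) = ½(xᵀGx − 2xᵀg + f), so multiplying the separator factor by 1/K (the constant the normalized conditional absorbed) adds log K to the negative log and hence 2·log K to constantTerm(). As a sketch of the bookkeeping, under GTSAM's usual conventions:

    p(x_f \mid x_s) = K \exp\!\left(-\tfrac{1}{2}\lVert R x_f + S x_s - d \rVert^2\right),
    \qquad \log K = \texttt{logNormalizationConstant()}

    \phi(x_f, x_s) = p(x_f \mid x_s)\,\tau(x_s)
    \;\Longrightarrow\;
    E_\tau = E_{\tilde\tau} + \log K
    \;\Longrightarrow\;
    f \mathrel{+}= 2\log K .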

gtsam/hybrid/HybridNonlinearFactorGraph.cpp

@@ -17,6 +17,7 @@
  */
 
 #include <gtsam/discrete/DecisionTreeFactor.h>
+#include <gtsam/hybrid/GaussianMixture.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 #include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
 #include <gtsam/hybrid/MixtureFactor.h>
@@ -69,6 +70,12 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
     } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
       // If discrete-only: doesn't need linearization.
       linearFG->push_back(f);
+    } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+      linearFG->push_back(gmf);
+    } else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+      linearFG->push_back(gm);
+    } else if (dynamic_pointer_cast<GaussianFactor>(f)) {
+      linearFG->push_back(f);
     } else {
       auto& fr = *f;
       throw std::invalid_argument(
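Note: with these branches a HybridNonlinearFactorGraph may carry factors that are already linear, and linearize() now forwards them unchanged instead of falling into the invalid_argument branch. A minimal sketch (hypothetical key; JacobianFactor stands in for any GaussianFactor):

    // Sketch, not part of the PR: already-linear factors survive linearize().
    using namespace gtsam;
    using symbol_shorthand::X;
    HybridNonlinearFactorGraph nfg;
    nfg.push_back(boost::make_shared<JacobianFactor>(X(0), I_1x1, Z_1x1));
    const Values values;  // nothing nonlinear here, so no linearization points
    const auto linearFG = nfg.linearize(values);
    // linearFG->size() == 1: the JacobianFactor was passed through as-is.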

gtsam/hybrid/HybridSmoother.cpp

@@ -23,6 +23,37 @@
 
 namespace gtsam {
 
+/* ************************************************************************* */
+Ordering HybridSmoother::getOrdering(
+    const HybridGaussianFactorGraph &newFactors) {
+  HybridGaussianFactorGraph factors(hybridBayesNet());
+  factors += newFactors;
+
+  // Get all the discrete keys from the factors
+  KeySet allDiscrete = factors.discreteKeySet();
+
+  // Create KeyVector with continuous keys followed by discrete keys.
+  KeyVector newKeysDiscreteLast;
+  const KeySet newFactorKeys = newFactors.keys();
+  // Insert continuous keys first.
+  for (auto &k : newFactorKeys) {
+    if (!allDiscrete.exists(k)) {
+      newKeysDiscreteLast.push_back(k);
+    }
+  }
+
+  // Insert discrete keys at the end
+  std::copy(allDiscrete.begin(), allDiscrete.end(),
+            std::back_inserter(newKeysDiscreteLast));
+
+  const VariableIndex index(newFactors);
+
+  // Get an ordering where the new keys are eliminated last
+  Ordering ordering = Ordering::ColamdConstrainedLast(
+      index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
+      true);
+  return ordering;
+}
+
 /* ************************************************************************* */
 void HybridSmoother::update(HybridGaussianFactorGraph graph,
                             const Ordering &ordering,
@@ -92,7 +123,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
     }
 
     graph.push_back(newConditionals);
-    // newConditionals.print("\n\n\nNew Conditionals to add back");
   }
 
   return {graph, hybridBayesNet};
 }

gtsam/hybrid/HybridSmoother.h

@@ -50,6 +50,8 @@ class HybridSmoother {
   void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
               boost::optional<size_t> maxNrLeaves = boost::none);
 
+  Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
+
   /**
    * @brief Add conditionals from previous timestep as part of liquefication.
    *
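Note: getOrdering combines the smoother's current hybridBayesNet() with the new factors and returns a constrained COLAMD ordering with continuous keys first and discrete modes last, so callers no longer build one by hand. A typical call sequence, sketched from the two declarations above (graph contents are hypothetical):

    // Sketch, not part of the PR: one smoother step using the new helper.
    gtsam::HybridSmoother smoother;
    gtsam::HybridGaussianFactorGraph newFactors;  // filled by the caller
    const gtsam::Ordering ordering = smoother.getOrdering(newFactors);
    smoother.update(newFactors, ordering, 10);  // keep at most 10 mixture leaves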

gtsam/hybrid/tests/testGaussianMixtureFactor.cpp

@@ -93,6 +93,7 @@ TEST(GaussianMixtureFactor, Sum) {
   EXPECT(actual.at(1) == f22);
 }
 
+/* ************************************************************************* */
 TEST(GaussianMixtureFactor, Printing) {
   DiscreteKey m1(1, 2);
   auto A1 = Matrix::Zero(2, 1);
@@ -136,6 +137,7 @@ TEST(GaussianMixtureFactor, Printing) {
   EXPECT(assert_print_equal(expected, mixtureFactor));
 }
 
+/* ************************************************************************* */
 TEST(GaussianMixtureFactor, GaussianMixture) {
   KeyVector keys;
   keys.push_back(X(0));

gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp

@@ -612,7 +612,6 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
 // Check that assembleGraphTree assembles Gaussian factor graphs for each
 // assignment.
 TEST(HybridGaussianFactorGraph, assembleGraphTree) {
-  using symbol_shorthand::Z;
   const int num_measurements = 1;
   auto fg = tiny::createHybridGaussianFactorGraph(
       num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
@@ -694,7 +693,6 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
 /* ****************************************************************************/
 // Check that eliminating tiny net with 1 measurement yields correct result.
 TEST(HybridGaussianFactorGraph, EliminateTiny1) {
-  using symbol_shorthand::Z;
   const int num_measurements = 1;
   const VectorValues measurements{{Z(0), Vector1(5.0)}};
   auto bn = tiny::createHybridBayesNet(num_measurements);
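Note: ratioTest appears to check elimination via Bayes' rule: with the measurements z̄ fixed, the joint of the original Bayes net and the eliminated posterior may differ only by the constant marginal p(z̄), so the ratio must agree across every sampled assignment:

    \frac{P_{\text{bn}}(x, m, z = \bar z)}{P_{\text{posterior}}(x, m \mid z = \bar z)} = p(\bar z)
    \quad \text{for all sampled } (x, m).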
@@ -726,11 +724,67 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
   EXPECT(ratioTest(bn, measurements, *posterior));
 }
 
+/* ****************************************************************************/
+// Check that eliminating tiny net with 1 measurement with mode order swapped
+// yields correct result.
+TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
+  const VectorValues measurements{{Z(0), Vector1(5.0)}};
+
+  // Create mode key: 1 is low-noise, 0 is high-noise.
+  const DiscreteKey mode{M(0), 2};
+  HybridBayesNet bn;
+
+  // Create Gaussian mixture z_0 = x0 + noise for each measurement.
+  bn.emplace_back(new GaussianMixture(
+      {Z(0)}, {X(0)}, {mode},
+      {GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3),
+       GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1,
+                                                0.5)}));
+
+  // Create prior on X(0).
+  bn.push_back(
+      GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));
+
+  // Add prior on mode.
+  bn.emplace_back(new DiscreteConditional(mode, "1/1"));
+
+  auto fg = bn.toFactorGraph(measurements);
+  EXPECT_LONGS_EQUAL(3, fg.size());
+
+  EXPECT(ratioTest(bn, measurements, fg));
+
+  // Create expected Bayes Net:
+  HybridBayesNet expectedBayesNet;
+
+  // Create Gaussian mixture on X(0).
+  // regression, but mean checked to be 5.0 in both cases:
+  const auto conditional0 = boost::make_shared<GaussianConditional>(
+                 X(0), Vector1(10.1379), I_1x1 * 2.02759),
+             conditional1 = boost::make_shared<GaussianConditional>(
+                 X(0), Vector1(14.1421), I_1x1 * 2.82843);
+  expectedBayesNet.emplace_back(
+      new GaussianMixture({X(0)}, {}, {mode}, {conditional0, conditional1}));
+
+  // Add prior on mode.
+  expectedBayesNet.emplace_back(new DiscreteConditional(mode, "1/1"));
+
+  // Test elimination
+  const auto posterior = fg.eliminateSequential();
+  // EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
+  EXPECT(ratioTest(bn, measurements, *posterior));
+}
+
 /* ****************************************************************************/
 // Check that eliminating tiny net with 2 measurements yields correct result.
 TEST(HybridGaussianFactorGraph, EliminateTiny2) {
   // Create factor graph with 2 measurements such that posterior mean = 5.0.
-  using symbol_shorthand::Z;
   const int num_measurements = 2;
   const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
   auto bn = tiny::createHybridBayesNet(num_measurements);
@@ -764,7 +818,6 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
 
 // Test eliminating tiny net with 1 mode per measurement.
 TEST(HybridGaussianFactorGraph, EliminateTiny22) {
   // Create factor graph with 2 measurements such that posterior mean = 5.0.
-  using symbol_shorthand::Z;
   const int num_measurements = 2;
   const bool manyModes = true;
@@ -835,12 +888,12 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
   //      D     D
   //      |     |
   //      m1    m2
   //      |     |
   // C-x0-HC-x1-HC-x2
   //   |     |     |
   //   HF    HF    HF
   //   |     |     |
   //   n0    n1    n2
   //   |     |     |
   //   D     D     D
   EXPECT_LONGS_EQUAL(11, fg.size());
@@ -853,7 +906,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
   EXPECT(ratioTest(bn, measurements, fg1));
 
   // Create ordering that eliminates in time order, then discrete modes:
-  Ordering ordering {X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)};
+  Ordering ordering{X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)};
 
   // Do elimination:
   const HybridBayesNet::shared_ptr posterior = fg.eliminateSequential(ordering);