diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
index c912a74fc..20950d4f9 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -106,7 +106,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
     // TODO(dellaert): just use a virtual method defined in HybridFactor.
     if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
       result = addGaussian(result, gf);
-    } else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+    } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+      result = gmf->add(result);
+    } else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
       result = gm->add(result);
     } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
       if (auto gm = hc->asMixture()) {
@@ -283,17 +285,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   //   taking care to correct for conditional constant.

   // Correct for the normalization constant used up by the conditional
-  auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
+  auto correct = [&](const Result &pair) {
     const auto &factor = pair.second;
-    if (!factor) return factor;  // TODO(dellaert): not loving this.
+    if (!factor) return;
     auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
     if (!hf) throw std::runtime_error("Expected HessianFactor!");
     hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
-    return hf;
   };
+  eliminationResults.visit(correct);

-  GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
-                                                  correct);
   const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
       continuousSeparator, discreteSeparator, newFactors);
diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp
index 71b064eb6..ab64b1ec3 100644
--- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp
+++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp
@@ -17,6 +17,7 @@
  */

 #include
+#include
 #include
 #include
 #include
@@ -69,6 +70,12 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
     } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
       // If discrete-only: doesn't need linearization.
       linearFG->push_back(f);
+    } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
+      linearFG->push_back(gmf);
+    } else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+      linearFG->push_back(gm);
+    } else if (dynamic_pointer_cast<GaussianFactor>(f)) {
+      linearFG->push_back(f);
     } else {
       auto& fr = *f;
       throw std::invalid_argument(
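Both files above extend the same idiom: a chain of `dynamic_pointer_cast` checks that routes each stored factor to a type-specific handler, with unknown types falling through to an error. A minimal, self-contained sketch of that dispatch pattern — `Factor`, `GaussianLike`, and `MixtureLike` are hypothetical stand-ins, and `std::shared_ptr` stands in for GTSAM's boost pointers:

```cpp
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for the factor hierarchy; names are illustrative only.
struct Factor { virtual ~Factor() = default; };
struct GaussianLike : Factor {};  // plays the role of GaussianFactor
struct MixtureLike : Factor {};   // plays the role of GaussianMixtureFactor

void dispatch(const std::vector<std::shared_ptr<Factor>>& factors) {
  for (const auto& f : factors) {
    // Each else-if adds one handled type; the hunks above add new branches
    // so already-linear hybrid factors no longer reach the final throw.
    if (auto m = std::dynamic_pointer_cast<MixtureLike>(f)) {
      std::cout << "mixture-like factor\n";
    } else if (auto g = std::dynamic_pointer_cast<GaussianLike>(f)) {
      std::cout << "gaussian-like factor\n";
    } else {
      throw std::invalid_argument("factor type not handled");
    }
  }
}

int main() {
  dispatch({std::make_shared<MixtureLike>(), std::make_shared<GaussianLike>()});
  return 0;
}
```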
diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp
index 35dd5f88b..fcee7833a 100644
--- a/gtsam/hybrid/HybridSmoother.cpp
+++ b/gtsam/hybrid/HybridSmoother.cpp
@@ -23,6 +23,37 @@

 namespace gtsam {

+/* ************************************************************************* */
+Ordering HybridSmoother::getOrdering(
+    const HybridGaussianFactorGraph &newFactors) {
+  HybridGaussianFactorGraph factors(hybridBayesNet());
+  factors += newFactors;
+  // Get all the discrete keys from the factors
+  KeySet allDiscrete = factors.discreteKeySet();
+
+  // Create KeyVector with continuous keys followed by discrete keys.
+  KeyVector newKeysDiscreteLast;
+  const KeySet newFactorKeys = newFactors.keys();
+  // Insert continuous keys first.
+  for (auto &k : newFactorKeys) {
+    if (!allDiscrete.exists(k)) {
+      newKeysDiscreteLast.push_back(k);
+    }
+  }
+
+  // Insert discrete keys at the end
+  std::copy(allDiscrete.begin(), allDiscrete.end(),
+            std::back_inserter(newKeysDiscreteLast));
+
+  const VariableIndex index(newFactors);
+
+  // Get an ordering where the new keys are eliminated last
+  Ordering ordering = Ordering::ColamdConstrainedLast(
+      index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
+      true);
+  return ordering;
+}
+
 /* ************************************************************************* */
 void HybridSmoother::update(HybridGaussianFactorGraph graph,
                             const Ordering &ordering,
@@ -92,7 +123,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
     }

     graph.push_back(newConditionals);
-    // newConditionals.print("\n\n\nNew Conditionals to add back");
   }
   return {graph, hybridBayesNet};
 }
diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h
index 7e90f9425..9f14a7002 100644
--- a/gtsam/hybrid/HybridSmoother.h
+++ b/gtsam/hybrid/HybridSmoother.h
@@ -50,6 +50,8 @@ class HybridSmoother {
   void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
               boost::optional<size_t> maxNrLeaves = boost::none);

+  Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
+
   /**
    * @brief Add conditionals from previous timestep as part of liquefication.
    *
diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp
index 962d238a8..4ef2af471 100644
--- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp
+++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp
@@ -93,6 +93,7 @@ TEST(GaussianMixtureFactor, Sum) {
   EXPECT(actual.at(1) == f22);
 }

+/* ************************************************************************* */
 TEST(GaussianMixtureFactor, Printing) {
   DiscreteKey m1(1, 2);
   auto A1 = Matrix::Zero(2, 1);
@@ -136,6 +137,7 @@ TEST(GaussianMixtureFactor, Printing) {
   EXPECT(assert_print_equal(expected, mixtureFactor));
 }

+/* ************************************************************************* */
 TEST(GaussianMixtureFactor, GaussianMixture) {
   KeyVector keys;
   keys.push_back(X(0));
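The heart of the new `HybridSmoother::getOrdering` is a partition step: continuous keys from the new factors go first, every discrete key is appended last, and the combined list constrains COLAMD so discrete modes are eliminated after the continuous variables. A self-contained sketch of just that partition, with `std` containers standing in for GTSAM's `KeySet`/`KeyVector` (names are illustrative):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

using Key = std::uint64_t;

// Mirror of the loop in getOrdering: continuous keys stay in front,
// every discrete key is appended, so COLAMD can be constrained to
// eliminate the discrete modes last.
std::vector<Key> keysDiscreteLast(const std::set<Key>& newFactorKeys,
                                  const std::set<Key>& allDiscrete) {
  std::vector<Key> ordered;
  for (Key k : newFactorKeys) {
    if (allDiscrete.count(k) == 0) ordered.push_back(k);  // continuous first
  }
  std::copy(allDiscrete.begin(), allDiscrete.end(),
            std::back_inserter(ordered));  // discrete last
  return ordered;
}

int main() {
  // Keys 1 and 2 are continuous; key 100 is discrete.
  for (Key k : keysDiscreteLast({1, 2, 100}, {100}))
    std::cout << k << ' ';  // prints: 1 2 100
  return 0;
}
```

Eliminating the modes last is what lets elimination produce Gaussian conditionals given the modes, with the discrete variables at the root of the resulting Bayes net.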
diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
index e5c11bf0c..3302994c0 100644
--- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
@@ -612,7 +612,6 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
 // Check that assembleGraphTree assembles Gaussian factor graphs for each
 // assignment.
 TEST(HybridGaussianFactorGraph, assembleGraphTree) {
-  using symbol_shorthand::Z;
   const int num_measurements = 1;
   auto fg = tiny::createHybridGaussianFactorGraph(
       num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
@@ -694,7 +693,6 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
 /* ****************************************************************************/
 // Check that eliminating tiny net with 1 measurement yields correct result.
 TEST(HybridGaussianFactorGraph, EliminateTiny1) {
-  using symbol_shorthand::Z;
   const int num_measurements = 1;
   const VectorValues measurements{{Z(0), Vector1(5.0)}};
   auto bn = tiny::createHybridBayesNet(num_measurements);
@@ -726,11 +724,61 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
   EXPECT(ratioTest(bn, measurements, *posterior));
 }

+/* ****************************************************************************/
+// Check that eliminating tiny net with 1 measurement with mode order swapped
+// yields correct result.
+TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
+  const VectorValues measurements{{Z(0), Vector1(5.0)}};
+
+  // Create mode key: 1 is low-noise, 0 is high-noise.
+  const DiscreteKey mode{M(0), 2};
+  HybridBayesNet bn;
+
+  // Create Gaussian mixture z_0 = x_0 + noise, one component per mode.
+  bn.emplace_back(new GaussianMixture(
+      {Z(0)}, {X(0)}, {mode},
+      {GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3),
+       GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1,
+                                                0.5)}));
+
+  // Create prior on X(0).
+  bn.push_back(
+      GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));
+
+  // Add prior on mode.
+  bn.emplace_back(new DiscreteConditional(mode, "1/1"));
+
+  auto fg = bn.toFactorGraph(measurements);
+  EXPECT_LONGS_EQUAL(3, fg.size());
+
+  EXPECT(ratioTest(bn, measurements, fg));
+
+  // Create expected Bayes Net:
+  HybridBayesNet expectedBayesNet;
+
+  // Create Gaussian mixture on X(0).
+  // regression, but mean checked to be 5.0 in both cases:
+  const auto conditional0 = boost::make_shared<GaussianConditional>(
+                 X(0), Vector1(10.1379), I_1x1 * 2.02759),
+             conditional1 = boost::make_shared<GaussianConditional>(
+                 X(0), Vector1(14.1421), I_1x1 * 2.82843);
+  expectedBayesNet.emplace_back(
+      new GaussianMixture({X(0)}, {}, {mode}, {conditional0, conditional1}));
+
+  // Add prior on mode.
+  expectedBayesNet.emplace_back(new DiscreteConditional(mode, "1/1"));
+
+  // Test elimination
+  const auto posterior = fg.eliminateSequential();
+  // EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
+
+  EXPECT(ratioTest(bn, measurements, *posterior));
+}
+
 /* ****************************************************************************/
 // Check that eliminating tiny net with 2 measurements yields correct result.
 TEST(HybridGaussianFactorGraph, EliminateTiny2) {
   // Create factor graph with 2 measurements such that posterior mean = 5.0.
-  using symbol_shorthand::Z;
   const int num_measurements = 2;
   const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
   auto bn = tiny::createHybridBayesNet(num_measurements);
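The new `EliminateTiny1Swapped` test leans on `ratioTest`, which samples assignments and checks the defining invariant of a correct elimination: the unnormalized factor-graph density divided by the eliminated posterior must equal one constant (the evidence) at every sampled point. A 1-D linear-Gaussian illustration of that invariant, independent of GTSAM (the numbers mirror the test's prior but are otherwise illustrative):

```cpp
#include <cmath>
#include <iostream>

// Density of N(mu, sigma^2) at x.
double gauss(double x, double mu, double sigma) {
  static const double kRoot2Pi = std::sqrt(2.0 * 3.141592653589793);
  const double d = (x - mu) / sigma;
  return std::exp(-0.5 * d * d) / (sigma * kRoot2Pi);
}

int main() {
  // Prior x ~ N(5, 0.5^2); measurement z = x + noise with noise ~ N(0, 3^2).
  const double z = 5.0, muPrior = 5.0, sPrior = 0.5, sMeas = 3.0;

  // Closed-form posterior of the linear-Gaussian pair (product of Gaussians).
  const double w = 1.0 / (sPrior * sPrior) + 1.0 / (sMeas * sMeas);
  const double muPost = (muPrior / (sPrior * sPrior) + z / (sMeas * sMeas)) / w;
  const double sPost = std::sqrt(1.0 / w);

  // joint(x) / posterior(x) = p(z): the same constant for every x.
  for (double x : {3.0, 5.0, 7.0}) {
    const double joint = gauss(x, muPrior, sPrior) * gauss(z, x, sMeas);
    std::cout << joint / gauss(x, muPost, sPost) << '\n';
  }
  return 0;
}
```

All three printed ratios agree to machine precision; in the hybrid setting, `ratioTest` applies the same check across sampled discrete assignments as well.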
@@ -764,7 +818,6 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {

 // Test eliminating tiny net with 1 mode per measurement.
 TEST(HybridGaussianFactorGraph, EliminateTiny22) {
   // Create factor graph with 2 measurements such that posterior mean = 5.0.
-  using symbol_shorthand::Z;
   const int num_measurements = 2;
   const bool manyModes = true;
@@ -835,12 +888,12 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
   //    D        D
   //    |        |
   //    m1       m2
-  //    |        | 
+  //    |        |
   //    C-x0-HC-x1-HC-x2
   //    |     |     |
   //    HF    HF    HF
   //    |     |     |
-  //    n0    n1    n2 
+  //    n0    n1    n2
   //    |     |     |
   //    D     D     D
   EXPECT_LONGS_EQUAL(11, fg.size());
@@ -853,7 +906,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
   EXPECT(ratioTest(bn, measurements, fg1));
   // Create ordering that eliminates in time order, then discrete modes:
-  Ordering ordering {X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)};
+  Ordering ordering{X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)};

   // Do elimination:
   const HybridBayesNet::shared_ptr posterior =
       fg.eliminateSequential(ordering);
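The hand-built ordering in the last hunk puts the continuous states `X` before the measurement-noise modes `N` and the mode keys `M`, so all discrete variables are eliminated after the continuous ones — the same discrete-last policy that `getOrdering` automates. A small sketch of constructing such an ordering with GTSAM's symbol shorthand (assumes GTSAM headers are available; the brace-initializer form is the one the test itself uses):

```cpp
#include <gtsam/inference/Ordering.h>
#include <gtsam/inference/Symbol.h>

using gtsam::symbol_shorthand::M;  // discrete mode keys
using gtsam::symbol_shorthand::X;  // continuous state keys

int main() {
  // Continuous states first, discrete modes last: hybrid elimination then
  // yields Gaussian conditionals on the X's given the modes, with the
  // discrete variables at the root.
  gtsam::Ordering ordering{X(2), X(1), X(0), M(1), M(2)};
  ordering.print("ordering: ");
  return 0;
}
```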