diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 7ab7893cb..787c0edd7 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -366,18 +366,17 @@ static std::shared_ptr createHybridGaussianFactor( return std::make_shared(discreteSeparator, newFactors); } -static std::pair> -hybridElimination(const HybridGaussianFactorGraph &factors, - const Ordering &frontalKeys, - const std::set &discreteSeparatorSet) { - // NOTE: since we use the special JunctionTree, - // only possibility is continuous conditioned on discrete. - DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), - discreteSeparatorSet.end()); +std::pair> +HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { + // Since we eliminate all continuous variables first, + // the discrete separator will be *all* the discrete keys. + const std::set keysForDiscreteVariables = discreteKeys(); + DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(), + keysForDiscreteVariables.end()); // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. - GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree(); + GaussianFactorGraphTree factorGraphTree = assembleGraphTree(); // Convert factor graphs with a nullptr to an empty factor graph. // This is done after assembly since it is non-trivial to keep track of which @@ -392,7 +391,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, } // Expensive elimination of product factor. - auto result = EliminatePreferCholesky(graph, frontalKeys); + auto result = EliminatePreferCholesky(graph, keys); // Record whether there any continuous variables left someContinuousLeft |= !result.second->empty(); @@ -436,7 +435,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, */ std::pair> // EliminateHybrid(const HybridGaussianFactorGraph &factors, - const Ordering &frontalKeys) { + const Ordering &keys) { // NOTE: Because we are in the Conditional Gaussian regime there are only // a few cases: // 1. continuous variable, make a hybrid Gaussian conditional if there are @@ -510,20 +509,13 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, if (only_discrete) { // Case 1: we are only dealing with discrete - return discreteElimination(factors, frontalKeys); + return discreteElimination(factors, keys); } else if (only_continuous) { // Case 2: we are only dealing with continuous - return continuousElimination(factors, frontalKeys); + return continuousElimination(factors, keys); } else { // Case 3: We are now in the hybrid land! - KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end()); - - // Find all discrete keys. - // Since we eliminate all continuous variables first, - // the discrete separator will be *all* the discrete keys. - std::set discreteSeparator = factors.discreteKeys(); - - return hybridElimination(factors, frontalKeys, discreteSeparator); + return factors.eliminate(keys); } } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index d7aa1042f..923f48e38 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -217,6 +217,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph */ GaussianFactorGraphTree assembleGraphTree() const; + /** + * @brief Eliminate the given continuous keys. + * + * @param keys The continuous keys to eliminate. + * @return The conditional on the keys and a factor on the separator. + */ + std::pair, std::shared_ptr> + eliminate(const Ordering& keys) const; /// @} /// Get the GaussianFactorGraph at a given discrete assignment. diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 0421727ff..ee7ca5265 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,7 @@ #include #include #include +#include #include #include @@ -120,6 +122,25 @@ std::vector components(Key key) { } } // namespace two +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) { + HybridGaussianFactorGraph hfg; + hfg.add(HybridGaussianFactor(m1, two::components(X(1)))); + + auto result = hfg.eliminate({X(1)}); + + // Check that we have a valid Gaussian conditional. + auto hgc = result.first->asHybrid(); + CHECK(hgc); + const HybridValues values{{{X(1), Z_3x1}}, {{M(1), 1}}}; + EXPECT(HybridConditional::CheckInvariants(*result.first, values)); + + // Check that factor is discrete and correct + auto factor = std::dynamic_pointer_cast(result.second); + CHECK(factor); + EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor)); +} + /* ************************************************************************* */ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) { HybridGaussianFactorGraph hfg; @@ -221,20 +242,16 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) { hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1)); - { - hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0)))); - hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2)))); - } + hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0)))); + hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2)))); hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1)); - { - hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3)))); - hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5)))); - } + hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3)))); + hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5)))); auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});