diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index fc91e0838..8d6fef3f4 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -32,9 +32,6 @@ namespace gtsam { class HybridValues; -/// Alias for DecisionTree of GaussianFactorGraphs -using GaussianFactorGraphTree = DecisionTree; - KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 2c0fb28a4..e69b966f8 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -136,9 +137,9 @@ HybridGaussianConditional::conditionals() const { } /* *******************************************************************************/ -GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree() +HybridGaussianProductFactor HybridGaussianConditional::asProductFactor() const { - auto wrap = [this](const GaussianConditional::shared_ptr &gc) { + auto wrap = [this](const std::shared_ptr &gc) { // First check if conditional has not been pruned if (gc) { const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_; @@ -155,7 +156,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree() } return GaussianFactorGraph{gc}; }; - return {conditionals_, wrap}; + return {{conditionals_, wrap}}; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 27e31d767..3e6574abc 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -221,6 +222,9 @@ class GTSAM_EXPORT HybridGaussianConditional HybridGaussianConditional::shared_ptr prune( const DecisionTreeFactor &discreteProbs) const; + /// Convert to a DecisionTree of Gaussian factor graphs. + HybridGaussianProductFactor asProductFactor() const override; + /// @} private: @@ -231,9 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional HybridGaussianConditional(const DiscreteKeys &discreteParents, const Helper &helper); - /// Convert to a DecisionTree of Gaussian factor graphs. - GaussianFactorGraphTree asGaussianFactorGraphTree() const; - /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const; diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index aa88ded30..cc72dc0ec 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -195,26 +195,6 @@ HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()( return factors_(assignment).first; } -/* *******************************************************************************/ -GaussianFactorGraphTree HybridGaussianFactor::add( - const GaussianFactorGraphTree &sum) const { - using Y = GaussianFactorGraph; - auto add = [](const Y &graph1, const Y &graph2) { - auto result = graph1; - result.push_back(graph2); - return result; - }; - const auto tree = asGaussianFactorGraphTree(); - return sum.empty() ? tree : sum.apply(tree, add); -} - -/* *******************************************************************************/ -GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() - const { - auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; }; - return {factors_, wrap}; -} - /* *******************************************************************************/ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const { return {{factors_, diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index d160798b6..64a9e1918 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -130,16 +130,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// Get factor at a given discrete assignment. sharedFactor operator()(const DiscreteValues &assignment) const; - /** - * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while - * maintaining the original tree structure. - * - * @param sum Decision Tree of Gaussian Factor Graphs indexed by the - * variables. - * @return Sum - */ - GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const; - /** * @brief Compute error of the HybridGaussianFactor as a tree. * @@ -158,14 +148,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// Getter for GaussianFactor decision tree const FactorValuePairs &factors() const { return factors_; } - - /// Add HybridNonlinearFactor to a Sum, syntactic sugar. - friend GaussianFactorGraphTree &operator+=( - GaussianFactorGraphTree &sum, const HybridGaussianFactor &factor) { - sum = factor.add(sum); - return sum; - } - /** * @brief Helper function to return factors and functional to create a * DecisionTree of Gaussian Factor Graphs. @@ -175,16 +157,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { virtual HybridGaussianProductFactor asProductFactor() const; /// @} - protected: - /** - * @brief Helper function to return factors and functional to create a - * DecisionTree of Gaussian Factor Graphs. - * - * @return GaussianFactorGraphTree - */ - GaussianFactorGraphTree asGaussianFactorGraphTree() const; - - private: /** * @brief Helper function to augment the [A|b] matrices in the factor diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 957a85038..b58400aa4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -127,43 +127,28 @@ void HybridGaussianFactorGraph::printErrors( std::cout.flush(); } -/* ************************************************************************ */ -static GaussianFactorGraphTree addGaussian( - const GaussianFactorGraphTree &gfgTree, - const GaussianFactor::shared_ptr &factor) { - // If the decision tree is not initialized, then initialize it. - if (gfgTree.empty()) { - GaussianFactorGraph result{factor}; - return GaussianFactorGraphTree(result); - } else { - auto add = [&factor](const GaussianFactorGraph &graph) { - auto result = graph; - result.push_back(factor); - return result; - }; - return gfgTree.apply(add); - } -} - /* ************************************************************************ */ // TODO(dellaert): it's probably more efficient to first collect the discrete // keys, and then loop over all assignments to populate a vector. -GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { - GaussianFactorGraphTree result; +HybridGaussianProductFactor +HybridGaussianFactorGraph::collectProductFactor() const { + HybridGaussianProductFactor result; for (auto &f : factors_) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. + // TODO(dellaert): can we make this cleaner and less error-prone? if (auto gf = dynamic_pointer_cast(f)) { - result = addGaussian(result, gf); + result += gf; + } else if (auto gc = dynamic_pointer_cast(f)) { + result += gc; } else if (auto gmf = dynamic_pointer_cast(f)) { - result = gmf->add(result); + result += *gmf; } else if (auto gm = dynamic_pointer_cast(f)) { - result = gm->add(result); + result += *gm; // handled above already? } else if (auto hc = dynamic_pointer_cast(f)) { if (auto gm = hc->asHybrid()) { - result = gm->add(result); + result += *gm; } else if (auto g = hc->asGaussian()) { - result = addGaussian(result, g); + result += g; } else { // Has to be discrete. // TODO(dellaert): in C++20, we can use std::visit. @@ -176,7 +161,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { } else { // TODO(dellaert): there was an unattributed comment here: We need to // handle the case where the object is actually an BayesTreeOrphanWrapper! - throwRuntimeError("gtsam::assembleGraphTree", f); + throwRuntimeError("gtsam::collectProductFactor", f); } } @@ -269,21 +254,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors, return {std::make_shared(result.first), result.second}; } -/* ************************************************************************ */ -// If any GaussianFactorGraph in the decision tree contains a nullptr, convert -// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will -// otherwise create a GFG with a single (null) factor, -// which doesn't register as null. -GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { - auto emptyGaussian = [](const GaussianFactorGraph &graph) { - bool hasNull = - std::any_of(graph.begin(), graph.end(), - [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); - return hasNull ? GaussianFactorGraph() : graph; - }; - return GaussianFactorGraphTree(sum, emptyGaussian); -} - /* ************************************************************************ */ using Result = std::pair, HybridGaussianFactor::sharedFactor>; @@ -334,7 +304,7 @@ static std::shared_ptr createHybridGaussianFactor( // and negative since we want log(k) hf->constantTerm() += -2.0 * conditional->negLogConstant(); } - return {factor, 0.0}; + return {factor, conditional->negLogConstant()}; }; DecisionTree newFactors(eliminationResults, correct); @@ -359,12 +329,12 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. - GaussianFactorGraphTree factorGraphTree = assembleGraphTree(); + HybridGaussianProductFactor productFactor = collectProductFactor(); // Convert factor graphs with a nullptr to an empty factor graph. // This is done after assembly since it is non-trivial to keep track of which // FG has a nullptr as we're looping over the factors. - factorGraphTree = removeEmpty(factorGraphTree); + auto prunedProductFactor = productFactor.removeEmpty(); // This is the elimination method on the leaf nodes bool someContinuousLeft = false; @@ -383,7 +353,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { }; // Perform elimination! - DecisionTree eliminationResults(factorGraphTree, eliminate); + DecisionTree eliminationResults(prunedProductFactor, eliminate); // If there are no more continuous parents we create a DiscreteFactor with the // error for each discrete choice. Otherwise, create a HybridGaussianFactor diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index a5130ca08..ad6d6337d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -218,7 +218,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * one for A and one for B. The leaves of the tree will be the Gaussian * factors that have only continuous keys. */ - GaussianFactorGraphTree assembleGraphTree() const; + HybridGaussianProductFactor collectProductFactor() const; /** * @brief Eliminate the given continuous keys. diff --git a/gtsam/hybrid/tests/testHybridGaussianFactor.cpp b/gtsam/hybrid/tests/testHybridGaussianFactor.cpp index 5ff8c1478..1c9476eeb 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactor.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactor.cpp @@ -82,40 +82,25 @@ TEST(HybridGaussianFactor, ConstructorVariants) { } /* ************************************************************************* */ -// "Add" two hybrid factors together. -TEST(HybridGaussianFactor, Sum) { +TEST(HybridGaussianFactor, Keys) { using namespace test_constructor; - DiscreteKey m2(2, 3); - - auto A3 = Matrix::Zero(2, 3); - auto f20 = std::make_shared(X(1), A1, X(3), A3, b); - auto f21 = std::make_shared(X(1), A1, X(3), A3, b); - auto f22 = std::make_shared(X(1), A1, X(3), A3, b); - - // TODO(Frank): why specify keys at all? And: keys in factor should be *all* - // keys, deviating from Kevin's scheme. Should we index DT on DiscreteKey? - // Design review! HybridGaussianFactor hybridFactorA(m1, {f10, f11}); - HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22}); - // Check the number of keys matches what we expect EXPECT_LONGS_EQUAL(3, hybridFactorA.keys().size()); EXPECT_LONGS_EQUAL(2, hybridFactorA.continuousKeys().size()); EXPECT_LONGS_EQUAL(1, hybridFactorA.discreteKeys().size()); - // Create sum of two hybrid factors: it will be a decision tree now on both - // discrete variables m1 and m2: - GaussianFactorGraphTree sum; - sum += hybridFactorA; - sum += hybridFactorB; + DiscreteKey m2(2, 3); + auto A3 = Matrix::Zero(2, 3); + auto f20 = std::make_shared(X(1), A1, X(3), A3, b); + auto f21 = std::make_shared(X(1), A1, X(3), A3, b); + auto f22 = std::make_shared(X(1), A1, X(3), A3, b); + HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22}); - // Let's check that this worked: - Assignment mode; - mode[m1.first] = 1; - mode[m2.first] = 2; - auto actual = sum(mode); - EXPECT(actual.at(0) == f11); - EXPECT(actual.at(1) == f22); + // Check the number of keys matches what we expect + EXPECT_LONGS_EQUAL(3, hybridFactorB.keys().size()); + EXPECT_LONGS_EQUAL(2, hybridFactorB.continuousKeys().size()); + EXPECT_LONGS_EQUAL(1, hybridFactorB.discreteKeys().size()); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index f30085f02..4c88fa8cd 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -674,14 +675,14 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { /* ****************************************************************************/ // Check that assembleGraphTree assembles Gaussian factor graphs for each // assignment. -TEST(HybridGaussianFactorGraph, assembleGraphTree) { +TEST(HybridGaussianFactorGraph, collectProductFactor) { const int num_measurements = 1; auto fg = tiny::createHybridGaussianFactorGraph( num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); EXPECT_LONGS_EQUAL(3, fg.size()); // Assemble graph tree: - auto actual = fg.assembleGraphTree(); + auto actual = fg.collectProductFactor(); // Create expected decision tree with two factor graphs: @@ -701,9 +702,9 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) { // Expected decision tree with two factor graphs: // f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0) - GaussianFactorGraphTree expected{ - M(0), GaussianFactorGraph(std::vector{(*hybrid)(d0), prior}), - GaussianFactorGraph(std::vector{(*hybrid)(d1), prior})}; + HybridGaussianProductFactor expected{ + {M(0), GaussianFactorGraph(std::vector{(*hybrid)(d0), prior}), + GaussianFactorGraph(std::vector{(*hybrid)(d1), prior})}}; EXPECT(assert_equal(expected(d0), actual(d0), 1e-5)); EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));