diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index f8e1b4bb8..b4b65f885 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -144,6 +144,23 @@ namespace gtsam { boost::dynamic_pointer_cast(lookup), max); } + /* ************************************************************************ */ + // sumProduct is just an alias for regular eliminateSequential. + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = BaseEliminateable::eliminateSequential(orderingType); + return *bayesNet; + } + + DiscreteLookupDAG DiscreteFactorGraph::sumProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = + BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + /* ************************************************************************ */ // The max-product solution below is a bit clunky: the elimination machinery // does not allow for differently *typed* versions of elimination, so we diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 1ba39ff9d..2e9b40823 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -132,11 +132,28 @@ class GTSAM_EXPORT DiscreteFactorGraph const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /** + * @brief Implement the sum-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct( + OptionalOrderingType orderingType = boost::none) const; + + /** + * @brief Implement the sum-product algorithm + * + * @param ordering + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteLookupDAG sumProduct(const Ordering& ordering) const; + /** * @brief Implement the max-product algorithm * * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM - * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables + * @return DiscreteLookupDAG DAG with lookup tables */ DiscreteLookupDAG maxProduct( OptionalOrderingType orderingType = boost::none) const; @@ -145,7 +162,7 @@ class GTSAM_EXPORT DiscreteFactorGraph * @brief Implement the max-product algorithm * * @param ordering - * @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables + * @return DiscreteLookupDAG `DAG with lookup tables */ DiscreteLookupDAG maxProduct(const Ordering& ordering) const; diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index f4819dab5..63f5b7319 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -154,6 +154,16 @@ TEST(DiscreteFactorGraph, test) { auto actualMPE = graph.optimize(); EXPECT(assert_equal(mpe, actualMPE)); EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression + + // Test sumProduct alias with all orderings: + auto mpeProbability = expectedBayesNet(mpe); + EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + auto bayesNet = graph.sumProduct(orderingType); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + } } /* ************************************************************************* */