diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index f8e1b4bb8..ebcac445c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -144,6 +144,22 @@ namespace gtsam { boost::dynamic_pointer_cast(lookup), max); } + /* ************************************************************************ */ + // sumProduct is just an alias for regular eliminateSequential. + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = eliminateSequential(orderingType); + return *bayesNet; + } + + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = eliminateSequential(ordering); + return *bayesNet; + } + /* ************************************************************************ */ // The max-product solution below is a bit clunky: the elimination machinery // does not allow for differently *typed* versions of elimination, so we @@ -153,16 +169,14 @@ namespace gtsam { DiscreteLookupDAG DiscreteFactorGraph::maxProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_maxProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + auto bayesNet = eliminateSequential(orderingType, EliminateForMPE); return DiscreteLookupDAG::FromBayesNet(*bayesNet); } DiscreteLookupDAG DiscreteFactorGraph::maxProduct( const Ordering& ordering) const { gttic(DiscreteFactorGraph_maxProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + auto bayesNet = eliminateSequential(ordering, EliminateForMPE); return DiscreteLookupDAG::FromBayesNet(*bayesNet); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 1ba39ff9d..f962b1802 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -132,11 +132,28 @@ class GTSAM_EXPORT DiscreteFactorGraph const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /** + * @brief Implement the sum-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct( + OptionalOrderingType orderingType = boost::none) const; + + /** + * @brief Implement the sum-product algorithm + * + * @param ordering + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct(const Ordering& ordering) const; + /** * @brief Implement the max-product algorithm * * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM - * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables + * @return DiscreteLookupDAG DAG with lookup tables */ DiscreteLookupDAG maxProduct( OptionalOrderingType orderingType = boost::none) const; @@ -145,7 +162,7 @@ class GTSAM_EXPORT DiscreteFactorGraph * @brief Implement the max-product algorithm * * @param ordering - * @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables + * @return DiscreteLookupDAG `DAG with lookup tables */ DiscreteLookupDAG maxProduct(const Ordering& ordering) const; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 3f2c3e060..258286901 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -277,7 +277,12 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; + gtsam::DiscreteBayesNet sumProduct(); + gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); + gtsam::DiscreteLookupDAG maxProduct(); + gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet eliminateSequential(); diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index f4819dab5..0a7d869ec 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -154,6 +154,21 @@ TEST(DiscreteFactorGraph, test) { auto actualMPE = graph.optimize(); EXPECT(assert_equal(mpe, actualMPE)); EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression + + // Test sumProduct alias with all orderings: + auto mpeProbability = expectedBayesNet(mpe); + EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression + + // Using custom ordering + DiscreteBayesNet bayesNet = graph.sumProduct(ordering); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + auto bayesNet = graph.sumProduct(orderingType); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + } } /* ************************************************************************* */ diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index b6ab086c4..a6883d38b 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -111,6 +111,11 @@ size_t mrsymbolIndex(size_t key); #include class Ordering { + /// Type of ordering to use + enum OrderingType { + COLAMD, METIS, NATURAL, CUSTOM + }; + // Standard Constructors and Named Constructors Ordering(); Ordering(const gtsam::Ordering& other); diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 1ba145e09..ef85fc753 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -13,9 +13,11 @@ Author: Frank Dellaert import unittest -from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering from gtsam.utils.test_case import GtsamTestCase +OrderingType = Ordering.OrderingType + class TestDiscreteFactorGraph(GtsamTestCase): """Tests for Discrete Factor Graphs.""" @@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase): graph.add([C, A], "0.2 0.8 0.3 0.7") graph.add([C, B], "0.1 0.9 0.4 0.6") - actualMPE = graph.optimize() + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 - expectedMPE = DiscreteValues() - expectedMPE[0] = 0 - expectedMPE[1] = 1 - expectedMPE[2] = 1 + # Use maxProduct + dag = graph.maxProduct(OrderingType.COLAMD) + actualMPE = dag.argmax() self.assertEqual(list(actualMPE.items()), - list(expectedMPE.items())) + list(mpe.items())) + + # All in one + actualMPE2 = graph.optimize() + self.assertEqual(list(actualMPE2.items()), + list(mpe.items())) + + def test_sumProduct(self): + """Test sumProduct.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 + + # Use default sumProduct + bayesNet = graph.sumProduct() + mpeProbability = bayesNet(mpe) + self.assertAlmostEqual(mpeProbability, 0.36) # regression + + # Use sumProduct + for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, + OrderingType.CUSTOM]: + bayesNet = graph.sumProduct(ordering_type) + self.assertEqual(bayesNet(mpe), mpeProbability) if __name__ == "__main__":