diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index b4b65f885..ebcac445c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -149,16 +149,15 @@ namespace gtsam { DiscreteBayesNet DiscreteFactorGraph::sumProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_sumProduct); - auto bayesNet = BaseEliminateable::eliminateSequential(orderingType); + auto bayesNet = eliminateSequential(orderingType); return *bayesNet; } - DiscreteLookupDAG DiscreteFactorGraph::sumProduct( + DiscreteBayesNet DiscreteFactorGraph::sumProduct( const Ordering& ordering) const { gttic(DiscreteFactorGraph_sumProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); - return DiscreteLookupDAG::FromBayesNet(*bayesNet); + auto bayesNet = eliminateSequential(ordering); + return *bayesNet; } /* ************************************************************************ */ @@ -170,16 +169,14 @@ namespace gtsam { DiscreteLookupDAG DiscreteFactorGraph::maxProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_maxProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + auto bayesNet = eliminateSequential(orderingType, EliminateForMPE); return DiscreteLookupDAG::FromBayesNet(*bayesNet); } DiscreteLookupDAG DiscreteFactorGraph::maxProduct( const Ordering& ordering) const { gttic(DiscreteFactorGraph_maxProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + auto bayesNet = eliminateSequential(ordering, EliminateForMPE); return DiscreteLookupDAG::FromBayesNet(*bayesNet); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 2e9b40823..f962b1802 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -147,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph * @param ordering * @return DiscreteBayesNet encoding posterior P(X|Z) */ - DiscreteLookupDAG sumProduct(const Ordering& ordering) const; + DiscreteBayesNet sumProduct(const Ordering& ordering) const; /** * @brief Implement the max-product algorithm diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 0dcbcc1cf..258286901 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -277,9 +277,9 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; - gtsam::DiscreteLookupDAG sumProduct(); - gtsam::DiscreteLookupDAG sumProduct(gtsam::Ordering::OrderingType type); - gtsam::DiscreteLookupDAG sumProduct(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesNet sumProduct(); + gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); gtsam::DiscreteLookupDAG maxProduct(); gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 63f5b7319..0a7d869ec 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -158,6 +158,11 @@ TEST(DiscreteFactorGraph, test) { // Test sumProduct alias with all orderings: auto mpeProbability = expectedBayesNet(mpe); EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression + + // Using custom ordering + DiscreteBayesNet bayesNet = graph.sumProduct(ordering); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + for (Ordering::OrderingType orderingType : {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, Ordering::CUSTOM}) {