diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index d8e9aa244..a166fdce9 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -96,18 +97,6 @@ namespace gtsam { // } /* ************************************************************************ */ - /** - * @brief Lookup table for max-product - * - * This inherits from a DiscreteConditional but is not normalized to 1 - * - */ - class Lookup : public DiscreteConditional { - public: - Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials) - : DiscreteConditional(nFrontals, keys, potentials) {} - }; - // Alternate eliminate function for MPE std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, @@ -133,7 +122,8 @@ namespace gtsam { // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = boost::make_shared(nrFrontals, orderedKeys, product); + auto lookup = boost::make_shared(nrFrontals, + orderedKeys, product); gttoc(lookup); return std::make_pair( @@ -141,18 +131,37 @@ namespace gtsam { } /* ************************************************************************ */ - DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct( + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_maxProduct); - return BaseEliminateable::eliminateSequential(orderingType, - EliminateForMPE); + + // The solution below is a bitclunky: the elimination machinery does not + // allow for differently *typed* versions of elimination, so we eliminate + // into a Bayes Net using the special eliminate function above, and then + // create the DiscreteLookupDAG after the fact, in linear time. + auto bayesNet = + BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + + // Copy to the DAG + DiscreteLookupDAG dag; + for (auto&& conditional : *bayesNet) { + if (auto lookupTable = + boost::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } + } + return dag; } /* ************************************************************************ */ DiscreteValues DiscreteFactorGraph::optimize( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_optimize); - return maxProduct()->optimize(); + DiscreteLookupDAG dag = maxProduct(); + return dag.argmax(); } /* ************************************************************************ */ diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index b4e98c876..7c658f548 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -18,10 +18,11 @@ #pragma once -#include -#include -#include #include +#include +#include +#include +#include #include #include @@ -132,9 +133,9 @@ class GTSAM_EXPORT DiscreteFactorGraph * @brief Implement the max-product algorithm * * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM - * @return DiscreteBayesNet::shared_ptr DAG with lookup tables + * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables */ - boost::shared_ptr maxProduct( + DiscreteLookupDAG maxProduct( OptionalOrderingType orderingType = boost::none) const; /** diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 14432d08c..e63cc26b8 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -52,13 +52,6 @@ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { DiscreteValues mpe; insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); EXPECT(assert_equal(mpe, actualMPE)); - - // Check Bayes Net - Ordering ordering; - ordering += Key(0), Key(1), Key(2), Key(3); - auto chordal = graph.eliminateSequential(ordering); - // happens to be the same, but not in general! - EXPECT(assert_equal(mpe, chordal->optimize())); } /* ************************************************************************* */ @@ -125,57 +118,46 @@ TEST(DiscreteFactorGraph, test) { DecisionTreeFactor::shared_ptr newFactor; boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); - // Check Bayes net + // Check Conditional CHECK(conditional); - DiscreteBayesNet expected; Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); - DiscreteConditional expectedConditional(signature); EXPECT(assert_equal(expectedConditional, *conditional)); - expected.add(signature); // Check Factor CHECK(newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); EXPECT(assert_equal(expectedFactor, *newFactor)); - // add conditionals to complete expected Bayes net - expected.add(B | A = "5/3 3/5"); - expected.add(A % "1/1"); - - // Test elimination tree + // Test using elimination tree Ordering ordering; ordering += Key(0), Key(1), Key(2); DiscreteEliminationTree etree(graph, ordering); DiscreteBayesNet::shared_ptr actual; DiscreteFactorGraph::shared_ptr remainingGraph; boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); - EXPECT(assert_equal(expected, *actual)); - DiscreteValues mpe; - insert(mpe)(0, 0)(1, 0)(2, 0); - EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression - - // Check Bayes Net - auto chordal = graph.eliminateSequential(); - auto notOptimal = chordal->optimize(); - // happens to be the same but not in general! - EXPECT(assert_equal(mpe, notOptimal)); + // Check Bayes net + DiscreteBayesNet expectedBayesNet; + expectedBayesNet.add(signature); + expectedBayesNet.add(B | A = "5/3 3/5"); + expectedBayesNet.add(A % "1/1"); + EXPECT(assert_equal(expectedBayesNet, *actual)); // Test eliminateSequential DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); - EXPECT(assert_equal(expected, *actual2)); - auto notOptimal2 = actual2->optimize(); - // happens to be the same but not in general! - EXPECT(assert_equal(mpe, notOptimal2)); + EXPECT(assert_equal(expectedBayesNet, *actual2)); // Test mpe + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 0)(2, 0); auto actualMPE = graph.optimize(); EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteFactorGraph, testMPE) { +TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) { // Declare a bunch of keys DiscreteKey C(0, 2), A(1, 2), B(2, 2); @@ -184,17 +166,20 @@ TEST_UNSAFE(DiscreteFactorGraph, testMPE) { graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); - // Check MPE. - auto actualMPE = graph.optimize(); + // Created expected MPE DiscreteValues mpe; insert(mpe)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(mpe, actualMPE)); - // Check Bayes Net - auto chordal = graph.eliminateSequential(); - auto notOptimal = chordal->optimize(); - // happens to be the same but not in general - EXPECT(assert_equal(mpe, notOptimal)); + // Do max-product with different orderings + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + DiscreteLookupDAG dag = graph.maxProduct(orderingType); + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); + auto actualMPE2 = graph.optimize(); // all in one + EXPECT(assert_equal(mpe, actualMPE2)); + } } /* ************************************************************************* */ @@ -218,10 +203,12 @@ TEST(DiscreteFactorGraph, marginalIsNotMPE) { EXPECT(assert_equal(mpe, actualMPE)); EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 // Optimize on BayesNet maximizes marginal, then the conditional marginals: auto notOptimal = bayesNet.optimize(); EXPECT(graph(notOptimal) < graph(mpe)); EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression +#endif } /* ************************************************************************* */ @@ -252,8 +239,11 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4); auto chordal = graph.eliminateSequential(ordering); + EXPECT_LONGS_EQUAL(2, chordal->size()); +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 auto notOptimal = chordal->optimize(); // not MPE ! EXPECT(graph(notOptimal) < graph(mpe)); +#endif // Let us create the Bayes tree here, just for fun, because we don't use it DiscreteBayesTree::shared_ptr bayesTree =