From ec39197cc3fb8b3a547d2b015f52537770eb80eb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 10:12:31 -0500 Subject: [PATCH] `optimize` now computes MPE --- gtsam/discrete/DiscreteFactorGraph.cpp | 85 +++++++-- gtsam/discrete/DiscreteFactorGraph.h | 33 +++- .../tests/testDiscreteFactorGraph.cpp | 179 +++++++++--------- 3 files changed, 187 insertions(+), 110 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index c1248c60b..d8e9aa244 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -95,22 +95,74 @@ namespace gtsam { // } // } - /* ************************************************************************* */ - DiscreteValues DiscreteFactorGraph::optimize() const - { - gttic(DiscreteFactorGraph_optimize); - return BaseEliminateable::eliminateSequential()->optimize(); - } + /* ************************************************************************ */ + /** + * @brief Lookup table for max-product + * + * This inherits from a DiscreteConditional but is not normalized to 1 + * + */ + class Lookup : public DiscreteConditional { + public: + Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials) + : DiscreteConditional(nFrontals, keys, potentials) {} + }; - /* ************************************************************************* */ + // Alternate eliminate function for MPE std::pair // - EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - + EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product; - for(const DiscreteFactor::shared_ptr& factor: factors) - product = (*factor) * product; + for (auto&& factor : factors) product = (*factor) * product; + gttoc(product); + + // max out frontals, this is the factor on the separator + gttic(max); + DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + gttoc(max); + + // Ordering keys for the conditional so that frontalKeys are really in front + DiscreteKeys orderedKeys; + for (auto&& key : frontalKeys) + orderedKeys.emplace_back(key, product.cardinality(key)); + for (auto&& key : max->keys()) + orderedKeys.emplace_back(key, product.cardinality(key)); + + // Make lookup with product + gttic(lookup); + size_t nrFrontals = frontalKeys.size(); + auto lookup = boost::make_shared(nrFrontals, orderedKeys, product); + gttoc(lookup); + + return std::make_pair( + boost::dynamic_pointer_cast(lookup), max); + } + + /* ************************************************************************ */ + DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_maxProduct); + return BaseEliminateable::eliminateSequential(orderingType, + EliminateForMPE); + } + + /* ************************************************************************ */ + DiscreteValues DiscreteFactorGraph::optimize( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_optimize); + return maxProduct()->optimize(); + } + + /* ************************************************************************ */ + std::pair // + EliminateDiscrete(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product; + for (auto&& factor : factors) product = (*factor) * product; gttoc(product); // sum out frontals, this is the factor on the separator @@ -120,15 +172,18 @@ namespace gtsam { // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), + frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), + sum->keys().end()); // now divide product/sum to get conditional gttic(divide); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); + auto conditional = + boost::make_shared(product, *sum, orderedKeys); gttoc(divide); - return std::make_pair(cond, sum); + return std::make_pair(conditional, sum); } /* ************************************************************************ */ diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 1da840eb8..b4e98c876 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -128,18 +128,31 @@ class GTSAM_EXPORT DiscreteFactorGraph const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** Solve the factor graph by performing variable elimination in COLAMD order using - * the dense elimination function specified in \c function, - * followed by back-substitution resulting from elimination. Is equivalent - * to calling graph.eliminateSequential()->optimize(). */ - DiscreteValues optimize() const; + /** + * @brief Implement the max-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteBayesNet::shared_ptr DAG with lookup tables + */ + boost::shared_ptr maxProduct( + OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param orderingType + * @return DiscreteValues : MPE + */ + DiscreteValues optimize( + OptionalOrderingType orderingType = boost::none) const; -// /** Permute the variables in the factors */ -// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); -// -// /** Apply a reduction, which is a remapping of variable indices. */ -// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + // /** Permute the variables in the factors */ + // GTSAM_EXPORT void permuteWithInverse(const Permutation& + // inversePermutation); + // + // /** Apply a reduction, which is a remapping of variable indices. */ + // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& + // inverseReduction); /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 579244c57..14432d08c 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -30,8 +30,8 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { - DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); +TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { + DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteFactorGraph graph; graph.add(AI, "1 0 0 1"); @@ -47,25 +47,18 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); -// graph.print("Graph: "); - DecisionTreeFactor product = graph.product(); - DecisionTreeFactor::shared_ptr sum = product.sum(1); -// sum->print("Debug SUM: "); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); + // Check MPE. + auto actualMPE = graph.optimize(); + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); + EXPECT(assert_equal(mpe, actualMPE)); -// cond->print("marginal:"); - -// pair result = EliminateDiscrete(graph, 1); -// result.first->print("BayesNet: "); -// result.second->print("New factor: "); -// + // Check Bayes Net Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3); - DiscreteEliminationTree eliminationTree(graph, ordering); -// eliminationTree.print("Elimination tree: "); - eliminationTree.eliminate(EliminateDiscrete); -// solver.optimize(); -// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate(); + ordering += Key(0), Key(1), Key(2), Key(3); + auto chordal = graph.eliminateSequential(ordering); + // happens to be the same, but not in general! + EXPECT(assert_equal(mpe, chordal->optimize())); } /* ************************************************************************* */ @@ -115,10 +108,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, test) -{ +TEST(DiscreteFactorGraph, test) { // Declare keys and ordering - DiscreteKey C(0,2), B(1,2), A(2,2); + DiscreteKey C(0, 2), B(1, 2), A(2, 2); // A simple factor graph (A)-fAC-(C)-fBC-(B) // with smoothness priors @@ -127,7 +119,6 @@ TEST( DiscreteFactorGraph, test) graph.add(C & B, "3 1 1 3"); // Test EliminateDiscrete - // FIXME: apparently Eliminate returns a conditional rather than a net Ordering frontalKeys; frontalKeys += Key(0); DiscreteConditional::shared_ptr conditional; @@ -138,7 +129,7 @@ TEST( DiscreteFactorGraph, test) CHECK(conditional); DiscreteBayesNet expected; Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); - // cout << signature << endl; + DiscreteConditional expectedConditional(signature); EXPECT(assert_equal(expectedConditional, *conditional)); expected.add(signature); @@ -151,7 +142,6 @@ TEST( DiscreteFactorGraph, test) // add conditionals to complete expected Bayes net expected.add(B | A = "5/3 3/5"); expected.add(A % "1/1"); - // GTSAM_PRINT(expected); // Test elimination tree Ordering ordering; @@ -162,42 +152,82 @@ TEST( DiscreteFactorGraph, test) boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); EXPECT(assert_equal(expected, *actual)); -// // Test solver -// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); -// EXPECT(assert_equal(expected, *actual2)); + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 0)(2, 0); + EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression - // Test optimization - DiscreteValues expectedValues; - insert(expectedValues)(0, 0)(1, 0)(2, 0); - auto actualValues = graph.optimize(); - EXPECT(assert_equal(expectedValues, actualValues)); + // Check Bayes Net + auto chordal = graph.eliminateSequential(); + auto notOptimal = chordal->optimize(); + // happens to be the same but not in general! + EXPECT(assert_equal(mpe, notOptimal)); + + // Test eliminateSequential + DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); + EXPECT(assert_equal(expected, *actual2)); + auto notOptimal2 = actual2->optimize(); + // happens to be the same but not in general! + EXPECT(assert_equal(mpe, notOptimal2)); + + // Test mpe + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE) -{ +TEST_UNSAFE(DiscreteFactorGraph, testMPE) { // Declare a bunch of keys - DiscreteKey C(0,2), A(1,2), B(2,2); + DiscreteKey C(0, 2), A(1, 2), B(2, 2); // Create Factor graph DiscreteFactorGraph graph; graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); - // graph.product().print(); - // DiscreteSequentialSolver(graph).eliminate()->print(); + // Check MPE. auto actualMPE = graph.optimize(); + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 1)(2, 1); + EXPECT(assert_equal(mpe, actualMPE)); - DiscreteValues expectedMPE; - insert(expectedMPE)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(expectedMPE, actualMPE)); + // Check Bayes Net + auto chordal = graph.eliminateSequential(); + auto notOptimal = chordal->optimize(); + // happens to be the same but not in general + EXPECT(assert_equal(mpe, notOptimal)); } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) -{ +TEST(DiscreteFactorGraph, marginalIsNotMPE) { + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create Bayes net such that marginal on A is bigger for 0 than 1, but the + // MPE does not have A=0. + DiscreteBayesNet bayesNet; + bayesNet.add(B | A = "1/1 1/2"); + bayesNet.add(A % "10/9"); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // Which we verify using max-product: + DiscreteFactorGraph graph(bayesNet); + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression + + // Optimize on BayesNet maximizes marginal, then the conditional marginals: + auto notOptimal = bayesNet.optimize(); + EXPECT(graph(notOptimal) < graph(mpe)); + EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression +} + +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { // The factor graph in Darwiche09book, page 244 - DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); + DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2); // Create Factor graph DiscreteFactorGraph graph; @@ -206,53 +236,32 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); - graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) - //graph.product().print("Darwiche-product"); - // graph.product().potentials().dot("Darwiche-product"); - // DiscreteSequentialSolver(graph).eliminate()->print(); + graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche) - DiscreteValues expectedMPE; - insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); + DiscreteValues mpe; + insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1); + EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression + // You can check visually by printing product: + // graph.product().print("Darwiche-product"); - // Use the solver machinery. - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto actualMPE = chordal->optimize(); - EXPECT(assert_equal(expectedMPE, actualMPE)); -// DiscreteConditional::shared_ptr root = chordal->back(); -// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9); - - // Let us create the Bayes tree here, just for fun, because we don't use it now -// typedef JunctionTreeOrdered JT; -// GenericMultifrontalSolver solver(graph); -// BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete); -//// bayesTree->print("Bayes Tree"); -// EXPECT_LONGS_EQUAL(2,bayesTree->size()); + // Check MPE. + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + // Check Bayes Net Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4); - DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); + ordering += Key(0), Key(1), Key(2), Key(3), Key(4); + auto chordal = graph.eliminateSequential(ordering); + auto notOptimal = chordal->optimize(); // not MPE ! + EXPECT(graph(notOptimal) < graph(mpe)); + + // Let us create the Bayes tree here, just for fun, because we don't use it + DiscreteBayesTree::shared_ptr bayesTree = + graph.eliminateMultifrontal(ordering); // bayesTree->print("Bayes Tree"); - EXPECT_LONGS_EQUAL(2,bayesTree->size()); - -#ifdef OLD -// Create the elimination tree manually -VariableIndexOrdered structure(graph); -typedef EliminationTreeOrdered ETree; -ETree::shared_ptr eTree = ETree::Create(graph, structure); -//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<"); - -// eliminate normally and check solution -DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete); -// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<"); -auto actualMPE = optimize(*bayesNet); -EXPECT(assert_equal(expectedMPE, actualMPE)); - -// Approximate and check solution -// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate(); -// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<"); -// EXPECT(assert_equal(expectedMPE, *actualMPE)); -#endif + EXPECT_LONGS_EQUAL(2, bayesTree->size()); } + #ifdef OLD /* ************************************************************************* */