Now doing MPE with DAG class

release/4.3a0
Frank Dellaert 2022-01-21 13:18:28 -05:00
parent fcdb5b43c1
commit 5add858c24
3 changed files with 61 additions and 61 deletions

View File

@ -21,6 +21,7 @@
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h> #include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
@ -96,18 +97,6 @@ namespace gtsam {
// } // }
/* ************************************************************************ */ /* ************************************************************************ */
/**
* @brief Lookup table for max-product
*
* This inherits from a DiscreteConditional but is not normalized to 1
*
*/
class Lookup : public DiscreteConditional {
public:
Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
};
// Alternate eliminate function for MPE // Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors, EliminateForMPE(const DiscreteFactorGraph& factors,
@ -133,7 +122,8 @@ namespace gtsam {
// Make lookup with product // Make lookup with product
gttic(lookup); gttic(lookup);
size_t nrFrontals = frontalKeys.size(); size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<Lookup>(nrFrontals, orderedKeys, product); auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
gttoc(lookup); gttoc(lookup);
return std::make_pair( return std::make_pair(
@ -141,18 +131,37 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct( DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const { OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct); gttic(DiscreteFactorGraph_maxProduct);
return BaseEliminateable::eliminateSequential(orderingType,
EliminateForMPE); // The solution below is a bitclunky: the elimination machinery does not
// allow for differently *typed* versions of elimination, so we eliminate
// into a Bayes Net using the special eliminate function above, and then
// create the DiscreteLookupDAG after the fact, in linear time.
auto bayesNet =
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
// Copy to the DAG
DiscreteLookupDAG dag;
for (auto&& conditional : *bayesNet) {
if (auto lookupTable =
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
}
return dag;
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize( DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const { OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize); gttic(DiscreteFactorGraph_optimize);
return maxProduct()->optimize(); DiscreteLookupDAG dag = maxProduct();
return dag.argmax();
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -18,10 +18,11 @@
#pragma once #pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
@ -132,9 +133,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
* @brief Implement the max-product algorithm * @brief Implement the max-product algorithm
* *
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteBayesNet::shared_ptr DAG with lookup tables * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
*/ */
boost::shared_ptr<DiscreteBayesNet> maxProduct( DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const; OptionalOrderingType orderingType = boost::none) const;
/** /**

View File

@ -52,13 +52,6 @@ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteValues mpe; DiscreteValues mpe;
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
EXPECT(assert_equal(mpe, actualMPE)); EXPECT(assert_equal(mpe, actualMPE));
// Check Bayes Net
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3);
auto chordal = graph.eliminateSequential(ordering);
// happens to be the same, but not in general!
EXPECT(assert_equal(mpe, chordal->optimize()));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -125,57 +118,46 @@ TEST(DiscreteFactorGraph, test) {
DecisionTreeFactor::shared_ptr newFactor; DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net // Check Conditional
CHECK(conditional); CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
DiscreteConditional expectedConditional(signature); DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
// Check Factor // Check Factor
CHECK(newFactor); CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor)); EXPECT(assert_equal(expectedFactor, *newFactor));
// add conditionals to complete expected Bayes net // Test using elimination tree
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// Test elimination tree
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2); ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering); DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual; DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph; DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
DiscreteValues mpe; // Check Bayes net
insert(mpe)(0, 0)(1, 0)(2, 0); DiscreteBayesNet expectedBayesNet;
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression expectedBayesNet.add(signature);
expectedBayesNet.add(B | A = "5/3 3/5");
// Check Bayes Net expectedBayesNet.add(A % "1/1");
auto chordal = graph.eliminateSequential(); EXPECT(assert_equal(expectedBayesNet, *actual));
auto notOptimal = chordal->optimize();
// happens to be the same but not in general!
EXPECT(assert_equal(mpe, notOptimal));
// Test eliminateSequential // Test eliminateSequential
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
EXPECT(assert_equal(expected, *actual2)); EXPECT(assert_equal(expectedBayesNet, *actual2));
auto notOptimal2 = actual2->optimize();
// happens to be the same but not in general!
EXPECT(assert_equal(mpe, notOptimal2));
// Test mpe // Test mpe
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize(); auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE)); EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(DiscreteFactorGraph, testMPE) { TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey C(0, 2), A(1, 2), B(2, 2); DiscreteKey C(0, 2), A(1, 2), B(2, 2);
@ -184,17 +166,20 @@ TEST_UNSAFE(DiscreteFactorGraph, testMPE) {
graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
// Check MPE. // Created expected MPE
auto actualMPE = graph.optimize();
DiscreteValues mpe; DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1); insert(mpe)(0, 0)(1, 1)(2, 1);
EXPECT(assert_equal(mpe, actualMPE));
// Check Bayes Net // Do max-product with different orderings
auto chordal = graph.eliminateSequential(); for (Ordering::OrderingType orderingType :
auto notOptimal = chordal->optimize(); {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
// happens to be the same but not in general Ordering::CUSTOM}) {
EXPECT(assert_equal(mpe, notOptimal)); DiscreteLookupDAG dag = graph.maxProduct(orderingType);
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
auto actualMPE2 = graph.optimize(); // all in one
EXPECT(assert_equal(mpe, actualMPE2));
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -218,10 +203,12 @@ TEST(DiscreteFactorGraph, marginalIsNotMPE) {
EXPECT(assert_equal(mpe, actualMPE)); EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Optimize on BayesNet maximizes marginal, then the conditional marginals: // Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize(); auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe)); EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
#endif
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -252,8 +239,11 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4); ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
auto chordal = graph.eliminateSequential(ordering); auto chordal = graph.eliminateSequential(ordering);
EXPECT_LONGS_EQUAL(2, chordal->size());
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
auto notOptimal = chordal->optimize(); // not MPE ! auto notOptimal = chordal->optimize(); // not MPE !
EXPECT(graph(notOptimal) < graph(mpe)); EXPECT(graph(notOptimal) < graph(mpe));
#endif
// Let us create the Bayes tree here, just for fun, because we don't use it // Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree = DiscreteBayesTree::shared_ptr bayesTree =