Now doing MPE with DAG class
parent
fcdb5b43c1
commit
5add858c24
|
@ -21,6 +21,7 @@
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
|
@ -96,18 +97,6 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
/**
|
|
||||||
* @brief Lookup table for max-product
|
|
||||||
*
|
|
||||||
* This inherits from a DiscreteConditional but is not normalized to 1
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
class Lookup : public DiscreteConditional {
|
|
||||||
public:
|
|
||||||
Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials)
|
|
||||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Alternate eliminate function for MPE
|
// Alternate eliminate function for MPE
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
@ -133,7 +122,8 @@ namespace gtsam {
|
||||||
// Make lookup with product
|
// Make lookup with product
|
||||||
gttic(lookup);
|
gttic(lookup);
|
||||||
size_t nrFrontals = frontalKeys.size();
|
size_t nrFrontals = frontalKeys.size();
|
||||||
auto lookup = boost::make_shared<Lookup>(nrFrontals, orderedKeys, product);
|
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||||
|
orderedKeys, product);
|
||||||
gttoc(lookup);
|
gttoc(lookup);
|
||||||
|
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
|
@ -141,18 +131,37 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct(
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
OptionalOrderingType orderingType) const {
|
OptionalOrderingType orderingType) const {
|
||||||
gttic(DiscreteFactorGraph_maxProduct);
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
return BaseEliminateable::eliminateSequential(orderingType,
|
|
||||||
EliminateForMPE);
|
// The solution below is a bitclunky: the elimination machinery does not
|
||||||
|
// allow for differently *typed* versions of elimination, so we eliminate
|
||||||
|
// into a Bayes Net using the special eliminate function above, and then
|
||||||
|
// create the DiscreteLookupDAG after the fact, in linear time.
|
||||||
|
auto bayesNet =
|
||||||
|
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
|
||||||
|
|
||||||
|
// Copy to the DAG
|
||||||
|
DiscreteLookupDAG dag;
|
||||||
|
for (auto&& conditional : *bayesNet) {
|
||||||
|
if (auto lookupTable =
|
||||||
|
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||||
|
dag.push_back(lookupTable);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dag;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteValues DiscreteFactorGraph::optimize(
|
DiscreteValues DiscreteFactorGraph::optimize(
|
||||||
OptionalOrderingType orderingType) const {
|
OptionalOrderingType orderingType) const {
|
||||||
gttic(DiscreteFactorGraph_optimize);
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
return maxProduct()->optimize();
|
DiscreteLookupDAG dag = maxProduct();
|
||||||
|
return dag.argmax();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
@ -18,10 +18,11 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
|
||||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
@ -132,9 +133,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
* @brief Implement the max-product algorithm
|
* @brief Implement the max-product algorithm
|
||||||
*
|
*
|
||||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
* @return DiscreteBayesNet::shared_ptr DAG with lookup tables
|
* @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
|
||||||
*/
|
*/
|
||||||
boost::shared_ptr<DiscreteBayesNet> maxProduct(
|
DiscreteLookupDAG maxProduct(
|
||||||
OptionalOrderingType orderingType = boost::none) const;
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -52,13 +52,6 @@ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||||
EXPECT(assert_equal(mpe, actualMPE));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
|
||||||
// Check Bayes Net
|
|
||||||
Ordering ordering;
|
|
||||||
ordering += Key(0), Key(1), Key(2), Key(3);
|
|
||||||
auto chordal = graph.eliminateSequential(ordering);
|
|
||||||
// happens to be the same, but not in general!
|
|
||||||
EXPECT(assert_equal(mpe, chordal->optimize()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -125,57 +118,46 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
DecisionTreeFactor::shared_ptr newFactor;
|
DecisionTreeFactor::shared_ptr newFactor;
|
||||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||||
|
|
||||||
// Check Bayes net
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
DiscreteBayesNet expected;
|
|
||||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||||
|
|
||||||
DiscreteConditional expectedConditional(signature);
|
DiscreteConditional expectedConditional(signature);
|
||||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||||
expected.add(signature);
|
|
||||||
|
|
||||||
// Check Factor
|
// Check Factor
|
||||||
CHECK(newFactor);
|
CHECK(newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||||
|
|
||||||
// add conditionals to complete expected Bayes net
|
// Test using elimination tree
|
||||||
expected.add(B | A = "5/3 3/5");
|
|
||||||
expected.add(A % "1/1");
|
|
||||||
|
|
||||||
// Test elimination tree
|
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2);
|
ordering += Key(0), Key(1), Key(2);
|
||||||
DiscreteEliminationTree etree(graph, ordering);
|
DiscreteEliminationTree etree(graph, ordering);
|
||||||
DiscreteBayesNet::shared_ptr actual;
|
DiscreteBayesNet::shared_ptr actual;
|
||||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||||
EXPECT(assert_equal(expected, *actual));
|
|
||||||
|
|
||||||
DiscreteValues mpe;
|
// Check Bayes net
|
||||||
insert(mpe)(0, 0)(1, 0)(2, 0);
|
DiscreteBayesNet expectedBayesNet;
|
||||||
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
expectedBayesNet.add(signature);
|
||||||
|
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||||
// Check Bayes Net
|
expectedBayesNet.add(A % "1/1");
|
||||||
auto chordal = graph.eliminateSequential();
|
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||||
auto notOptimal = chordal->optimize();
|
|
||||||
// happens to be the same but not in general!
|
|
||||||
EXPECT(assert_equal(mpe, notOptimal));
|
|
||||||
|
|
||||||
// Test eliminateSequential
|
// Test eliminateSequential
|
||||||
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||||
EXPECT(assert_equal(expected, *actual2));
|
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||||
auto notOptimal2 = actual2->optimize();
|
|
||||||
// happens to be the same but not in general!
|
|
||||||
EXPECT(assert_equal(mpe, notOptimal2));
|
|
||||||
|
|
||||||
// Test mpe
|
// Test mpe
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||||
auto actualMPE = graph.optimize();
|
auto actualMPE = graph.optimize();
|
||||||
EXPECT(assert_equal(mpe, actualMPE));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE(DiscreteFactorGraph, testMPE) {
|
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||||
// Declare a bunch of keys
|
// Declare a bunch of keys
|
||||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||||
|
|
||||||
|
@ -184,17 +166,20 @@ TEST_UNSAFE(DiscreteFactorGraph, testMPE) {
|
||||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||||
|
|
||||||
// Check MPE.
|
// Created expected MPE
|
||||||
auto actualMPE = graph.optimize();
|
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
insert(mpe)(0, 0)(1, 1)(2, 1);
|
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||||
EXPECT(assert_equal(mpe, actualMPE));
|
|
||||||
|
|
||||||
// Check Bayes Net
|
// Do max-product with different orderings
|
||||||
auto chordal = graph.eliminateSequential();
|
for (Ordering::OrderingType orderingType :
|
||||||
auto notOptimal = chordal->optimize();
|
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||||
// happens to be the same but not in general
|
Ordering::CUSTOM}) {
|
||||||
EXPECT(assert_equal(mpe, notOptimal));
|
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||||
|
auto actualMPE = dag.argmax();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
auto actualMPE2 = graph.optimize(); // all in one
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -218,10 +203,12 @@ TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
||||||
EXPECT(assert_equal(mpe, actualMPE));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||||
auto notOptimal = bayesNet.optimize();
|
auto notOptimal = bayesNet.optimize();
|
||||||
EXPECT(graph(notOptimal) < graph(mpe));
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -252,8 +239,11 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||||
auto chordal = graph.eliminateSequential(ordering);
|
auto chordal = graph.eliminateSequential(ordering);
|
||||||
|
EXPECT_LONGS_EQUAL(2, chordal->size());
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
auto notOptimal = chordal->optimize(); // not MPE !
|
auto notOptimal = chordal->optimize(); // not MPE !
|
||||||
EXPECT(graph(notOptimal) < graph(mpe));
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
|
#endif
|
||||||
|
|
||||||
// Let us create the Bayes tree here, just for fun, because we don't use it
|
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||||
DiscreteBayesTree::shared_ptr bayesTree =
|
DiscreteBayesTree::shared_ptr bayesTree =
|
||||||
|
|
Loading…
Reference in New Issue