Now doing MPE with DAG class
parent
fcdb5b43c1
commit
5add858c24
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
|
@ -96,18 +97,6 @@ namespace gtsam {
|
|||
// }
|
||||
|
||||
/* ************************************************************************ */
|
||||
/**
|
||||
* @brief Lookup table for max-product
|
||||
*
|
||||
* This inherits from a DiscreteConditional but is not normalized to 1
|
||||
*
|
||||
*/
|
||||
class Lookup : public DiscreteConditional {
|
||||
public:
|
||||
Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials)
|
||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||
};
|
||||
|
||||
// Alternate eliminate function for MPE
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
|
@ -133,7 +122,8 @@ namespace gtsam {
|
|||
// Make lookup with product
|
||||
gttic(lookup);
|
||||
size_t nrFrontals = frontalKeys.size();
|
||||
auto lookup = boost::make_shared<Lookup>(nrFrontals, orderedKeys, product);
|
||||
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||
orderedKeys, product);
|
||||
gttoc(lookup);
|
||||
|
||||
return std::make_pair(
|
||||
|
@ -141,18 +131,37 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct(
|
||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_maxProduct);
|
||||
return BaseEliminateable::eliminateSequential(orderingType,
|
||||
EliminateForMPE);
|
||||
|
||||
// The solution below is a bitclunky: the elimination machinery does not
|
||||
// allow for differently *typed* versions of elimination, so we eliminate
|
||||
// into a Bayes Net using the special eliminate function above, and then
|
||||
// create the DiscreteLookupDAG after the fact, in linear time.
|
||||
auto bayesNet =
|
||||
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
|
||||
|
||||
// Copy to the DAG
|
||||
DiscreteLookupDAG dag;
|
||||
for (auto&& conditional : *bayesNet) {
|
||||
if (auto lookupTable =
|
||||
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||
dag.push_back(lookupTable);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||
}
|
||||
}
|
||||
return dag;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteValues DiscreteFactorGraph::optimize(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
return maxProduct()->optimize();
|
||||
DiscreteLookupDAG dag = maxProduct();
|
||||
return dag.argmax();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -18,10 +18,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
@ -132,9 +133,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
* @brief Implement the max-product algorithm
|
||||
*
|
||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||
* @return DiscreteBayesNet::shared_ptr DAG with lookup tables
|
||||
* @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
|
||||
*/
|
||||
boost::shared_ptr<DiscreteBayesNet> maxProduct(
|
||||
DiscreteLookupDAG maxProduct(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
|
|
|
@ -52,13 +52,6 @@ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
|||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
|
||||
// Check Bayes Net
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3);
|
||||
auto chordal = graph.eliminateSequential(ordering);
|
||||
// happens to be the same, but not in general!
|
||||
EXPECT(assert_equal(mpe, chordal->optimize()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -125,57 +118,46 @@ TEST(DiscreteFactorGraph, test) {
|
|||
DecisionTreeFactor::shared_ptr newFactor;
|
||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
// Check Bayes net
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
DiscreteBayesNet expected;
|
||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||
|
||||
DiscreteConditional expectedConditional(signature);
|
||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||
expected.add(signature);
|
||||
|
||||
// Check Factor
|
||||
CHECK(newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||
|
||||
// add conditionals to complete expected Bayes net
|
||||
expected.add(B | A = "5/3 3/5");
|
||||
expected.add(A % "1/1");
|
||||
|
||||
// Test elimination tree
|
||||
// Test using elimination tree
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2);
|
||||
DiscreteEliminationTree etree(graph, ordering);
|
||||
DiscreteBayesNet::shared_ptr actual;
|
||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||
EXPECT(assert_equal(expected, *actual));
|
||||
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||
|
||||
// Check Bayes Net
|
||||
auto chordal = graph.eliminateSequential();
|
||||
auto notOptimal = chordal->optimize();
|
||||
// happens to be the same but not in general!
|
||||
EXPECT(assert_equal(mpe, notOptimal));
|
||||
// Check Bayes net
|
||||
DiscreteBayesNet expectedBayesNet;
|
||||
expectedBayesNet.add(signature);
|
||||
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||
expectedBayesNet.add(A % "1/1");
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||
|
||||
// Test eliminateSequential
|
||||
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||
EXPECT(assert_equal(expected, *actual2));
|
||||
auto notOptimal2 = actual2->optimize();
|
||||
// happens to be the same but not in general!
|
||||
EXPECT(assert_equal(mpe, notOptimal2));
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||
|
||||
// Test mpe
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(DiscreteFactorGraph, testMPE) {
|
||||
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
|
||||
|
@ -184,17 +166,20 @@ TEST_UNSAFE(DiscreteFactorGraph, testMPE) {
|
|||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
|
||||
// Check MPE.
|
||||
auto actualMPE = graph.optimize();
|
||||
// Created expected MPE
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
|
||||
// Check Bayes Net
|
||||
auto chordal = graph.eliminateSequential();
|
||||
auto notOptimal = chordal->optimize();
|
||||
// happens to be the same but not in general
|
||||
EXPECT(assert_equal(mpe, notOptimal));
|
||||
// Do max-product with different orderings
|
||||
for (Ordering::OrderingType orderingType :
|
||||
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||
Ordering::CUSTOM}) {
|
||||
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
auto actualMPE2 = graph.optimize(); // all in one
|
||||
EXPECT(assert_equal(mpe, actualMPE2));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -218,10 +203,12 @@ TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
|||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||
auto notOptimal = bayesNet.optimize();
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -252,8 +239,11 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
|||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||
auto chordal = graph.eliminateSequential(ordering);
|
||||
EXPECT_LONGS_EQUAL(2, chordal->size());
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
auto notOptimal = chordal->optimize(); // not MPE !
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
#endif
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||
DiscreteBayesTree::shared_ptr bayesTree =
|
||||
|
|
Loading…
Reference in New Issue