`optimize` now computes MPE

release/4.3a0
Frank Dellaert 2022-01-21 10:12:31 -05:00
parent 6e4f50dfac
commit ec39197cc3
3 changed files with 187 additions and 110 deletions

View File

@ -95,22 +95,74 @@ namespace gtsam {
// }
// }
/* ************************************************************************* */
DiscreteValues DiscreteFactorGraph::optimize() const
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************ */
/**
* @brief Lookup table for max-product
*
* This inherits from a DiscreteConditional but is not normalized to 1
*
*/
class Lookup : public DiscreteConditional {
public:
Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
};
/* ************************************************************************* */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors)
product = (*factor) * product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<Lookup>(nrFrontals, orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
return BaseEliminateable::eliminateSequential(orderingType,
EliminateForMPE);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
return maxProduct()->optimize();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// sum out frontals, this is the factor on the separator
@ -120,15 +172,18 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional
gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide);
return std::make_pair(cond, sum);
return std::make_pair(conditional, sum);
}
/* ************************************************************************ */

View File

@ -128,18 +128,31 @@ class GTSAM_EXPORT DiscreteFactorGraph
const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using
* the dense elimination function specified in \c function,
* followed by back-substitution resulting from elimination. Is equivalent
* to calling graph.eliminateSequential()->optimize(). */
DiscreteValues optimize() const;
/**
* @brief Implement the max-product algorithm
*
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
* @return DiscreteBayesNet::shared_ptr DAG with lookup tables
*/
boost::shared_ptr<DiscreteBayesNet> maxProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param orderingType
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
// /** Permute the variables in the factors */
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
// GTSAM_EXPORT void permuteWithInverse(const Permutation&
// inversePermutation);
//
// /** Apply a reduction, which is a remapping of variable indices. */
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction&
// inverseReduction);
/// @name Wrapper support
/// @{

View File

@ -47,25 +47,18 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: ");
DecisionTreeFactor product = graph.product();
DecisionTreeFactor::shared_ptr sum = product.sum(1);
// sum->print("Debug SUM: ");
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
// Check MPE.
auto actualMPE = graph.optimize();
DiscreteValues mpe;
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
EXPECT(assert_equal(mpe, actualMPE));
// cond->print("marginal:");
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
// Check Bayes Net
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
auto chordal = graph.eliminateSequential(ordering);
// happens to be the same, but not in general!
EXPECT(assert_equal(mpe, chordal->optimize()));
}
/* ************************************************************************* */
@ -115,8 +108,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, test)
{
TEST(DiscreteFactorGraph, test) {
// Declare keys and ordering
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
@ -127,7 +119,6 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys;
frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional;
@ -138,7 +129,7 @@ TEST( DiscreteFactorGraph, test)
CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
@ -151,7 +142,6 @@ TEST( DiscreteFactorGraph, test)
// add conditionals to complete expected Bayes net
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
Ordering ordering;
@ -162,20 +152,30 @@ TEST( DiscreteFactorGraph, test)
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// // Test solver
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
// EXPECT(assert_equal(expected, *actual2));
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
// Test optimization
DiscreteValues expectedValues;
insert(expectedValues)(0, 0)(1, 0)(2, 0);
auto actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, actualValues));
// Check Bayes Net
auto chordal = graph.eliminateSequential();
auto notOptimal = chordal->optimize();
// happens to be the same but not in general!
EXPECT(assert_equal(mpe, notOptimal));
// Test eliminateSequential
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
EXPECT(assert_equal(expected, *actual2));
auto notOptimal2 = actual2->optimize();
// happens to be the same but not in general!
EXPECT(assert_equal(mpe, notOptimal2));
// Test mpe
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE)
{
TEST_UNSAFE(DiscreteFactorGraph, testMPE) {
// Declare a bunch of keys
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
@ -183,19 +183,49 @@ TEST( DiscreteFactorGraph, testMPE)
DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
// Check MPE.
auto actualMPE = graph.optimize();
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
EXPECT(assert_equal(mpe, actualMPE));
DiscreteValues expectedMPE;
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Check Bayes Net
auto chordal = graph.eliminateSequential();
auto notOptimal = chordal->optimize();
// happens to be the same but not in general
EXPECT(assert_equal(mpe, notOptimal));
}
/* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
{
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
@ -207,52 +237,31 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
DiscreteValues mpe;
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteValues expectedMPE;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
// Use the solver machinery.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto actualMPE = chordal->optimize();
EXPECT(assert_equal(expectedMPE, actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check MPE.
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
// Check Bayes Net
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
auto chordal = graph.eliminateSequential(ordering);
auto notOptimal = chordal->optimize(); // not MPE !
EXPECT(graph(notOptimal) < graph(mpe));
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2, bayesTree->size());
#ifdef OLD
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
auto actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
#endif
}
#ifdef OLD
/* ************************************************************************* */