`optimize` now computes MPE

release/4.3a0
Frank Dellaert 2022-01-21 10:12:31 -05:00
parent 6e4f50dfac
commit ec39197cc3
3 changed files with 187 additions and 110 deletions

View File

@ -95,22 +95,74 @@ namespace gtsam {
// } // }
// } // }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize() const /**
{ * @brief Lookup table for max-product
gttic(DiscreteFactorGraph_optimize); *
return BaseEliminateable::eliminateSequential()->optimize(); * This inherits from a DiscreteConditional but is not normalized to 1
} *
*/
class Lookup : public DiscreteConditional {
public:
Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
};
/* ************************************************************************* */ // Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
DecisionTreeFactor product; DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors) for (auto&& factor : factors) product = (*factor) * product;
product = (*factor) * product; gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<Lookup>(nrFrontals, orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
return BaseEliminateable::eliminateSequential(orderingType,
EliminateForMPE);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
return maxProduct()->optimize();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
@ -120,15 +172,18 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide); gttoc(divide);
return std::make_pair(cond, sum); return std::make_pair(conditional, sum);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -128,18 +128,31 @@ class GTSAM_EXPORT DiscreteFactorGraph
const std::string& s = "DiscreteFactorGraph", const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using /**
* the dense elimination function specified in \c function, * @brief Implement the max-product algorithm
* followed by back-substitution resulting from elimination. Is equivalent *
* to calling graph.eliminateSequential()->optimize(). */ * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
DiscreteValues optimize() const; * @return DiscreteBayesNet::shared_ptr DAG with lookup tables
*/
boost::shared_ptr<DiscreteBayesNet> maxProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param orderingType
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
// /** Permute the variables in the factors */ // /** Permute the variables in the factors */
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); // GTSAM_EXPORT void permuteWithInverse(const Permutation&
// // inversePermutation);
// /** Apply a reduction, which is a remapping of variable indices. */ //
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); // /** Apply a reduction, which is a remapping of variable indices. */
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction&
// inverseReduction);
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -30,8 +30,8 @@ using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(AI, "1 0 0 1"); graph.add(AI, "1 0 0 1");
@ -47,25 +47,18 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: "); // Check MPE.
DecisionTreeFactor product = graph.product(); auto actualMPE = graph.optimize();
DecisionTreeFactor::shared_ptr sum = product.sum(1); DiscreteValues mpe;
// sum->print("Debug SUM: "); insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); EXPECT(assert_equal(mpe, actualMPE));
// cond->print("marginal:"); // Check Bayes Net
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3); ordering += Key(0), Key(1), Key(2), Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering); auto chordal = graph.eliminateSequential(ordering);
// eliminationTree.print("Elimination tree: "); // happens to be the same, but not in general!
eliminationTree.eliminate(EliminateDiscrete); EXPECT(assert_equal(mpe, chordal->optimize()));
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -115,10 +108,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, test) TEST(DiscreteFactorGraph, test) {
{
// Declare keys and ordering // Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2); DiscreteKey C(0, 2), B(1, 2), A(2, 2);
// A simple factor graph (A)-fAC-(C)-fBC-(B) // A simple factor graph (A)-fAC-(C)-fBC-(B)
// with smoothness priors // with smoothness priors
@ -127,7 +119,6 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3"); graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete // Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys; Ordering frontalKeys;
frontalKeys += Key(0); frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional; DiscreteConditional::shared_ptr conditional;
@ -138,7 +129,7 @@ TEST( DiscreteFactorGraph, test)
CHECK(conditional); CHECK(conditional);
DiscreteBayesNet expected; DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature); DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature); expected.add(signature);
@ -151,7 +142,6 @@ TEST( DiscreteFactorGraph, test)
// add conditionals to complete expected Bayes net // add conditionals to complete expected Bayes net
expected.add(B | A = "5/3 3/5"); expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1"); expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree // Test elimination tree
Ordering ordering; Ordering ordering;
@ -162,42 +152,82 @@ TEST( DiscreteFactorGraph, test)
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual)); EXPECT(assert_equal(expected, *actual));
// // Test solver DiscreteValues mpe;
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); insert(mpe)(0, 0)(1, 0)(2, 0);
// EXPECT(assert_equal(expected, *actual2)); EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
// Test optimization // Check Bayes Net
DiscreteValues expectedValues; auto chordal = graph.eliminateSequential();
insert(expectedValues)(0, 0)(1, 0)(2, 0); auto notOptimal = chordal->optimize();
auto actualValues = graph.optimize(); // happens to be the same but not in general!
EXPECT(assert_equal(expectedValues, actualValues)); EXPECT(assert_equal(mpe, notOptimal));
// Test eliminateSequential
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
EXPECT(assert_equal(expected, *actual2));
auto notOptimal2 = actual2->optimize();
// happens to be the same but not in general!
EXPECT(assert_equal(mpe, notOptimal2));
// Test mpe
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE) TEST_UNSAFE(DiscreteFactorGraph, testMPE) {
{
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey C(0,2), A(1,2), B(2,2); DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
// Check MPE.
auto actualMPE = graph.optimize(); auto actualMPE = graph.optimize();
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
EXPECT(assert_equal(mpe, actualMPE));
DiscreteValues expectedMPE; // Check Bayes Net
insert(expectedMPE)(0, 0)(1, 1)(2, 1); auto chordal = graph.eliminateSequential();
EXPECT(assert_equal(expectedMPE, actualMPE)); auto notOptimal = chordal->optimize();
// happens to be the same but not in general
EXPECT(assert_equal(mpe, notOptimal));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) TEST(DiscreteFactorGraph, marginalIsNotMPE) {
{ // Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244 // The factor graph in Darwiche09book, page 244
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
@ -206,53 +236,32 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(C & T1, "0.80 0.20 0.20 0.80");
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
//graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteValues expectedMPE; DiscreteValues mpe;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// Use the solver machinery. // Check MPE.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto actualMPE = graph.optimize();
auto actualMPE = chordal->optimize(); EXPECT(assert_equal(mpe, actualMPE));
EXPECT(assert_equal(expectedMPE, actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check Bayes Net
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4); ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); auto chordal = graph.eliminateSequential(ordering);
auto notOptimal = chordal->optimize(); // not MPE !
EXPECT(graph(notOptimal) < graph(mpe));
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree"); // bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2,bayesTree->size()); EXPECT_LONGS_EQUAL(2, bayesTree->size());
#ifdef OLD
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
auto actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
#endif
} }
#ifdef OLD #ifdef OLD
/* ************************************************************************* */ /* ************************************************************************* */