From 9094fe27449c02e9131a2034a8bef0c69a6cebb2 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 16 Sep 2012 04:37:59 +0000 Subject: [PATCH] Fully functioning, non-buggy separator shortcuts. Still not as tight as they can be.... --- .../discrete/tests/testDiscreteBayesTree.cpp | 175 +++++++++++++++++- 1 file changed, 167 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index c2bbcb55e..a86cd3e32 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -16,6 +16,7 @@ */ #include +#include #include #include @@ -32,9 +33,42 @@ static bool debug = false; * Custom clique class to debug shortcuts */ class Clique: public BayesTreeCliqueBase { + +protected: + + /** + * Determine variable indices to keep in recursive separator shortcut calculation + * The factor graph p_Cp_B has keys from the parent clique Cp and from B. + * But we only keep the variables not in S union B. + */ + vector indices(derived_ptr B, + const FactorGraph& p_Cp_B) const { + + // Get all keys + set allKeys = p_Cp_B.keys(); + + // We do this by first merging S and B + boost::iterator_range indicesS = + this->conditional()->parents(); + size_t sizeS = indicesS.end() - indicesS.begin(); + vector &indicesB = B->conditional()->keys(); + vector S_union_B(indicesB.size() + sizeS); + vector::iterator it = set_union(indicesS.begin(), indicesS.end(), + indicesB.begin(), indicesB.end(), S_union_B.begin()); + + // then intersecting S_union_B with allKeys + vector keepers(indicesB.size() + sizeS); + it = set_intersection(S_union_B.begin(), it, allKeys.begin(), allKeys.end(), + keepers.begin()); + keepers.erase(it, keepers.end()); + + return keepers; + } + public: typedef BayesTreeCliqueBase Base; + typedef boost::shared_ptr shared_ptr; // Constructors Clique() { @@ -48,7 +82,13 @@ public: Base(result) { } - // evaluate value of sub-tree + /// print index signature only + void printSignature(const std::string& s = "Clique: ", + const IndexFormatter& indexFormatter = DefaultIndexFormatter) const { + ((IndexConditional::shared_ptr) conditional_)->print(s, indexFormatter); + } + + /// evaluate value of sub-tree double evaluate(const DiscreteConditional::Values & values) { double result = (*(this->conditional_))(values); // evaluate all children and multiply into result @@ -56,6 +96,85 @@ public: result *= c->evaluate(values); return result; } + + /** + * Separator shortcut function P(S||B) = P(S\B|B) + * where S is a clique separator, and B any node (e.g., a brancing in the tree) + * We can compute it recursively from the parent shortcut + * P(Sp||B) as \int P(Fp|Sp) P(Sp||B), where Fp are the frontal nodes in p + */ + FactorGraph::shared_ptr separatorShortcut(derived_ptr B) const { + + FactorGraph::shared_ptr p_S_B; //shortcut P(S||B) This is empty now + + // We only calculate the shortcut when this clique is not B + derived_ptr parent(parent_.lock()); + if (B.get() != this) { + + // Obtain P(Fp|Sp) as a factor + boost::shared_ptr p_Fp_Sp = parent->conditional()->toFactor(); + + // Obtain the parent shortcut P(Sp|B) as factors + // TODO: really annoying that we eliminate more than we have to ! + // TODO: we should only eliminate C_p\B, with S\B variables last + // TODO: and this index dance will be easier then, as well + FactorGraph p_Sp_B(parent->shortcut(B, &EliminateDiscrete)); + + // now combine P(Cp||B) = P(Fp|Sp) * P(Sp||B) + FactorGraph p_Cp_B; + p_Cp_B.push_back(p_Fp_Sp); + p_Cp_B.push_back(p_Sp_B); + + // Create a generic solver that will marginalize for us + GenericSequentialSolver solver(p_Cp_B); + + // The factor graph above will have keys from the parent clique Cp and from B. + // But we only keep the variables not in S union B. + vector keepers = indices(B, p_Cp_B); + + p_S_B = solver.jointFactorGraph(keepers, &EliminateDiscrete); + } + // return the shortcut P(S||B) + return p_S_B; + } + + /** + * The shortcut density is a conditional P(S||B) of the separator of this + * clique on the clique B. + */ + BayesNet shortcut(derived_ptr B, + Eliminate function) const { + + //Check if the ShortCut already exists + if (cachedShortcut_) { + return *cachedShortcut_; // return the cached version + } else { + BayesNet bn; + FactorGraph::shared_ptr fg = separatorShortcut(B); + if (fg) { + // calculate set S\B of indices to keep in Bayes net + vector indicesS(this->conditional()->beginParents(), + this->conditional()->endParents()); + // now get B indices out + vector &indicesB = B->conditional()->keys(); + vector S_setminus_B(indicesS.size()); + vector::iterator it = set_difference(indicesS.begin(), + indicesS.end(), indicesB.begin(), indicesB.end(), + S_setminus_B.begin()); + S_setminus_B.erase(it, S_setminus_B.end()); + set keep(S_setminus_B.begin(), S_setminus_B.end()); + BOOST_FOREACH (FactorType::shared_ptr factor,*fg) { + DecisionTreeFactor::shared_ptr df = boost::dynamic_pointer_cast< + DecisionTreeFactor>(factor); + if (keep.count(*factor->begin())) + bn.push_front(boost::make_shared(1, *df)); + } + } + cachedShortcut_ = bn; + return bn; + } + } + }; typedef BayesTree DiscreteBayesTree; @@ -73,14 +192,14 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) { const int nrNodes = 15; const size_t nrStates = 2; - // define variables +// define variables vector key; for (int i = 0; i < nrNodes; i++) { DiscreteKey key_i(i, nrStates); key.push_back(key_i); } - // create a thin-tree Bayesnet, a la Jean-Guillaume +// create a thin-tree Bayesnet, a la Jean-Guillaume DiscreteBayesNet bayesNet; add_front(bayesNet, key[14] % "1/3"); @@ -89,8 +208,8 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) { add_front(bayesNet, (key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); add_front(bayesNet, (key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); - add_front(bayesNet, (key[9] | key[12], key[14]) = "4/1 2/3 3/2 1/4"); - add_front(bayesNet, (key[8] | key[12], key[14]) = "2/3 1/4 3/2 4/1"); + add_front(bayesNet, (key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); + add_front(bayesNet, (key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); add_front(bayesNet, (key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); add_front(bayesNet, (key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); @@ -98,7 +217,7 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) { add_front(bayesNet, (key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); add_front(bayesNet, (key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); - add_front(bayesNet, (key[2] | key[9], key[12]) = "1/4 3/2 2/3 4/1"); + add_front(bayesNet, (key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); add_front(bayesNet, (key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); add_front(bayesNet, (key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); @@ -107,14 +226,17 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) { bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); } - // create a BayesTree out of a Bayes net +// create a BayesTree out of a Bayes net DiscreteBayesTree bayesTree(bayesNet); if (debug) { GTSAM_PRINT(bayesTree); bayesTree.saveGraph("/tmp/discreteBayesTree.dot"); } - // Check whether BN and BT give the same answer on all configurations +// Check whether BN and BT give the same answer on all configurations +// Also calculate all some marginals + Vector marginals = zero(15); + double shortcut0, sum0; vector allPosbValues = cartesianProduct( key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); @@ -123,7 +245,44 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) { double expected = evaluate(bayesNet, x); double actual = evaluate(bayesTree, x); DOUBLES_EQUAL(expected, actual, 1e-9); + // collect marginals + for (size_t i = 0; i < 15; i++) + if (x[i]) + marginals[i] += actual; + // calculate a deep shortcut + if (x[12] && x[14] & x[8]) + shortcut0 += actual; + if (x[14]) + sum0 += actual; } + DiscreteFactor::Values all1 = allPosbValues.back(); + + // check shortcut P(S0||R) to root + Clique::shared_ptr R = bayesTree.root(); + Clique::shared_ptr c = bayesTree[0]; + DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(shortcut0/sum0, evaluate(shortcut,all1), 1e-9); + + // calculate all shortcuts to root + DiscreteBayesTree::Nodes cliques = bayesTree.nodes(); + BOOST_FOREACH(Clique::shared_ptr c, cliques) { + DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete); + if (debug) { + c->printSignature(); + shortcut.print("shortcut:"); + } + } + + // Check all marginals + DiscreteFactor::shared_ptr marginalFactor; + for (size_t i = 0; i < 15; i++) { + marginalFactor = bayesTree.marginalFactor(i, &EliminateDiscrete); + DiscreteFactor::Values x; + x[i] = 1; + double actual = (*marginalFactor)(x); + EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9); + } + } /* ************************************************************************* */