Fully functioning, non-buggy separator shortcuts. Still not as tight as they can be....

release/4.3a0
Frank Dellaert 2012-09-16 04:37:59 +00:00
parent 44c66cb0cb
commit 9094fe2744
1 changed files with 167 additions and 8 deletions

View File

@ -16,6 +16,7 @@
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesTree.h>
#include <boost/assign/std/vector.hpp>
@ -32,9 +33,42 @@ static bool debug = false;
* Custom clique class to debug shortcuts
*/
class Clique: public BayesTreeCliqueBase<Clique, DiscreteConditional> {
protected:
/**
* Determine variable indices to keep in recursive separator shortcut calculation
* The factor graph p_Cp_B has keys from the parent clique Cp and from B.
* But we only keep the variables not in S union B.
*/
vector<Index> indices(derived_ptr B,
const FactorGraph<FactorType>& p_Cp_B) const {
// Get all keys
set<Index> allKeys = p_Cp_B.keys();
// We do this by first merging S and B
boost::iterator_range<FactorType::iterator> indicesS =
this->conditional()->parents();
size_t sizeS = indicesS.end() - indicesS.begin();
vector<Index> &indicesB = B->conditional()->keys();
vector<Index> S_union_B(indicesB.size() + sizeS);
vector<Index>::iterator it = set_union(indicesS.begin(), indicesS.end(),
indicesB.begin(), indicesB.end(), S_union_B.begin());
// then intersecting S_union_B with allKeys
vector<Index> keepers(indicesB.size() + sizeS);
it = set_intersection(S_union_B.begin(), it, allKeys.begin(), allKeys.end(),
keepers.begin());
keepers.erase(it, keepers.end());
return keepers;
}
public:
typedef BayesTreeCliqueBase<Clique, DiscreteConditional> Base;
typedef boost::shared_ptr<Clique> shared_ptr;
// Constructors
Clique() {
@ -48,7 +82,13 @@ public:
Base(result) {
}
// evaluate value of sub-tree
/// print index signature only
void printSignature(const std::string& s = "Clique: ",
const IndexFormatter& indexFormatter = DefaultIndexFormatter) const {
((IndexConditional::shared_ptr) conditional_)->print(s, indexFormatter);
}
/// evaluate value of sub-tree
double evaluate(const DiscreteConditional::Values & values) {
double result = (*(this->conditional_))(values);
// evaluate all children and multiply into result
@ -56,6 +96,85 @@ public:
result *= c->evaluate(values);
return result;
}
/**
* Separator shortcut function P(S||B) = P(S\B|B)
* where S is a clique separator, and B any node (e.g., a brancing in the tree)
* We can compute it recursively from the parent shortcut
* P(Sp||B) as \int P(Fp|Sp) P(Sp||B), where Fp are the frontal nodes in p
*/
FactorGraph<FactorType>::shared_ptr separatorShortcut(derived_ptr B) const {
FactorGraph<FactorType>::shared_ptr p_S_B; //shortcut P(S||B) This is empty now
// We only calculate the shortcut when this clique is not B
derived_ptr parent(parent_.lock());
if (B.get() != this) {
// Obtain P(Fp|Sp) as a factor
boost::shared_ptr<FactorType> p_Fp_Sp = parent->conditional()->toFactor();
// Obtain the parent shortcut P(Sp|B) as factors
// TODO: really annoying that we eliminate more than we have to !
// TODO: we should only eliminate C_p\B, with S\B variables last
// TODO: and this index dance will be easier then, as well
FactorGraph<FactorType> p_Sp_B(parent->shortcut(B, &EliminateDiscrete));
// now combine P(Cp||B) = P(Fp|Sp) * P(Sp||B)
FactorGraph<FactorType> p_Cp_B;
p_Cp_B.push_back(p_Fp_Sp);
p_Cp_B.push_back(p_Sp_B);
// Create a generic solver that will marginalize for us
GenericSequentialSolver<FactorType> solver(p_Cp_B);
// The factor graph above will have keys from the parent clique Cp and from B.
// But we only keep the variables not in S union B.
vector<Index> keepers = indices(B, p_Cp_B);
p_S_B = solver.jointFactorGraph(keepers, &EliminateDiscrete);
}
// return the shortcut P(S||B)
return p_S_B;
}
/**
* The shortcut density is a conditional P(S||B) of the separator of this
* clique on the clique B.
*/
BayesNet<DiscreteConditional> shortcut(derived_ptr B,
Eliminate function) const {
//Check if the ShortCut already exists
if (cachedShortcut_) {
return *cachedShortcut_; // return the cached version
} else {
BayesNet<DiscreteConditional> bn;
FactorGraph<FactorType>::shared_ptr fg = separatorShortcut(B);
if (fg) {
// calculate set S\B of indices to keep in Bayes net
vector<Index> indicesS(this->conditional()->beginParents(),
this->conditional()->endParents());
// now get B indices out
vector<Index> &indicesB = B->conditional()->keys();
vector<Index> S_setminus_B(indicesS.size());
vector<Index>::iterator it = set_difference(indicesS.begin(),
indicesS.end(), indicesB.begin(), indicesB.end(),
S_setminus_B.begin());
S_setminus_B.erase(it, S_setminus_B.end());
set<Index> keep(S_setminus_B.begin(), S_setminus_B.end());
BOOST_FOREACH (FactorType::shared_ptr factor,*fg) {
DecisionTreeFactor::shared_ptr df = boost::dynamic_pointer_cast<
DecisionTreeFactor>(factor);
if (keep.count(*factor->begin()))
bn.push_front(boost::make_shared<DiscreteConditional>(1, *df));
}
}
cachedShortcut_ = bn;
return bn;
}
}
};
typedef BayesTree<DiscreteConditional, Clique> DiscreteBayesTree;
@ -73,14 +192,14 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) {
const int nrNodes = 15;
const size_t nrStates = 2;
// define variables
// define variables
vector<DiscreteKey> key;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey key_i(i, nrStates);
key.push_back(key_i);
}
// create a thin-tree Bayesnet, a la Jean-Guillaume
// create a thin-tree Bayesnet, a la Jean-Guillaume
DiscreteBayesNet bayesNet;
add_front(bayesNet, key[14] % "1/3");
@ -89,8 +208,8 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) {
add_front(bayesNet, (key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
add_front(bayesNet, (key[9] | key[12], key[14]) = "4/1 2/3 3/2 1/4");
add_front(bayesNet, (key[8] | key[12], key[14]) = "2/3 1/4 3/2 4/1");
add_front(bayesNet, (key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
add_front(bayesNet, (key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
add_front(bayesNet, (key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
@ -98,7 +217,7 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) {
add_front(bayesNet, (key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
add_front(bayesNet, (key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[2] | key[9], key[12]) = "1/4 3/2 2/3 4/1");
add_front(bayesNet, (key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
add_front(bayesNet, (key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
add_front(bayesNet, (key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
@ -107,14 +226,17 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) {
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
}
// create a BayesTree out of a Bayes net
// create a BayesTree out of a Bayes net
DiscreteBayesTree bayesTree(bayesNet);
if (debug) {
GTSAM_PRINT(bayesTree);
bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
}
// Check whether BN and BT give the same answer on all configurations
// Check whether BN and BT give the same answer on all configurations
// Also calculate all some marginals
Vector marginals = zero(15);
double shortcut0, sum0;
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
@ -123,7 +245,44 @@ TEST_UNSAFE( DiscreteMarginals, thinTree ) {
double expected = evaluate(bayesNet, x);
double actual = evaluate(bayesTree, x);
DOUBLES_EQUAL(expected, actual, 1e-9);
// collect marginals
for (size_t i = 0; i < 15; i++)
if (x[i])
marginals[i] += actual;
// calculate a deep shortcut
if (x[12] && x[14] & x[8])
shortcut0 += actual;
if (x[14])
sum0 += actual;
}
DiscreteFactor::Values all1 = allPosbValues.back();
// check shortcut P(S0||R) to root
Clique::shared_ptr R = bayesTree.root();
Clique::shared_ptr c = bayesTree[0];
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(shortcut0/sum0, evaluate(shortcut,all1), 1e-9);
// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
BOOST_FOREACH(Clique::shared_ptr c, cliques) {
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
if (debug) {
c->printSignature();
shortcut.print("shortcut:");
}
}
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree.marginalFactor(i, &EliminateDiscrete);
DiscreteFactor::Values x;
x[i] = 1;
double actual = (*marginalFactor)(x);
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
}
/* ************************************************************************* */