2-variable joint marginals computed with shortcuts are buggy: they only work if the cliques are disjoint or one of them is the root. They now throw an exception if that is the case.

release/4.3a0
Frank Dellaert 2012-09-17 03:31:24 +00:00
parent 970efd9e29
commit 1f0cc0aaa4
5 changed files with 152 additions and 42 deletions

View File

@ -85,14 +85,14 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
const int nrNodes = 15;
const size_t nrStates = 2;
// define variables
// define variables
vector<DiscreteKey> key;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey key_i(i, nrStates);
key.push_back(key_i);
}
// create a thin-tree Bayesnet, a la Jean-Guillaume
// create a thin-tree Bayesnet, a la Jean-Guillaume
DiscreteBayesNet bayesNet;
add_front(bayesNet, key[14] % "1/3");
@ -119,17 +119,18 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
}
// create a BayesTree out of a Bayes net
// create a BayesTree out of a Bayes net
DiscreteBayesTree bayesTree(bayesNet);
if (debug) {
GTSAM_PRINT(bayesTree);
bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
}
// Check whether BN and BT give the same answer on all configurations
// Also calculate all some marginals
// Check whether BN and BT give the same answer on all configurations
// Also calculate all some marginals
Vector marginals = zero(15);
double shortcut8, shortcut0;
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
@ -144,48 +145,94 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
marginals[i] += actual;
// calculate shortcut 8 and 0
if (x[12] && x[14])
shortcut8 += actual;
joint_12_14 += actual;
if (x[9] && x[12] & x[14])
joint_9_12_14 += actual;
if (x[8] && x[12] & x[14])
shortcut0 += actual;
joint_8_12_14 += actual;
if (x[8] && x[2])
joint82 += actual;
if (x[1] && x[2])
joint12 += actual;
if (x[2] && x[4])
joint24 += actual;
if (x[4] && x[5])
joint45 += actual;
if (x[4] && x[6])
joint46 += actual;
if (x[4] && x[11])
joint_4_11 += actual;
}
DiscreteFactor::Values all1 = allPosbValues.back();
// check shortcut P(S9||R) to root
// check shortcut P(S9||R) to root
Clique::shared_ptr R = bayesTree.root();
Clique::shared_ptr c = bayesTree[9];
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_LONGS_EQUAL(0, shortcut.size());
// check shortcut P(S8||R) to root
// check shortcut P(S8||R) to root
c = bayesTree[8];
shortcut = c->shortcut(R, &EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(shortcut8/marginals[14], evaluate(shortcut,all1), 1e-9);
shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
1e-9);
// check shortcut P(S0||R) to root
// check shortcut P(S2||R) to root
c = bayesTree[2];
shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
1e-9);
// check shortcut P(S0||R) to root
c = bayesTree[0];
shortcut = c->shortcut(R, &EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(shortcut0/marginals[14], evaluate(shortcut,all1), 1e-9);
shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
1e-9);
// calculate all shortcuts to root
// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
BOOST_FOREACH(Clique::shared_ptr c, cliques) {
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
if (debug) {
c->printSignature();
shortcut.print("shortcut:");
}
}
// Check all marginals
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree.marginalFactor(i, &EliminateDiscrete);
DiscreteFactor::Values x;
x[i] = 1;
double actual = (*marginalFactor)(x);
marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1);
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8,2) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(1,2) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(2,4)
actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(4,5) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(4,6) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(4,11)
actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
}
/* ************************************************************************* */

View File

@ -186,10 +186,16 @@ namespace gtsam {
*/
typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index j, Eliminate function) const;
/** return joint on two variables */
/**
* return joint on two variables
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
*/
typename FactorGraph<FactorType>::shared_ptr joint(Index j1, Index j2, Eliminate function) const;
/** return joint on two variables as a BayesNet */
/**
* return joint on two variables as a BayesNet
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
*/
typename BayesNet<CONDITIONAL>::shared_ptr jointBayesNet(Index j1, Index j2, Eliminate function) const;
/**

View File

@ -221,32 +221,38 @@ namespace gtsam {
Eliminate function) const {
// For now, assume neither is the root
sharedConditional p_F1_S1 = this->conditional();
sharedConditional p_F2_S2 = C2->conditional();
// Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R)
FactorGraph<FactorType> joint;
if (!isRoot())
joint.push_back(this->conditional()->toFactor()); // P(F1|S1)
if (!isRoot())
if (!isRoot()) {
joint.push_back(p_F1_S1->toFactor()); // P(F1|S1)
joint.push_back(shortcut(R, function)); // P(S1|R)
if (!C2->isRoot())
joint.push_back(C2->conditional()->toFactor()); // P(F2|S2)
if (!C2->isRoot())
}
if (!C2->isRoot()) {
joint.push_back(p_F2_S2->toFactor()); // P(F2|S2)
joint.push_back(C2->shortcut(R, function)); // P(S2|R)
}
joint.push_back(R->conditional()->toFactor()); // P(R)
// Find the keys of both C1 and C2
std::vector<Index> keys1(conditional_->keys());
std::vector<Index> keys2(C2->conditional_->keys());
FastSet<Index> keys12;
keys12.insert(keys1.begin(), keys1.end());
keys12.insert(keys2.begin(), keys2.end());
// Merge the keys of C1 and C2
std::vector<Index> keys12;
std::vector<Index> &indices1 = p_F1_S1->keys(), &indices2 = p_F2_S2->keys();
std::set_union(indices1.begin(), indices1.end(), //
indices2.begin(), indices2.end(), std::back_inserter(keys12));
// Check validity
bool cliques_intersect = (keys12.size() < indices1.size() + indices2.size());
if (!isRoot() && !C2->isRoot() && cliques_intersect)
throw std::runtime_error(
"BayesTreeCliqueBase::joint can only calculate joint if cliques are disjoint\n"
"or one of them is the root clique");
// Calculate the marginal
std::vector<Index> keys12vector;
keys12vector.reserve(keys12.size());
keys12vector.insert(keys12vector.begin(), keys12.begin(), keys12.end());
assertInvariants();
GenericSequentialSolver<FactorType> solver(joint);
return *solver.jointFactorGraph(keys12vector, function);
return *solver.jointFactorGraph(keys12, function);
}
/* ************************************************************************* */

View File

@ -192,7 +192,10 @@ namespace gtsam {
FactorGraph<FactorType> marginal(derived_ptr root,
Eliminate function) const;
/** return the joint P(C1,C2), where C1==this. TODO: not a method? */
/**
* return the joint P(C1,C2), where C1==this. TODO: not a method?
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
*/
FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr root,
Eliminate function) const;

View File

@ -99,6 +99,16 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) {
EXPECT(assert_equal(expected, shortcut));
}
{
// check shortcut P(S2||R) to root
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[2];
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
SymbolicBayesNet expected;
expected.push_front(boost::make_shared<IndexConditional>(12, 14));
expected.push_front(boost::make_shared<IndexConditional>(9, 12, 14));
EXPECT(assert_equal(expected, shortcut));
}
{
// check shortcut P(S0||R) to root
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[0];
@ -108,6 +118,44 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) {
expected.push_front(boost::make_shared<IndexConditional>(8, 12, 14));
EXPECT(assert_equal(expected, shortcut));
}
SymbolicBayesNet::shared_ptr actualJoint;
// Check joint P(8,2)
if (false) { // TODO, not disjoint
actualJoint = bayesTree.jointBayesNet(8, 2, EliminateSymbolic);
SymbolicBayesNet expected;
expected.push_front(boost::make_shared<IndexConditional>(8));
expected.push_front(boost::make_shared<IndexConditional>(2, 8));
EXPECT(assert_equal(expected, *actualJoint));
}
// Check joint P(1,2)
if (false) { // TODO, not disjoint
actualJoint = bayesTree.jointBayesNet(1, 2, EliminateSymbolic);
SymbolicBayesNet expected;
expected.push_front(boost::make_shared<IndexConditional>(2));
expected.push_front(boost::make_shared<IndexConditional>(1, 2));
EXPECT(assert_equal(expected, *actualJoint));
}
// Check joint P(2,6)
if (true) {
actualJoint = bayesTree.jointBayesNet(2, 6, EliminateSymbolic);
SymbolicBayesNet expected;
expected.push_front(boost::make_shared<IndexConditional>(6));
expected.push_front(boost::make_shared<IndexConditional>(2, 6));
EXPECT(assert_equal(expected, *actualJoint));
}
// Check joint P(4,6)
if (false) { // TODO, not disjoint
actualJoint = bayesTree.jointBayesNet(4, 6, EliminateSymbolic);
SymbolicBayesNet expected;
expected.push_front(boost::make_shared<IndexConditional>(6));
expected.push_front(boost::make_shared<IndexConditional>(4, 6));
EXPECT(assert_equal(expected, *actualJoint));
}
}
/* ************************************************************************* *