2-variable joint marginals computed with shortcuts are buggy: they only work if the cliques are disjoint or one of them is the root. They now throw an exception if that is the case.
parent
970efd9e29
commit
1f0cc0aaa4
|
@ -85,14 +85,14 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
|||
const int nrNodes = 15;
|
||||
const size_t nrStates = 2;
|
||||
|
||||
// define variables
|
||||
// define variables
|
||||
vector<DiscreteKey> key;
|
||||
for (int i = 0; i < nrNodes; i++) {
|
||||
DiscreteKey key_i(i, nrStates);
|
||||
key.push_back(key_i);
|
||||
}
|
||||
|
||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
DiscreteBayesNet bayesNet;
|
||||
add_front(bayesNet, key[14] % "1/3");
|
||||
|
||||
|
@ -119,17 +119,18 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
|||
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
}
|
||||
|
||||
// create a BayesTree out of a Bayes net
|
||||
// create a BayesTree out of a Bayes net
|
||||
DiscreteBayesTree bayesTree(bayesNet);
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesTree);
|
||||
bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
|
||||
}
|
||||
|
||||
// Check whether BN and BT give the same answer on all configurations
|
||||
// Also calculate all some marginals
|
||||
// Check whether BN and BT give the same answer on all configurations
|
||||
// Also calculate all some marginals
|
||||
Vector marginals = zero(15);
|
||||
double shortcut8, shortcut0;
|
||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
|
||||
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
|
||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
|
||||
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||
|
@ -144,48 +145,94 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
|||
marginals[i] += actual;
|
||||
// calculate shortcut 8 and 0
|
||||
if (x[12] && x[14])
|
||||
shortcut8 += actual;
|
||||
joint_12_14 += actual;
|
||||
if (x[9] && x[12] & x[14])
|
||||
joint_9_12_14 += actual;
|
||||
if (x[8] && x[12] & x[14])
|
||||
shortcut0 += actual;
|
||||
joint_8_12_14 += actual;
|
||||
if (x[8] && x[2])
|
||||
joint82 += actual;
|
||||
if (x[1] && x[2])
|
||||
joint12 += actual;
|
||||
if (x[2] && x[4])
|
||||
joint24 += actual;
|
||||
if (x[4] && x[5])
|
||||
joint45 += actual;
|
||||
if (x[4] && x[6])
|
||||
joint46 += actual;
|
||||
if (x[4] && x[11])
|
||||
joint_4_11 += actual;
|
||||
}
|
||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
// check shortcut P(S9||R) to root
|
||||
Clique::shared_ptr R = bayesTree.root();
|
||||
Clique::shared_ptr c = bayesTree[9];
|
||||
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
|
||||
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_LONGS_EQUAL(0, shortcut.size());
|
||||
|
||||
// check shortcut P(S8||R) to root
|
||||
// check shortcut P(S8||R) to root
|
||||
c = bayesTree[8];
|
||||
shortcut = c->shortcut(R, &EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(shortcut8/marginals[14], evaluate(shortcut,all1), 1e-9);
|
||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
1e-9);
|
||||
|
||||
// check shortcut P(S0||R) to root
|
||||
// check shortcut P(S2||R) to root
|
||||
c = bayesTree[2];
|
||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
1e-9);
|
||||
|
||||
// check shortcut P(S0||R) to root
|
||||
c = bayesTree[0];
|
||||
shortcut = c->shortcut(R, &EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(shortcut0/marginals[14], evaluate(shortcut,all1), 1e-9);
|
||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
1e-9);
|
||||
|
||||
// calculate all shortcuts to root
|
||||
// calculate all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
|
||||
BOOST_FOREACH(Clique::shared_ptr c, cliques) {
|
||||
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
|
||||
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
if (debug) {
|
||||
c->printSignature();
|
||||
shortcut.print("shortcut:");
|
||||
}
|
||||
}
|
||||
|
||||
// Check all marginals
|
||||
// Check all marginals
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
for (size_t i = 0; i < 15; i++) {
|
||||
marginalFactor = bayesTree.marginalFactor(i, &EliminateDiscrete);
|
||||
DiscreteFactor::Values x;
|
||||
x[i] = 1;
|
||||
double actual = (*marginalFactor)(x);
|
||||
marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
|
||||
double actual = (*marginalFactor)(all1);
|
||||
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||
}
|
||||
|
||||
DiscreteBayesNet::shared_ptr actualJoint;
|
||||
|
||||
// Check joint P(8,2) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(1,2) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(2,4)
|
||||
actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(4,5) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(4,6) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(4,11)
|
||||
actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -186,10 +186,16 @@ namespace gtsam {
|
|||
*/
|
||||
typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index j, Eliminate function) const;
|
||||
|
||||
/** return joint on two variables */
|
||||
/**
|
||||
* return joint on two variables
|
||||
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
|
||||
*/
|
||||
typename FactorGraph<FactorType>::shared_ptr joint(Index j1, Index j2, Eliminate function) const;
|
||||
|
||||
/** return joint on two variables as a BayesNet */
|
||||
/**
|
||||
* return joint on two variables as a BayesNet
|
||||
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
|
||||
*/
|
||||
typename BayesNet<CONDITIONAL>::shared_ptr jointBayesNet(Index j1, Index j2, Eliminate function) const;
|
||||
|
||||
/**
|
||||
|
|
|
@ -221,32 +221,38 @@ namespace gtsam {
|
|||
Eliminate function) const {
|
||||
// For now, assume neither is the root
|
||||
|
||||
sharedConditional p_F1_S1 = this->conditional();
|
||||
sharedConditional p_F2_S2 = C2->conditional();
|
||||
|
||||
// Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R)
|
||||
FactorGraph<FactorType> joint;
|
||||
if (!isRoot())
|
||||
joint.push_back(this->conditional()->toFactor()); // P(F1|S1)
|
||||
if (!isRoot())
|
||||
if (!isRoot()) {
|
||||
joint.push_back(p_F1_S1->toFactor()); // P(F1|S1)
|
||||
joint.push_back(shortcut(R, function)); // P(S1|R)
|
||||
if (!C2->isRoot())
|
||||
joint.push_back(C2->conditional()->toFactor()); // P(F2|S2)
|
||||
if (!C2->isRoot())
|
||||
}
|
||||
if (!C2->isRoot()) {
|
||||
joint.push_back(p_F2_S2->toFactor()); // P(F2|S2)
|
||||
joint.push_back(C2->shortcut(R, function)); // P(S2|R)
|
||||
}
|
||||
joint.push_back(R->conditional()->toFactor()); // P(R)
|
||||
|
||||
// Find the keys of both C1 and C2
|
||||
std::vector<Index> keys1(conditional_->keys());
|
||||
std::vector<Index> keys2(C2->conditional_->keys());
|
||||
FastSet<Index> keys12;
|
||||
keys12.insert(keys1.begin(), keys1.end());
|
||||
keys12.insert(keys2.begin(), keys2.end());
|
||||
// Merge the keys of C1 and C2
|
||||
std::vector<Index> keys12;
|
||||
std::vector<Index> &indices1 = p_F1_S1->keys(), &indices2 = p_F2_S2->keys();
|
||||
std::set_union(indices1.begin(), indices1.end(), //
|
||||
indices2.begin(), indices2.end(), std::back_inserter(keys12));
|
||||
|
||||
// Check validity
|
||||
bool cliques_intersect = (keys12.size() < indices1.size() + indices2.size());
|
||||
if (!isRoot() && !C2->isRoot() && cliques_intersect)
|
||||
throw std::runtime_error(
|
||||
"BayesTreeCliqueBase::joint can only calculate joint if cliques are disjoint\n"
|
||||
"or one of them is the root clique");
|
||||
|
||||
// Calculate the marginal
|
||||
std::vector<Index> keys12vector;
|
||||
keys12vector.reserve(keys12.size());
|
||||
keys12vector.insert(keys12vector.begin(), keys12.begin(), keys12.end());
|
||||
assertInvariants();
|
||||
GenericSequentialSolver<FactorType> solver(joint);
|
||||
return *solver.jointFactorGraph(keys12vector, function);
|
||||
return *solver.jointFactorGraph(keys12, function);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -192,7 +192,10 @@ namespace gtsam {
|
|||
FactorGraph<FactorType> marginal(derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
/** return the joint P(C1,C2), where C1==this. TODO: not a method? */
|
||||
/**
|
||||
* return the joint P(C1,C2), where C1==this. TODO: not a method?
|
||||
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
|
||||
*/
|
||||
FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
|
|
|
@ -99,6 +99,16 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) {
|
|||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
{
|
||||
// check shortcut P(S2||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[2];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(12, 14));
|
||||
expected.push_front(boost::make_shared<IndexConditional>(9, 12, 14));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
{
|
||||
// check shortcut P(S0||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[0];
|
||||
|
@ -108,6 +118,44 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) {
|
|||
expected.push_front(boost::make_shared<IndexConditional>(8, 12, 14));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
SymbolicBayesNet::shared_ptr actualJoint;
|
||||
|
||||
// Check joint P(8,2)
|
||||
if (false) { // TODO, not disjoint
|
||||
actualJoint = bayesTree.jointBayesNet(8, 2, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(8));
|
||||
expected.push_front(boost::make_shared<IndexConditional>(2, 8));
|
||||
EXPECT(assert_equal(expected, *actualJoint));
|
||||
}
|
||||
|
||||
// Check joint P(1,2)
|
||||
if (false) { // TODO, not disjoint
|
||||
actualJoint = bayesTree.jointBayesNet(1, 2, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(2));
|
||||
expected.push_front(boost::make_shared<IndexConditional>(1, 2));
|
||||
EXPECT(assert_equal(expected, *actualJoint));
|
||||
}
|
||||
|
||||
// Check joint P(2,6)
|
||||
if (true) {
|
||||
actualJoint = bayesTree.jointBayesNet(2, 6, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(6));
|
||||
expected.push_front(boost::make_shared<IndexConditional>(2, 6));
|
||||
EXPECT(assert_equal(expected, *actualJoint));
|
||||
}
|
||||
|
||||
// Check joint P(4,6)
|
||||
if (false) { // TODO, not disjoint
|
||||
actualJoint = bayesTree.jointBayesNet(4, 6, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(6));
|
||||
expected.push_front(boost::make_shared<IndexConditional>(4, 6));
|
||||
EXPECT(assert_equal(expected, *actualJoint));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* *
|
||||
|
|
Loading…
Reference in New Issue