2-variable joint marginals computed with shortcuts are buggy: they only work if the cliques are disjoint or one of them is the root. They now throw an exception if that is the case.
parent
970efd9e29
commit
1f0cc0aaa4
|
|
@ -85,14 +85,14 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
||||||
const int nrNodes = 15;
|
const int nrNodes = 15;
|
||||||
const size_t nrStates = 2;
|
const size_t nrStates = 2;
|
||||||
|
|
||||||
// define variables
|
// define variables
|
||||||
vector<DiscreteKey> key;
|
vector<DiscreteKey> key;
|
||||||
for (int i = 0; i < nrNodes; i++) {
|
for (int i = 0; i < nrNodes; i++) {
|
||||||
DiscreteKey key_i(i, nrStates);
|
DiscreteKey key_i(i, nrStates);
|
||||||
key.push_back(key_i);
|
key.push_back(key_i);
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
add_front(bayesNet, key[14] % "1/3");
|
add_front(bayesNet, key[14] % "1/3");
|
||||||
|
|
||||||
|
|
@ -119,17 +119,18 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
||||||
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a BayesTree out of a Bayes net
|
// create a BayesTree out of a Bayes net
|
||||||
DiscreteBayesTree bayesTree(bayesNet);
|
DiscreteBayesTree bayesTree(bayesNet);
|
||||||
if (debug) {
|
if (debug) {
|
||||||
GTSAM_PRINT(bayesTree);
|
GTSAM_PRINT(bayesTree);
|
||||||
bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
|
bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
// Check whether BN and BT give the same answer on all configurations
|
||||||
// Also calculate all some marginals
|
// Also calculate all some marginals
|
||||||
Vector marginals = zero(15);
|
Vector marginals = zero(15);
|
||||||
double shortcut8, shortcut0;
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
|
||||||
|
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
|
||||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
|
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
|
||||||
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||||
|
|
@ -144,48 +145,94 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
||||||
marginals[i] += actual;
|
marginals[i] += actual;
|
||||||
// calculate shortcut 8 and 0
|
// calculate shortcut 8 and 0
|
||||||
if (x[12] && x[14])
|
if (x[12] && x[14])
|
||||||
shortcut8 += actual;
|
joint_12_14 += actual;
|
||||||
|
if (x[9] && x[12] & x[14])
|
||||||
|
joint_9_12_14 += actual;
|
||||||
if (x[8] && x[12] & x[14])
|
if (x[8] && x[12] & x[14])
|
||||||
shortcut0 += actual;
|
joint_8_12_14 += actual;
|
||||||
|
if (x[8] && x[2])
|
||||||
|
joint82 += actual;
|
||||||
|
if (x[1] && x[2])
|
||||||
|
joint12 += actual;
|
||||||
|
if (x[2] && x[4])
|
||||||
|
joint24 += actual;
|
||||||
|
if (x[4] && x[5])
|
||||||
|
joint45 += actual;
|
||||||
|
if (x[4] && x[6])
|
||||||
|
joint46 += actual;
|
||||||
|
if (x[4] && x[11])
|
||||||
|
joint_4_11 += actual;
|
||||||
}
|
}
|
||||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||||
|
|
||||||
// check shortcut P(S9||R) to root
|
// check shortcut P(S9||R) to root
|
||||||
Clique::shared_ptr R = bayesTree.root();
|
Clique::shared_ptr R = bayesTree.root();
|
||||||
Clique::shared_ptr c = bayesTree[9];
|
Clique::shared_ptr c = bayesTree[9];
|
||||||
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
|
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
EXPECT_LONGS_EQUAL(0, shortcut.size());
|
EXPECT_LONGS_EQUAL(0, shortcut.size());
|
||||||
|
|
||||||
// check shortcut P(S8||R) to root
|
// check shortcut P(S8||R) to root
|
||||||
c = bayesTree[8];
|
c = bayesTree[8];
|
||||||
shortcut = c->shortcut(R, &EliminateDiscrete);
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
EXPECT_DOUBLES_EQUAL(shortcut8/marginals[14], evaluate(shortcut,all1), 1e-9);
|
EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
|
||||||
|
1e-9);
|
||||||
|
|
||||||
// check shortcut P(S0||R) to root
|
// check shortcut P(S2||R) to root
|
||||||
|
c = bayesTree[2];
|
||||||
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
|
EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
|
||||||
|
1e-9);
|
||||||
|
|
||||||
|
// check shortcut P(S0||R) to root
|
||||||
c = bayesTree[0];
|
c = bayesTree[0];
|
||||||
shortcut = c->shortcut(R, &EliminateDiscrete);
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
EXPECT_DOUBLES_EQUAL(shortcut0/marginals[14], evaluate(shortcut,all1), 1e-9);
|
EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
|
||||||
|
1e-9);
|
||||||
|
|
||||||
// calculate all shortcuts to root
|
// calculate all shortcuts to root
|
||||||
DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
|
DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
|
||||||
BOOST_FOREACH(Clique::shared_ptr c, cliques) {
|
BOOST_FOREACH(Clique::shared_ptr c, cliques) {
|
||||||
DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete);
|
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
if (debug) {
|
if (debug) {
|
||||||
c->printSignature();
|
c->printSignature();
|
||||||
shortcut.print("shortcut:");
|
shortcut.print("shortcut:");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check all marginals
|
// Check all marginals
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteFactor::shared_ptr marginalFactor;
|
||||||
for (size_t i = 0; i < 15; i++) {
|
for (size_t i = 0; i < 15; i++) {
|
||||||
marginalFactor = bayesTree.marginalFactor(i, &EliminateDiscrete);
|
marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
|
||||||
DiscreteFactor::Values x;
|
double actual = (*marginalFactor)(all1);
|
||||||
x[i] = 1;
|
|
||||||
double actual = (*marginalFactor)(x);
|
|
||||||
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DiscreteBayesNet::shared_ptr actualJoint;
|
||||||
|
|
||||||
|
// Check joint P(8,2) TODO: not disjoint !
|
||||||
|
// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
|
||||||
|
// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
|
||||||
|
|
||||||
|
// Check joint P(1,2) TODO: not disjoint !
|
||||||
|
// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
|
||||||
|
// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
|
||||||
|
|
||||||
|
// Check joint P(2,4)
|
||||||
|
actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
|
||||||
|
EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
|
||||||
|
|
||||||
|
// Check joint P(4,5) TODO: not disjoint !
|
||||||
|
// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
|
||||||
|
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||||
|
|
||||||
|
// Check joint P(4,6) TODO: not disjoint !
|
||||||
|
// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
|
||||||
|
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||||
|
|
||||||
|
// Check joint P(4,11)
|
||||||
|
actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
|
||||||
|
EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -186,10 +186,16 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index j, Eliminate function) const;
|
typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index j, Eliminate function) const;
|
||||||
|
|
||||||
/** return joint on two variables */
|
/**
|
||||||
|
* return joint on two variables
|
||||||
|
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
|
||||||
|
*/
|
||||||
typename FactorGraph<FactorType>::shared_ptr joint(Index j1, Index j2, Eliminate function) const;
|
typename FactorGraph<FactorType>::shared_ptr joint(Index j1, Index j2, Eliminate function) const;
|
||||||
|
|
||||||
/** return joint on two variables as a BayesNet */
|
/**
|
||||||
|
* return joint on two variables as a BayesNet
|
||||||
|
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
|
||||||
|
*/
|
||||||
typename BayesNet<CONDITIONAL>::shared_ptr jointBayesNet(Index j1, Index j2, Eliminate function) const;
|
typename BayesNet<CONDITIONAL>::shared_ptr jointBayesNet(Index j1, Index j2, Eliminate function) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -221,32 +221,38 @@ namespace gtsam {
|
||||||
Eliminate function) const {
|
Eliminate function) const {
|
||||||
// For now, assume neither is the root
|
// For now, assume neither is the root
|
||||||
|
|
||||||
|
sharedConditional p_F1_S1 = this->conditional();
|
||||||
|
sharedConditional p_F2_S2 = C2->conditional();
|
||||||
|
|
||||||
// Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R)
|
// Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R)
|
||||||
FactorGraph<FactorType> joint;
|
FactorGraph<FactorType> joint;
|
||||||
if (!isRoot())
|
if (!isRoot()) {
|
||||||
joint.push_back(this->conditional()->toFactor()); // P(F1|S1)
|
joint.push_back(p_F1_S1->toFactor()); // P(F1|S1)
|
||||||
if (!isRoot())
|
|
||||||
joint.push_back(shortcut(R, function)); // P(S1|R)
|
joint.push_back(shortcut(R, function)); // P(S1|R)
|
||||||
if (!C2->isRoot())
|
}
|
||||||
joint.push_back(C2->conditional()->toFactor()); // P(F2|S2)
|
if (!C2->isRoot()) {
|
||||||
if (!C2->isRoot())
|
joint.push_back(p_F2_S2->toFactor()); // P(F2|S2)
|
||||||
joint.push_back(C2->shortcut(R, function)); // P(S2|R)
|
joint.push_back(C2->shortcut(R, function)); // P(S2|R)
|
||||||
|
}
|
||||||
joint.push_back(R->conditional()->toFactor()); // P(R)
|
joint.push_back(R->conditional()->toFactor()); // P(R)
|
||||||
|
|
||||||
// Find the keys of both C1 and C2
|
// Merge the keys of C1 and C2
|
||||||
std::vector<Index> keys1(conditional_->keys());
|
std::vector<Index> keys12;
|
||||||
std::vector<Index> keys2(C2->conditional_->keys());
|
std::vector<Index> &indices1 = p_F1_S1->keys(), &indices2 = p_F2_S2->keys();
|
||||||
FastSet<Index> keys12;
|
std::set_union(indices1.begin(), indices1.end(), //
|
||||||
keys12.insert(keys1.begin(), keys1.end());
|
indices2.begin(), indices2.end(), std::back_inserter(keys12));
|
||||||
keys12.insert(keys2.begin(), keys2.end());
|
|
||||||
|
// Check validity
|
||||||
|
bool cliques_intersect = (keys12.size() < indices1.size() + indices2.size());
|
||||||
|
if (!isRoot() && !C2->isRoot() && cliques_intersect)
|
||||||
|
throw std::runtime_error(
|
||||||
|
"BayesTreeCliqueBase::joint can only calculate joint if cliques are disjoint\n"
|
||||||
|
"or one of them is the root clique");
|
||||||
|
|
||||||
// Calculate the marginal
|
// Calculate the marginal
|
||||||
std::vector<Index> keys12vector;
|
|
||||||
keys12vector.reserve(keys12.size());
|
|
||||||
keys12vector.insert(keys12vector.begin(), keys12.begin(), keys12.end());
|
|
||||||
assertInvariants();
|
assertInvariants();
|
||||||
GenericSequentialSolver<FactorType> solver(joint);
|
GenericSequentialSolver<FactorType> solver(joint);
|
||||||
return *solver.jointFactorGraph(keys12vector, function);
|
return *solver.jointFactorGraph(keys12, function);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -192,7 +192,10 @@ namespace gtsam {
|
||||||
FactorGraph<FactorType> marginal(derived_ptr root,
|
FactorGraph<FactorType> marginal(derived_ptr root,
|
||||||
Eliminate function) const;
|
Eliminate function) const;
|
||||||
|
|
||||||
/** return the joint P(C1,C2), where C1==this. TODO: not a method? */
|
/**
|
||||||
|
* return the joint P(C1,C2), where C1==this. TODO: not a method?
|
||||||
|
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
|
||||||
|
*/
|
||||||
FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr root,
|
FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr root,
|
||||||
Eliminate function) const;
|
Eliminate function) const;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,16 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) {
|
||||||
EXPECT(assert_equal(expected, shortcut));
|
EXPECT(assert_equal(expected, shortcut));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// check shortcut P(S2||R) to root
|
||||||
|
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[2];
|
||||||
|
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||||
|
SymbolicBayesNet expected;
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(12, 14));
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(9, 12, 14));
|
||||||
|
EXPECT(assert_equal(expected, shortcut));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// check shortcut P(S0||R) to root
|
// check shortcut P(S0||R) to root
|
||||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[0];
|
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[0];
|
||||||
|
|
@ -108,6 +118,44 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) {
|
||||||
expected.push_front(boost::make_shared<IndexConditional>(8, 12, 14));
|
expected.push_front(boost::make_shared<IndexConditional>(8, 12, 14));
|
||||||
EXPECT(assert_equal(expected, shortcut));
|
EXPECT(assert_equal(expected, shortcut));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SymbolicBayesNet::shared_ptr actualJoint;
|
||||||
|
|
||||||
|
// Check joint P(8,2)
|
||||||
|
if (false) { // TODO, not disjoint
|
||||||
|
actualJoint = bayesTree.jointBayesNet(8, 2, EliminateSymbolic);
|
||||||
|
SymbolicBayesNet expected;
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(8));
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(2, 8));
|
||||||
|
EXPECT(assert_equal(expected, *actualJoint));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check joint P(1,2)
|
||||||
|
if (false) { // TODO, not disjoint
|
||||||
|
actualJoint = bayesTree.jointBayesNet(1, 2, EliminateSymbolic);
|
||||||
|
SymbolicBayesNet expected;
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(2));
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(1, 2));
|
||||||
|
EXPECT(assert_equal(expected, *actualJoint));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check joint P(2,6)
|
||||||
|
if (true) {
|
||||||
|
actualJoint = bayesTree.jointBayesNet(2, 6, EliminateSymbolic);
|
||||||
|
SymbolicBayesNet expected;
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(6));
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(2, 6));
|
||||||
|
EXPECT(assert_equal(expected, *actualJoint));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check joint P(4,6)
|
||||||
|
if (false) { // TODO, not disjoint
|
||||||
|
actualJoint = bayesTree.jointBayesNet(4, 6, EliminateSymbolic);
|
||||||
|
SymbolicBayesNet expected;
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(6));
|
||||||
|
expected.push_front(boost::make_shared<IndexConditional>(4, 6));
|
||||||
|
EXPECT(assert_equal(expected, *actualJoint));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* *
|
/* ************************************************************************* *
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue