diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 8ff50ad3c..68a3f45af 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -85,14 +85,14 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { const int nrNodes = 15; const size_t nrStates = 2; -// define variables + // define variables vector key; for (int i = 0; i < nrNodes; i++) { DiscreteKey key_i(i, nrStates); key.push_back(key_i); } -// create a thin-tree Bayesnet, a la Jean-Guillaume + // create a thin-tree Bayesnet, a la Jean-Guillaume DiscreteBayesNet bayesNet; add_front(bayesNet, key[14] % "1/3"); @@ -119,17 +119,18 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); } -// create a BayesTree out of a Bayes net + // create a BayesTree out of a Bayes net DiscreteBayesTree bayesTree(bayesNet); if (debug) { GTSAM_PRINT(bayesTree); bayesTree.saveGraph("/tmp/discreteBayesTree.dot"); } -// Check whether BN and BT give the same answer on all configurations -// Also calculate all some marginals + // Check whether BN and BT give the same answer on all configurations + // Also calculate all some marginals Vector marginals = zero(15); - double shortcut8, shortcut0; + double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0, + joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0; vector allPosbValues = cartesianProduct( key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); @@ -144,48 +145,94 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { marginals[i] += actual; // calculate shortcut 8 and 0 if (x[12] && x[14]) - shortcut8 += actual; + joint_12_14 += actual; + if (x[9] && x[12] & x[14]) + joint_9_12_14 += actual; if (x[8] && x[12] & x[14]) - shortcut0 += actual; + joint_8_12_14 += actual; + if (x[8] && x[2]) + joint82 += actual; + if (x[1] && x[2]) + joint12 += actual; + if (x[2] && x[4]) + joint24 += actual; + if (x[4] && x[5]) + joint45 += actual; + if (x[4] && x[6]) + joint46 += actual; + if (x[4] && x[11]) + joint_4_11 += actual; } DiscreteFactor::Values all1 = allPosbValues.back(); -// check shortcut P(S9||R) to root + // check shortcut P(S9||R) to root Clique::shared_ptr R = bayesTree.root(); Clique::shared_ptr c = bayesTree[9]; - DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete); + DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); EXPECT_LONGS_EQUAL(0, shortcut.size()); -// check shortcut P(S8||R) to root + // check shortcut P(S8||R) to root c = bayesTree[8]; - shortcut = c->shortcut(R, &EliminateDiscrete); - EXPECT_DOUBLES_EQUAL(shortcut8/marginals[14], evaluate(shortcut,all1), 1e-9); + shortcut = c->shortcut(R, EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1), + 1e-9); -// check shortcut P(S0||R) to root + // check shortcut P(S2||R) to root + c = bayesTree[2]; + shortcut = c->shortcut(R, EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1), + 1e-9); + + // check shortcut P(S0||R) to root c = bayesTree[0]; - shortcut = c->shortcut(R, &EliminateDiscrete); - EXPECT_DOUBLES_EQUAL(shortcut0/marginals[14], evaluate(shortcut,all1), 1e-9); + shortcut = c->shortcut(R, EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1), + 1e-9); -// calculate all shortcuts to root + // calculate all shortcuts to root DiscreteBayesTree::Nodes cliques = bayesTree.nodes(); BOOST_FOREACH(Clique::shared_ptr c, cliques) { - DiscreteBayesNet shortcut = c->shortcut(R, &EliminateDiscrete); + DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); if (debug) { c->printSignature(); shortcut.print("shortcut:"); } } -// Check all marginals + // Check all marginals DiscreteFactor::shared_ptr marginalFactor; for (size_t i = 0; i < 15; i++) { - marginalFactor = bayesTree.marginalFactor(i, &EliminateDiscrete); - DiscreteFactor::Values x; - x[i] = 1; - double actual = (*marginalFactor)(x); + marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete); + double actual = (*marginalFactor)(all1); EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9); } + DiscreteBayesNet::shared_ptr actualJoint; + + // Check joint P(8,2) TODO: not disjoint ! +// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete); +// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9); + + // Check joint P(1,2) TODO: not disjoint ! +// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete); +// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9); + + // Check joint P(2,4) + actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9); + + // Check joint P(4,5) TODO: not disjoint ! +// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete); +// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9); + + // Check joint P(4,6) TODO: not disjoint ! +// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete); +// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9); + + // Check joint P(4,11) + actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9); + } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 6406fe8f6..075b0b411 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -186,10 +186,16 @@ namespace gtsam { */ typename BayesNet::shared_ptr marginalBayesNet(Index j, Eliminate function) const; - /** return joint on two variables */ + /** + * return joint on two variables + * Limitation: can only calculate joint if cliques are disjoint or one of them is root + */ typename FactorGraph::shared_ptr joint(Index j1, Index j2, Eliminate function) const; - /** return joint on two variables as a BayesNet */ + /** + * return joint on two variables as a BayesNet + * Limitation: can only calculate joint if cliques are disjoint or one of them is root + */ typename BayesNet::shared_ptr jointBayesNet(Index j1, Index j2, Eliminate function) const; /** diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index c08a592f7..7d77a5150 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -221,32 +221,38 @@ namespace gtsam { Eliminate function) const { // For now, assume neither is the root + sharedConditional p_F1_S1 = this->conditional(); + sharedConditional p_F2_S2 = C2->conditional(); + // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) FactorGraph joint; - if (!isRoot()) - joint.push_back(this->conditional()->toFactor()); // P(F1|S1) - if (!isRoot()) + if (!isRoot()) { + joint.push_back(p_F1_S1->toFactor()); // P(F1|S1) joint.push_back(shortcut(R, function)); // P(S1|R) - if (!C2->isRoot()) - joint.push_back(C2->conditional()->toFactor()); // P(F2|S2) - if (!C2->isRoot()) + } + if (!C2->isRoot()) { + joint.push_back(p_F2_S2->toFactor()); // P(F2|S2) joint.push_back(C2->shortcut(R, function)); // P(S2|R) + } joint.push_back(R->conditional()->toFactor()); // P(R) - // Find the keys of both C1 and C2 - std::vector keys1(conditional_->keys()); - std::vector keys2(C2->conditional_->keys()); - FastSet keys12; - keys12.insert(keys1.begin(), keys1.end()); - keys12.insert(keys2.begin(), keys2.end()); + // Merge the keys of C1 and C2 + std::vector keys12; + std::vector &indices1 = p_F1_S1->keys(), &indices2 = p_F2_S2->keys(); + std::set_union(indices1.begin(), indices1.end(), // + indices2.begin(), indices2.end(), std::back_inserter(keys12)); + + // Check validity + bool cliques_intersect = (keys12.size() < indices1.size() + indices2.size()); + if (!isRoot() && !C2->isRoot() && cliques_intersect) + throw std::runtime_error( + "BayesTreeCliqueBase::joint can only calculate joint if cliques are disjoint\n" + "or one of them is the root clique"); // Calculate the marginal - std::vector keys12vector; - keys12vector.reserve(keys12.size()); - keys12vector.insert(keys12vector.begin(), keys12.begin(), keys12.end()); assertInvariants(); GenericSequentialSolver solver(joint); - return *solver.jointFactorGraph(keys12vector, function); + return *solver.jointFactorGraph(keys12, function); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index ceeef3f45..87711b3f2 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -192,7 +192,10 @@ namespace gtsam { FactorGraph marginal(derived_ptr root, Eliminate function) const; - /** return the joint P(C1,C2), where C1==this. TODO: not a method? */ + /** + * return the joint P(C1,C2), where C1==this. TODO: not a method? + * Limitation: can only calculate joint if cliques are disjoint or one of them is root + */ FactorGraph joint(derived_ptr C2, derived_ptr root, Eliminate function) const; diff --git a/gtsam/inference/tests/testSymbolicBayesTree.cpp b/gtsam/inference/tests/testSymbolicBayesTree.cpp index 5e00e23b1..b69bdac5e 100644 --- a/gtsam/inference/tests/testSymbolicBayesTree.cpp +++ b/gtsam/inference/tests/testSymbolicBayesTree.cpp @@ -99,6 +99,16 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) { EXPECT(assert_equal(expected, shortcut)); } + { + // check shortcut P(S2||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[2]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(12, 14)); + expected.push_front(boost::make_shared(9, 12, 14)); + EXPECT(assert_equal(expected, shortcut)); + } + { // check shortcut P(S0||R) to root SymbolicBayesTree::Clique::shared_ptr c = bayesTree[0]; @@ -108,6 +118,44 @@ TEST_UNSAFE( SymbolicBayesTree, thinTree ) { expected.push_front(boost::make_shared(8, 12, 14)); EXPECT(assert_equal(expected, shortcut)); } + + SymbolicBayesNet::shared_ptr actualJoint; + + // Check joint P(8,2) + if (false) { // TODO, not disjoint + actualJoint = bayesTree.jointBayesNet(8, 2, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(8)); + expected.push_front(boost::make_shared(2, 8)); + EXPECT(assert_equal(expected, *actualJoint)); + } + + // Check joint P(1,2) + if (false) { // TODO, not disjoint + actualJoint = bayesTree.jointBayesNet(1, 2, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(2)); + expected.push_front(boost::make_shared(1, 2)); + EXPECT(assert_equal(expected, *actualJoint)); + } + + // Check joint P(2,6) + if (true) { + actualJoint = bayesTree.jointBayesNet(2, 6, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(6)); + expected.push_front(boost::make_shared(2, 6)); + EXPECT(assert_equal(expected, *actualJoint)); + } + + // Check joint P(4,6) + if (false) { // TODO, not disjoint + actualJoint = bayesTree.jointBayesNet(4, 6, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(6)); + expected.push_front(boost::make_shared(4, 6)); + EXPECT(assert_equal(expected, *actualJoint)); + } } /* ************************************************************************* *