diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index f58fd2b19..150a41c24 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -33,41 +33,6 @@ using namespace gtsam; static bool debug = false; -// /** -// * Custom clique class to debug shortcuts -// */ -// struct Clique : public BayesTreeCliqueBase { -// typedef BayesTreeCliqueBase Base; -// typedef boost::shared_ptr shared_ptr; - -// // Constructors -// Clique() {} -// explicit Clique(const DiscreteConditional::shared_ptr& conditional) -// : Base(conditional) {} -// Clique(const std::pair& -// result) -// : Base(result) {} - -// /// print index signature only -// void printSignature( -// const std::string& s = "Clique: ", -// const KeyFormatter& indexFormatter = DefaultKeyFormatter) const { -// ((IndexConditionalOrdered::shared_ptr)conditional_) -// ->print(s, indexFormatter); -// } - -// /// evaluate value of sub-tree -// double evaluate(const DiscreteConditional::Values& values) { -// double result = (*(this->conditional_))(values); -// // evaluate all children and multiply into result -// for (boost::shared_ptr c : children_) result *= -// c->evaluate(values); return result; -// } -// }; - -// typedef BayesTreeOrdered DiscreteBayesTree; - /* ************************************************************************* */ TEST_UNSAFE(DiscreteBayesTree, thinTree) { @@ -124,24 +89,24 @@ TEST_UNSAFE(DiscreteBayesTree, thinTree) { for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteFactor::Values x = allPosbValues[i]; double expected = bayesNet.evaluate(x); - double actual = R->evaluate(x); + double actual = bayesTree->evaluate(x); DOUBLES_EQUAL(expected, actual, 1e-9); } - // Calculate all some marginals + // Calculate all some marginals for Values==all1 Vector marginals = zero(15); double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, - joint_4_11 = 0; + joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0, + joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteFactor::Values x = allPosbValues[i]; - double px = R->evaluate(x); + double px = bayesTree->evaluate(x); for (size_t i = 0; i < 15; i++) if (x[i]) marginals[i] += px; - // calculate shortcut 8 and 0 if (x[12] && x[14]) joint_12_14 += px; - if (x[9] && x[12] & x[14]) joint_9_12_14 += px; - if (x[8] && x[12] & x[14]) joint_8_12_14 += px; + if (x[9] && x[12] && x[14]) joint_9_12_14 += px; + if (x[8] && x[12] && x[14]) joint_8_12_14 += px; if (x[8] && x[12]) joint_8_12 += px; if (x[8] && x[2]) joint82 += px; if (x[1] && x[2]) joint12 += px; @@ -149,96 +114,102 @@ TEST_UNSAFE(DiscreteBayesTree, thinTree) { if (x[4] && x[5]) joint45 += px; if (x[4] && x[6]) joint46 += px; if (x[4] && x[11]) joint_4_11 += px; + if (x[11] && x[13]) { + joint_11_13 += px; + if (x[8] && x[12]) joint_8_11_12_13 += px; + if (x[9] && x[12]) joint_9_11_12_13 += px; + if (x[14]) { + joint_11_13_14 += px; + if (x[12]) { + joint_11_12_13_14 += px; + } + } + } } DiscreteFactor::Values all1 = allPosbValues.back(); - // check separator marginal P(S0) auto c = (*bayesTree)[0]; DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(EliminateDiscrete); - EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); + DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); - // // check separator marginal P(S9), should be P(14) - // c = (*bayesTree)[9]; - // DiscreteFactorGraph separatorMarginal9 = - // c->separatorMarginal(EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); + // check separator marginal P(S9), should be P(14) + c = (*bayesTree)[9]; + DiscreteFactorGraph separatorMarginal9 = + c->separatorMarginal(EliminateDiscrete); + DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); - // // check separator marginal of root, should be empty - // c = (*bayesTree)[11]; - // DiscreteFactorGraph separatorMarginal11 = - // c->separatorMarginal(EliminateDiscrete); - // EXPECT_LONGS_EQUAL(0, separatorMarginal11.size()); + // check separator marginal of root, should be empty + c = (*bayesTree)[11]; + DiscreteFactorGraph separatorMarginal11 = + c->separatorMarginal(EliminateDiscrete); + LONGS_EQUAL(0, separatorMarginal11.size()); - // // check shortcut P(S9||R) to root - // c = (*bayesTree)[9]; - // DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); - // EXPECT_LONGS_EQUAL(0, shortcut.size()); + // check shortcut P(S9||R) to root + c = (*bayesTree)[9]; + DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); + LONGS_EQUAL(1, shortcut.size()); + DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); - // // check shortcut P(S8||R) to root - // c = (*bayesTree)[8]; - // shortcut = c->shortcut(R, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint_12_14 / marginals[14], evaluate(shortcut, all1), - // 1e-9); + // check shortcut P(S8||R) to root + c = (*bayesTree)[8]; + shortcut = c->shortcut(R, EliminateDiscrete); + DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); - // // check shortcut P(S2||R) to root - // c = (*bayesTree)[2]; - // shortcut = c->shortcut(R, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint_9_12_14 / marginals[14], evaluate(shortcut, - // all1), - // 1e-9); + // check shortcut P(S2||R) to root + c = (*bayesTree)[2]; + shortcut = c->shortcut(R, EliminateDiscrete); + DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); - // // check shortcut P(S0||R) to root - // c = (*bayesTree)[0]; - // shortcut = c->shortcut(R, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint_8_12_14 / marginals[14], evaluate(shortcut, - // all1), - // 1e-9); + // check shortcut P(S0||R) to root + c = (*bayesTree)[0]; + shortcut = c->shortcut(R, EliminateDiscrete); + DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); - // // calculate all shortcuts to root - // DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); - // for (auto c : cliques) { - // DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); - // if (debug) { - // c->printSignature(); - // shortcut.print("shortcut:"); - // } - // } + // calculate all shortcuts to root + DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); + for (auto c : cliques) { + DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete); + if (debug) { + c.second->conditional_->printSignature(); + shortcut.print("shortcut:"); + } + } - // // Check all marginals - // DiscreteFactor::shared_ptr marginalFactor; - // for (size_t i = 0; i < 15; i++) { - // marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); - // double actual = (*marginalFactor)(all1); - // EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9); - // } + // Check all marginals + DiscreteFactor::shared_ptr marginalFactor; + for (size_t i = 0; i < 15; i++) { + marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); + double actual = (*marginalFactor)(all1); + DOUBLES_EQUAL(marginals[i], actual, 1e-9); + } - // DiscreteBayesNet::shared_ptr actualJoint; + DiscreteBayesNet::shared_ptr actualJoint; - // Check joint P(8,2) TODO: not disjoint ! - // actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9); + // Check joint P(8, 2) + actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); + DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); - // Check joint P(1,2) TODO: not disjoint ! - // actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9); + // Check joint P(1, 2) + actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); + DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); - // Check joint P(2,4) - // actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint, all1), 1e-9); + // Check joint P(2, 4) + actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); + DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); - // Check joint P(4,5) TODO: not disjoint ! - // actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9); + // Check joint P(4, 5) + actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); + DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); - // Check joint P(4,6) TODO: not disjoint ! - // actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9); + // Check joint P(4, 6) + actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); + DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); - // Check joint P(4,11) - // actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); - // EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint, all1), 1e-9); + // Check joint P(4, 11) + actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); + DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); } /* ************************************************************************* */