Fixed tests
parent
d456dddc6f
commit
468c7aee0c
|
@ -33,41 +33,6 @@ using namespace gtsam;
|
||||||
|
|
||||||
static bool debug = false;
|
static bool debug = false;
|
||||||
|
|
||||||
// /**
|
|
||||||
// * Custom clique class to debug shortcuts
|
|
||||||
// */
|
|
||||||
// struct Clique : public BayesTreeCliqueBase<Clique, DiscreteConditional> {
|
|
||||||
// typedef BayesTreeCliqueBase<Clique, DiscreteConditional> Base;
|
|
||||||
// typedef boost::shared_ptr<Clique> shared_ptr;
|
|
||||||
|
|
||||||
// // Constructors
|
|
||||||
// Clique() {}
|
|
||||||
// explicit Clique(const DiscreteConditional::shared_ptr& conditional)
|
|
||||||
// : Base(conditional) {}
|
|
||||||
// Clique(const std::pair<DiscreteConditional::shared_ptr,
|
|
||||||
// DiscreteConditional::FactorType::shared_ptr>&
|
|
||||||
// result)
|
|
||||||
// : Base(result) {}
|
|
||||||
|
|
||||||
// /// print index signature only
|
|
||||||
// void printSignature(
|
|
||||||
// const std::string& s = "Clique: ",
|
|
||||||
// const KeyFormatter& indexFormatter = DefaultKeyFormatter) const {
|
|
||||||
// ((IndexConditionalOrdered::shared_ptr)conditional_)
|
|
||||||
// ->print(s, indexFormatter);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// /// evaluate value of sub-tree
|
|
||||||
// double evaluate(const DiscreteConditional::Values& values) {
|
|
||||||
// double result = (*(this->conditional_))(values);
|
|
||||||
// // evaluate all children and multiply into result
|
|
||||||
// for (boost::shared_ptr<Clique> c : children_) result *=
|
|
||||||
// c->evaluate(values); return result;
|
|
||||||
// }
|
|
||||||
// };
|
|
||||||
|
|
||||||
// typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
TEST_UNSAFE(DiscreteBayesTree, thinTree) {
|
TEST_UNSAFE(DiscreteBayesTree, thinTree) {
|
||||||
|
@ -124,24 +89,24 @@ TEST_UNSAFE(DiscreteBayesTree, thinTree) {
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteFactor::Values x = allPosbValues[i];
|
||||||
double expected = bayesNet.evaluate(x);
|
double expected = bayesNet.evaluate(x);
|
||||||
double actual = R->evaluate(x);
|
double actual = bayesTree->evaluate(x);
|
||||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate all some marginals
|
// Calculate all some marginals for Values==all1
|
||||||
Vector marginals = zero(15);
|
Vector marginals = zero(15);
|
||||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||||
joint_4_11 = 0;
|
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
||||||
|
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteFactor::Values x = allPosbValues[i];
|
||||||
double px = R->evaluate(x);
|
double px = bayesTree->evaluate(x);
|
||||||
for (size_t i = 0; i < 15; i++)
|
for (size_t i = 0; i < 15; i++)
|
||||||
if (x[i]) marginals[i] += px;
|
if (x[i]) marginals[i] += px;
|
||||||
// calculate shortcut 8 and 0
|
|
||||||
if (x[12] && x[14]) joint_12_14 += px;
|
if (x[12] && x[14]) joint_12_14 += px;
|
||||||
if (x[9] && x[12] & x[14]) joint_9_12_14 += px;
|
if (x[9] && x[12] && x[14]) joint_9_12_14 += px;
|
||||||
if (x[8] && x[12] & x[14]) joint_8_12_14 += px;
|
if (x[8] && x[12] && x[14]) joint_8_12_14 += px;
|
||||||
if (x[8] && x[12]) joint_8_12 += px;
|
if (x[8] && x[12]) joint_8_12 += px;
|
||||||
if (x[8] && x[2]) joint82 += px;
|
if (x[8] && x[2]) joint82 += px;
|
||||||
if (x[1] && x[2]) joint12 += px;
|
if (x[1] && x[2]) joint12 += px;
|
||||||
|
@ -149,96 +114,102 @@ TEST_UNSAFE(DiscreteBayesTree, thinTree) {
|
||||||
if (x[4] && x[5]) joint45 += px;
|
if (x[4] && x[5]) joint45 += px;
|
||||||
if (x[4] && x[6]) joint46 += px;
|
if (x[4] && x[6]) joint46 += px;
|
||||||
if (x[4] && x[11]) joint_4_11 += px;
|
if (x[4] && x[11]) joint_4_11 += px;
|
||||||
|
if (x[11] && x[13]) {
|
||||||
|
joint_11_13 += px;
|
||||||
|
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
||||||
|
if (x[9] && x[12]) joint_9_11_12_13 += px;
|
||||||
|
if (x[14]) {
|
||||||
|
joint_11_13_14 += px;
|
||||||
|
if (x[12]) {
|
||||||
|
joint_11_12_13_14 += px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||||
|
|
||||||
|
|
||||||
// check separator marginal P(S0)
|
// check separator marginal P(S0)
|
||||||
auto c = (*bayesTree)[0];
|
auto c = (*bayesTree)[0];
|
||||||
DiscreteFactorGraph separatorMarginal0 =
|
DiscreteFactorGraph separatorMarginal0 =
|
||||||
c->separatorMarginal(EliminateDiscrete);
|
c->separatorMarginal(EliminateDiscrete);
|
||||||
EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||||
|
|
||||||
// // check separator marginal P(S9), should be P(14)
|
// check separator marginal P(S9), should be P(14)
|
||||||
// c = (*bayesTree)[9];
|
c = (*bayesTree)[9];
|
||||||
// DiscreteFactorGraph separatorMarginal9 =
|
DiscreteFactorGraph separatorMarginal9 =
|
||||||
// c->separatorMarginal(EliminateDiscrete);
|
c->separatorMarginal(EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||||
|
|
||||||
// // check separator marginal of root, should be empty
|
// check separator marginal of root, should be empty
|
||||||
// c = (*bayesTree)[11];
|
c = (*bayesTree)[11];
|
||||||
// DiscreteFactorGraph separatorMarginal11 =
|
DiscreteFactorGraph separatorMarginal11 =
|
||||||
// c->separatorMarginal(EliminateDiscrete);
|
c->separatorMarginal(EliminateDiscrete);
|
||||||
// EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
|
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||||
|
|
||||||
// // check shortcut P(S9||R) to root
|
// check shortcut P(S9||R) to root
|
||||||
// c = (*bayesTree)[9];
|
c = (*bayesTree)[9];
|
||||||
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// EXPECT_LONGS_EQUAL(0, shortcut.size());
|
LONGS_EQUAL(1, shortcut.size());
|
||||||
|
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// // check shortcut P(S8||R) to root
|
// check shortcut P(S8||R) to root
|
||||||
// c = (*bayesTree)[8];
|
c = (*bayesTree)[8];
|
||||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_12_14 / marginals[14], evaluate(shortcut, all1),
|
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
// 1e-9);
|
|
||||||
|
|
||||||
// // check shortcut P(S2||R) to root
|
// check shortcut P(S2||R) to root
|
||||||
// c = (*bayesTree)[2];
|
c = (*bayesTree)[2];
|
||||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_9_12_14 / marginals[14], evaluate(shortcut,
|
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
// all1),
|
|
||||||
// 1e-9);
|
|
||||||
|
|
||||||
// // check shortcut P(S0||R) to root
|
// check shortcut P(S0||R) to root
|
||||||
// c = (*bayesTree)[0];
|
c = (*bayesTree)[0];
|
||||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_8_12_14 / marginals[14], evaluate(shortcut,
|
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
// all1),
|
|
||||||
// 1e-9);
|
|
||||||
|
|
||||||
// // calculate all shortcuts to root
|
// calculate all shortcuts to root
|
||||||
// DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
||||||
// for (auto c : cliques) {
|
for (auto c : cliques) {
|
||||||
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete);
|
||||||
// if (debug) {
|
if (debug) {
|
||||||
// c->printSignature();
|
c.second->conditional_->printSignature();
|
||||||
// shortcut.print("shortcut:");
|
shortcut.print("shortcut:");
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
|
||||||
// // Check all marginals
|
// Check all marginals
|
||||||
// DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteFactor::shared_ptr marginalFactor;
|
||||||
// for (size_t i = 0; i < 15; i++) {
|
for (size_t i = 0; i < 15; i++) {
|
||||||
// marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
|
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||||
// double actual = (*marginalFactor)(all1);
|
double actual = (*marginalFactor)(all1);
|
||||||
// EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||||
// }
|
}
|
||||||
|
|
||||||
// DiscreteBayesNet::shared_ptr actualJoint;
|
DiscreteBayesNet::shared_ptr actualJoint;
|
||||||
|
|
||||||
// Check joint P(8,2) TODO: not disjoint !
|
// Check joint P(8, 2)
|
||||||
// actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
|
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(1,2) TODO: not disjoint !
|
// Check joint P(1, 2)
|
||||||
// actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
|
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(2,4)
|
// Check joint P(2, 4)
|
||||||
// actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint, all1), 1e-9);
|
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(4,5) TODO: not disjoint !
|
// Check joint P(4, 5)
|
||||||
// actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(4,6) TODO: not disjoint !
|
// Check joint P(4, 6)
|
||||||
// actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// Check joint P(4,11)
|
// Check joint P(4, 11)
|
||||||
// actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||||
// EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint, all1), 1e-9);
|
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue