refactor into smaller tests
parent
3fd5c2501e
commit
9857e62a56
|
@ -16,10 +16,10 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Vector.h>
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
@ -32,7 +32,8 @@ static constexpr bool debug = false;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
struct TestFixture {
|
struct TestFixture {
|
||||||
vector<DiscreteKey> keys;
|
DiscreteKeys keys;
|
||||||
|
std::vector<DiscreteValues> assignments;
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||||
|
|
||||||
|
@ -47,6 +48,9 @@ struct TestFixture {
|
||||||
keys.push_back(key_i);
|
keys.push_back(key_i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enumerate all assignments.
|
||||||
|
assignments = DiscreteValues::CartesianProduct(keys);
|
||||||
|
|
||||||
// Create thin-tree Bayesnet.
|
// Create thin-tree Bayesnet.
|
||||||
bayesNet.add(keys[14] % "1/3");
|
bayesNet.add(keys[14] % "1/3");
|
||||||
|
|
||||||
|
@ -74,9 +78,9 @@ struct TestFixture {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
// Check that BN and BT give the same answer on all configurations
|
||||||
TEST(DiscreteBayesTree, ThinTree) {
|
TEST(DiscreteBayesTree, ThinTree) {
|
||||||
const TestFixture self;
|
TestFixture self;
|
||||||
const auto& keys = self.keys;
|
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
GTSAM_PRINT(self.bayesNet);
|
GTSAM_PRINT(self.bayesNet);
|
||||||
|
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto R = self.bayesTree->roots().front();
|
for (const auto& x : self.assignments) {
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
|
||||||
auto allPosbValues = DiscreteValues::CartesianProduct(
|
|
||||||
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
|
|
||||||
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
|
|
||||||
keys[14]);
|
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
|
||||||
DiscreteValues x = allPosbValues[i];
|
|
||||||
double expected = self.bayesNet.evaluate(x);
|
double expected = self.bayesNet.evaluate(x);
|
||||||
double actual = self.bayesTree->evaluate(x);
|
double actual = self.bayesTree->evaluate(x);
|
||||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate all some marginals for DiscreteValues==all1
|
/* ************************************************************************* */
|
||||||
Vector marginals = Vector::Zero(15);
|
// Check calculation of separator marginals
|
||||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
TEST(DiscreteBayesTree, separatorMarginal) {
|
||||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
TestFixture self;
|
||||||
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
|
||||||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
double marginal_14 = 0, joint_8_12 = 0;
|
||||||
DiscreteValues x = allPosbValues[i];
|
for (auto& x : self.assignments) {
|
||||||
double px = self.bayesTree->evaluate(x);
|
double px = self.bayesTree->evaluate(x);
|
||||||
for (size_t i = 0; i < 15; i++)
|
|
||||||
if (x[i]) marginals[i] += px;
|
|
||||||
if (x[12] && x[14]) {
|
|
||||||
joint_12_14 += px;
|
|
||||||
if (x[9]) joint_9_12_14 += px;
|
|
||||||
if (x[8]) joint_8_12_14 += px;
|
|
||||||
}
|
|
||||||
if (x[8] && x[12]) joint_8_12 += px;
|
if (x[8] && x[12]) joint_8_12 += px;
|
||||||
if (x[2]) {
|
if (x[14]) marginal_14 += px;
|
||||||
if (x[8]) joint82 += px;
|
}
|
||||||
if (x[1]) joint12 += px;
|
DiscreteValues all1 = self.assignments.back();
|
||||||
}
|
|
||||||
if (x[4]) {
|
// check separator marginal P(S0)
|
||||||
if (x[2]) joint24 += px;
|
auto clique = (*self.bayesTree)[0];
|
||||||
if (x[5]) joint45 += px;
|
DiscreteFactorGraph separatorMarginal0 =
|
||||||
if (x[6]) joint46 += px;
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
if (x[11]) joint_4_11 += px;
|
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||||
}
|
|
||||||
|
// check separator marginal P(S9), should be P(14)
|
||||||
|
clique = (*self.bayesTree)[9];
|
||||||
|
DiscreteFactorGraph separatorMarginal9 =
|
||||||
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
|
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
|
||||||
|
|
||||||
|
// check separator marginal of root, should be empty
|
||||||
|
clique = (*self.bayesTree)[11];
|
||||||
|
DiscreteFactorGraph separatorMarginal11 =
|
||||||
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
|
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check shortcuts in the tree
|
||||||
|
TEST(DiscreteBayesTree, shortcut) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
|
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
|
||||||
|
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||||
|
for (auto& x : self.assignments) {
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
if (x[11] && x[13]) {
|
if (x[11] && x[13]) {
|
||||||
joint_11_13 += px;
|
joint_11_13 += px;
|
||||||
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
||||||
|
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DiscreteValues all1 = allPosbValues.back();
|
DiscreteValues all1 = self.assignments.back();
|
||||||
|
|
||||||
// check separator marginal P(S0)
|
auto R = self.bayesTree->roots().front();
|
||||||
auto clique = (*self.bayesTree)[0];
|
|
||||||
DiscreteFactorGraph separatorMarginal0 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
|
||||||
|
|
||||||
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
|
||||||
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
|
||||||
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
|
||||||
|
|
||||||
// check separator marginal P(S9), should be P(14)
|
|
||||||
clique = (*self.bayesTree)[9];
|
|
||||||
DiscreteFactorGraph separatorMarginal9 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
|
||||||
|
|
||||||
// check separator marginal of root, should be empty
|
|
||||||
clique = (*self.bayesTree)[11];
|
|
||||||
DiscreteFactorGraph separatorMarginal11 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
|
||||||
|
|
||||||
// check shortcut P(S9||R) to root
|
// check shortcut P(S9||R) to root
|
||||||
clique = (*self.bayesTree)[9];
|
auto clique = (*self.bayesTree)[9];
|
||||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
LONGS_EQUAL(1, shortcut.size());
|
LONGS_EQUAL(1, shortcut.size());
|
||||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
shortcut.print("shortcut:");
|
shortcut.print("shortcut:");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check all marginals
|
||||||
|
TEST(DiscreteBayesTree, marginalFactor) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
Vector marginals = Vector::Zero(15);
|
||||||
|
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||||
|
DiscreteValues& x = self.assignments[i];
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
|
for (size_t i = 0; i < 15; i++)
|
||||||
|
if (x[i]) marginals[i] += px;
|
||||||
|
}
|
||||||
|
|
||||||
// Check all marginals
|
// Check all marginals
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteValues all1 = self.assignments.back();
|
||||||
for (size_t i = 0; i < 15; i++) {
|
for (size_t i = 0; i < 15; i++) {
|
||||||
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||||
double actual = (*marginalFactor)(all1);
|
double actual = (*marginalFactor)(all1);
|
||||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check a number of joint marginals.
|
||||||
|
TEST(DiscreteBayesTree, Joints) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
|
Vector marginals = Vector::Zero(15);
|
||||||
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
|
||||||
|
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
|
||||||
|
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||||
|
DiscreteValues& x = self.assignments[i];
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
|
for (size_t i = 0; i < 15; i++)
|
||||||
|
if (x[i]) marginals[i] += px;
|
||||||
|
if (x[12] && x[14]) {
|
||||||
|
joint_12_14 += px;
|
||||||
|
if (x[9]) joint_9_12_14 += px;
|
||||||
|
if (x[8]) joint_8_12_14 += px;
|
||||||
|
}
|
||||||
|
if (x[2]) {
|
||||||
|
if (x[8]) joint82 += px;
|
||||||
|
if (x[1]) joint12 += px;
|
||||||
|
}
|
||||||
|
if (x[4]) {
|
||||||
|
if (x[2]) joint24 += px;
|
||||||
|
if (x[5]) joint45 += px;
|
||||||
|
if (x[6]) joint46 += px;
|
||||||
|
if (x[11]) joint_4_11 += px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// regression tests:
|
||||||
|
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
||||||
|
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
||||||
|
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
||||||
|
|
||||||
|
DiscreteValues all1 = self.assignments.back();
|
||||||
DiscreteBayesNet::shared_ptr actualJoint;
|
DiscreteBayesNet::shared_ptr actualJoint;
|
||||||
|
|
||||||
// Check joint P(8, 2)
|
// Check joint P(8, 2)
|
||||||
|
@ -240,7 +285,7 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesTree, Dot) {
|
TEST(DiscreteBayesTree, Dot) {
|
||||||
const TestFixture self;
|
TestFixture self;
|
||||||
string actual = self.bayesTree->dot();
|
string actual = self.bayesTree->dot();
|
||||||
EXPECT(actual ==
|
EXPECT(actual ==
|
||||||
"digraph G{\n"
|
"digraph G{\n"
|
||||||
|
|
Loading…
Reference in New Issue