refactor into smaller tests
parent
3fd5c2501e
commit
9857e62a56
|
@ -16,10 +16,10 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
|
@ -32,7 +32,8 @@ static constexpr bool debug = false;
|
|||
|
||||
/* ************************************************************************* */
|
||||
struct TestFixture {
|
||||
vector<DiscreteKey> keys;
|
||||
DiscreteKeys keys;
|
||||
std::vector<DiscreteValues> assignments;
|
||||
DiscreteBayesNet bayesNet;
|
||||
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||
|
||||
|
@ -47,6 +48,9 @@ struct TestFixture {
|
|||
keys.push_back(key_i);
|
||||
}
|
||||
|
||||
// Enumerate all assignments.
|
||||
assignments = DiscreteValues::CartesianProduct(keys);
|
||||
|
||||
// Create thin-tree Bayesnet.
|
||||
bayesNet.add(keys[14] % "1/3");
|
||||
|
||||
|
@ -74,9 +78,9 @@ struct TestFixture {
|
|||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check that BN and BT give the same answer on all configurations
|
||||
TEST(DiscreteBayesTree, ThinTree) {
|
||||
const TestFixture self;
|
||||
const auto& keys = self.keys;
|
||||
TestFixture self;
|
||||
|
||||
if (debug) {
|
||||
GTSAM_PRINT(self.bayesNet);
|
||||
|
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
|
|||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||
}
|
||||
|
||||
auto R = self.bayesTree->roots().front();
|
||||
|
||||
// Check whether BN and BT give the same answer on all configurations
|
||||
auto allPosbValues = DiscreteValues::CartesianProduct(
|
||||
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
|
||||
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
|
||||
keys[14]);
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
for (const auto& x : self.assignments) {
|
||||
double expected = self.bayesNet.evaluate(x);
|
||||
double actual = self.bayesTree->evaluate(x);
|
||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate all some marginals for DiscreteValues==all1
|
||||
Vector marginals = Vector::Zero(15);
|
||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
||||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of separator marginals
|
||||
TEST(DiscreteBayesTree, separatorMarginal) {
|
||||
TestFixture self;
|
||||
|
||||
// Calculate some marginals for DiscreteValues==all1
|
||||
double marginal_14 = 0, joint_8_12 = 0;
|
||||
for (auto& x : self.assignments) {
|
||||
double px = self.bayesTree->evaluate(x);
|
||||
for (size_t i = 0; i < 15; i++)
|
||||
if (x[i]) marginals[i] += px;
|
||||
if (x[12] && x[14]) {
|
||||
joint_12_14 += px;
|
||||
if (x[9]) joint_9_12_14 += px;
|
||||
if (x[8]) joint_8_12_14 += px;
|
||||
}
|
||||
if (x[8] && x[12]) joint_8_12 += px;
|
||||
if (x[2]) {
|
||||
if (x[8]) joint82 += px;
|
||||
if (x[1]) joint12 += px;
|
||||
}
|
||||
if (x[4]) {
|
||||
if (x[2]) joint24 += px;
|
||||
if (x[5]) joint45 += px;
|
||||
if (x[6]) joint46 += px;
|
||||
if (x[11]) joint_4_11 += px;
|
||||
}
|
||||
if (x[14]) marginal_14 += px;
|
||||
}
|
||||
DiscreteValues all1 = self.assignments.back();
|
||||
|
||||
// check separator marginal P(S0)
|
||||
auto clique = (*self.bayesTree)[0];
|
||||
DiscreteFactorGraph separatorMarginal0 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteFactorGraph separatorMarginal9 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
|
||||
|
||||
// check separator marginal of root, should be empty
|
||||
clique = (*self.bayesTree)[11];
|
||||
DiscreteFactorGraph separatorMarginal11 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check shortcuts in the tree
|
||||
TEST(DiscreteBayesTree, shortcut) {
|
||||
TestFixture self;
|
||||
|
||||
// Calculate some marginals for DiscreteValues==all1
|
||||
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
|
||||
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||
for (auto& x : self.assignments) {
|
||||
double px = self.bayesTree->evaluate(x);
|
||||
if (x[11] && x[13]) {
|
||||
joint_11_13 += px;
|
||||
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
||||
|
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
|
|||
}
|
||||
}
|
||||
}
|
||||
DiscreteValues all1 = allPosbValues.back();
|
||||
DiscreteValues all1 = self.assignments.back();
|
||||
|
||||
// check separator marginal P(S0)
|
||||
auto clique = (*self.bayesTree)[0];
|
||||
DiscreteFactorGraph separatorMarginal0 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
||||
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
||||
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteFactorGraph separatorMarginal9 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||
|
||||
// check separator marginal of root, should be empty
|
||||
clique = (*self.bayesTree)[11];
|
||||
DiscreteFactorGraph separatorMarginal11 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
auto R = self.bayesTree->roots().front();
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
clique = (*self.bayesTree)[9];
|
||||
auto clique = (*self.bayesTree)[9];
|
||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
LONGS_EQUAL(1, shortcut.size());
|
||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
|
|||
shortcut.print("shortcut:");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check all marginals
|
||||
TEST(DiscreteBayesTree, marginalFactor) {
|
||||
TestFixture self;
|
||||
|
||||
Vector marginals = Vector::Zero(15);
|
||||
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||
DiscreteValues& x = self.assignments[i];
|
||||
double px = self.bayesTree->evaluate(x);
|
||||
for (size_t i = 0; i < 15; i++)
|
||||
if (x[i]) marginals[i] += px;
|
||||
}
|
||||
|
||||
// Check all marginals
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
DiscreteValues all1 = self.assignments.back();
|
||||
for (size_t i = 0; i < 15; i++) {
|
||||
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
double actual = (*marginalFactor)(all1);
|
||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check a number of joint marginals.
|
||||
TEST(DiscreteBayesTree, Joints) {
|
||||
TestFixture self;
|
||||
|
||||
// Calculate some marginals for DiscreteValues==all1
|
||||
Vector marginals = Vector::Zero(15);
|
||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
|
||||
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
|
||||
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||
DiscreteValues& x = self.assignments[i];
|
||||
double px = self.bayesTree->evaluate(x);
|
||||
for (size_t i = 0; i < 15; i++)
|
||||
if (x[i]) marginals[i] += px;
|
||||
if (x[12] && x[14]) {
|
||||
joint_12_14 += px;
|
||||
if (x[9]) joint_9_12_14 += px;
|
||||
if (x[8]) joint_8_12_14 += px;
|
||||
}
|
||||
if (x[2]) {
|
||||
if (x[8]) joint82 += px;
|
||||
if (x[1]) joint12 += px;
|
||||
}
|
||||
if (x[4]) {
|
||||
if (x[2]) joint24 += px;
|
||||
if (x[5]) joint45 += px;
|
||||
if (x[6]) joint46 += px;
|
||||
if (x[11]) joint_4_11 += px;
|
||||
}
|
||||
}
|
||||
|
||||
// regression tests:
|
||||
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
||||
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
||||
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
||||
|
||||
DiscreteValues all1 = self.assignments.back();
|
||||
DiscreteBayesNet::shared_ptr actualJoint;
|
||||
|
||||
// Check joint P(8, 2)
|
||||
|
@ -240,7 +285,7 @@ TEST(DiscreteBayesTree, ThinTree) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, Dot) {
|
||||
const TestFixture self;
|
||||
TestFixture self;
|
||||
string actual = self.bayesTree->dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
|
|
Loading…
Reference in New Issue