refactor into smaller tests

release/4.3a0
Frank Dellaert 2023-06-10 09:39:13 -07:00
parent 3fd5c2501e
commit 9857e62a56
1 changed files with 109 additions and 64 deletions

View File

@ -16,10 +16,10 @@
*/
#include <gtsam/base/Vector.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesNet.h>
#include <CppUnitLite/TestHarness.h>
@ -32,7 +32,8 @@ static constexpr bool debug = false;
/* ************************************************************************* */
struct TestFixture {
vector<DiscreteKey> keys;
DiscreteKeys keys;
std::vector<DiscreteValues> assignments;
DiscreteBayesNet bayesNet;
boost::shared_ptr<DiscreteBayesTree> bayesTree;
@ -47,6 +48,9 @@ struct TestFixture {
keys.push_back(key_i);
}
// Enumerate all assignments.
assignments = DiscreteValues::CartesianProduct(keys);
// Create thin-tree Bayesnet.
bayesNet.add(keys[14] % "1/3");
@ -74,9 +78,9 @@ struct TestFixture {
};
/* ************************************************************************* */
// Check that BN and BT give the same answer on all configurations
TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self;
const auto& keys = self.keys;
TestFixture self;
if (debug) {
GTSAM_PRINT(self.bayesNet);
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
}
auto R = self.bayesTree->roots().front();
// Check whether BN and BT give the same answer on all configurations
auto allPosbValues = DiscreteValues::CartesianProduct(
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
for (const auto& x : self.assignments) {
double expected = self.bayesNet.evaluate(x);
double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9);
}
}
// Calculate all some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
/* ************************************************************************* */
// Check calculation of separator marginals
TEST(DiscreteBayesTree, separatorMarginal) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
double marginal_14 = 0, joint_8_12 = 0;
for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[8] && x[12]) joint_8_12 += px;
if (x[2]) {
if (x[8]) joint82 += px;
if (x[1]) joint12 += px;
}
if (x[4]) {
if (x[2]) joint24 += px;
if (x[5]) joint45 += px;
if (x[6]) joint46 += px;
if (x[11]) joint_4_11 += px;
}
if (x[14]) marginal_14 += px;
}
DiscreteValues all1 = self.assignments.back();
// check separator marginal P(S0)
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
}
/* ************************************************************************* */
// Check shortcuts in the tree
TEST(DiscreteBayesTree, shortcut) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x);
if (x[11] && x[13]) {
joint_11_13 += px;
if (x[8] && x[12]) joint_8_11_12_13 += px;
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
}
}
}
DiscreteValues all1 = allPosbValues.back();
DiscreteValues all1 = self.assignments.back();
// check separator marginal P(S0)
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
auto R = self.bayesTree->roots().front();
// check shortcut P(S9||R) to root
clique = (*self.bayesTree)[9];
auto clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
shortcut.print("shortcut:");
}
}
}
/* ************************************************************************* */
// Check all marginals
TEST(DiscreteBayesTree, marginalFactor) {
TestFixture self;
Vector marginals = Vector::Zero(15);
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
}
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
DiscreteValues all1 = self.assignments.back();
for (size_t i = 0; i < 15; i++) {
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
}
/* ************************************************************************* */
// Check a number of joint marginals.
TEST(DiscreteBayesTree, Joints) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[2]) {
if (x[8]) joint82 += px;
if (x[1]) joint12 += px;
}
if (x[4]) {
if (x[2]) joint24 += px;
if (x[5]) joint45 += px;
if (x[6]) joint46 += px;
if (x[11]) joint_4_11 += px;
}
}
// regression tests:
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
DiscreteValues all1 = self.assignments.back();
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2)
@ -240,7 +285,7 @@ TEST(DiscreteBayesTree, ThinTree) {
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
const TestFixture self;
TestFixture self;
string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"