Fix tests
parent
fff14ab0b7
commit
2b1f51f098
|
@ -28,7 +28,7 @@
|
|||
#include <vector>
|
||||
|
||||
using namespace gtsam;
|
||||
static constexpr bool debug = false;
|
||||
static constexpr bool debug = true;
|
||||
|
||||
/* ************************************************************************* */
|
||||
struct TestFixture {
|
||||
|
@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) {
|
|||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// calculate all shortcuts to root
|
||||
if (debug) {
|
||||
// print all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||
for (auto clique : cliques) {
|
||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||
if (debug) {
|
||||
clique.second->conditional_->printSignature();
|
||||
shortcut.print("shortcut:");
|
||||
}
|
||||
|
@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) {
|
|||
TEST(DiscreteBayesTree, MarginalFactors) {
|
||||
TestFixture self;
|
||||
|
||||
// Caclulate marginals with brute force enumeration.
|
||||
Vector marginals = Vector::Zero(15);
|
||||
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||
DiscreteValues& x = self.assignments[i];
|
||||
|
@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) {
|
|||
TEST(DiscreteBayesTree, Dot) {
|
||||
TestFixture self;
|
||||
std::string actual = self.bayesTree->dot();
|
||||
// print actual:
|
||||
if (debug) std::cout << actual << std::endl;
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0[label=\"13, 11, 6, 7\"];\n"
|
||||
|
@ -374,18 +377,18 @@ TEST(DiscreteBayesTree, Lookup) {
|
|||
TEST(DiscreteBayesTree, DirectFromCliques) {
|
||||
// Create a BayesNet
|
||||
DiscreteBayesNet bayesNet;
|
||||
DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2);
|
||||
bayesNet.add(key0 % "1/3");
|
||||
bayesNet.add(key1 | key0 = "1/3 3/1");
|
||||
bayesNet.add(key2 | key1 = "3/1 3/1");
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||
bayesNet.add(A % "1/3");
|
||||
bayesNet.add(B | A = "1/3 3/1");
|
||||
bayesNet.add(C | B = "3/1 3/1");
|
||||
|
||||
// Create cliques directly
|
||||
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(key2 | key1 = "3/1 3/1"));
|
||||
std::make_shared<DiscreteConditional>(C | B = "3/1 3/1"));
|
||||
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(key1 | key0 = "1/3 3/1"));
|
||||
std::make_shared<DiscreteConditional>(B | A = "1/3 3/1"));
|
||||
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(key0 % "1/3"));
|
||||
std::make_shared<DiscreteConditional>(A % "1/3"));
|
||||
|
||||
// Create a BayesTree
|
||||
DiscreteBayesTree bayesTree;
|
||||
|
@ -395,13 +398,13 @@ TEST(DiscreteBayesTree, DirectFromCliques) {
|
|||
|
||||
// Check that the BayesTree is correct
|
||||
DiscreteValues values;
|
||||
values[0] = 1;
|
||||
values[1] = 1;
|
||||
values[2] = 1;
|
||||
values[A.first] = 1;
|
||||
values[A.first] = 1;
|
||||
values[A.first] = 1;
|
||||
|
||||
double expected = bayesNet.evaluate(values);
|
||||
double actual = bayesTree.evaluate(values);
|
||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
// Regression
|
||||
double expected = .046875;
|
||||
DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -160,15 +160,15 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
"""Test creating a Bayes tree directly from cliques."""
|
||||
# Create a BayesNet
|
||||
bayesNet = DiscreteBayesNet()
|
||||
key0, key1, key2 = (0, 2), (1, 2), (2, 2)
|
||||
bayesNet.add(key0, "1/3")
|
||||
bayesNet.add(key1, [key0], "1/3 3/1")
|
||||
bayesNet.add(key2, [key1], "3/1 3/1")
|
||||
A, B, C = (0, 2), (1, 2), (2, 2)
|
||||
bayesNet.add(A, "1/3")
|
||||
bayesNet.add(B, [A], "1/3 3/1")
|
||||
bayesNet.add(C, [B], "3/1 3/1")
|
||||
|
||||
# Create cliques directly
|
||||
clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1"))
|
||||
clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1"))
|
||||
clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3"))
|
||||
clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1"))
|
||||
clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1"))
|
||||
clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3"))
|
||||
|
||||
# Create a BayesTree
|
||||
bayesTree = gtsam.DiscreteBayesTree()
|
||||
|
@ -182,9 +182,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
values[1] = 1
|
||||
values[2] = 1
|
||||
|
||||
expected = bayesNet.evaluate(values)
|
||||
actual = bayesTree.evaluate(values)
|
||||
self.assertAlmostEqual(expected, actual, places=9)
|
||||
# regression
|
||||
expected = .046875
|
||||
self.assertAlmostEqual(expected, bayesNet.evaluate(values))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue