fix tests to work when GTSAM_DT_MERGING=OFF
parent
3d7163a995
commit
b24f20afe1
|
@ -178,7 +178,11 @@ TEST(ADT, joint) {
|
|||
dot(joint, "Asia-ASTLBEX");
|
||||
joint = apply(joint, pD, &mul);
|
||||
dot(joint, "Asia-ASTLBEXD");
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(346, muls);
|
||||
#else
|
||||
EXPECT_LONGS_EQUAL(508, muls);
|
||||
#endif
|
||||
gttoc_(asiaJoint);
|
||||
tictoc_getNode(asiaJointNode, asiaJoint);
|
||||
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
||||
|
@ -239,7 +243,11 @@ TEST(ADT, inference) {
|
|||
dot(joint, "Joint-Product-ASTLBEX");
|
||||
joint = apply(joint, pD, &mul);
|
||||
dot(joint, "Joint-Product-ASTLBEXD");
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
|
||||
#else
|
||||
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
|
||||
#endif
|
||||
gttoc_(asiaProd);
|
||||
tictoc_getNode(asiaProdNode, asiaProd);
|
||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
||||
|
@ -257,7 +265,11 @@ TEST(ADT, inference) {
|
|||
dot(marginal, "Joint-Sum-ADBLE");
|
||||
marginal = marginal.combine(E, &add_);
|
||||
dot(marginal, "Joint-Sum-ADBL");
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(161, (long)adds);
|
||||
#else
|
||||
EXPECT_LONGS_EQUAL(240, (long)adds);
|
||||
#endif
|
||||
gttoc_(asiaSum);
|
||||
tictoc_getNode(asiaSumNode, asiaSum);
|
||||
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
|
||||
|
@ -295,7 +307,11 @@ TEST(ADT, factor_graph) {
|
|||
fg = apply(fg, pX, &mul);
|
||||
fg = apply(fg, pD, &mul);
|
||||
dot(fg, "FactorGraph");
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(158, (long)muls);
|
||||
#else
|
||||
EXPECT_LONGS_EQUAL(188, (long)muls);
|
||||
#endif
|
||||
gttoc_(asiaFG);
|
||||
tictoc_getNode(asiaFGNode, asiaFG);
|
||||
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
|
||||
|
@ -314,7 +330,11 @@ TEST(ADT, factor_graph) {
|
|||
dot(fg, "Marginalized-3E");
|
||||
fg = fg.combine(L, &add_);
|
||||
dot(fg, "Marginalized-2L");
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
LONGS_EQUAL(49, adds);
|
||||
#else
|
||||
LONGS_EQUAL(62, adds);
|
||||
#endif
|
||||
gttoc_(marg);
|
||||
tictoc_getNode(margNode, marg);
|
||||
elapsed = margNode->secs() + margNode->wall();
|
||||
|
|
|
@ -191,7 +191,11 @@ TEST(DecisionTree, example) {
|
|||
|
||||
// Test choose 0
|
||||
DT actual0 = notba.choose(A, 0);
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT(assert_equal(DT(0.0), actual0));
|
||||
#else
|
||||
// EXPECT(assert_equal(DT({0.0, 0.0}), actual0));
|
||||
#endif
|
||||
DOT(actual0);
|
||||
|
||||
// Test choose 1
|
||||
|
@ -332,9 +336,11 @@ TEST(DecisionTree, NrAssignments) {
|
|||
|
||||
EXPECT_LONGS_EQUAL(8, tree.nrAssignments());
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT(tree.root_->isLeaf());
|
||||
auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
|
||||
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
|
||||
#endif
|
||||
|
||||
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
|
||||
/* The tree is
|
||||
|
@ -357,6 +363,8 @@ TEST(DecisionTree, NrAssignments) {
|
|||
CHECK(root);
|
||||
auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
|
||||
CHECK(choice0);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT(choice0->branches()[0]->isLeaf());
|
||||
auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
|
||||
CHECK(choice00);
|
||||
|
@ -370,6 +378,7 @@ TEST(DecisionTree, NrAssignments) {
|
|||
CHECK(choice11);
|
||||
EXPECT(choice11->isLeaf());
|
||||
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
|
@ -411,27 +420,61 @@ TEST(DecisionTree, VisitWithPruned) {
|
|||
};
|
||||
tree.visitWith(func);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(6, choices.size());
|
||||
#else
|
||||
EXPECT_LONGS_EQUAL(8, choices.size());
|
||||
#endif
|
||||
|
||||
Assignment<string> expectedAssignment;
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
expectedAssignment = {{"B", 0}, {"C", 0}};
|
||||
EXPECT(expectedAssignment == choices.at(0));
|
||||
#else
|
||||
expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 0}};
|
||||
EXPECT(expectedAssignment == choices.at(0));
|
||||
#endif
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
|
||||
EXPECT(expectedAssignment == choices.at(1));
|
||||
#else
|
||||
expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 0}};
|
||||
EXPECT(expectedAssignment == choices.at(1));
|
||||
#endif
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
|
||||
EXPECT(expectedAssignment == choices.at(2));
|
||||
#else
|
||||
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
|
||||
EXPECT(expectedAssignment == choices.at(2));
|
||||
#endif
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
expectedAssignment = {{"B", 0}, {"C", 1}};
|
||||
EXPECT(expectedAssignment == choices.at(3));
|
||||
#else
|
||||
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
|
||||
EXPECT(expectedAssignment == choices.at(3));
|
||||
#endif
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}};
|
||||
EXPECT(expectedAssignment == choices.at(4));
|
||||
#else
|
||||
expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 1}};
|
||||
EXPECT(expectedAssignment == choices.at(4));
|
||||
#endif
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}};
|
||||
EXPECT(expectedAssignment == choices.at(5));
|
||||
#else
|
||||
expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 1}};
|
||||
EXPECT(expectedAssignment == choices.at(5));
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
|
@ -442,7 +485,11 @@ TEST(DecisionTree, fold) {
|
|||
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||
auto add = [](const int& y, double x) { return y + x; };
|
||||
double sum = tree.fold(add, 0.0);
|
||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to merging!
|
||||
#else
|
||||
EXPECT_DOUBLES_EQUAL(7.0, sum, 1e-9);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
|
@ -494,9 +541,14 @@ TEST(DecisionTree, threshold) {
|
|||
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||
DT thresholded(tree, threshold);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
// Check number of leaves equal to zero now = 2
|
||||
// Note: it is 2, because the pruned branches are counted as 1!
|
||||
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||
#else
|
||||
// if GTSAM_DT_MERGING is disabled, the count will be larger
|
||||
EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0));
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
|
@ -532,8 +584,13 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
|||
};
|
||||
DT prunedTree2 = prunedTree.apply(counter);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
// Check if apply doesn't enumerate all leaves.
|
||||
EXPECT_LONGS_EQUAL(5, count);
|
||||
#else
|
||||
// if GTSAM_DT_MERGING is disabled, the count will be full
|
||||
EXPECT_LONGS_EQUAL(8, count);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
|
|
|
@ -350,6 +350,7 @@ TEST(DiscreteFactorGraph, markdown) {
|
|||
}
|
||||
|
||||
TEST(DiscreteFactorGraph, NrAssignments) {
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
string expected_dfg = R"(
|
||||
size: 2
|
||||
factor 0: f[ (m0,2), (m1,2), (m2,2), ]
|
||||
|
@ -387,6 +388,59 @@ factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ]
|
|||
1 1 1 0 Leaf [1] 1
|
||||
1 1 1 1 Leaf [1]0.009321081
|
||||
)";
|
||||
#else
|
||||
string expected_dfg = R"(
|
||||
size: 2
|
||||
factor 0: f[ (m0,2), (m1,2), (m2,2), ]
|
||||
Choice(m2)
|
||||
0 Choice(m1)
|
||||
0 0 Choice(m0)
|
||||
0 0 0 Leaf [1] 0
|
||||
0 0 1 Leaf [1] 0
|
||||
0 1 Choice(m0)
|
||||
0 1 0 Leaf [1]0.27527634
|
||||
0 1 1 Leaf [1]0.44944733
|
||||
1 Choice(m1)
|
||||
1 0 Choice(m0)
|
||||
1 0 0 Leaf [1] 0
|
||||
1 0 1 Leaf [1] 0
|
||||
1 1 Choice(m0)
|
||||
1 1 0 Leaf [1] 0
|
||||
1 1 1 Leaf [1]0.27527634
|
||||
factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ]
|
||||
Choice(m3)
|
||||
0 Choice(m2)
|
||||
0 0 Choice(m1)
|
||||
0 0 0 Choice(m0)
|
||||
0 0 0 0 Leaf [1] 1
|
||||
0 0 0 1 Leaf [1] 1
|
||||
0 0 1 Choice(m0)
|
||||
0 0 1 0 Leaf [1]0.015366387
|
||||
0 0 1 1 Leaf [1]0.015366387
|
||||
0 1 Choice(m1)
|
||||
0 1 0 Choice(m0)
|
||||
0 1 0 0 Leaf [1] 1
|
||||
0 1 0 1 Leaf [1] 1
|
||||
0 1 1 Choice(m0)
|
||||
0 1 1 0 Leaf [1] 1
|
||||
0 1 1 1 Leaf [1]0.015365663
|
||||
1 Choice(m2)
|
||||
1 0 Choice(m1)
|
||||
1 0 0 Choice(m0)
|
||||
1 0 0 0 Leaf [1] 1
|
||||
1 0 0 1 Leaf [1] 1
|
||||
1 0 1 Choice(m0)
|
||||
1 0 1 0 Leaf [1]0.0094115739
|
||||
1 0 1 1 Leaf [1]0.0094115652
|
||||
1 1 Choice(m1)
|
||||
1 1 0 Choice(m0)
|
||||
1 1 0 0 Leaf [1] 1
|
||||
1 1 0 1 Leaf [1] 1
|
||||
1 1 1 Choice(m0)
|
||||
1 1 1 0 Leaf [1] 1
|
||||
1 1 1 1 Leaf [1]0.009321081
|
||||
)";
|
||||
#endif
|
||||
|
||||
DiscreteKeys d0{{M(2), 2}, {M(1), 2}, {M(0), 2}};
|
||||
std::vector<double> p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468};
|
||||
|
|
|
@ -288,8 +288,12 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
std::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||
prunedDecisionTree->nrLeaves());
|
||||
#else
|
||||
EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves());
|
||||
#endif
|
||||
|
||||
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
||||
|
||||
|
|
|
@ -481,6 +481,7 @@ TEST(HybridFactorGraph, Printing) {
|
|||
const auto [hybridBayesNet, remainingFactorGraph] =
|
||||
linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
string expected_hybridFactorGraph = R"(
|
||||
size: 7
|
||||
factor 0:
|
||||
|
@ -562,6 +563,92 @@ factor 6: P( m1 | m0 ):
|
|||
1 1 Leaf [1] 0.4
|
||||
|
||||
)";
|
||||
#else
|
||||
string expected_hybridFactorGraph = R"(
|
||||
size: 7
|
||||
factor 0:
|
||||
A[x0] = [
|
||||
10
|
||||
]
|
||||
b = [ -10 ]
|
||||
No noise model
|
||||
factor 1:
|
||||
Hybrid [x0 x1; m0]{
|
||||
Choice(m0)
|
||||
0 Leaf [1]:
|
||||
A[x0] = [
|
||||
-1
|
||||
]
|
||||
A[x1] = [
|
||||
1
|
||||
]
|
||||
b = [ -1 ]
|
||||
No noise model
|
||||
|
||||
1 Leaf [1]:
|
||||
A[x0] = [
|
||||
-1
|
||||
]
|
||||
A[x1] = [
|
||||
1
|
||||
]
|
||||
b = [ -0 ]
|
||||
No noise model
|
||||
|
||||
}
|
||||
factor 2:
|
||||
Hybrid [x1 x2; m1]{
|
||||
Choice(m1)
|
||||
0 Leaf [1]:
|
||||
A[x1] = [
|
||||
-1
|
||||
]
|
||||
A[x2] = [
|
||||
1
|
||||
]
|
||||
b = [ -1 ]
|
||||
No noise model
|
||||
|
||||
1 Leaf [1]:
|
||||
A[x1] = [
|
||||
-1
|
||||
]
|
||||
A[x2] = [
|
||||
1
|
||||
]
|
||||
b = [ -0 ]
|
||||
No noise model
|
||||
|
||||
}
|
||||
factor 3:
|
||||
A[x1] = [
|
||||
10
|
||||
]
|
||||
b = [ -10 ]
|
||||
No noise model
|
||||
factor 4:
|
||||
A[x2] = [
|
||||
10
|
||||
]
|
||||
b = [ -10 ]
|
||||
No noise model
|
||||
factor 5: P( m0 ):
|
||||
Choice(m0)
|
||||
0 Leaf [1] 0.5
|
||||
1 Leaf [1] 0.5
|
||||
|
||||
factor 6: P( m1 | m0 ):
|
||||
Choice(m1)
|
||||
0 Choice(m0)
|
||||
0 0 Leaf [1]0.33333333
|
||||
0 1 Leaf [1] 0.6
|
||||
1 Choice(m0)
|
||||
1 0 Leaf [1]0.66666667
|
||||
1 1 Leaf [1] 0.4
|
||||
|
||||
)";
|
||||
#endif
|
||||
|
||||
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
||||
|
||||
// Expected output for hybridBayesNet.
|
||||
|
|
Loading…
Reference in New Issue