diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index f1b661e6f..ffb1f0b5a 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -133,20 +133,25 @@ ADT create(const Signature& signature) { return p; } +/* ************************************************************************* */ +namespace asiaCPTs { +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); + +ADT pA = create(A % "99/1"); +ADT pS = create(S % "50/50"); +ADT pT = create(T | A = "99/1 95/5"); +ADT pL = create(L | S = "99/1 90/10"); +ADT pB = create(B | S = "70/30 40/60"); +ADT pE = create((E | T, L) = "F T T T"); +ADT pX = create(X | E = "95/5 2/98"); +ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); +} // namespace asiaCPTs + /* ************************************************************************* */ // test Asia Joint TEST(ADT, joint) { - DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), - D(7, 2); - - ADT pA = create(A % "99/1"); - ADT pS = create(S % "50/50"); - ADT pT = create(T | A = "99/1 95/5"); - ADT pL = create(L | S = "99/1 90/10"); - ADT pB = create(B | S = "70/30 40/60"); - ADT pE = create((E | T, L) = "F T T T"); - ADT pX = create(X | E = "95/5 2/98"); - ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + using namespace asiaCPTs; // Create joint resetCounts(); @@ -172,6 +177,11 @@ TEST(ADT, joint) { EXPECT_LONGS_EQUAL(508, muls); #endif printCounts("Asia joint"); +} + +/* ************************************************************************* */ +TEST(ADT, combine) { + using namespace asiaCPTs; // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S) ADT pASTL = pA; @@ -187,7 +197,7 @@ TEST(ADT, joint) { } /* ************************************************************************* */ -// test Inference with joint +// test Inference with joint, created using different ordering TEST(ADT, inference) { DiscreteKey A(0, 2), D(1, 2), // B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); @@ -248,7 +258,6 @@ TEST(ADT, inference) { TEST(ADT, factor_graph) { DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); - resetCounts(); ADT pS = create(S % "50/50"); ADT pT = create(T % "95/5"); ADT pL = create(L | S = "99/1 90/10"); @@ -256,7 +265,6 @@ TEST(ADT, factor_graph) { ADT pX = create(X | E = "95/5 2/98"); ADT pD = create(B | E = "1/8 7/9"); ADT pB = create(B | S = "70/30 40/60"); - printCounts("Create CPTs"); // Create joint resetCounts();