diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index e6501c3f7..a41d06c2b 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -174,6 +174,70 @@ TEST(DecisionTreeFactor, Prune) { EXPECT(assert_equal(expected3, pruned3)); } +/* ************************************************************************** */ +// Asia Bayes Network +/* ************************************************************************** */ + +#define DISABLE_DOT + +void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { +#ifndef DISABLE_DOT + std::vector names = {"A", "S", "T", "L", "B", "E", "X", "D"}; + auto formatter = [&](Key key) { return names[key]; }; + f.dot(filename, formatter, true); +#endif +} + +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + +/* ************************************************************************* */ +// test Asia Joint +TEST(DecisionTreeFactor, joint) { + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); + + gttic_(asiaCPTs); + DecisionTreeFactor pA = create(A % "99/1"); + DecisionTreeFactor pS = create(S % "50/50"); + DecisionTreeFactor pT = create(T | A = "99/1 95/5"); + DecisionTreeFactor pL = create(L | S = "99/1 90/10"); + DecisionTreeFactor pB = create(B | S = "70/30 40/60"); + DecisionTreeFactor pE = create((E | T, L) = "F T T T"); + DecisionTreeFactor pX = create(X | E = "95/5 2/98"); + DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + + // Create joint + gttic_(asiaJoint); + DecisionTreeFactor joint = pA; + maybeSaveDotFile(joint, "Asia-A"); + joint = joint * pS; + maybeSaveDotFile(joint, "Asia-AS"); + joint = joint * pT; + maybeSaveDotFile(joint, "Asia-AST"); + joint = joint * pL; + maybeSaveDotFile(joint, "Asia-ASTL"); + joint = joint * pB; + maybeSaveDotFile(joint, "Asia-ASTLB"); + joint = joint * pE; + maybeSaveDotFile(joint, "Asia-ASTLBE"); + joint = joint * pX; + maybeSaveDotFile(joint, "Asia-ASTLBEX"); + joint = joint * pD; + maybeSaveDotFile(joint, "Asia-ASTLBEXD"); + + // Check that discrete keys are as expected + EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D})); + + // Check that summing out variables maintains the keys even if merged, as is + // the case with S. + auto noAB = joint.sum(Ordering{A.first, B.first}); + EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D})); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, DotWithNames) { DiscreteKey A(12, 3), B(5, 2);