Add test for combine

release/4.3a0
Frank Dellaert 2024-09-23 18:35:36 -07:00
parent 52acceebc9
commit ae8d79cb3c
1 changed files with 64 additions and 0 deletions

View File

@ -174,6 +174,70 @@ TEST(DecisionTreeFactor, Prune) {
EXPECT(assert_equal(expected3, pruned3));
}
/* ************************************************************************** */
// Asia Bayes Network
/* ************************************************************************** */
#define DISABLE_DOT
void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#ifndef DISABLE_DOT
std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
auto formatter = [&](Key key) { return names[key]; };
f.dot(filename, formatter, true);
#endif
}
/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
gttic_(asiaCPTs);
DecisionTreeFactor pA = create(A % "99/1");
DecisionTreeFactor pS = create(S % "50/50");
DecisionTreeFactor pT = create(T | A = "99/1 95/5");
DecisionTreeFactor pL = create(L | S = "99/1 90/10");
DecisionTreeFactor pB = create(B | S = "70/30 40/60");
DecisionTreeFactor pE = create((E | T, L) = "F T T T");
DecisionTreeFactor pX = create(X | E = "95/5 2/98");
DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
// Create joint
gttic_(asiaJoint);
DecisionTreeFactor joint = pA;
maybeSaveDotFile(joint, "Asia-A");
joint = joint * pS;
maybeSaveDotFile(joint, "Asia-AS");
joint = joint * pT;
maybeSaveDotFile(joint, "Asia-AST");
joint = joint * pL;
maybeSaveDotFile(joint, "Asia-ASTL");
joint = joint * pB;
maybeSaveDotFile(joint, "Asia-ASTLB");
joint = joint * pE;
maybeSaveDotFile(joint, "Asia-ASTLBE");
joint = joint * pX;
maybeSaveDotFile(joint, "Asia-ASTLBEX");
joint = joint * pD;
maybeSaveDotFile(joint, "Asia-ASTLBEXD");
// Check that discrete keys are as expected
EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
// Check that summing out variables maintains the keys even if merged, as is
// the case with S.
auto noAB = joint.sum(Ordering{A.first, B.first});
EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
}
/* ************************************************************************* */
TEST(DecisionTreeFactor, DotWithNames) {
DiscreteKey A(12, 3), B(5, 2);