Add test for combine
parent
52acceebc9
commit
ae8d79cb3c
|
|
@ -174,6 +174,70 @@ TEST(DecisionTreeFactor, Prune) {
|
||||||
EXPECT(assert_equal(expected3, pruned3));
|
EXPECT(assert_equal(expected3, pruned3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Asia Bayes Network
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
#define DISABLE_DOT
|
||||||
|
|
||||||
|
void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
|
||||||
|
#ifndef DISABLE_DOT
|
||||||
|
std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
|
||||||
|
auto formatter = [&](Key key) { return names[key]; };
|
||||||
|
f.dot(filename, formatter, true);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Convert Signature into CPT */
|
||||||
|
DecisionTreeFactor create(const Signature& signature) {
|
||||||
|
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// test Asia Joint
|
||||||
|
TEST(DecisionTreeFactor, joint) {
|
||||||
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||||
|
D(7, 2);
|
||||||
|
|
||||||
|
gttic_(asiaCPTs);
|
||||||
|
DecisionTreeFactor pA = create(A % "99/1");
|
||||||
|
DecisionTreeFactor pS = create(S % "50/50");
|
||||||
|
DecisionTreeFactor pT = create(T | A = "99/1 95/5");
|
||||||
|
DecisionTreeFactor pL = create(L | S = "99/1 90/10");
|
||||||
|
DecisionTreeFactor pB = create(B | S = "70/30 40/60");
|
||||||
|
DecisionTreeFactor pE = create((E | T, L) = "F T T T");
|
||||||
|
DecisionTreeFactor pX = create(X | E = "95/5 2/98");
|
||||||
|
DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||||
|
|
||||||
|
// Create joint
|
||||||
|
gttic_(asiaJoint);
|
||||||
|
DecisionTreeFactor joint = pA;
|
||||||
|
maybeSaveDotFile(joint, "Asia-A");
|
||||||
|
joint = joint * pS;
|
||||||
|
maybeSaveDotFile(joint, "Asia-AS");
|
||||||
|
joint = joint * pT;
|
||||||
|
maybeSaveDotFile(joint, "Asia-AST");
|
||||||
|
joint = joint * pL;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTL");
|
||||||
|
joint = joint * pB;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLB");
|
||||||
|
joint = joint * pE;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLBE");
|
||||||
|
joint = joint * pX;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLBEX");
|
||||||
|
joint = joint * pD;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLBEXD");
|
||||||
|
|
||||||
|
// Check that discrete keys are as expected
|
||||||
|
EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
|
||||||
|
|
||||||
|
// Check that summing out variables maintains the keys even if merged, as is
|
||||||
|
// the case with S.
|
||||||
|
auto noAB = joint.sum(Ordering{A.first, B.first});
|
||||||
|
EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DecisionTreeFactor, DotWithNames) {
|
TEST(DecisionTreeFactor, DotWithNames) {
|
||||||
DiscreteKey A(12, 3), B(5, 2);
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue