Add test for combine
parent
52acceebc9
commit
ae8d79cb3c
|
|
@ -174,6 +174,70 @@ TEST(DecisionTreeFactor, Prune) {
|
|||
EXPECT(assert_equal(expected3, pruned3));
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Asia Bayes Network
|
||||
/* ************************************************************************** */
|
||||
|
||||
#define DISABLE_DOT
|
||||
|
||||
void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
|
||||
#ifndef DISABLE_DOT
|
||||
std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
|
||||
auto formatter = [&](Key key) { return names[key]; };
|
||||
f.dot(filename, formatter, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
/** Convert Signature into CPT */
|
||||
DecisionTreeFactor create(const Signature& signature) {
|
||||
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
|
||||
return p;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// test Asia Joint
|
||||
TEST(DecisionTreeFactor, joint) {
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||
D(7, 2);
|
||||
|
||||
gttic_(asiaCPTs);
|
||||
DecisionTreeFactor pA = create(A % "99/1");
|
||||
DecisionTreeFactor pS = create(S % "50/50");
|
||||
DecisionTreeFactor pT = create(T | A = "99/1 95/5");
|
||||
DecisionTreeFactor pL = create(L | S = "99/1 90/10");
|
||||
DecisionTreeFactor pB = create(B | S = "70/30 40/60");
|
||||
DecisionTreeFactor pE = create((E | T, L) = "F T T T");
|
||||
DecisionTreeFactor pX = create(X | E = "95/5 2/98");
|
||||
DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||
|
||||
// Create joint
|
||||
gttic_(asiaJoint);
|
||||
DecisionTreeFactor joint = pA;
|
||||
maybeSaveDotFile(joint, "Asia-A");
|
||||
joint = joint * pS;
|
||||
maybeSaveDotFile(joint, "Asia-AS");
|
||||
joint = joint * pT;
|
||||
maybeSaveDotFile(joint, "Asia-AST");
|
||||
joint = joint * pL;
|
||||
maybeSaveDotFile(joint, "Asia-ASTL");
|
||||
joint = joint * pB;
|
||||
maybeSaveDotFile(joint, "Asia-ASTLB");
|
||||
joint = joint * pE;
|
||||
maybeSaveDotFile(joint, "Asia-ASTLBE");
|
||||
joint = joint * pX;
|
||||
maybeSaveDotFile(joint, "Asia-ASTLBEX");
|
||||
joint = joint * pD;
|
||||
maybeSaveDotFile(joint, "Asia-ASTLBEXD");
|
||||
|
||||
// Check that discrete keys are as expected
|
||||
EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
|
||||
|
||||
// Check that summing out variables maintains the keys even if merged, as is
|
||||
// the case with S.
|
||||
auto noAB = joint.sum(Ordering{A.first, B.first});
|
||||
EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DecisionTreeFactor, DotWithNames) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
|
|
|
|||
Loading…
Reference in New Issue