update tests
parent
e9822a70d2
commit
2f8c8ddb75
|
@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor expected2 = f2 / f2.sum(1);
|
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
|
||||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
|
|
||||||
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
||||||
|
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor expected2 = f2 / f2.sum(1);
|
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
|
||||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
||||||
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
|
||||||
|
|
||||||
// Check if graph product works
|
// Check if graph product works
|
||||||
DecisionTreeFactor product = graph.product();
|
DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
|
||||||
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||||
|
|
||||||
// Normalize newFactor by max for comparison with expected
|
// Normalize newFactor by max for comparison with expected
|
||||||
auto denominator = newFactor.max(newFactor.size());
|
auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
|
||||||
|
|
||||||
newFactor = newFactor / denominator;
|
newFactor = newFactor / denominator;
|
||||||
|
|
||||||
|
@ -131,7 +131,8 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
CHECK(&newFactor);
|
CHECK(&newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
// Normalize by max.
|
// Normalize by max.
|
||||||
denominator = expectedFactor.max(expectedFactor.size());
|
denominator =
|
||||||
|
expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
|
||||||
// Ensure denominator is correct.
|
// Ensure denominator is correct.
|
||||||
expectedFactor = expectedFactor / denominator;
|
expectedFactor = expectedFactor / denominator;
|
||||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||||
|
|
|
@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||||
|
|
||||||
// Just for fun, create the product and check it
|
// Just for fun, create the product and check it
|
||||||
DecisionTreeFactor product = csp.product();
|
DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
|
||||||
// product.dot("product");
|
// product.dot("product");
|
||||||
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
|
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
|
||||||
EXPECT(assert_equal(expectedProduct, product));
|
EXPECT(assert_equal(expectedProduct, product));
|
||||||
|
|
|
@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
|
||||||
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
|
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
|
||||||
|
|
||||||
// Do brute force product and output that to file
|
// Do brute force product and output that to file
|
||||||
DecisionTreeFactor product = s.product();
|
DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
|
||||||
// product.dot("scheduling", false);
|
// product.dot("scheduling", false);
|
||||||
|
|
||||||
// Do exact inference
|
// Do exact inference
|
||||||
|
|
Loading…
Reference in New Issue