fix tests
parent
fc2d33f437
commit
2c02efcae2
|
|
@ -111,15 +111,15 @@ TEST(DecisionTreeFactor, sum_max) {
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
DecisionTreeFactor expected(v1, "9 12");
|
DecisionTreeFactor expected(v1, "9 12");
|
||||||
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
|
||||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||||
|
|
||||||
DecisionTreeFactor expected2(v1, "5 6");
|
DecisionTreeFactor expected2(v1, "5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
|
||||||
CHECK(assert_equal(expected2, *actual2));
|
CHECK(assert_equal(expected2, *actual2));
|
||||||
|
|
||||||
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
DecisionTreeFactor expected2 = f2 / f2.sum(1);
|
||||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
|
|
||||||
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
||||||
|
|
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
DecisionTreeFactor expected2 = f2 / f2.sum(1);
|
||||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -113,12 +113,12 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
const Ordering frontalKeys{0};
|
const Ordering frontalKeys{0};
|
||||||
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
||||||
|
|
||||||
DecisionTreeFactor newFactor = *newFactorPtr;
|
auto newFactor = *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||||
|
|
||||||
// Normalize newFactor by max for comparison with expected
|
// Normalize newFactor by max for comparison with expected
|
||||||
auto normalization = newFactor.max(newFactor.size());
|
auto normalization = newFactor.max(newFactor.size());
|
||||||
|
|
||||||
newFactor = newFactor / *normalization;
|
newFactor = newFactor / normalization;
|
||||||
|
|
||||||
// Check Conditional
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
|
|
@ -132,7 +132,7 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
// Normalize by max.
|
// Normalize by max.
|
||||||
normalization = expectedFactor.max(expectedFactor.size());
|
normalization = expectedFactor.max(expectedFactor.size());
|
||||||
// Ensure normalization is correct.
|
// Ensure normalization is correct.
|
||||||
expectedFactor = expectedFactor / *normalization;
|
expectedFactor = expectedFactor / normalization;
|
||||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||||
|
|
||||||
// Test using elimination tree
|
// Test using elimination tree
|
||||||
|
|
|
||||||
|
|
@ -242,15 +242,15 @@ TEST(TableFactor, sum_max) {
|
||||||
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
TableFactor expected(v1, "9 12");
|
TableFactor expected(v1, "9 12");
|
||||||
TableFactor::shared_ptr actual = f1.sum(1);
|
auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
|
||||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||||
|
|
||||||
TableFactor expected2(v1, "5 6");
|
TableFactor expected2(v1, "5 6");
|
||||||
TableFactor::shared_ptr actual2 = f1.max(1);
|
auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
|
||||||
CHECK(assert_equal(expected2, *actual2));
|
CHECK(assert_equal(expected2, *actual2));
|
||||||
|
|
||||||
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
|
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||||
TableFactor::shared_ptr actual22 = f2.sum(1);
|
auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue