fix tests
parent
fc2d33f437
commit
2c02efcae2
|
|
@ -111,15 +111,15 @@ TEST(DecisionTreeFactor, sum_max) {
|
|||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||
|
||||
DecisionTreeFactor expected(v1, "9 12");
|
||||
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
||||
auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
|
||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||
|
||||
DecisionTreeFactor expected2(v1, "5 6");
|
||||
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
||||
auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
|
||||
CHECK(assert_equal(expected2, *actual2));
|
||||
|
||||
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||
auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
|
|||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||
DecisionTreeFactor expected2 = f2 / f2.sum(1);
|
||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||
|
||||
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
||||
|
|
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
|||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||
DecisionTreeFactor expected2 = f2 / f2.sum(1);
|
||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -113,12 +113,12 @@ TEST(DiscreteFactorGraph, test) {
|
|||
const Ordering frontalKeys{0};
|
||||
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
DecisionTreeFactor newFactor = *newFactorPtr;
|
||||
auto newFactor = *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||
|
||||
// Normalize newFactor by max for comparison with expected
|
||||
auto normalization = newFactor.max(newFactor.size());
|
||||
|
||||
newFactor = newFactor / *normalization;
|
||||
newFactor = newFactor / normalization;
|
||||
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
|
|
@ -132,7 +132,7 @@ TEST(DiscreteFactorGraph, test) {
|
|||
// Normalize by max.
|
||||
normalization = expectedFactor.max(expectedFactor.size());
|
||||
// Ensure normalization is correct.
|
||||
expectedFactor = expectedFactor / *normalization;
|
||||
expectedFactor = expectedFactor / normalization;
|
||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||
|
||||
// Test using elimination tree
|
||||
|
|
|
|||
|
|
@ -242,15 +242,15 @@ TEST(TableFactor, sum_max) {
|
|||
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||
|
||||
TableFactor expected(v1, "9 12");
|
||||
TableFactor::shared_ptr actual = f1.sum(1);
|
||||
auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
|
||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||
|
||||
TableFactor expected2(v1, "5 6");
|
||||
TableFactor::shared_ptr actual2 = f1.max(1);
|
||||
auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
|
||||
CHECK(assert_equal(expected2, *actual2));
|
||||
|
||||
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||
TableFactor::shared_ptr actual22 = f2.sum(1);
|
||||
auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
Loading…
Reference in New Issue