fix tests

release/4.3a0
Varun Agrawal 2024-12-08 17:02:47 -05:00
parent fc2d33f437
commit 2c02efcae2
4 changed files with 11 additions and 11 deletions

View File

@ -111,15 +111,15 @@ TEST(DecisionTreeFactor, sum_max) {
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
}
/* ************************************************************************* */

View File

@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
}

View File

@ -113,12 +113,12 @@ TEST(DiscreteFactorGraph, test) {
const Ordering frontalKeys{0};
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
DecisionTreeFactor newFactor = *newFactorPtr;
auto newFactor = *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());
newFactor = newFactor / *normalization;
newFactor = newFactor / normalization;
// Check Conditional
CHECK(conditional);
@ -132,7 +132,7 @@ TEST(DiscreteFactorGraph, test) {
// Normalize by max.
normalization = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct.
expectedFactor = expectedFactor / *normalization;
expectedFactor = expectedFactor / normalization;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree

View File

@ -242,15 +242,15 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12");
TableFactor::shared_ptr actual = f1.sum(1);
auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6");
TableFactor::shared_ptr actual2 = f1.max(1);
auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
TableFactor::shared_ptr actual22 = f2.sum(1);
auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
}
/* ************************************************************************* */