fix tests

release/4.3a0
Varun Agrawal 2024-12-08 17:02:47 -05:00
parent fc2d33f437
commit 2c02efcae2
4 changed files with 11 additions and 11 deletions

View File

@ -111,15 +111,15 @@ TEST(DecisionTreeFactor, sum_max) {
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12"); DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1); auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5)); CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6"); DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1); auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
CHECK(assert_equal(expected2, *actual2)); CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2( DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1); DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2( DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1); DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
} }

View File

@ -113,12 +113,12 @@ TEST(DiscreteFactorGraph, test) {
const Ordering frontalKeys{0}; const Ordering frontalKeys{0};
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
DecisionTreeFactor newFactor = *newFactorPtr; auto newFactor = *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected // Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size()); auto normalization = newFactor.max(newFactor.size());
newFactor = newFactor / *normalization; newFactor = newFactor / normalization;
// Check Conditional // Check Conditional
CHECK(conditional); CHECK(conditional);
@ -132,7 +132,7 @@ TEST(DiscreteFactorGraph, test) {
// Normalize by max. // Normalize by max.
normalization = expectedFactor.max(expectedFactor.size()); normalization = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct. // Ensure normalization is correct.
expectedFactor = expectedFactor / *normalization; expectedFactor = expectedFactor / normalization;
EXPECT(assert_equal(expectedFactor, newFactor)); EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree // Test using elimination tree

View File

@ -242,15 +242,15 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12"); TableFactor expected(v1, "9 12");
TableFactor::shared_ptr actual = f1.sum(1); auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5)); CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6"); TableFactor expected2(v1, "5 6");
TableFactor::shared_ptr actual2 = f1.max(1); auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
CHECK(assert_equal(expected2, *actual2)); CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6"); TableFactor f2(v1 & v0, "1 2 3 4 5 6");
TableFactor::shared_ptr actual22 = f2.sum(1); auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
} }
/* ************************************************************************* */ /* ************************************************************************* */