From 2f8c8ddb75cb96b30b550bd03e0a659746938857 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 20:50:40 -0500 Subject: [PATCH] update tests --- gtsam/discrete/tests/testDiscreteConditional.cpp | 4 ++-- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 7 ++++--- gtsam_unstable/discrete/tests/testCSP.cpp | 2 +- gtsam_unstable/discrete/tests/testScheduler.cpp | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index d17c76837..b91e1bd8a 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; @@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 4ee36f0ab..0c1dd7a2a 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9); // Check if graph product works - DecisionTreeFactor product = graph.product(); + DecisionTreeFactor product = graph.product()->toDecisionTreeFactor(); EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9); } @@ -117,7 +117,7 @@ TEST(DiscreteFactorGraph, test) { *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected - auto denominator = newFactor.max(newFactor.size()); + auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor(); newFactor = newFactor / denominator; @@ -131,7 +131,8 @@ TEST(DiscreteFactorGraph, test) { CHECK(&newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); // Normalize by max. - denominator = expectedFactor.max(expectedFactor.size()); + denominator = + expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor(); // Ensure denominator is correct. expectedFactor = expectedFactor / denominator; EXPECT(assert_equal(expectedFactor, newFactor)); diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 2b9a20ca6..6806bfe58 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -124,7 +124,7 @@ TEST(CSP, allInOne) { EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Just for fun, create the product and check it - DecisionTreeFactor product = csp.product(); + DecisionTreeFactor product = csp.product()->toDecisionTreeFactor(); // product.dot("product"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); EXPECT(assert_equal(expectedProduct, product)); diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index f868abb5e..5f9b7f287 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -113,7 +113,7 @@ TEST(schedulingExample, test) { EXPECT(assert_equal(expected, (DiscreteFactorGraph)s)); // Do brute force product and output that to file - DecisionTreeFactor product = s.product(); + DecisionTreeFactor product = s.product()->toDecisionTreeFactor(); // product.dot("scheduling", false); // Do exact inference