From c15bbed9dc044ffa159ec5a243dce6985e5203cd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 08:44:10 -0500 Subject: [PATCH] exposing more factor methods --- gtsam/discrete/discrete.i | 9 ++++ .../discrete/tests/testDecisionTreeFactor.cpp | 26 ++++++---- python/gtsam/tests/test_DecisionTreeFactor.py | 52 +++++++++++++++++-- 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 12bd5be54..24a941056 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; + size_t cardinality(gtsam::Key j) const; + gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; + gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; + gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; + gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; + string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, bool showZero = true) const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 594134edf..f2ab5f6bc 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -17,10 +17,12 @@ * @author Duy-Nguyen Ta */ -#include -#include -#include #include +#include +#include +#include +#include + #include using namespace boost::assign; @@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors) } /* ************************************************************************* */ -TEST_UNSAFE( DecisionTreeFactor, multiplication) -{ - DiscreteKey v0(0,2), v1(1,2), v2(2,2); +TEST(DecisionTreeFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + // Multiply with a DiscretePrior, i.e., Bayes Law! + DiscretePrior prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, prior * f1)); + CHECK(assert_equal(expected, f1 * prior)); + + // Multiply two factors DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); - - DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); - DecisionTreeFactor actual = f1 * f2; - CHECK(assert_equal(expected, actual)); + DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 12a60d5cb..03d9f82d7 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,7 +13,7 @@ Author: Frank Dellaert import unittest -from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys +from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering from gtsam.utils.test_case import GtsamTestCase @@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase): """Tests for DecisionTreeFactors.""" def setUp(self): - A = (12, 3) - B = (5, 2) - self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6") + self.A = (12, 3) + self.B = (5, 2) + self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") def test_enumerate(self): actual = self.factor.enumerate() _, values = zip(*actual) self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + def test_multiplication(self): + """Test whether multiplication works with overloading.""" + v0 = (0, 2) + v1 = (1, 2) + v2 = (2, 2) + + # Multiply with a DiscretePrior, i.e., Bayes Law! + prior = DiscretePrior(v1, [1, 3]) + f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") + expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") + self.gtsamAssertEquals(prior * f1, expected) + self.gtsamAssertEquals(f1 * prior, expected) + + # Multiply two factors + f2 = DecisionTreeFactor([v1, v2], "5 6 7 8") + actual = f1 * f2 + expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32") + self.gtsamAssertEquals(actual, expected2) + + def test_methods(self): + """Test whether we can call methods in python.""" + # double operator()(const DiscreteValues& values) const; + values = DiscreteValues() + values[self.A[0]] = 0 + values[self.B[0]] = 0 + self.assertIsInstance(self.factor(values), float) + + # size_t cardinality(Key j) const; + self.assertIsInstance(self.factor.cardinality(self.A[0]), int) + + # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const; + self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor) + + # DecisionTreeFactor* sum(size_t nrFrontals) const; + self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor) + + # DecisionTreeFactor* sum(const Ordering& keys) const; + ordering = Ordering() + ordering.push_back(self.A[0]) + self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor) + + # DecisionTreeFactor* max(size_t nrFrontals) const; + self.assertIsInstance(self.factor.max(1), DecisionTreeFactor) + def test_markdown(self): """Test whether the _repr_markdown_ method."""