exposing more factor methods
							parent
							
								
									be5aa56df7
								
							
						
					
					
						commit
						c15bbed9dc
					
				|  | @ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { | |||
|              const gtsam::KeyFormatter& keyFormatter = | ||||
|                  gtsam::DefaultKeyFormatter) const; | ||||
|   bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; | ||||
| 
 | ||||
|   double operator()(const gtsam::DiscreteValues& values) const; | ||||
|   gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; | ||||
|   size_t cardinality(gtsam::Key j) const; | ||||
|   gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; | ||||
|   gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; | ||||
|   gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; | ||||
|   gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; | ||||
| 
 | ||||
|   string dot( | ||||
|       const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, | ||||
|       bool showZero = true) const; | ||||
|  |  | |||
|  | @ -17,10 +17,12 @@ | |||
|  *  @author Duy-Nguyen Ta | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| #include <gtsam/base/Testable.h> | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| #include <gtsam/base/Testable.h> | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| #include <gtsam/discrete/DiscretePrior.h> | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| 
 | ||||
| #include <boost/assign/std/map.hpp> | ||||
| using namespace boost::assign; | ||||
| 
 | ||||
|  | @ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors) | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST_UNSAFE( DecisionTreeFactor, multiplication) | ||||
| { | ||||
|   DiscreteKey v0(0,2), v1(1,2), v2(2,2); | ||||
| TEST(DecisionTreeFactor, multiplication) { | ||||
|   DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); | ||||
| 
 | ||||
|   // Multiply with a DiscretePrior, i.e., Bayes Law!
 | ||||
|   DiscretePrior prior(v1 % "1/3"); | ||||
|   DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); | ||||
|   DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); | ||||
|   CHECK(assert_equal(expected, prior * f1)); | ||||
|   CHECK(assert_equal(expected, f1 * prior)); | ||||
| 
 | ||||
|   // Multiply two factors
 | ||||
|   DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); | ||||
| 
 | ||||
|   DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); | ||||
| 
 | ||||
|   DecisionTreeFactor actual = f1 * f2; | ||||
|   CHECK(assert_equal(expected, actual)); | ||||
|   DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); | ||||
|   CHECK(assert_equal(expected2, actual)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
|  | @ -13,7 +13,7 @@ Author: Frank Dellaert | |||
| 
 | ||||
| import unittest | ||||
| 
 | ||||
| from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys | ||||
| from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| 
 | ||||
|  | @ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase): | |||
|     """Tests for DecisionTreeFactors.""" | ||||
| 
 | ||||
|     def setUp(self): | ||||
|         A = (12, 3) | ||||
|         B = (5, 2) | ||||
|         self.factor = DecisionTreeFactor([A, B], "1 2  3 4  5 6") | ||||
|         self.A = (12, 3) | ||||
|         self.B = (5, 2) | ||||
|         self.factor = DecisionTreeFactor([self.A, self.B], "1 2  3 4  5 6") | ||||
| 
 | ||||
|     def test_enumerate(self): | ||||
|         actual = self.factor.enumerate() | ||||
|         _, values = zip(*actual) | ||||
|         self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) | ||||
| 
 | ||||
|     def test_multiplication(self): | ||||
|         """Test whether multiplication works with overloading.""" | ||||
|         v0 = (0, 2) | ||||
|         v1 = (1, 2) | ||||
|         v2 = (2, 2) | ||||
| 
 | ||||
|         # Multiply with a DiscretePrior, i.e., Bayes Law! | ||||
|         prior = DiscretePrior(v1, [1, 3]) | ||||
|         f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") | ||||
|         expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") | ||||
|         self.gtsamAssertEquals(prior * f1, expected) | ||||
|         self.gtsamAssertEquals(f1 * prior, expected) | ||||
| 
 | ||||
|         # Multiply two factors | ||||
|         f2 = DecisionTreeFactor([v1, v2], "5 6 7 8") | ||||
|         actual = f1 * f2 | ||||
|         expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32") | ||||
|         self.gtsamAssertEquals(actual, expected2) | ||||
| 
 | ||||
|     def test_methods(self): | ||||
|         """Test whether we can call methods in python.""" | ||||
|         # double operator()(const DiscreteValues& values) const; | ||||
|         values = DiscreteValues() | ||||
|         values[self.A[0]] = 0 | ||||
|         values[self.B[0]] = 0 | ||||
|         self.assertIsInstance(self.factor(values), float) | ||||
| 
 | ||||
|         # size_t cardinality(Key j) const; | ||||
|         self.assertIsInstance(self.factor.cardinality(self.A[0]), int) | ||||
| 
 | ||||
|         # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const; | ||||
|         self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor) | ||||
| 
 | ||||
|         # DecisionTreeFactor* sum(size_t nrFrontals) const; | ||||
|         self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor) | ||||
| 
 | ||||
|         # DecisionTreeFactor* sum(const Ordering& keys) const; | ||||
|         ordering = Ordering() | ||||
|         ordering.push_back(self.A[0]) | ||||
|         self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor) | ||||
| 
 | ||||
|         # DecisionTreeFactor* max(size_t nrFrontals) const; | ||||
|         self.assertIsInstance(self.factor.max(1), DecisionTreeFactor) | ||||
| 
 | ||||
|     def test_markdown(self): | ||||
|         """Test whether the _repr_markdown_ method.""" | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue