diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 78efd57e2..a89d71b24 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -44,14 +44,19 @@ class DiscreteFactor { #include virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(); - + DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::vector& spec); DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); - + + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, + const std::vector& table); DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + + DecisionTreeFactor(const std::vector& keys, + const std::vector& table); DecisionTreeFactor(const std::vector& keys, string table); - + DecisionTreeFactor(const gtsam::DiscreteConditional& c); void print(string s = "DecisionTreeFactor\n", diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 0499e7215..173a782f2 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -25,7 +25,13 @@ class TestDecisionTreeFactor(GtsamTestCase): self.B = (5, 2) self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") + def test_from_floats(self): + """Test whether we can construct a factor from floats.""" + actual = DecisionTreeFactor([self.A, self.B], [1., 2., 3., 4., 5., 6.]) + self.gtsamAssertEquals(actual, self.factor) + def test_enumerate(self): + """Test whether we can enumerate the factor.""" actual = self.factor.enumerate() _, values = zip(*actual) self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])