diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 43dd892fc..b5a037119 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -66,6 +66,10 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKey& key, SOURCE table) : DecisionTreeFactor(DiscreteKeys{key}, table) {} + /// Single-key specialization, with vector of doubles. + DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) + : DecisionTreeFactor(DiscreteKeys{key}, row) {} + /// Two-key specialization template DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2, diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 0f319562f..971250ba1 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -33,6 +33,8 @@ class DiscreteFactor { virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(); DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteKey& key, + const std::vector& spec); DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::string& spec); DecisionTreeFactor(const gtsam::DiscreteKey& key1, const gtsam::DiscreteKey& key2, const std::string& spec); @@ -175,6 +177,7 @@ class DiscreteFactorGraph { DiscreteFactorGraph(); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); + void add(const gtsam::DiscreteKey& j, const std::vector& spec); void add(const gtsam::DiscreteKey& j, string table); void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); void add(const gtsam::DiscreteKeys& keys, string table); diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index ad8e9bd2a..542d16b29 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -34,7 +34,7 @@ TEST( DecisionTreeFactor, constructors) DiscreteKey X(0,2), Y(1,3), Z(2,2); // Create factors - DecisionTreeFactor f1(X, "2 8"); + DecisionTreeFactor f1(X, {2, 8}); DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); EXPECT_LONGS_EQUAL(1,f1.size()); diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 9dafff33f..dc2c7a4f5 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -32,7 +32,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): graph = DiscreteFactorGraph() # Add two unary factors (priors) - graph.add(P1, "0.9 0.3") + graph.add(P1, [0.9, 0.3]) graph.add(P2, "0.9 0.6") # Add a binary factor