Allow a vector of doubles for single-variable factors

release/4.3a0
Frank Dellaert 2021-12-26 23:42:12 -05:00
parent 34c3d6af5e
commit dbe5c0fa81
4 changed files with 9 additions and 2 deletions

View File

@ -66,6 +66,10 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
: DecisionTreeFactor(DiscreteKeys{key}, table) {}
/// Single-key specialization, with vector of doubles.
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
/// Two-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2,

View File

@ -33,6 +33,8 @@ class DiscreteFactor {
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteKey& key,
const std::vector<double>& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key1,
const gtsam::DiscreteKey& key2, const std::string& spec);
@ -175,6 +177,7 @@ class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKey& j, string table);
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
void add(const gtsam::DiscreteKeys& keys, string table);

View File

@ -34,7 +34,7 @@ TEST( DecisionTreeFactor, constructors)
DiscreteKey X(0,2), Y(1,3), Z(2,2);
// Create factors
DecisionTreeFactor f1(X, "2 8");
DecisionTreeFactor f1(X, {2, 8});
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size());

View File

@ -32,7 +32,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
graph = DiscreteFactorGraph()
# Add two unary factors (priors)
graph.add(P1, "0.9 0.3")
graph.add(P1, [0.9, 0.3])
graph.add(P2, "0.9 0.6")
# Add a binary factor