Removed 2 and 3 key constructors for DecisionTreeFactor because wrapper is awesome!

release/4.3a0
Frank Dellaert 2021-12-28 13:00:14 -05:00
parent 93e9756ef0
commit 340ac7569d
5 changed files with 21 additions and 32 deletions

View File

@ -70,18 +70,6 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
/// Two-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2,
SOURCE table)
: DecisionTreeFactor({key1, key2}, table) {}
/// Three-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2,
const DiscreteKey& key3, SOURCE table)
: DecisionTreeFactor({key1, key2, key3}, table) {}
/** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c);

View File

@ -32,16 +32,16 @@ class DiscreteFactor {
#include <gtsam/discrete/DecisionTreeFactor.h>
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteKey& key,
const std::vector<double>& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key1,
const gtsam::DiscreteKey& key2, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key1,
const gtsam::DiscreteKey& key2,
const gtsam::DiscreteKey& key3, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
@ -174,12 +174,13 @@ class DotWriter {
class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKey& j, string table);
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string table);
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;

View File

@ -133,10 +133,10 @@ void Scheduler::addStudentSpecificConstraints(size_t i,
Potentials::ADT p(dummy & areaKey,
available_); // available_ is Doodle string
Potentials::ADT q = p.choose(dummyIndex, *slot);
DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q));
CSP::push_back(f);
CSP::add(areaKey, q);
} else {
CSP::add(s.key_, areaKey, available_); // available_ is Doodle string
DiscreteKeys keys {s.key_, areaKey};
CSP::add(keys, available_); // available_ is Doodle string
}
}

View File

@ -23,7 +23,7 @@ class TestDecisionTreeFactor(GtsamTestCase):
def setUp(self):
A = (12, 3)
B = (5, 2)
self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6")
self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6")
def test_enumerate(self):
actual = self.factor.enumerate()

View File

@ -36,7 +36,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
graph.add(P2, "0.9 0.6")
# Add a binary factor
graph.add(P1, P2, "4 1 10 4")
graph.add([P1, P2], "4 1 10 4")
# Instantiate Values
assignment = DiscreteValues()
@ -85,8 +85,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# A simple factor graph (A)-fAC-(C)-fBC-(B)
# with smoothness priors
graph = DiscreteFactorGraph()
graph.add(A, C, "3 1 1 3")
graph.add(C, B, "3 1 1 3")
graph.add([A, C], "3 1 1 3")
graph.add([C, B], "3 1 1 3")
# Test optimization
expectedValues = DiscreteValues()
@ -105,8 +105,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Create Factor graph
graph = DiscreteFactorGraph()
graph.add(C, A, "0.2 0.8 0.3 0.7")
graph.add(C, B, "0.1 0.9 0.4 0.6")
graph.add([C, A], "0.2 0.8 0.3 0.7")
graph.add([C, B], "0.1 0.9 0.4 0.6")
actualMPE = graph.optimize()