Specialized DecisionTreeFactor constructors

release/4.3a0
Frank Dellaert 2021-12-26 23:14:22 -05:00
parent 1d12995be5
commit 5882847604
3 changed files with 37 additions and 9 deletions

View File

@ -61,6 +61,23 @@ namespace gtsam {
DiscreteFactor(keys.indices()), Potentials(keys, table) {
}
/// Single-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
: DecisionTreeFactor(DiscreteKeys{key}, table) {}
/// Two-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2,
SOURCE table)
: DecisionTreeFactor({key1, key2}, table) {}
/// Three-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2,
const DiscreteKey& key3, SOURCE table)
: DecisionTreeFactor({key1, key2, key3}, table) {}
/** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c);

View File

@ -43,9 +43,7 @@ namespace gtsam {
DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
/// Construct from a key
DiscreteKeys(const DiscreteKey& key) {
push_back(key);
}
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
/// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) :

View File

@ -33,6 +33,12 @@ class DiscreteFactor {
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key1,
const gtsam::DiscreteKey& key2, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key1,
const gtsam::DiscreteKey& key2,
const gtsam::DiscreteKey& key3, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter =
@ -47,6 +53,12 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
DiscreteConditional(const gtsam::DiscreteKey& key, string spec,
const gtsam::DiscreteKey& parent1);
DiscreteConditional(const gtsam::DiscreteKey& key, string spec,
const gtsam::DiscreteKey& parent1,
const gtsam::DiscreteKey& parent2);
DiscreteConditional(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
@ -62,7 +74,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* toFactor() const;
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* chooseAsFactor(
const gtsam::DiscreteValues& parentsValues) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;