Specialized DecisionTreeFactor constructors

release/4.3a0
Frank Dellaert 2021-12-26 23:14:22 -05:00
parent 1d12995be5
commit 5882847604
3 changed files with 37 additions and 9 deletions

View File

@ -61,6 +61,23 @@ namespace gtsam {
DiscreteFactor(keys.indices()), Potentials(keys, table) { DiscreteFactor(keys.indices()), Potentials(keys, table) {
} }
/// Single-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
: DecisionTreeFactor(DiscreteKeys{key}, table) {}
/// Two-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2,
SOURCE table)
: DecisionTreeFactor({key1, key2}, table) {}
/// Three-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2,
const DiscreteKey& key3, SOURCE table)
: DecisionTreeFactor({key1, key2, key3}, table) {}
/** Construct from a DiscreteConditional type */ /** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c); DecisionTreeFactor(const DiscreteConditional& c);

View File

@ -43,9 +43,7 @@ namespace gtsam {
DiscreteKeys() : std::vector<DiscreteKey>::vector() {} DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
/// Construct from a key /// Construct from a key
DiscreteKeys(const DiscreteKey& key) { explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
push_back(key);
}
/// Construct from a vector of keys /// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) : DiscreteKeys(const std::vector<DiscreteKey>& keys) :

View File

@ -30,9 +30,15 @@ class DiscreteFactor {
}; };
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
virtual class DecisionTreeFactor: gtsam::DiscreteFactor { virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor(); DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key1,
const gtsam::DiscreteKey& key2, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key1,
const gtsam::DiscreteKey& key2,
const gtsam::DiscreteKey& key3, const std::string& spec);
DecisionTreeFactor(const gtsam::DiscreteConditional& c); DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n", void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
@ -40,13 +46,19 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const; string dot(bool showZero = false) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
}; };
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor { virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(); DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
DiscreteConditional(const gtsam::DiscreteKey& key, string spec,
const gtsam::DiscreteKey& parent1);
DiscreteConditional(const gtsam::DiscreteKey& key, string spec,
const gtsam::DiscreteKey& parent1,
const gtsam::DiscreteKey& parent2);
DiscreteConditional(const gtsam::DiscreteKey& key, DiscreteConditional(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec); const gtsam::DiscreteKeys& parents, string spec);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint, DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
@ -62,13 +74,14 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
string s = "Discrete Conditional: ", string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* toFactor() const; gtsam::DecisionTreeFactor* toFactor() const;
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; gtsam::DecisionTreeFactor* chooseAsFactor(
const gtsam::DiscreteValues& parentsValues) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
}; };
#include <gtsam/discrete/DiscretePrior.h> #include <gtsam/discrete/DiscretePrior.h>