From 5882847604fbcd5b3c35ef58ffff1bab07caeb80 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Dec 2021 23:14:22 -0500 Subject: [PATCH] Specialized DecisionTreeFactor constructors --- gtsam/discrete/DecisionTreeFactor.h | 17 +++++++++++++++++ gtsam/discrete/DiscreteKey.h | 4 +--- gtsam/discrete/discrete.i | 25 +++++++++++++++++++------ 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 27ee67cf2..43dd892fc 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -61,6 +61,23 @@ namespace gtsam { DiscreteFactor(keys.indices()), Potentials(keys, table) { } + /// Single-key specialization + template + DecisionTreeFactor(const DiscreteKey& key, SOURCE table) + : DecisionTreeFactor(DiscreteKeys{key}, table) {} + + /// Two-key specialization + template + DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2, + SOURCE table) + : DecisionTreeFactor({key1, key2}, table) {} + + /// Three-key specialization + template + DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2, + const DiscreteKey& key3, SOURCE table) + : DecisionTreeFactor({key1, key2, key3}, table) {} + /** Construct from a DiscreteConditional type */ DecisionTreeFactor(const DiscreteConditional& c); diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index 86f1bcf63..ae4dac38f 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -43,9 +43,7 @@ namespace gtsam { DiscreteKeys() : std::vector::vector() {} /// Construct from a key - DiscreteKeys(const DiscreteKey& key) { - push_back(key); - } + explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } /// Construct from a vector of keys DiscreteKeys(const std::vector& keys) : diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 9bb05085b..0f319562f 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -30,9 +30,15 @@ class DiscreteFactor { }; #include -virtual class DecisionTreeFactor: gtsam::DiscreteFactor { +virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(); DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::string& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key1, + const gtsam::DiscreteKey& key2, const std::string& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key1, + const gtsam::DiscreteKey& key2, + const gtsam::DiscreteKey& key3, const std::string& spec); DecisionTreeFactor(const gtsam::DiscreteConditional& c); void print(string s = "DecisionTreeFactor\n", const gtsam::KeyFormatter& keyFormatter = @@ -40,13 +46,19 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; string dot(bool showZero = false) const; string markdown(const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; + gtsam::DefaultKeyFormatter) const; }; #include virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec, + const gtsam::DiscreteKey& parent1); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec, + const gtsam::DiscreteKey& parent1, + const gtsam::DiscreteKey& parent2); DiscreteConditional(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, string spec); DiscreteConditional(const gtsam::DecisionTreeFactor& joint, @@ -62,13 +74,14 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; gtsam::DecisionTreeFactor* toFactor() const; - gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* chooseAsFactor( + const gtsam::DiscreteValues& parentsValues) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; - void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; - void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; + void solveInPlace(gtsam::DiscreteValues @parentsValues) const; + void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; + gtsam::DefaultKeyFormatter) const; }; #include