Merge pull request #1155 from borglab/decisiontree-refactor

release/4.3a0
Frank Dellaert 2022-04-13 22:24:43 -04:00 committed by GitHub
commit 78d7e903f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 163 additions and 72 deletions

View File

@ -59,7 +59,7 @@ namespace gtsam {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
/** The number of assignments contained within this leaf /** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned. * Particularly useful when leaves have been pruned.
*/ */
size_t nrAssignments_; size_t nrAssignments_;
@ -68,7 +68,7 @@ namespace gtsam {
Leaf(const Y& constant, size_t nrAssignments = 1) Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {} : constant_(constant), nrAssignments_(nrAssignments) {}
/** return the constant */ /// Return the constant
const Y& constant() const { const Y& constant() const {
return constant_; return constant_;
} }
@ -81,19 +81,19 @@ namespace gtsam {
return constant_ == q.constant_; return constant_ == q.constant_;
} }
/// polymorphic equality: is q is a leaf, could be /// polymorphic equality: is q a leaf and is it the same as this leaf?
bool sameLeaf(const Node& q) const override { bool sameLeaf(const Node& q) const override {
return (q.isLeaf() && q.sameLeaf(*this)); return (q.isLeaf() && q.sameLeaf(*this));
} }
/** equality up to tolerance */ /// equality up to tolerance
bool equals(const Node& q, const CompareFunc& compare) const override { bool equals(const Node& q, const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*>(&q); const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false; if (!other) return false;
return compare(this->constant_, other->constant_); return compare(this->constant_, other->constant_);
} }
/** print */ /// print
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
@ -122,8 +122,8 @@ namespace gtsam {
/// Apply unary operator with assignment /// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op, NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override { const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_)); NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
return f; return f;
} }
@ -168,7 +168,10 @@ namespace gtsam {
std::vector<NodePtr> branches_; std::vector<NodePtr> branches_;
private: private:
/** incremental allSame */ /**
* Incremental allSame.
* Records if all the branches are the same leaf.
*/
size_t allSame_; size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
@ -181,9 +184,9 @@ namespace gtsam {
#endif #endif
} }
/** If all branches of a choice node f are the same, just return a branch */ /// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) { static NodePtr Unique(const ChoicePtr& f) {
#ifndef DT_NO_PRUNING #ifndef GTSAM_DT_NO_PRUNING
if (f->allSame_) { if (f->allSame_) {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
@ -205,15 +208,13 @@ namespace gtsam {
bool isLeaf() const override { return false; } bool isLeaf() const override { return false; }
/** Constructor, given choice label and mandatory expected branch count */ /// Constructor, given choice label and mandatory expected branch count.
Choice(const L& label, size_t count) : Choice(const L& label, size_t count) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(count); branches_.reserve(count);
} }
/** /// Construct from applying binary op to two Choice nodes.
* Construct from applying binary op to two Choice nodes
*/
Choice(const Choice& f, const Choice& g, const Binary& op) : Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) { allSame_(true) {
// Choose what to do based on label // Choose what to do based on label
@ -241,6 +242,7 @@ namespace gtsam {
} }
} }
/// Return the label of this choice node.
const L& label() const { const L& label() const {
return label_; return label_;
} }
@ -262,7 +264,7 @@ namespace gtsam {
branches_.push_back(node); branches_.push_back(node);
} }
/** print (as a tree) */ /// print (as a tree).
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice("; std::cout << s << " Choice(";
@ -308,7 +310,7 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this)); return (q.isLeaf() && q.sameLeaf(*this));
} }
/** equality */ /// equality
bool equals(const Node& q, const CompareFunc& compare) const override { bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q); const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false; if (!other) return false;
@ -321,7 +323,7 @@ namespace gtsam {
return true; return true;
} }
/** evaluate */ /// evaluate
const Y& operator()(const Assignment<L>& x) const override { const Y& operator()(const Assignment<L>& x) const override {
#ifndef NDEBUG #ifndef NDEBUG
typename Assignment<L>::const_iterator it = x.find(label_); typename Assignment<L>::const_iterator it = x.find(label_);
@ -336,13 +338,13 @@ namespace gtsam {
return (*child)(x); return (*child)(x);
} }
/** /// Construct from applying unary op to a Choice node.
* Construct from applying unary op to a Choice node
*/
Choice(const L& label, const Choice& f, const Unary& op) : Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); for (const NodePtr& branch : f.branches_) {
push_back(branch->apply(op));
}
} }
/** /**
@ -353,28 +355,28 @@ namespace gtsam {
* @param f The original choice node to apply the op on. * @param f The original choice node to apply the op on.
* @param op Function to apply on the choice node. Takes Assignment and * @param op Function to apply on the choice node. Takes Assignment and
* value as arguments. * value as arguments.
* @param choices The Assignment that will go to op. * @param assignment The Assignment that will go to op.
*/ */
Choice(const L& label, const Choice& f, const UnaryAssignment& op, Choice(const L& label, const Choice& f, const UnaryAssignment& op,
const Assignment<L>& choices) const Assignment<L>& assignment)
: label_(label), allSame_(true) { : label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space branches_.reserve(f.branches_.size()); // reserve space
Assignment<L> choices_ = choices; Assignment<L> assignment_ = assignment;
for (size_t i = 0; i < f.branches_.size(); i++) { for (size_t i = 0; i < f.branches_.size(); i++) {
choices_[label_] = i; // Set assignment for label to i assignment_[label_] = i; // Set assignment for label to i
const NodePtr branch = f.branches_[i]; const NodePtr branch = f.branches_[i];
push_back(branch->apply(op, choices_)); push_back(branch->apply(op, assignment_));
// Remove the choice so we are backtracking // Remove the assignment so we are backtracking
auto choice_it = choices_.find(label_); auto assignment_it = assignment_.find(label_);
choices_.erase(choice_it); assignment_.erase(assignment_it);
} }
} }
/** apply unary operator */ /// apply unary operator.
NodePtr apply(const Unary& op) const override { NodePtr apply(const Unary& op) const override {
auto r = boost::make_shared<Choice>(label_, *this, op); auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r); return Unique(r);
@ -382,8 +384,8 @@ namespace gtsam {
/// Apply unary operator with assignment /// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op, NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override { const Assignment<L>& assignment) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, choices); auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
return Unique(r); return Unique(r);
} }
@ -678,7 +680,16 @@ namespace gtsam {
} }
/****************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument. /**
* Functor performing depth-first visit to each leaf with the leaf value as
* the argument.
*
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have fewer than 8 leaves. For example, if a tree has all assignment
* values as 1, then pruning will cause the tree to have only 1 leaf yet 8
* assignments.
*/
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {
using F = std::function<void(const Y&)>; using F = std::function<void(const Y&)>;
@ -707,33 +718,74 @@ namespace gtsam {
} }
/****************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument. /**
* Functor performing depth-first visit to each leaf with the Leaf object
* passed as an argument.
*
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have <8 leaves. For example, if a tree has all assignment values as 1,
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
*/
template <typename L, typename Y>
struct VisitLeaf {
  /// Callback type: receives each visited Leaf by const reference.
  using F = std::function<void(const typename DecisionTree<L, Y>::Leaf&)>;
  explicit VisitLeaf(F f) : f(f) {}  ///< Store the per-leaf callback.
  F f;                               ///< Callback invoked on every leaf.

  /**
   * Walk the subtree rooted at `node` depth-first, calling `f` on each Leaf.
   *
   * @param node Root of the subtree to traverse.
   * @throws std::invalid_argument if `node` is neither a Leaf nor a Choice.
   */
  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
    using Leaf = typename DecisionTree<L, Y>::Leaf;
    using Choice = typename DecisionTree<L, Y>::Choice;

    // Terminal case: hand the leaf itself to the callback and stop.
    const auto asLeaf = boost::dynamic_pointer_cast<const Leaf>(node);
    if (asLeaf) {
      f(*asLeaf);
      return;
    }

    // Anything that is not a leaf must be a choice node.
    const auto asChoice = boost::dynamic_pointer_cast<const Choice>(node);
    if (!asChoice) {
      throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
    }

    // Recurse into every child branch, in stored order.
    for (const auto& child : asChoice->branches()) {
      (*this)(child);
    }
  }
};
/// Visit every leaf of this tree depth-first, passing each Leaf object to `f`.
template <typename L, typename Y>
template <typename Func>
void DecisionTree<L, Y>::visitLeaf(Func f) const {
  // Wrap the callback in the recursive functor and start at the root.
  const VisitLeaf<L, Y> leafVisitor(f);
  leafVisitor(root_);
}
/****************************************************************************/
/**
* Functor performing depth-first visit to each leaf with the leaf's
* `Assignment<L>` and value passed as arguments.
*
* NOTE: Follows the same pruning semantics as `visit`.
*/
template <typename L, typename Y> template <typename L, typename Y>
struct VisitWith { struct VisitWith {
using Choices = Assignment<L>; using F = std::function<void(const Assignment<L>&, const Y&)>;
using F = std::function<void(const Choices&, const Y&)>;
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion. Assignment<L> assignment; ///< Assignment, mutating through recursion.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
using Leaf = typename DecisionTree<L, Y>::Leaf; using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node)) if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(choices, leaf->constant()); return f(assignment, leaf->constant());
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice) if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) { for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i assignment[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse! (*this)(choice->branches()[i]); // recurse!
// Remove the choice so we are backtracking // Remove the choice so we are backtracking
auto choice_it = choices.find(choice->label()); auto choice_it = assignment.find(choice->label());
choices.erase(choice_it); assignment.erase(choice_it);
} }
} }
}; };
@ -763,12 +815,26 @@ namespace gtsam {
} }
/****************************************************************************/ /****************************************************************************/
// labels is just done with a visit /**
* Get (partial) labels by performing a visit.
*
* This method performs a depth-first search to go to every leaf and records
* the key assignment that leads to that leaf. Since the tree can be pruned,
* there might be a leaf at a lower depth which results in a partial
* assignment (i.e. not all keys are specified).
*
* E.g. given a tree with 3 keys, there may be a branch where the 3rd key has
* the same values for all the leaves. This leads to the branch being pruned
* so we get a leaf which is arrived at by just the first 2 keys and their
* assignments.
*/
template <typename L, typename Y> template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const { std::set<L> DecisionTree<L, Y>::labels() const {
std::set<L> unique; std::set<L> unique;
auto f = [&](const Assignment<L>& choices, const Y&) { auto f = [&](const Assignment<L>& assignment, const Y&) {
for (auto&& kv : choices) unique.insert(kv.first); for (auto&& kv : assignment) {
unique.insert(kv.first);
}
}; };
visitWith(f); visitWith(f);
return unique; return unique;
@ -817,8 +883,8 @@ namespace gtsam {
throw std::runtime_error( throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree."); "DecisionTree::apply(unary op) undefined for empty tree.");
} }
Assignment<L> choices; Assignment<L> assignment;
return DecisionTree(root_->apply(op, choices)); return DecisionTree(root_->apply(op, assignment));
} }
/****************************************************************************/ /****************************************************************************/

View File

@ -105,7 +105,7 @@ namespace gtsam {
virtual const Y& operator()(const Assignment<L>& x) const = 0; virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply(const UnaryAssignment& op, virtual Ptr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const = 0; const Assignment<L>& assignment) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
@ -153,7 +153,7 @@ namespace gtsam {
/** Create a constant */ /** Create a constant */
explicit DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */ /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */ /** Allow Label+Cardinality for convenience */
@ -219,9 +219,8 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/** Make virtual */ /// Make virtual
virtual ~DecisionTree() { virtual ~DecisionTree() {}
}
/// Check if tree is empty. /// Check if tree is empty.
bool empty() const { return !root_; } bool empty() const { return !root_; }
@ -235,9 +234,11 @@ namespace gtsam {
/** /**
* @brief Visit all leaves in depth-first fashion. * @brief Visit all leaves in depth-first fashion.
* *
* @param f side-effect taking a value. * @param f (side-effect) Function taking a value.
* *
* @note Due to pruning, leaves might not exhaust choices. * @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
* *
* Example: * Example:
* int sum = 0; * int sum = 0;
@ -250,13 +251,32 @@ namespace gtsam {
/** /**
* @brief Visit all leaves in depth-first fashion. * @brief Visit all leaves in depth-first fashion.
* *
* @param f side-effect taking an assignment and a value. * @param f (side-effect) Function taking a const reference to the leaf node.
* *
* @note Due to pruning, leaves might not exhaust choices. * @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
* *
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; }; * auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); };
* tree.visitLeaf(visitor);
*/
template <typename Func>
void visitLeaf(Func f) const;
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f (side-effect) Function taking an assignment and a value.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; };
* tree.visitWith(visitor); * tree.visitWith(visitor);
*/ */
template <typename Func> template <typename Func>

View File

@ -287,12 +287,16 @@ namespace gtsam {
cardinalities_(keys.cardinalities()) {} cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrLeaves; const size_t N = maxNrAssignments;
// Get the probabilities in the decision tree so we can threshold. // Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities; std::vector<double> probabilities;
this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
// The number of probabilities can be lower than maxNrAssignments
if (probabilities.size() <= N) { if (probabilities.size() <= N) {

View File

@ -175,12 +175,13 @@ namespace gtsam {
* *
* Pruning will set the leaves to be "pruned" to 0 indicating a 0 * Pruning will set the leaves to be "pruned" to 0 indicating a 0
* probability. * probability.
* A leaf is pruned if it is not in the top `maxNrLeaves` values. * An assignment is pruned if it is not in the top `maxNrAssignments`
* values.
* *
* @param maxNrLeaves The maximum number of leaves to keep. * @param maxNrAssignments The maximum number of assignments to keep.
* @return DecisionTreeFactor * @return DecisionTreeFactor
*/ */
DecisionTreeFactor prune(size_t maxNrLeaves) const; DecisionTreeFactor prune(size_t maxNrAssignments) const;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support

View File

@ -20,7 +20,7 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define DT_NO_PRUNING //#define GTSAM_DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING

View File

@ -18,7 +18,7 @@
*/ */
// #define DT_DEBUG_MEMORY // #define DT_DEBUG_MEMORY
// #define DT_NO_PRUNING // #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT #define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>