Merge pull request #1155 from borglab/decisiontree-refactor

release/4.3a0
Frank Dellaert 2022-04-13 22:24:43 -04:00 committed by GitHub
commit 78d7e903f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 163 additions and 72 deletions

View File

@ -59,7 +59,7 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;
/** The number of assignments contained within this leaf
/** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;
@ -68,7 +68,7 @@ namespace gtsam {
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
/** return the constant */
/// Return the constant
const Y& constant() const {
return constant_;
}
@ -81,19 +81,19 @@ namespace gtsam {
return constant_ == q.constant_;
}
/// polymorphic equality: is q is a leaf, could be
/// polymorphic equality: is q a leaf and is it the same as this leaf?
bool sameLeaf(const Node& q) const override {
return (q.isLeaf() && q.sameLeaf(*this));
}
/** equality up to tolerance */
/// equality up to tolerance
bool equals(const Node& q, const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false;
return compare(this->constant_, other->constant_);
}
/** print */
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
@ -122,8 +122,8 @@ namespace gtsam {
/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
return f;
}
@ -168,7 +168,10 @@ namespace gtsam {
std::vector<NodePtr> branches_;
private:
/** incremental allSame */
/**
* Incremental allSame.
* Records if all the branches are the same leaf.
*/
size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>;
@ -181,9 +184,9 @@ namespace gtsam {
#endif
}
/** If all branches of a choice node f are the same, just return a branch */
/// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) {
#ifndef DT_NO_PRUNING
#ifndef GTSAM_DT_NO_PRUNING
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
@ -205,15 +208,13 @@ namespace gtsam {
bool isLeaf() const override { return false; }
/** Constructor, given choice label and mandatory expected branch count */
/// Constructor, given choice label and mandatory expected branch count.
Choice(const L& label, size_t count) :
label_(label), allSame_(true) {
branches_.reserve(count);
}
/**
* Construct from applying binary op to two Choice nodes
*/
/// Construct from applying binary op to two Choice nodes.
Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) {
// Choose what to do based on label
@ -241,6 +242,7 @@ namespace gtsam {
}
}
/// Return the label of this choice node.
const L& label() const {
return label_;
}
@ -262,7 +264,7 @@ namespace gtsam {
branches_.push_back(node);
}
/** print (as a tree) */
/// print (as a tree).
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice(";
@ -308,7 +310,7 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this));
}
/** equality */
/// equality
bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false;
@ -321,7 +323,7 @@ namespace gtsam {
return true;
}
/** evaluate */
/// evaluate
const Y& operator()(const Assignment<L>& x) const override {
#ifndef NDEBUG
typename Assignment<L>::const_iterator it = x.find(label_);
@ -336,13 +338,13 @@ namespace gtsam {
return (*child)(x);
}
/**
* Construct from applying unary op to a Choice node
*/
/// Construct from applying unary op to a Choice node.
Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
for (const NodePtr& branch : f.branches_) {
push_back(branch->apply(op));
}
}
/**
@ -353,28 +355,28 @@ namespace gtsam {
* @param f The original choice node to apply the op on.
* @param op Function to apply on the choice node. Takes Assignment and
* value as arguments.
* @param choices The Assignment that will go to op.
* @param assignment The Assignment that will go to op.
*/
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
const Assignment<L>& choices)
const Assignment<L>& assignment)
: label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
Assignment<L> choices_ = choices;
Assignment<L> assignment_ = assignment;
for (size_t i = 0; i < f.branches_.size(); i++) {
choices_[label_] = i; // Set assignment for label to i
assignment_[label_] = i; // Set assignment for label to i
const NodePtr branch = f.branches_[i];
push_back(branch->apply(op, choices_));
push_back(branch->apply(op, assignment_));
// Remove the choice so we are backtracking
auto choice_it = choices_.find(label_);
choices_.erase(choice_it);
// Remove the assignment so we are backtracking
auto assignment_it = assignment_.find(label_);
assignment_.erase(assignment_it);
}
}
/** apply unary operator */
/// apply unary operator.
NodePtr apply(const Unary& op) const override {
auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r);
@ -382,8 +384,8 @@ namespace gtsam {
/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
const Assignment<L>& assignment) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
return Unique(r);
}
@ -678,7 +680,16 @@ namespace gtsam {
}
/****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument.
/**
* Functor performing depth-first visit to each leaf with the leaf value as
* the argument.
*
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have less than 8 leaves. For example, if a tree has all assignment
* values as 1, then pruning will cause the tree to have only 1 leaf yet 8
* assignments.
*/
template <typename L, typename Y>
struct Visit {
using F = std::function<void(const Y&)>;
@ -707,33 +718,74 @@ namespace gtsam {
}
/****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument.
/**
* Functor performing depth-first visit to each leaf with the Leaf object
* passed as an argument.
*
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have <8 leaves. For example, if a tree has all assignment values as 1,
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
*/
template <typename L, typename Y>
struct VisitLeaf {
using F = std::function<void(const typename DecisionTree<L, Y>::Leaf&)>;
explicit VisitLeaf(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(*leaf);
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
}
};
template <typename L, typename Y>
template <typename Func>
void DecisionTree<L, Y>::visitLeaf(Func f) const {
VisitLeaf<L, Y> visit(f);
visit(root_);
}
/****************************************************************************/
/**
* Functor performing depth-first visit to each leaf with the leaf's
* `Assignment<L>` and value passed as arguments.
*
* NOTE: Follows the same pruning semantics as `visit`.
*/
template <typename L, typename Y>
struct VisitWith {
using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>;
using F = std::function<void(const Assignment<L>&, const Y&)>;
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
Assignment<L> assignment; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(choices, leaf->constant());
return f(assignment, leaf->constant());
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
assignment[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!
// Remove the choice so we are backtracking
auto choice_it = choices.find(choice->label());
choices.erase(choice_it);
auto choice_it = assignment.find(choice->label());
assignment.erase(choice_it);
}
}
};
@ -763,12 +815,26 @@ namespace gtsam {
}
/****************************************************************************/
// labels is just done with a visit
/**
* Get (partial) labels by performing a visit.
*
* This method performs a depth-first search to go to every leaf and records
* the keys assignment which leads to that leaf. Since the tree can be pruned,
* there might be a leaf at a lower depth which results in a partial
* assignment (i.e. not all keys are specified).
*
* E.g. given a tree with 3 keys, there may be a branch where the 3rd key has
* the same values for all the leaves. This leads to the branch being pruned
* so we get a leaf which is arrived at by just the first 2 keys and their
* assignments.
*/
template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const {
std::set<L> unique;
auto f = [&](const Assignment<L>& choices, const Y&) {
for (auto&& kv : choices) unique.insert(kv.first);
auto f = [&](const Assignment<L>& assignment, const Y&) {
for (auto&& kv : assignment) {
unique.insert(kv.first);
}
};
visitWith(f);
return unique;
@ -817,8 +883,8 @@ namespace gtsam {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
Assignment<L> choices;
return DecisionTree(root_->apply(op, choices));
Assignment<L> assignment;
return DecisionTree(root_->apply(op, assignment));
}
/****************************************************************************/

View File

@ -105,7 +105,7 @@ namespace gtsam {
virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const = 0;
const Assignment<L>& assignment) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
@ -153,7 +153,7 @@ namespace gtsam {
/** Create a constant */
explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */
@ -219,9 +219,8 @@ namespace gtsam {
/// @name Standard Interface
/// @{
/** Make virtual */
virtual ~DecisionTree() {
}
/// Make virtual
virtual ~DecisionTree() {}
/// Check if tree is empty.
bool empty() const { return !root_; }
@ -234,11 +233,13 @@ namespace gtsam {
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
*
* @param f (side-effect) Function taking a value.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
@ -249,14 +250,33 @@ namespace gtsam {
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking an assignment and a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
*
* @param f (side-effect) Function taking the leaf node pointer.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
* auto visitor = [&](int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
void visitLeaf(Func f) const;
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f (side-effect) Function taking an assignment and a value.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
@ -275,7 +295,7 @@ namespace gtsam {
*
* @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
*
*
* Example:
* auto add = [](const double& y, double x) { return y + x; };
* double sum = tree.fold(add, 0.0);

View File

@ -287,12 +287,16 @@ namespace gtsam {
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const {
const size_t N = maxNrLeaves;
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;
// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
this->visit([&](const double& prob) { probabilities.emplace_back(prob); });
this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {

View File

@ -175,12 +175,13 @@ namespace gtsam {
*
* Pruning will set the leaves to be "pruned" to 0 indicating a 0
* probability.
* A leaf is pruned if it is not in the top `maxNrLeaves` values.
* An assignment is pruned if it is not in the top `maxNrAssignments`
* values.
*
* @param maxNrLeaves The maximum number of leaves to keep.
* @param maxNrAssignments The maximum number of assignments to keep.
* @return DecisionTreeFactor
*/
DecisionTreeFactor prune(size_t maxNrLeaves) const;
DecisionTreeFactor prune(size_t maxNrAssignments) const;
/// @}
/// @name Wrapper support

View File

@ -20,7 +20,7 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
//#define DT_NO_PRUNING
//#define GTSAM_DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING

View File

@ -18,7 +18,7 @@
*/
// #define DT_DEBUG_MEMORY
// #define DT_NO_PRUNING
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>