Merge pull request #1144 from borglab/decision-tree-improvements
commit
239a978086
|
|
@ -635,11 +635,13 @@ namespace gtsam {
|
||||||
std::function<Y(const X&)> Y_of_X) const {
|
std::function<Y(const X&)> Y_of_X) const {
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
// ugliness below because apparently we can't have templated virtual
|
// Ugliness below because apparently we can't have templated virtual
|
||||||
// functions If leaf, apply unary conversion "op" and create a unique leaf
|
// functions.
|
||||||
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
}
|
||||||
|
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
|
|
@ -727,6 +729,14 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
size_t DecisionTree<L, Y>::nrLeaves() const {
|
||||||
|
size_t total = 0;
|
||||||
|
visit([&total](const Y& node) { total += 1; });
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
// fold is just done with a visit
|
// fold is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
|
|
||||||
|
|
@ -262,6 +262,9 @@ namespace gtsam {
|
||||||
template <typename Func>
|
template <typename Func>
|
||||||
void visitWith(Func f) const;
|
void visitWith(Func f) const;
|
||||||
|
|
||||||
|
/// Return the number of leaves in the tree.
|
||||||
|
size_t nrLeaves() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Fold a binary function over the tree, returning accumulator.
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
*
|
*
|
||||||
|
|
|
||||||
|
|
@ -156,10 +156,7 @@ namespace gtsam {
|
||||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
||||||
const {
|
const {
|
||||||
// Get all possible assignments
|
// Get all possible assignments
|
||||||
std::vector<std::pair<Key, size_t>> pairs;
|
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
|
||||||
for (auto& key : keys()) {
|
|
||||||
pairs.emplace_back(key, cardinalities_.at(key));
|
|
||||||
}
|
|
||||||
// Reverse to make cartesian product output a more natural ordering.
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue