Faster version of convertFrom when no label translation needed
parent
f98b9223e8
commit
6a5dd60d33
|
@ -557,9 +557,7 @@ namespace gtsam {
|
||||||
template <typename X, typename Func>
|
template <typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
Func Y_of_X) {
|
Func Y_of_X) {
|
||||||
// Define functor for identity mapping of node label.
|
root_ = convertFrom<X>(other.root_, Y_of_X);
|
||||||
auto L_of_L = [](const L& label) { return label; };
|
|
||||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
@ -698,6 +696,36 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename X>
|
||||||
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
|
const typename DecisionTree<L, X>::NodePtr& f,
|
||||||
|
std::function<Y(const X&)> Y_of_X) {
|
||||||
|
|
||||||
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
|
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||||
|
if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
|
||||||
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if Choice
|
||||||
|
using LXChoice = typename DecisionTree<L, X>::Choice;
|
||||||
|
auto choice = std::dynamic_pointer_cast<const LXChoice>(f);
|
||||||
|
if (!choice) throw std::invalid_argument(
|
||||||
|
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||||
|
|
||||||
|
// Create a new Choice node with the same label
|
||||||
|
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||||
|
|
||||||
|
// Convert each branch recursively
|
||||||
|
for (auto&& branch : choice->branches()) {
|
||||||
|
newChoice->push_back(convertFrom<X>(branch, Y_of_X));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Choice::Unique(newChoice);
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
|
@ -745,8 +773,9 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
|
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
|
||||||
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
|
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
|
||||||
* can have <8 leaves. For example, if a tree has all assignment values as 1,
|
* can have less than 8 leaves. For example, if a tree has all assignment
|
||||||
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
|
* values as 1, then pruning will cause the tree to have only 1 leaf yet 8
|
||||||
|
* assignments.
|
||||||
*/
|
*/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct Visit {
|
struct Visit {
|
||||||
|
|
|
@ -165,6 +165,19 @@ namespace gtsam {
|
||||||
template <typename It, typename ValueIt>
|
template <typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert from a DecisionTree<L, X> to DecisionTree<L, Y>.
|
||||||
|
*
|
||||||
|
* @tparam M The previous label type.
|
||||||
|
* @tparam X The previous value type.
|
||||||
|
* @param f The node pointer to the root of the previous DecisionTree.
|
||||||
|
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||||
|
* @return NodePtr
|
||||||
|
*/
|
||||||
|
template <typename X>
|
||||||
|
static NodePtr convertFrom(const typename DecisionTree<L, X>::NodePtr& f,
|
||||||
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue