Faster version of convertFrom when no label translation needed
parent
f98b9223e8
commit
6a5dd60d33
|
@ -557,9 +557,7 @@ namespace gtsam {
|
|||
template <typename X, typename Func>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||
Func Y_of_X) {
|
||||
// Define functor for identity mapping of node label.
|
||||
auto L_of_L = [](const L& label) { return label; };
|
||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||
root_ = convertFrom<X>(other.root_, Y_of_X);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
|
@ -698,6 +696,36 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename X>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||
const typename DecisionTree<L, X>::NodePtr& f,
|
||||
std::function<Y(const X&)> Y_of_X) {
|
||||
|
||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||
if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
}
|
||||
|
||||
// Check if Choice
|
||||
using LXChoice = typename DecisionTree<L, X>::Choice;
|
||||
auto choice = std::dynamic_pointer_cast<const LXChoice>(f);
|
||||
if (!choice) throw std::invalid_argument(
|
||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||
|
||||
// Create a new Choice node with the same label
|
||||
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||
|
||||
// Convert each branch recursively
|
||||
for (auto&& branch : choice->branches()) {
|
||||
newChoice->push_back(convertFrom<X>(branch, Y_of_X));
|
||||
}
|
||||
|
||||
return Choice::Unique(newChoice);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename M, typename X>
|
||||
|
@ -745,8 +773,9 @@ namespace gtsam {
|
|||
*
|
||||
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
|
||||
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
|
||||
* can have <8 leaves. For example, if a tree has all assignment values as 1,
|
||||
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
|
||||
* can have less than 8 leaves. For example, if a tree has all assignment
|
||||
* values as 1, then pruning will cause the tree to have only 1 leaf yet 8
|
||||
* assignments.
|
||||
*/
|
||||
template <typename L, typename Y>
|
||||
struct Visit {
|
||||
|
|
|
@ -165,6 +165,19 @@ namespace gtsam {
|
|||
template <typename It, typename ValueIt>
|
||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||
|
||||
/**
|
||||
* @brief Convert from a DecisionTree<L, X> to DecisionTree<L, Y>.
|
||||
*
|
||||
* @tparam M The previous label type.
|
||||
* @tparam X The previous value type.
|
||||
* @param f The node pointer to the root of the previous DecisionTree.
|
||||
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||
* @return NodePtr
|
||||
*/
|
||||
template <typename X>
|
||||
static NodePtr convertFrom(const typename DecisionTree<L, X>::NodePtr& f,
|
||||
std::function<Y(const X&)> Y_of_X);
|
||||
|
||||
/**
|
||||
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue