Faster version of convertFrom when no label translation needed

release/4.3a0
Frank Dellaert 2024-10-15 15:51:21 +09:00
parent f98b9223e8
commit 6a5dd60d33
2 changed files with 47 additions and 5 deletions

View File

@ -557,9 +557,7 @@ namespace gtsam {
template <typename X, typename Func> template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) { Func Y_of_X) {
// Define functor for identity mapping of node label. root_ = convertFrom<X>(other.root_, Y_of_X);
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
} }
/****************************************************************************/ /****************************************************************************/
@ -698,6 +696,36 @@ namespace gtsam {
} }
} }
/****************************************************************************/
template <typename L, typename Y>
template <typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
const typename DecisionTree<L, X>::NodePtr& f,
std::function<Y(const X&)> Y_of_X) {
// If leaf, apply unary conversion "op" and create a unique leaf.
using LXLeaf = typename DecisionTree<L, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}
// Check if Choice
using LXChoice = typename DecisionTree<L, X>::Choice;
auto choice = std::dynamic_pointer_cast<const LXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::convertFrom: Invalid NodePtr");
// Create a new Choice node with the same label
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
// Convert each branch recursively
for (auto&& branch : choice->branches()) {
newChoice->push_back(convertFrom<X>(branch, Y_of_X));
}
return Choice::Unique(newChoice);
}
/****************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
@ -745,8 +773,9 @@ namespace gtsam {
* *
* NOTE: We differentiate between leaves and assignments. Concretely, a 3 * NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it * binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have <8 leaves. For example, if a tree has all assignment values as 1, * can have less than 8 leaves. For example, if a tree has all assignment
* then pruning will cause the tree to have only 1 leaf yet 8 assignments. * values as 1, then pruning will cause the tree to have only 1 leaf yet 8
* assignments.
*/ */
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {

View File

@ -165,6 +165,19 @@ namespace gtsam {
template <typename It, typename ValueIt> template <typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/**
* @brief Convert from a DecisionTree<L, X> to DecisionTree<L, Y>.
*
* @tparam M The previous label type.
* @tparam X The previous value type.
* @param f The node pointer to the root of the previous DecisionTree.
* @param Y_of_X Functor to convert from value type X to type Y.
* @return NodePtr
*/
template <typename X>
static NodePtr convertFrom(const typename DecisionTree<L, X>::NodePtr& f,
std::function<Y(const X&)> Y_of_X);
/** /**
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>. * @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
* *