diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 8be5efaa6..6f19574fc 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -557,9 +557,7 @@ namespace gtsam { template DecisionTree::DecisionTree(const DecisionTree& other, Func Y_of_X) { - // Define functor for identity mapping of node label. - auto L_of_L = [](const L& label) { return label; }; - root_ = convertFrom(other.root_, L_of_L, Y_of_X); + root_ = convertFrom(other.root_, Y_of_X); } /****************************************************************************/ @@ -698,6 +696,36 @@ namespace gtsam { } } + /****************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convertFrom( + const typename DecisionTree::NodePtr& f, + std::function Y_of_X) { + + // If leaf, apply unary conversion "op" and create a unique leaf. + using LXLeaf = typename DecisionTree::Leaf; + if (auto leaf = std::dynamic_pointer_cast(f)) { + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + } + + // Check if Choice + using LXChoice = typename DecisionTree::Choice; + auto choice = std::dynamic_pointer_cast(f); + if (!choice) throw std::invalid_argument( + "DecisionTree::convertFrom: Invalid NodePtr"); + + // Create a new Choice node with the same label + auto newChoice = std::make_shared(choice->label(), choice->nrChoices()); + + // Convert each branch recursively + for (auto&& branch : choice->branches()) { + newChoice->push_back(convertFrom(branch, Y_of_X)); + } + + return Choice::Unique(newChoice); + } + /****************************************************************************/ template template @@ -745,8 +773,9 @@ namespace gtsam { * * NOTE: We differentiate between leaves and assignments. Concretely, a 3 * binary variable tree will have 2^3=8 assignments, but based on pruning, it - * can have <8 leaves. For example, if a tree has all assignment values as 1, - * then pruning will cause the tree to have only 1 leaf yet 8 assignments. + * can have less than 8 leaves. For example, if a tree has all assignment + * values as 1, then pruning will cause the tree to have only 1 leaf yet 8 + * assignments. */ template struct Visit { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0d9db1fce..c1d7ea05f 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -165,6 +165,19 @@ namespace gtsam { template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous value type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param Y_of_X Functor to convert from value type X to type Y. + * @return NodePtr + */ + template + static NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function Y_of_X); + /** * @brief Convert from a DecisionTree to DecisionTree. *