diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 1f50014d6..bda44bb9d 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -287,6 +287,10 @@ namespace gtsam { return branches_; } + std::vector& branches() { + return branches_; + } + /** add a branch: TODO merge into constructor */ void push_back(NodePtr&& node) { // allSame_ is restricted to leaf nodes in a decision tree @@ -555,6 +559,36 @@ namespace gtsam { root_ = compose(functions.begin(), functions.end(), label); } + /****************************************************************************/ + template + DecisionTree::DecisionTree(const Unary& op, + DecisionTree&& other) noexcept + : root_(std::move(other.root_)) { + // Apply the unary operation directly to each leaf in the tree + if (root_) { + // Define a helper function to traverse and apply the operation + struct ApplyUnary { + const Unary& op; + void operator()(typename DecisionTree::NodePtr& node) const { + if (auto leaf = std::dynamic_pointer_cast(node)) { + // Apply the unary operation to the leaf's constant value + leaf->constant_ = op(leaf->constant_); + } else if (auto choice = std::dynamic_pointer_cast(node)) { + // Recurse into the choice branches + for (NodePtr& branch : choice->branches()) { + (*this)(branch); + } + } + } + }; + + ApplyUnary applyUnary{op}; + applyUnary(root_); + } + // Reset the other tree's root to nullptr to avoid dangling references + other.root_ = nullptr; + } + /****************************************************************************/ template template @@ -695,7 +729,7 @@ namespace gtsam { typename DecisionTree::NodePtr DecisionTree::create( It begin, It end, ValueIt beginY, ValueIt endY) { auto node = build(begin, end, beginY, endY); - if (auto choice = std::dynamic_pointer_cast(node)) { + if (auto choice = std::dynamic_pointer_cast(node)) { return Choice::Unique(choice); } else { return node; @@ -711,7 +745,7 @@ namespace gtsam { // If leaf, apply unary conversion "op" and create a unique leaf. using LXLeaf = typename DecisionTree::Leaf; - if (auto leaf = std::dynamic_pointer_cast(f)) { + if (auto leaf = std::dynamic_pointer_cast(f)) { return NodePtr(new Leaf(Y_of_X(leaf->constant()))); } diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index aba5b88b7..486f798e9 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -85,7 +85,7 @@ namespace gtsam { /** ------------------------ Node base class --------------------------- */ struct Node { - using Ptr = std::shared_ptr; + using Ptr = std::shared_ptr; #ifdef DT_DEBUG_MEMORY static int nrNodes; @@ -228,6 +228,15 @@ namespace gtsam { DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1); + /** + * @brief Move constructor for DecisionTree. Very efficient as does not + * allocate anything, just changes in-place. But `other` is consumed. + * + * @param op The unary operation to apply to the moved DecisionTree. + * @param other The DecisionTree to move from, will be empty afterwards. + */ + DecisionTree(const Unary& op, DecisionTree&& other) noexcept; + /** * @brief Convert from a different value type. * diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 1382fc704..526001b51 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -108,6 +108,7 @@ struct DT : public DecisionTree { std::cout << s; Base::print("", keyFormatter, valueFormatter); } + /// Equality method customized to int node type bool equals(const Base& other, double tol = 1e-9) const { auto compare = [](const int& v, const int& w) { return v == w; }; @@ -302,6 +303,27 @@ TEST(DecisionTree, Split) { } +/* ************************************************************************** */ +// Test that we can create a tree by modifying an rvalue. +TEST(DecisionTree, Consume) { + // Create labels + string A("A"), B("B"); + + // Create a decision tree + DT original(A, DT(B, 1, 2), DT(B, 3, 4)); + + DT modified([](int i){return i*2;}, std::move(original)); + + // Check the first resulting tree + EXPECT_LONGS_EQUAL(2, modified(Assignment{{A, 0}, {B, 0}})); + EXPECT_LONGS_EQUAL(4, modified(Assignment{{A, 0}, {B, 1}})); + EXPECT_LONGS_EQUAL(6, modified(Assignment{{A, 1}, {B, 0}})); + EXPECT_LONGS_EQUAL(8, modified(Assignment{{A, 1}, {B, 1}})); + + // Check original was moved + EXPECT(original.root_ == nullptr); +} + /* ************************************************************************** */ // test Conversion of values bool bool_of_int(const int& y) { return y != 0; };