Move constructor

release/4.3a0
Frank Dellaert 2024-10-16 14:41:22 -07:00
parent a19fb28fab
commit 08167d08cc
3 changed files with 68 additions and 3 deletions

View File

@ -287,6 +287,10 @@ namespace gtsam {
return branches_;
}
std::vector<NodePtr>& branches() {
return branches_;
}
/** add a branch: TODO merge into constructor */
void push_back(NodePtr&& node) {
// allSame_ is restricted to leaf nodes in a decision tree
@ -555,6 +559,36 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label);
}
/****************************************************************************/
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Unary& op,
DecisionTree&& other) noexcept
: root_(std::move(other.root_)) {
// Apply the unary operation directly to each leaf in the tree
if (root_) {
// Define a helper function to traverse and apply the operation
struct ApplyUnary {
const Unary& op;
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
// Apply the unary operation to the leaf's constant value
leaf->constant_ = op(leaf->constant_);
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
// Recurse into the choice branches
for (NodePtr& branch : choice->branches()) {
(*this)(branch);
}
}
}
};
ApplyUnary applyUnary{op};
applyUnary(root_);
}
// Reset the other tree's root to nullptr to avoid dangling references
other.root_ = nullptr;
}
/****************************************************************************/
template <typename L, typename Y>
template <typename X, typename Func>
@ -695,7 +729,7 @@ namespace gtsam {
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
return Choice::Unique(choice);
} else {
return node;
@ -711,7 +745,7 @@ namespace gtsam {
// If leaf, apply unary conversion "op" and create a unique leaf.
using LXLeaf = typename DecisionTree<L, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}

View File

@ -85,7 +85,7 @@ namespace gtsam {
/** ------------------------ Node base class --------------------------- */
struct Node {
using Ptr = std::shared_ptr<const Node>;
using Ptr = std::shared_ptr<Node>;
#ifdef DT_DEBUG_MEMORY
static int nrNodes;
@ -228,6 +228,15 @@ namespace gtsam {
DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f1);
/**
* @brief Move constructor for DecisionTree. Very efficient as does not
* allocate anything, just changes in-place. But `other` is consumed.
*
* @param op The unary operation to apply to the moved DecisionTree.
* @param other The DecisionTree to move from, will be empty afterwards.
*/
DecisionTree(const Unary& op, DecisionTree&& other) noexcept;
/**
* @brief Convert from a different value type.
*

View File

@ -108,6 +108,7 @@ struct DT : public DecisionTree<string, int> {
std::cout << s;
Base::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to int node type
bool equals(const Base& other, double tol = 1e-9) const {
auto compare = [](const int& v, const int& w) { return v == w; };
@ -302,6 +303,27 @@ TEST(DecisionTree, Split) {
}
/* ************************************************************************** */
// Test that we can create a tree by modifying an rvalue.
TEST(DecisionTree, Consume) {
// Create labels
string A("A"), B("B");
// Create a decision tree
DT original(A, DT(B, 1, 2), DT(B, 3, 4));
DT modified([](int i){return i*2;}, std::move(original));
// Check the first resulting tree
EXPECT_LONGS_EQUAL(2, modified(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(4, modified(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(6, modified(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(8, modified(Assignment<string>{{A, 1}, {B, 1}}));
// Check original was moved
EXPECT(original.root_ == nullptr);
}
/* ************************************************************************** */
// test Conversion of values
bool bool_of_int(const int& y) { return y != 0; };