diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index efc19d9ee..89091c78b 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -74,8 +74,8 @@ namespace gtsam { /// equality up to tolerance bool equals(const Node& q, const CompareFunc& compare) const override { - const Leaf* other = dynamic_cast(&q); - if (!other) return false; + if (!q.isLeaf()) return false; + const Leaf* other = static_cast(&q); return compare(this->constant_, other->constant_); } @@ -202,36 +202,39 @@ namespace gtsam { * @param node The root node of the decision tree. * @return NodePtr */ + #ifdef GTSAM_DT_MERGING + static NodePtr Unique(const NodePtr& node) { - if (auto choice = std::dynamic_pointer_cast(node)) { - // Choice node, we recurse! - // Make non-const copy so we can update - auto f = std::make_shared(choice->label(), choice->nrChoices()); + if (node->isLeaf()) return node; // Leaf node, return as is - // Iterate over all the branches - for (size_t i = 0; i < choice->nrChoices(); i++) { - auto branch = choice->branches_[i]; - f->push_back(Unique(branch)); - } + auto choice = std::static_pointer_cast(node); + // Choice node, we recurse! + // Make non-const copy so we can update + auto f = std::make_shared(choice->label(), choice->nrChoices()); -#ifdef GTSAM_DT_MERGING - // If all the branches are the same, we can merge them into one - if (f->allSame_) { - assert(f->branches().size() > 0); - NodePtr f0 = f->branches_[0]; - - NodePtr newLeaf( - new Leaf(std::dynamic_pointer_cast(f0)->constant())); - return newLeaf; - } -#endif - return f; - } else { - // Leaf node, return as is - return node; + // Iterate over all the branches + for (const auto& branch : choice->branches_) { + f->push_back(Unique(branch)); } + + // If all the branches are the same, we can merge them into one + if (f->allSame_) { + assert(f->branches().size() > 0); + auto f0 = std::static_pointer_cast(f->branches_[0]); + return std::make_shared(f0->constant()); + } + + return f; } + #else + + static NodePtr Unique(const NodePtr& node) { + // No-op when GTSAM_DT_MERGING is not defined + return node; + } + + #endif bool isLeaf() const override { return false; } /// Constructor, given choice label and mandatory expected branch count. @@ -322,9 +325,9 @@ namespace gtsam { const NodePtr& branch = branches_[i]; // Check if zero - if (!showZero) { - const Leaf* leaf = dynamic_cast(branch.get()); - if (leaf && valueFormatter(leaf->constant()).compare("0")) continue; + if (!showZero && branch->isLeaf()) { + auto leaf = std::static_pointer_cast(branch); + if (valueFormatter(leaf->constant()).compare("0")) continue; } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; @@ -346,8 +349,8 @@ namespace gtsam { /// equality bool equals(const Node& q, const CompareFunc& compare) const override { - const Choice* other = dynamic_cast(&q); - if (!other) return false; + if (q.isLeaf()) return false; + const Choice* other = static_cast(&q); if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here @@ -570,11 +573,13 @@ namespace gtsam { struct ApplyUnary { const Unary& op; void operator()(typename DecisionTree::NodePtr& node) const { - if (auto leaf = std::dynamic_pointer_cast(node)) { + if (node->isLeaf()) { // Apply the unary operation to the leaf's constant value + auto leaf = std::static_pointer_cast(node); leaf->constant_ = op(leaf->constant_); - } else if (auto choice = std::dynamic_pointer_cast(node)) { + } else { // Recurse into the choice branches + auto choice = std::static_pointer_cast(node); for (NodePtr& branch : choice->branches()) { (*this)(branch); } @@ -622,8 +627,7 @@ namespace gtsam { for (Iterator it = begin; it != end; it++) { if (it->root_->isLeaf()) continue; - std::shared_ptr c = - std::dynamic_pointer_cast(it->root_); + auto c = std::static_pointer_cast(it->root_); if (!highestLabel || c->label() > *highestLabel) { highestLabel = c->label(); nrChoices = c->nrChoices(); @@ -729,11 +733,7 @@ namespace gtsam { typename DecisionTree::NodePtr DecisionTree::create( It begin, It end, ValueIt beginY, ValueIt endY) { auto node = build(begin, end, beginY, endY); - if (auto choice = std::dynamic_pointer_cast(node)) { - return Choice::Unique(choice); - } else { - return node; - } + return Choice::Unique(node); } /****************************************************************************/ @@ -742,18 +742,17 @@ namespace gtsam { typename DecisionTree::NodePtr DecisionTree::convertFrom( const typename DecisionTree::NodePtr& f, std::function Y_of_X) { + using LXLeaf = typename DecisionTree::Leaf; + using LXChoice = typename DecisionTree::Choice; // If leaf, apply unary conversion "op" and create a unique leaf. - using LXLeaf = typename DecisionTree::Leaf; - if (auto leaf = std::dynamic_pointer_cast(f)) { + if (f->isLeaf()) { + auto leaf = std::static_pointer_cast(f); return NodePtr(new Leaf(Y_of_X(leaf->constant()))); } - // Check if Choice - using LXChoice = typename DecisionTree::Choice; - auto choice = std::dynamic_pointer_cast(f); - if (!choice) throw std::invalid_argument( - "DecisionTree::convertFrom: Invalid NodePtr"); + // Now a Choice! + auto choice = std::static_pointer_cast(f); // Create a new Choice node with the same label auto newChoice = std::make_shared(choice->label(), choice->nrChoices()); @@ -773,18 +772,17 @@ namespace gtsam { const typename DecisionTree::NodePtr& f, std::function L_of_M, std::function Y_of_X) { using LY = DecisionTree; + using MXLeaf = typename DecisionTree::Leaf; + using MXChoice = typename DecisionTree::Choice; // If leaf, apply unary conversion "op" and create a unique leaf. - using MXLeaf = typename DecisionTree::Leaf; - if (auto leaf = std::dynamic_pointer_cast(f)) { + if (f->isLeaf()) { + auto leaf = std::static_pointer_cast(f); return NodePtr(new Leaf(Y_of_X(leaf->constant()))); } - // Check if Choice - using MXChoice = typename DecisionTree::Choice; - auto choice = std::dynamic_pointer_cast(f); - if (!choice) - throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr"); + // Now is Choice! + auto choice = std::static_pointer_cast(f); // get new label const M oldLabel = choice->label(); @@ -826,13 +824,14 @@ namespace gtsam { /// Do a depth-first visit on the tree rooted at node. void operator()(const typename DecisionTree::NodePtr& node) const { using Leaf = typename DecisionTree::Leaf; - if (auto leaf = std::dynamic_pointer_cast(node)) - return f(leaf->constant()); - using Choice = typename DecisionTree::Choice; - auto choice = std::dynamic_pointer_cast(node); - if (!choice) - throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr"); + + if (node->isLeaf()) { + auto leaf = std::static_pointer_cast(node); + return f(leaf->constant()); + } + + auto choice = std::static_pointer_cast(node); for (auto&& branch : choice->branches()) (*this)(branch); // recurse! } }; @@ -863,13 +862,14 @@ namespace gtsam { /// Do a depth-first visit on the tree rooted at node. void operator()(const typename DecisionTree::NodePtr& node) const { using Leaf = typename DecisionTree::Leaf; - if (auto leaf = std::dynamic_pointer_cast(node)) - return f(*leaf); - using Choice = typename DecisionTree::Choice; - auto choice = std::dynamic_pointer_cast(node); - if (!choice) - throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr"); + + if (node->isLeaf()) { + auto leaf = std::static_pointer_cast(node); + return f(*leaf); + } + + auto choice = std::static_pointer_cast(node); for (auto&& branch : choice->branches()) (*this)(branch); // recurse! } }; @@ -898,13 +898,16 @@ namespace gtsam { /// Do a depth-first visit on the tree rooted at node. void operator()(const typename DecisionTree::NodePtr& node) { using Leaf = typename DecisionTree::Leaf; - if (auto leaf = std::dynamic_pointer_cast(node)) - return f(assignment, leaf->constant()); - using Choice = typename DecisionTree::Choice; - auto choice = std::dynamic_pointer_cast(node); - if (!choice) - throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); + + if (node->isLeaf()) { + auto leaf = std::static_pointer_cast(node); + return f(assignment, leaf->constant()); + } + + + + auto choice = std::static_pointer_cast(node); for (size_t i = 0; i < choice->nrChoices(); i++) { assignment[choice->label()] = i; // Set assignment for label to i