diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index ab14b2a72..84116ccd5 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -604,7 +604,7 @@ namespace gtsam { using MXChoice = typename DecisionTree::Choice; auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( - "DecisionTree::Convert: Invalid NodePtr"); + "DecisionTree::convertFrom: Invalid NodePtr"); // get new label const M oldLabel = choice->label(); @@ -634,6 +634,8 @@ namespace gtsam { using Choice = typename DecisionTree::Choice; auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr"); for (auto&& branch : choice->branches()) (*this)(branch); // recurse! } }; @@ -663,6 +665,8 @@ namespace gtsam { using Choice = typename DecisionTree::Choice; auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); for (size_t i = 0; i < choice->nrChoices(); i++) { choices[choice->label()] = i; // Set assignment for label to i (*this)(choice->branches()[i]); // recurse! diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 9692094e1..78f3a75b7 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -38,7 +38,7 @@ namespace gtsam { * Y = function range (any algebra), e.g., bool, int, double */ template - class GTSAM_EXPORT DecisionTree { + class DecisionTree { protected: /// Default method for comparison of two objects of type Y. @@ -340,4 +340,11 @@ namespace gtsam { return f.apply(g, op); } + /// unzip a DecisionTree if its leaves are `std::pair` + template + std::pair, DecisionTree > unzip(const DecisionTree > &input) { + return std::make_pair(DecisionTree(input, [](std::pair i) { return i.first; }), + DecisionTree(input, [](std::pair i) { return i.second; })); + } + } // namespace gtsam