diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index c26d25420..099ccb528 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -457,6 +457,14 @@ namespace gtsam { root_ = convert(other.root_, map, op); } + /*********************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + std::function op) { + root_ = convert(other.root_, op); + } + /*********************************************************************************/ // Called by two constructors above. // Takes a label and a corresponding range of decision trees, and creates a new @@ -602,6 +610,38 @@ namespace gtsam { return LY::compose(functions.begin(), functions.end(), newLabel); } + /*********************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convert( + const typename DecisionTree::NodePtr& f, + std::function op) { + + typedef DecisionTree LX; + typedef typename LX::Leaf LXLeaf; + typedef typename LX::Choice LXChoice; + typedef typename LX::NodePtr LXNodePtr; + typedef DecisionTree LY; + + // ugliness below because apparently we can't have templated virtual functions + // If leaf, apply unary conversion "op" and create a unique leaf + const LXLeaf* leaf = dynamic_cast (f.get()); + if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + + // Check if Choice + boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + if (!choice) throw std::invalid_argument( + "DecisionTree::Convert: Invalid NodePtr"); + + // put together via Shannon expansion otherwise not sorted. + std::vector functions; + for(const LXNodePtr& branch: choice->branches()) { + LY converted(convert(branch, op)); + functions += converted; + } + return LY::compose(functions.begin(), functions.end(), choice->label()); + } + /*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, double tol) const { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 68ddfa06b..3b91def63 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -119,7 +119,12 @@ namespace gtsam { convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); - public: + /** Convert only node to a different type */ + template + NodePtr convert(const typename DecisionTree::NodePtr& f, + const std::function op); + + public: /// @name Standard Constructors /// @{ @@ -155,6 +160,11 @@ namespace gtsam { DecisionTree(const DecisionTree& other, const std::map& map, std::function op); + /** Convert only nodes from a different type */ + template + DecisionTree(const DecisionTree& other, + std::function op); + /// @} /// @name Testable /// @{ @@ -231,12 +241,19 @@ namespace gtsam { /** free versions of apply */ + //TODO(Varun) where are these templates Y, L and not L, Y? template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } + template + DecisionTree apply(const DecisionTree& f, + const std::function& op) { + return f.apply(op); + } + template DecisionTree apply(const DecisionTree& f, const DecisionTree& g,