diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index be083236e..58de338a8 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -632,16 +632,16 @@ namespace gtsam { } /****************************************************************************/ - // "create" is a bit of a complicated thing, but very useful. + // "build" is a bit of a complicated thing, but very useful. // It takes a range of labels and a corresponding range of values, - // and creates a decision tree, as follows: + // and builds a decision tree, as follows: // - if there is only one label, creates a choice node with values in leaves // - otherwise, it evenly splits up the range of values and creates a tree for // each sub-range, and assigns that tree to first label's choices // Example: - // create([B A],[1 2 3 4]) would call - // create([A],[1 2]) - // create([A],[3 4]) + // build([B A],[1 2 3 4]) would call + // build([A],[1 2]) + // build([A],[3 4]) // and produce // B=0 // A=0: 1 @@ -649,12 +649,12 @@ namespace gtsam { // B=1 // A=0: 3 // A=1: 4 - // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce + // Note, through the magic of "compose", build([A B],[1 2 3 4]) will produce // exactly the same tree as above: the highest label is always the root. // However, it will be *way* faster if labels are given highest to lowest. template template - typename DecisionTree::NodePtr DecisionTree::create( + typename DecisionTree::NodePtr DecisionTree::build( It begin, It end, ValueIt beginY, ValueIt endY) const { // get crucial counts size_t nrChoices = begin->second; @@ -684,12 +684,27 @@ namespace gtsam { std::vector functions; size_t split = size / nrChoices; for (size_t i = 0; i < nrChoices; i++, beginY += split) { - NodePtr f = create(labelC, end, beginY, beginY + split); + NodePtr f = build(labelC, end, beginY, beginY + split); functions.emplace_back(f); } return compose(functions.begin(), functions.end(), begin->first); } + /****************************************************************************/ + // Take a range of labels and a corresponding range of values, + // and creates a decision tree. + template + template + typename DecisionTree::NodePtr DecisionTree::create( + It begin, It end, ValueIt beginY, ValueIt endY) const { + auto node = build(begin, end, beginY, endY); + if (auto choice = std::dynamic_pointer_cast(node)) { + return Choice::Unique(choice); + } else { + return node; + } + } + /****************************************************************************/ template template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 9cff7aa47..a2b05070b 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -136,10 +136,16 @@ namespace gtsam { NodePtr root_; protected: - /** Internal recursive function to create from keys, cardinalities, - * and Y values + /** Internal recursive function to create from keys, cardinalities, + * and Y values */ - template + template + NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const; + + /** Internal helper function to create from keys, cardinalities, + * and Y values + */ + template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; /**