add new build method to replace create, and let create call Unique

release/4.3a0
Varun Agrawal 2023-06-08 09:53:39 -04:00
parent dbd0a7d3ba
commit 68cb724970
2 changed files with 32 additions and 11 deletions

View File

@ -632,16 +632,16 @@ namespace gtsam {
}
/****************************************************************************/
// "create" is a bit of a complicated thing, but very useful.
// "build" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows:
// and builds a decision tree, as follows:
// - if there is only one label, creates a choice node with values in leaves
// - otherwise, it evenly splits up the range of values and creates a tree for
// each sub-range, and assigns that tree to first label's choices
// Example:
// create([B A],[1 2 3 4]) would call
// create([A],[1 2])
// create([A],[3 4])
// build([B A],[1 2 3 4]) would call
// build([A],[1 2])
// build([A],[3 4])
// and produce
// B=0
// A=0: 1
@ -649,12 +649,12 @@ namespace gtsam {
// B=1
// A=0: 3
// A=1: 4
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
// Note, through the magic of "compose", build([A B],[1 2 3 4]) will produce
// exactly the same tree as above: the highest label is always the root.
// However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build(
It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts
size_t nrChoices = begin->second;
@ -684,12 +684,27 @@ namespace gtsam {
std::vector<DecisionTree> functions;
size_t split = size / nrChoices;
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
functions.emplace_back(f);
}
return compose(functions.begin(), functions.end(), begin->first);
}
/****************************************************************************/
// Take a range of labels and a corresponding range of values,
// and creates a decision tree.
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
return Choice::Unique(choice);
} else {
return node;
}
}
/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>

View File

@ -136,10 +136,16 @@ namespace gtsam {
NodePtr root_;
protected:
/** Internal recursive function to create from keys, cardinalities,
* and Y values
/** Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
template <typename It, typename ValueIt>
NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const;
/** Internal helper function to create from keys, cardinalities,
* and Y values
*/
template <typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/**