diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 1b0472a80..bf2c4e1a9 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -93,7 +93,7 @@ namespace gtsam { /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { - std::cout << s << " Leaf " << valueFormatter(constant_) << " | nrAssignments: " << nrAssignments_ << std::endl; + std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; } /** Write graphviz format to stream `os`. */ @@ -201,6 +201,7 @@ namespace gtsam { /// If all branches of a choice node f are the same, just return a branch. static NodePtr Unique(const ChoicePtr& f) { #ifndef GTSAM_DT_NO_PRUNING + // If all the branches are the same, we can merge them into one if (f->allSame_) { assert(f->branches().size() > 0); NodePtr f0 = f->branches_[0]; @@ -215,34 +216,30 @@ namespace gtsam { new Leaf(std::dynamic_pointer_cast(f0)->constant(), nrAssignments)); return newLeaf; + } else + // Else we recurse #endif - // { - // Choice choice_node; + { - // for (auto branch : f->branches()) { - // if (auto choice = std::dynamic_pointer_cast(branch)) { - // // `branch` is a Choice node so we apply Unique to it. - // choice_node.push_back(Unique(choice)); + // Make non-const copy + auto ff = std::make_shared(f->label(), f->nrChoices()); - // } else if (auto leaf = - // std::dynamic_pointer_cast(branch)) { - // choice_node.push_back(leaf); - // } - // } - // return std::make_shared(choice_node); - // } - return f; - } + // Iterate over all the branches + for (size_t i = 0; i < f->nrChoices(); i++) { + auto branch = f->branches_[i]; + if (auto leaf = std::dynamic_pointer_cast(branch)) { + // Leaf node, simply assign + ff->push_back(branch); - static NodePtr UpdateNrAssignments(const NodePtr& f) { - if (auto choice = std::dynamic_pointer_cast(f)) { - // `f` is a Choice node so we recurse. - return UpdateNrAssignments(f); + } else if (auto choice = + std::dynamic_pointer_cast(branch)) { + // Choice node, we recurse + ff->push_back(Unique(choice)); + } + } - } else if (auto leaf = std::dynamic_pointer_cast(f)) { - + return ff; } } @@ -308,7 +305,7 @@ namespace gtsam { void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; - std::cout << labelFormatter(label_) << ") " << " | All Same: " << allSame_ << " | nrBranches: " << branches_.size() << std::endl; + std::cout << labelFormatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) { branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter); } @@ -595,16 +592,16 @@ namespace gtsam { // find highest label among branches std::optional highestLabel; size_t nrChoices = 0; - // for (Iterator it = begin; it != end; it++) { - // if (it->root_->isLeaf()) - // continue; - // std::shared_ptr c = - // std::dynamic_pointer_cast(it->root_); - // if (!highestLabel || c->label() > *highestLabel) { - // highestLabel = c->label(); - // nrChoices = c->nrChoices(); - // } - // } + for (Iterator it = begin; it != end; it++) { + if (it->root_->isLeaf()) + continue; + std::shared_ptr c = + std::dynamic_pointer_cast(it->root_); + if (!highestLabel || c->label() > *highestLabel) { + highestLabel = c->label(); + nrChoices = c->nrChoices(); + } + } // if label is already in correct order, just put together a choice on label if (!nrChoices || !highestLabel || label > *highestLabel) { @@ -679,7 +676,7 @@ namespace gtsam { for (ValueIt y = beginY; y != endY; y++) { choice->push_back(NodePtr(new Leaf(*y))); } - return Choice::Unique(choice); + return choice; } // Recursive case: perform "Shannon expansion"