update Unique to be recursive

release/4.3a0
Varun Agrawal 2023-06-08 09:36:08 -04:00
parent 73b563a9aa
commit 8a8f146517
1 changed files with 32 additions and 35 deletions

View File

@ -93,7 +93,7 @@ namespace gtsam {
/// print /// print
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << " | nrAssignments: " << nrAssignments_ << std::endl; std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
} }
/** Write graphviz format to stream `os`. */ /** Write graphviz format to stream `os`. */
@ -201,6 +201,7 @@ namespace gtsam {
/// If all branches of a choice node f are the same, just return a branch. /// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) { static NodePtr Unique(const ChoicePtr& f) {
#ifndef GTSAM_DT_NO_PRUNING #ifndef GTSAM_DT_NO_PRUNING
// If all the branches are the same, we can merge them into one
if (f->allSame_) { if (f->allSame_) {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
@ -215,34 +216,30 @@ namespace gtsam {
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(), new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments)); nrAssignments));
return newLeaf; return newLeaf;
} else } else
// Else we recurse
#endif #endif
// { {
// Choice choice_node;
// for (auto branch : f->branches()) { // Make non-const copy
// if (auto choice = std::dynamic_pointer_cast<const auto ff = std::make_shared<Choice>(f->label(), f->nrChoices());
// Choice>(branch)) {
// // `branch` is a Choice node so we apply Unique to it.
// choice_node.push_back(Unique(choice));
// } else if (auto leaf = // Iterate over all the branches
// std::dynamic_pointer_cast<const Leaf>(branch)) { for (size_t i = 0; i < f->nrChoices(); i++) {
// choice_node.push_back(leaf); auto branch = f->branches_[i];
// } if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
// } // Leaf node, simply assign
// return std::make_shared<const Choice>(choice_node); ff->push_back(branch);
// }
return f; } else if (auto choice =
std::dynamic_pointer_cast<const Choice>(branch)) {
// Choice node, we recurse
ff->push_back(Unique(choice));
}
} }
static NodePtr UpdateNrAssignments(const NodePtr& f) { return ff;
if (auto choice = std::dynamic_pointer_cast<const Choice>(f)) {
// `f` is a Choice node so we recurse.
return UpdateNrAssignments(f);
} else if (auto leaf = std::dynamic_pointer_cast<const Leaf>(f)) {
} }
} }
@ -308,7 +305,7 @@ namespace gtsam {
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice("; std::cout << s << " Choice(";
std::cout << labelFormatter(label_) << ") " << " | All Same: " << allSame_ << " | nrBranches: " << branches_.size() << std::endl; std::cout << labelFormatter(label_) << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++) { for (size_t i = 0; i < branches_.size(); i++) {
branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter); branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
} }
@ -595,16 +592,16 @@ namespace gtsam {
// find highest label among branches // find highest label among branches
std::optional<L> highestLabel; std::optional<L> highestLabel;
size_t nrChoices = 0; size_t nrChoices = 0;
// for (Iterator it = begin; it != end; it++) { for (Iterator it = begin; it != end; it++) {
// if (it->root_->isLeaf()) if (it->root_->isLeaf())
// continue; continue;
// std::shared_ptr<const Choice> c = std::shared_ptr<const Choice> c =
// std::dynamic_pointer_cast<const Choice>(it->root_); std::dynamic_pointer_cast<const Choice>(it->root_);
// if (!highestLabel || c->label() > *highestLabel) { if (!highestLabel || c->label() > *highestLabel) {
// highestLabel = c->label(); highestLabel = c->label();
// nrChoices = c->nrChoices(); nrChoices = c->nrChoices();
// } }
// } }
// if label is already in correct order, just put together a choice on label // if label is already in correct order, just put together a choice on label
if (!nrChoices || !highestLabel || label > *highestLabel) { if (!nrChoices || !highestLabel || label > *highestLabel) {
@ -679,7 +676,7 @@ namespace gtsam {
for (ValueIt y = beginY; y != endY; y++) { for (ValueIt y = beginY; y != endY; y++) {
choice->push_back(NodePtr(new Leaf(*y))); choice->push_back(NodePtr(new Leaf(*y)));
} }
return Choice::Unique(choice); return choice;
} }
// Recursive case: perform "Shannon expansion" // Recursive case: perform "Shannon expansion"