update Unique to be recursive

release/4.3a0
Varun Agrawal 2023-06-08 09:36:08 -04:00
parent 73b563a9aa
commit 8a8f146517
1 changed files with 32 additions and 35 deletions

View File

@ -93,7 +93,7 @@ namespace gtsam {
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << " | nrAssignments: " << nrAssignments_ << std::endl;
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
}
/** Write graphviz format to stream `os`. */
@ -201,6 +201,7 @@ namespace gtsam {
/// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) {
#ifndef GTSAM_DT_NO_PRUNING
// If all the branches are the same, we can merge them into one
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
@ -215,34 +216,30 @@ namespace gtsam {
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;
} else
// Else we recurse
#endif
// {
// Choice choice_node;
{
// for (auto branch : f->branches()) {
// if (auto choice = std::dynamic_pointer_cast<const
// Choice>(branch)) {
// // `branch` is a Choice node so we apply Unique to it.
// choice_node.push_back(Unique(choice));
// Make non-const copy
auto ff = std::make_shared<Choice>(f->label(), f->nrChoices());
// } else if (auto leaf =
// std::dynamic_pointer_cast<const Leaf>(branch)) {
// choice_node.push_back(leaf);
// }
// }
// return std::make_shared<const Choice>(choice_node);
// }
return f;
}
// Iterate over all the branches
for (size_t i = 0; i < f->nrChoices(); i++) {
auto branch = f->branches_[i];
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
// Leaf node, simply assign
ff->push_back(branch);
static NodePtr UpdateNrAssignments(const NodePtr& f) {
if (auto choice = std::dynamic_pointer_cast<const Choice>(f)) {
// `f` is a Choice node so we recurse.
return UpdateNrAssignments(f);
} else if (auto choice =
std::dynamic_pointer_cast<const Choice>(branch)) {
// Choice node, we recurse
ff->push_back(Unique(choice));
}
}
} else if (auto leaf = std::dynamic_pointer_cast<const Leaf>(f)) {
return ff;
}
}
@ -308,7 +305,7 @@ namespace gtsam {
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice(";
std::cout << labelFormatter(label_) << ") " << " | All Same: " << allSame_ << " | nrBranches: " << branches_.size() << std::endl;
std::cout << labelFormatter(label_) << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++) {
branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
}
@ -595,16 +592,16 @@ namespace gtsam {
// find highest label among branches
std::optional<L> highestLabel;
size_t nrChoices = 0;
// for (Iterator it = begin; it != end; it++) {
// if (it->root_->isLeaf())
// continue;
// std::shared_ptr<const Choice> c =
// std::dynamic_pointer_cast<const Choice>(it->root_);
// if (!highestLabel || c->label() > *highestLabel) {
// highestLabel = c->label();
// nrChoices = c->nrChoices();
// }
// }
for (Iterator it = begin; it != end; it++) {
if (it->root_->isLeaf())
continue;
std::shared_ptr<const Choice> c =
std::dynamic_pointer_cast<const Choice>(it->root_);
if (!highestLabel || c->label() > *highestLabel) {
highestLabel = c->label();
nrChoices = c->nrChoices();
}
}
// if label is already in correct order, just put together a choice on label
if (!nrChoices || !highestLabel || label > *highestLabel) {
@ -679,7 +676,7 @@ namespace gtsam {
for (ValueIt y = beginY; y != endY; y++) {
choice->push_back(NodePtr(new Leaf(*y)));
}
return Choice::Unique(choice);
return choice;
}
// Recursive case: perform "Shannon expansion"