update Unique to be recursive
parent
73b563a9aa
commit
8a8f146517
|
@ -93,7 +93,7 @@ namespace gtsam {
|
||||||
/// print
|
/// print
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Leaf " << valueFormatter(constant_) << " | nrAssignments: " << nrAssignments_ << std::endl;
|
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Write graphviz format to stream `os`. */
|
/** Write graphviz format to stream `os`. */
|
||||||
|
@ -201,6 +201,7 @@ namespace gtsam {
|
||||||
/// If all branches of a choice node f are the same, just return a branch.
|
/// If all branches of a choice node f are the same, just return a branch.
|
||||||
static NodePtr Unique(const ChoicePtr& f) {
|
static NodePtr Unique(const ChoicePtr& f) {
|
||||||
#ifndef GTSAM_DT_NO_PRUNING
|
#ifndef GTSAM_DT_NO_PRUNING
|
||||||
|
// If all the branches are the same, we can merge them into one
|
||||||
if (f->allSame_) {
|
if (f->allSame_) {
|
||||||
assert(f->branches().size() > 0);
|
assert(f->branches().size() > 0);
|
||||||
NodePtr f0 = f->branches_[0];
|
NodePtr f0 = f->branches_[0];
|
||||||
|
@ -215,34 +216,30 @@ namespace gtsam {
|
||||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
||||||
nrAssignments));
|
nrAssignments));
|
||||||
return newLeaf;
|
return newLeaf;
|
||||||
|
|
||||||
} else
|
} else
|
||||||
|
// Else we recurse
|
||||||
#endif
|
#endif
|
||||||
// {
|
{
|
||||||
// Choice choice_node;
|
|
||||||
|
|
||||||
// for (auto branch : f->branches()) {
|
// Make non-const copy
|
||||||
// if (auto choice = std::dynamic_pointer_cast<const
|
auto ff = std::make_shared<Choice>(f->label(), f->nrChoices());
|
||||||
// Choice>(branch)) {
|
|
||||||
// // `branch` is a Choice node so we apply Unique to it.
|
|
||||||
// choice_node.push_back(Unique(choice));
|
|
||||||
|
|
||||||
// } else if (auto leaf =
|
// Iterate over all the branches
|
||||||
// std::dynamic_pointer_cast<const Leaf>(branch)) {
|
for (size_t i = 0; i < f->nrChoices(); i++) {
|
||||||
// choice_node.push_back(leaf);
|
auto branch = f->branches_[i];
|
||||||
// }
|
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
|
||||||
// }
|
// Leaf node, simply assign
|
||||||
// return std::make_shared<const Choice>(choice_node);
|
ff->push_back(branch);
|
||||||
// }
|
|
||||||
return f;
|
} else if (auto choice =
|
||||||
|
std::dynamic_pointer_cast<const Choice>(branch)) {
|
||||||
|
// Choice node, we recurse
|
||||||
|
ff->push_back(Unique(choice));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static NodePtr UpdateNrAssignments(const NodePtr& f) {
|
return ff;
|
||||||
if (auto choice = std::dynamic_pointer_cast<const Choice>(f)) {
|
|
||||||
// `f` is a Choice node so we recurse.
|
|
||||||
return UpdateNrAssignments(f);
|
|
||||||
|
|
||||||
} else if (auto leaf = std::dynamic_pointer_cast<const Leaf>(f)) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -308,7 +305,7 @@ namespace gtsam {
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Choice(";
|
std::cout << s << " Choice(";
|
||||||
std::cout << labelFormatter(label_) << ") " << " | All Same: " << allSame_ << " | nrBranches: " << branches_.size() << std::endl;
|
std::cout << labelFormatter(label_) << ") " << std::endl;
|
||||||
for (size_t i = 0; i < branches_.size(); i++) {
|
for (size_t i = 0; i < branches_.size(); i++) {
|
||||||
branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
|
branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
@ -595,16 +592,16 @@ namespace gtsam {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
std::optional<L> highestLabel;
|
std::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
// for (Iterator it = begin; it != end; it++) {
|
for (Iterator it = begin; it != end; it++) {
|
||||||
// if (it->root_->isLeaf())
|
if (it->root_->isLeaf())
|
||||||
// continue;
|
continue;
|
||||||
// std::shared_ptr<const Choice> c =
|
std::shared_ptr<const Choice> c =
|
||||||
// std::dynamic_pointer_cast<const Choice>(it->root_);
|
std::dynamic_pointer_cast<const Choice>(it->root_);
|
||||||
// if (!highestLabel || c->label() > *highestLabel) {
|
if (!highestLabel || c->label() > *highestLabel) {
|
||||||
// highestLabel = c->label();
|
highestLabel = c->label();
|
||||||
// nrChoices = c->nrChoices();
|
nrChoices = c->nrChoices();
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
|
||||||
// if label is already in correct order, just put together a choice on label
|
// if label is already in correct order, just put together a choice on label
|
||||||
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
||||||
|
@ -679,7 +676,7 @@ namespace gtsam {
|
||||||
for (ValueIt y = beginY; y != endY; y++) {
|
for (ValueIt y = beginY; y != endY; y++) {
|
||||||
choice->push_back(NodePtr(new Leaf(*y)));
|
choice->push_back(NodePtr(new Leaf(*y)));
|
||||||
}
|
}
|
||||||
return Choice::Unique(choice);
|
return choice;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recursive case: perform "Shannon expansion"
|
// Recursive case: perform "Shannon expansion"
|
||||||
|
|
Loading…
Reference in New Issue