dynamic -> static_pointer_cast
parent
fce62a61b5
commit
3ca88f2a46
|
|
@ -74,8 +74,8 @@ namespace gtsam {
|
||||||
|
|
||||||
/// equality up to tolerance
|
/// equality up to tolerance
|
||||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
if (!q.isLeaf()) return false;
|
||||||
if (!other) return false;
|
const Leaf* other = static_cast<const Leaf*>(&q);
|
||||||
return compare(this->constant_, other->constant_);
|
return compare(this->constant_, other->constant_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -202,36 +202,39 @@ namespace gtsam {
|
||||||
* @param node The root node of the decision tree.
|
* @param node The root node of the decision tree.
|
||||||
* @return NodePtr
|
* @return NodePtr
|
||||||
*/
|
*/
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
|
|
||||||
static NodePtr Unique(const NodePtr& node) {
|
static NodePtr Unique(const NodePtr& node) {
|
||||||
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
|
if (node->isLeaf()) return node; // Leaf node, return as is
|
||||||
|
|
||||||
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
// Choice node, we recurse!
|
// Choice node, we recurse!
|
||||||
// Make non-const copy so we can update
|
// Make non-const copy so we can update
|
||||||
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||||
|
|
||||||
// Iterate over all the branches
|
// Iterate over all the branches
|
||||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
for (const auto& branch : choice->branches_) {
|
||||||
auto branch = choice->branches_[i];
|
|
||||||
f->push_back(Unique(branch));
|
f->push_back(Unique(branch));
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GTSAM_DT_MERGING
|
|
||||||
// If all the branches are the same, we can merge them into one
|
// If all the branches are the same, we can merge them into one
|
||||||
if (f->allSame_) {
|
if (f->allSame_) {
|
||||||
assert(f->branches().size() > 0);
|
assert(f->branches().size() > 0);
|
||||||
NodePtr f0 = f->branches_[0];
|
auto f0 = std::static_pointer_cast<const Leaf>(f->branches_[0]);
|
||||||
|
return std::make_shared<Leaf>(f0->constant());
|
||||||
NodePtr newLeaf(
|
|
||||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
|
||||||
return newLeaf;
|
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
return f;
|
return f;
|
||||||
} else {
|
}
|
||||||
// Leaf node, return as is
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
static NodePtr Unique(const NodePtr& node) {
|
||||||
|
// No-op when GTSAM_DT_MERGING is not defined
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
#endif
|
||||||
bool isLeaf() const override { return false; }
|
bool isLeaf() const override { return false; }
|
||||||
|
|
||||||
/// Constructor, given choice label and mandatory expected branch count.
|
/// Constructor, given choice label and mandatory expected branch count.
|
||||||
|
|
@ -322,9 +325,9 @@ namespace gtsam {
|
||||||
const NodePtr& branch = branches_[i];
|
const NodePtr& branch = branches_[i];
|
||||||
|
|
||||||
// Check if zero
|
// Check if zero
|
||||||
if (!showZero) {
|
if (!showZero && branch->isLeaf()) {
|
||||||
const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
|
auto leaf = std::static_pointer_cast<const Leaf>(branch);
|
||||||
if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
|
if (valueFormatter(leaf->constant()).compare("0")) continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||||
|
|
@ -346,8 +349,8 @@ namespace gtsam {
|
||||||
|
|
||||||
/// equality
|
/// equality
|
||||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
if (q.isLeaf()) return false;
|
||||||
if (!other) return false;
|
const Choice* other = static_cast<const Choice*>(&q);
|
||||||
if (this->label_ != other->label_) return false;
|
if (this->label_ != other->label_) return false;
|
||||||
if (branches_.size() != other->branches_.size()) return false;
|
if (branches_.size() != other->branches_.size()) return false;
|
||||||
// we don't care about shared pointers being equal here
|
// we don't care about shared pointers being equal here
|
||||||
|
|
@ -570,11 +573,13 @@ namespace gtsam {
|
||||||
struct ApplyUnary {
|
struct ApplyUnary {
|
||||||
const Unary& op;
|
const Unary& op;
|
||||||
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
|
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
|
if (node->isLeaf()) {
|
||||||
// Apply the unary operation to the leaf's constant value
|
// Apply the unary operation to the leaf's constant value
|
||||||
|
auto leaf = std::static_pointer_cast<Leaf>(node);
|
||||||
leaf->constant_ = op(leaf->constant_);
|
leaf->constant_ = op(leaf->constant_);
|
||||||
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
} else {
|
||||||
// Recurse into the choice branches
|
// Recurse into the choice branches
|
||||||
|
auto choice = std::static_pointer_cast<Choice>(node);
|
||||||
for (NodePtr& branch : choice->branches()) {
|
for (NodePtr& branch : choice->branches()) {
|
||||||
(*this)(branch);
|
(*this)(branch);
|
||||||
}
|
}
|
||||||
|
|
@ -622,8 +627,7 @@ namespace gtsam {
|
||||||
for (Iterator it = begin; it != end; it++) {
|
for (Iterator it = begin; it != end; it++) {
|
||||||
if (it->root_->isLeaf())
|
if (it->root_->isLeaf())
|
||||||
continue;
|
continue;
|
||||||
std::shared_ptr<const Choice> c =
|
auto c = std::static_pointer_cast<const Choice>(it->root_);
|
||||||
std::dynamic_pointer_cast<const Choice>(it->root_);
|
|
||||||
if (!highestLabel || c->label() > *highestLabel) {
|
if (!highestLabel || c->label() > *highestLabel) {
|
||||||
highestLabel = c->label();
|
highestLabel = c->label();
|
||||||
nrChoices = c->nrChoices();
|
nrChoices = c->nrChoices();
|
||||||
|
|
@ -729,11 +733,7 @@ namespace gtsam {
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) {
|
It begin, It end, ValueIt beginY, ValueIt endY) {
|
||||||
auto node = build(begin, end, beginY, endY);
|
auto node = build(begin, end, beginY, endY);
|
||||||
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
return Choice::Unique(node);
|
||||||
return Choice::Unique(choice);
|
|
||||||
} else {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
@ -742,18 +742,17 @@ namespace gtsam {
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
const typename DecisionTree<L, X>::NodePtr& f,
|
const typename DecisionTree<L, X>::NodePtr& f,
|
||||||
std::function<Y(const X&)> Y_of_X) {
|
std::function<Y(const X&)> Y_of_X) {
|
||||||
|
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||||
|
using LXChoice = typename DecisionTree<L, X>::Choice;
|
||||||
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
if (f->isLeaf()) {
|
||||||
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
|
auto leaf = std::static_pointer_cast<LXLeaf>(f);
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if Choice
|
// Now a Choice!
|
||||||
using LXChoice = typename DecisionTree<L, X>::Choice;
|
auto choice = std::static_pointer_cast<const LXChoice>(f);
|
||||||
auto choice = std::dynamic_pointer_cast<const LXChoice>(f);
|
|
||||||
if (!choice) throw std::invalid_argument(
|
|
||||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
|
||||||
|
|
||||||
// Create a new Choice node with the same label
|
// Create a new Choice node with the same label
|
||||||
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||||
|
|
@ -773,18 +772,17 @@ namespace gtsam {
|
||||||
const typename DecisionTree<M, X>::NodePtr& f,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
|
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
if (f->isLeaf()) {
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
auto leaf = std::static_pointer_cast<const MXLeaf>(f);
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if Choice
|
// Now is Choice!
|
||||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
auto choice = std::static_pointer_cast<const MXChoice>(f);
|
||||||
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
|
|
||||||
if (!choice)
|
|
||||||
throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");
|
|
||||||
|
|
||||||
// get new label
|
// get new label
|
||||||
const M oldLabel = choice->label();
|
const M oldLabel = choice->label();
|
||||||
|
|
@ -826,13 +824,14 @@ namespace gtsam {
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
|
||||||
return f(leaf->constant());
|
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
|
||||||
if (!choice)
|
if (node->isLeaf()) {
|
||||||
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||||
|
return f(leaf->constant());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -863,13 +862,14 @@ namespace gtsam {
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
|
||||||
return f(*leaf);
|
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
|
||||||
if (!choice)
|
if (node->isLeaf()) {
|
||||||
throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
|
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||||
|
return f(*leaf);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -898,13 +898,16 @@ namespace gtsam {
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
|
||||||
return f(assignment, leaf->constant());
|
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
|
||||||
if (!choice)
|
if (node->isLeaf()) {
|
||||||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||||
|
return f(assignment, leaf->constant());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||||
assignment[choice->label()] = i; // Set assignment for label to i
|
assignment[choice->label()] = i; // Set assignment for label to i
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue