Merge branch 'discrete-multiply' into discrete-elimination-refactor

release/4.3a0
Varun Agrawal 2025-01-05 18:22:01 -05:00
commit 5e9c1300db
7 changed files with 47 additions and 5 deletions

View File

@ -70,6 +70,9 @@ namespace gtsam {
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this)); result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf)); result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
} else {
// Simulate double dispatch in C++
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
} }
return result; return result;
} }

View File

@ -65,11 +65,18 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result; DiscreteFactor::shared_ptr result;
for (const sharedFactor& factor : *this) { for (auto it = this->begin(); it != this->end(); ++it) {
if (factor) result = (*factor) * result; if (*it) {
if (result) {
result = result->multiply(*it);
} else {
// Assign to the first non-null factor
result = *it;
}
}
} }
return result; return result->toDecisionTreeFactor();
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -256,12 +256,16 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply( DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const override { const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result; DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(*tf)); result = std::make_shared<TableFactor>(this->operator*(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf))); result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
} else {
// Simulate double dispatch in C++
result = std::make_shared<DecisionTreeFactor>(
f->operator*(this->toDecisionTreeFactor()));
} }
return result; return result;
} }

View File

@ -53,6 +53,13 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply into a decisiontree /// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override { AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented"); throw std::runtime_error("AllDiff::error not implemented");

View File

@ -69,6 +69,13 @@ class BinaryAllDiff : public Constraint {
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}
/* /*
* Ensure Arc-consistency by checking every possible value of domain j. * Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked * @param j domain to be checked

View File

@ -90,6 +90,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Multiply into a decisiontree /// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}
/* /*
* Ensure Arc-consistency by checking every possible value of domain j. * Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked * @param j domain to be checked

View File

@ -63,6 +63,13 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Multiply into a decisiontree /// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}
/* /*
* Ensure Arc-consistency: just sets domain[j] to {value_}. * Ensure Arc-consistency: just sets domain[j] to {value_}.
* @param j domain to be checked * @param j domain to be checked