diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index c15fd4e2e..4b16dad8a 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -70,6 +70,9 @@ namespace gtsam { result = std::make_shared((*tf) * TableFactor(*this)); } else if (auto dtf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(*dtf)); + } else { + // Simulate double dispatch in C++ + result = std::make_shared(f->operator*(*this)); } return result; } diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 678c70e2d..e05cf9e33 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,11 +65,18 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; - for (const sharedFactor& factor : *this) { - if (factor) result = (*factor) * result; + DiscreteFactor::shared_ptr result; + for (auto it = this->begin(); it != this->end(); ++it) { + if (*it) { + if (result) { + result = result->multiply(*it); + } else { + // Assign to the first non-null factor + result = *it; + } + } } - return result; + return result->toDecisionTreeFactor(); } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index cfa56b43a..6516a4a98 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -256,12 +256,16 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { /* ************************************************************************ */ DiscreteFactor::shared_ptr TableFactor::multiply( - const DiscreteFactor::shared_ptr& f) const override { + const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(*tf)); } else if (auto dtf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(TableFactor(*dtf))); + } else { + // Simulate double dispatch in C++ + result = std::make_shared( + f->operator*(this->toDecisionTreeFactor())); } return result; } diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 07c6ff4af..cf0e5e3cf 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -53,6 +53,13 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /// Compute error for each assignment and return as a tree AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index ac099209a..c15ac8aec 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -69,6 +69,13 @@ class BinaryAllDiff : public Constraint { return toDecisionTreeFactor() * f; } + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 786bc874b..f64716028 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -90,6 +90,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index ab2563f58..c5824a96a 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -63,6 +63,13 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked