diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ac782b88..cf22fe153 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -62,6 +62,18 @@ namespace gtsam { return error(values.discrete()); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared((*tf) * TableFactor(*this)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*dtf)); + } + return result; + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b..3e70c0df9 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -147,6 +148,10 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index a1fde0f86..c18eaae2f 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /** + * @brief Multiply in a DiscreteFactor and return the result as + * DiscreteFactor, both via shared pointers. + * + * @param df DiscreteFactor shared_ptr + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; /// @} diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a59095d40..cfa56b43a 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(TableFactor(*dtf))); + } + return result; +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index a2fdb4d32..4b53d7e2b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely)