diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 58acb21b0..bd10e28b4 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -55,17 +55,15 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals, : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional( - const DiscreteFactor::shared_ptr& joint, - const DiscreteFactor::shared_ptr& marginal) - : BaseFactor(*std::dynamic_pointer_cast( - joint->operator/(marginal))), - BaseConditional(joint->size() - marginal->size()) {} +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional( - const DiscreteFactor::shared_ptr& joint, - const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys) +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 6e9f69619..a6356a045 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -126,10 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Compute error for each assignment and return as a tree virtual AlgebraicDecisionTree errorTree() const; - /// Multiply in a DiscreteFactor and return the result as - /// DiscreteFactor - virtual DiscreteFactor::shared_ptr operator*( - const DiscreteFactor::shared_ptr&) const = 0; + /// Multiply in a DecisionTreeFactor and return the result as + /// DecisionTreeFactor + virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; @@ -145,9 +144,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Create new factor by maximizing over all values with the same separator. virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; - /// divide by factor f (safely) - virtual DiscreteFactor::shared_ptr operator/( - const DiscreteFactor::shared_ptr& f) const = 0; /** * Get the number of non-zero values contained in this factor. diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 50d15ff5e..a4947012e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -171,13 +171,8 @@ double TableFactor::error(const HybridValues& values) const { } /* ************************************************************************ */ -DiscreteFactor::shared_ptr TableFactor::operator*( - const DiscreteFactor::shared_ptr& f) const { - if (auto derived = std::dynamic_pointer_cast(f)) { - return std::make_shared(this->operator*(*derived)); - } else { - throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor"); - } +DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 345cbc254..ba1d05fe9 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -161,9 +161,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiply with DiscreteFactor - DiscreteFactor::shared_ptr operator*( - const DiscreteFactor::shared_ptr& f) const override; + /// multiply with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; static double safe_div(const double& a, const double& b); @@ -172,15 +171,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } - /// divide by factor f (pointer version) - DiscreteFactor::shared_ptr operator/( - const DiscreteFactor::shared_ptr& f) const override { - if (auto derived = std::dynamic_pointer_cast(f)) { - return std::make_shared(apply(*derived, safe_div)); - } else { - throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor"); - } - } /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override;