diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index e353fdebf..ef7979d0a 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -68,11 +68,20 @@ namespace gtsam { const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { + // If f is a TableFactor, we convert `this` to a TableFactor since this + // conversion is cheaper than converting `f` to a DecisionTreeFactor. We + // then return a TableFactor. result = std::make_shared((*tf) * TableFactor(*this)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, simply call operator*. result = std::make_shared(this->operator*(*dtf)); + } else { // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated. result = std::make_shared(f->operator*(*this)); } return result; diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index ff9bf0df9..907f29a45 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,7 +147,20 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; - /// Multiply factors, DiscreteFactor::shared_ptr edition + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. Hence we specialize the code to return a TableFactor if + * `f` is a TableFactor, and DecisionTreeFactor otherwise. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& f) const override; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 6516a4a98..fe901aac1 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -259,11 +259,21 @@ DiscreteFactor::shared_ptr TableFactor::multiply( const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { + // If `f` is a TableFactor, we can simply call `operator*`. result = std::make_shared(this->operator*(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, we convert to a TableFactor which is + // cheaper than converting `this` to a DecisionTreeFactor. result = std::make_shared(this->operator*(TableFactor(*dtf))); + } else { // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated to know about TableFactor. + // Those classes can be specialized to use TableFactor + // if efficiency is a problem. result = std::make_shared( f->operator*(this->toDecisionTreeFactor())); } diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 4b53d7e2b..a2e89b302 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -178,7 +178,20 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Multiply factors, DiscreteFactor::shared_ptr edition + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. + * Hence we specialize the code to return a TableFactor always. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& f) const override;