diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 7afbab0b0..8445c5332 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -158,26 +158,37 @@ namespace gtsam { return apply(f, safe_div); } + /// divide by factor f (pointer version) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(apply(*derived, safe_div)); + } else { + throw std::runtime_error( + "Cannot convert DiscreteFactor to Table Factor"); + } + } + /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, ADT::Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, ADT::Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, ADT::Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, ADT::Ring::max); } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 4c486dca8..1ada7b7b2 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace gtsam { @@ -114,6 +115,22 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; + + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; + + /// divide by factor f (safely) + virtual DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const = 0; + /** * Get the number of non-zero values contained in this factor. * It could be much smaller than `prod_{key}(cardinality(key))`. diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 29cbd5e9b..e452b5be0 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -197,6 +197,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } + /// divide by factor f (pointer version) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(apply(*derived, safe_div)); + } else { + throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor"); + } + } + /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; @@ -205,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, Ring::max); }