diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 4da5a7c17..716c43b63 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -164,6 +164,12 @@ namespace gtsam { virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& f) const override; + /// multiply with a scalar + DiscreteFactor::shared_ptr operator*(double s) const override { + return std::make_shared( + apply([s](const double& a) { return Ring::mul(a, s); })); + } + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); @@ -201,6 +207,9 @@ namespace gtsam { return combine(keys, Ring::add); } + /// Find the maximum value in the factor. + double max() const override { return ADT::max(); }; + /// Create new factor by maximizing over all values with the same separator. DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, Ring::max); diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index faae02af2..61b4c135c 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -73,10 +73,7 @@ AlgebraicDecisionTree DiscreteFactor::errorTree() const { /* ************************************************************************ */ DiscreteFactor::shared_ptr DiscreteFactor::scale() const { - // Max over all the potentials by pretending all keys are frontal: - shared_ptr denominator = this->max(this->size()); - // Normalize the product factor to prevent underflow. - return this->operator/(denominator); + return this->operator*(1.0 / max()); } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index fafb4dbf5..6fa074379 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Compute error for each assignment and return as a tree virtual AlgebraicDecisionTree errorTree() const; + /// Multiply with a scalar + virtual DiscreteFactor::shared_ptr operator*(double s) const = 0; + /// Multiply in a DecisionTreeFactor and return the result as /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; @@ -152,6 +155,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Create new factor by summing all values with the same separator values virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0; + /// Find the maximum value in the factor. + virtual double max() const = 0; + /// Create new factor by maximizing over all values with the same separator. virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0; diff --git a/gtsam/discrete/TableDistribution.cpp b/gtsam/discrete/TableDistribution.cpp index e8696c5b1..ce0d92bff 100644 --- a/gtsam/discrete/TableDistribution.cpp +++ b/gtsam/discrete/TableDistribution.cpp @@ -110,6 +110,11 @@ DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const { return table_.max(keys); } +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::operator*(double s) const { + return table_ * s; +} + /* ****************************************************************************/ DiscreteFactor::shared_ptr TableDistribution::operator/( const DiscreteFactor::shared_ptr& f) const { diff --git a/gtsam/discrete/TableDistribution.h b/gtsam/discrete/TableDistribution.h index 72786a515..8e28bed5f 100644 --- a/gtsam/discrete/TableDistribution.h +++ b/gtsam/discrete/TableDistribution.h @@ -116,12 +116,19 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional { /// Create new factor by summing all values with the same separator values DiscreteFactor::shared_ptr sum(const Ordering& keys) const override; + /// Find the maximum value in the factor. + double max() const override { return table_.max(); } + /// Create new factor by maximizing over all values with the same separator. DiscreteFactor::shared_ptr max(size_t nrFrontals) const override; /// Create new factor by maximizing over all values with the same separator. DiscreteFactor::shared_ptr max(const Ordering& keys) const override; + + /// Multiply by scalar s + DiscreteFactor::shared_ptr operator*(double s) const override; + /// divide by DiscreteFactor::shared_ptr f (safely) DiscreteFactor::shared_ptr operator/( const DiscreteFactor::shared_ptr& f) const override; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index d1cedc9ef..25acae06e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -389,6 +389,36 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const { cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const { +return combine(nrFrontals, Ring::add); +} + +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const { +return combine(keys, Ring::add); +} + +/* ************************************************************************ */ +double TableFactor::max() const { + double max_value = std::numeric_limits::lowest(); + for (Eigen::SparseVector::InnerIterator it(sparse_table_); it; ++it) { + max_value = std::max(max_value, it.value()); + } + return max_value; +} + +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); +} + +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const { + return combine(keys, Ring::max); +} + + /* ************************************************************************ */ TableFactor TableFactor::apply(Unary op) const { // Initialize new factor. diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 1cb9eda8b..ce58d14bc 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -171,6 +171,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; + /// multiply with a scalar + DiscreteFactor::shared_ptr operator*(double s) const override { + return std::make_shared( + apply([s](const double& a) { return Ring::mul(a, s); })); + } + /// multiply two TableFactors TableFactor operator*(const TableFactor& f) const { return apply(f, Ring::mul); @@ -215,24 +221,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values - DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - return combine(nrFrontals, Ring::add); - } + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override; /// Create new factor by summing all values with the same separator values - DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - return combine(keys, Ring::add); - } + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override; + + /// Find the maximum value in the factor. + double max() const override; /// Create new factor by maximizing over all values with the same separator. - DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - return combine(nrFrontals, Ring::max); - } + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override; /// Create new factor by maximizing over all values with the same separator. - DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - return combine(keys, Ring::max); - } + DiscreteFactor::shared_ptr max(const Ordering& keys) const override; /// @} /// @name Advanced Interface