New scalar operators

release/4.3a0
Frank Dellaert 2025-01-30 15:54:11 -05:00
parent 39e4610077
commit c7864d32b5
7 changed files with 71 additions and 16 deletions

View File

@ -164,6 +164,12 @@ namespace gtsam {
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
/// multiply with a scalar
DiscreteFactor::shared_ptr operator*(double s) const override {
return std::make_shared<DecisionTreeFactor>(
apply([s](const double& a) { return Ring::mul(a, s); }));
}
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
@ -201,6 +207,9 @@ namespace gtsam {
return combine(keys, Ring::add);
}
/// Find the maximum value in the factor.
double max() const override { return ADT::max(); };
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);

View File

@ -73,10 +73,7 @@ AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
/* ************************************************************************ */
DiscreteFactor::shared_ptr DiscreteFactor::scale() const {
// Max over all the potentials by pretending all keys are frontal:
shared_ptr denominator = this->max(this->size());
// Normalize the product factor to prevent underflow.
return this->operator/(denominator);
return this->operator*(1.0 / max());
}
} // namespace gtsam

View File

@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const;
/// Multiply with a scalar
virtual DiscreteFactor::shared_ptr operator*(double s) const = 0;
/// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
@ -152,6 +155,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
/// Find the maximum value in the factor.
virtual double max() const = 0;
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;

View File

@ -110,6 +110,11 @@ DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
return table_.max(keys);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::operator*(double s) const {
return table_ * s;
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::operator/(
const DiscreteFactor::shared_ptr& f) const {

View File

@ -116,12 +116,19 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
/// Find the maximum value in the factor.
double max() const override { return table_.max(); }
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// Multiply by scalar s
DiscreteFactor::shared_ptr operator*(double s) const override;
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;

View File

@ -389,6 +389,36 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const {
return combine(nrFrontals, Ring::add);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const {
return combine(keys, Ring::add);
}
/* ************************************************************************ */
double TableFactor::max() const {
double max_value = std::numeric_limits<double>::lowest();
for (Eigen::SparseVector<double>::InnerIterator it(sparse_table_); it; ++it) {
max_value = std::max(max_value, it.value());
}
return max_value;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::max(size_t nrFrontals) const {
return combine(nrFrontals, Ring::max);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const {
return combine(keys, Ring::max);
}
/* ************************************************************************ */
TableFactor TableFactor::apply(Unary op) const {
// Initialize new factor.

View File

@ -171,6 +171,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
/// multiply with a scalar
DiscreteFactor::shared_ptr operator*(double s) const override {
return std::make_shared<TableFactor>(
apply([s](const double& a) { return Ring::mul(a, s); }));
}
/// multiply two TableFactors
TableFactor operator*(const TableFactor& f) const {
return apply(f, Ring::mul);
@ -215,24 +221,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override;
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
/// Find the maximum value in the factor.
double max() const override;
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// @}
/// @name Advanced Interface