From 4ebca711461eb3fd914b74b4532242aedf38c048 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 20:44:10 -0500 Subject: [PATCH] divide operator for DiscreteFactor::shared_ptr --- gtsam/discrete/DecisionTreeFactor.cpp | 13 +++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 5 ++--- gtsam/discrete/DiscreteFactor.h | 4 ++++ gtsam/discrete/TableFactor.cpp | 14 ++++++++++++++ gtsam/discrete/TableFactor.h | 11 ++--------- gtsam_unstable/discrete/AllDiff.cpp | 6 ++++++ gtsam_unstable/discrete/AllDiff.h | 4 ++++ gtsam_unstable/discrete/BinaryAllDiff.h | 6 ++++++ gtsam_unstable/discrete/Domain.cpp | 6 ++++++ gtsam_unstable/discrete/Domain.h | 4 ++++ gtsam_unstable/discrete/SingleValue.cpp | 6 ++++++ gtsam_unstable/discrete/SingleValue.h | 4 ++++ 12 files changed, 71 insertions(+), 12 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4b16dad8a..2f2c039a4 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -77,6 +77,19 @@ namespace gtsam { return result; } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + return std::make_shared(tf->operator/(TableFactor(*this))); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator/(*dtf)); + } else { + return std::make_shared( + this->operator/(this->toDecisionTreeFactor())); + } + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index d3cb55fa5..a5327bdd0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -165,9 +165,8 @@ namespace gtsam { } /// divide by DiscreteFactor::shared_ptr f (safely) - DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const { - return apply(*std::dynamic_pointer_cast(f), safe_div); - } + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index adc79bbd5..6cbc00d09 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -140,6 +140,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const = 0; + /// divide by DiscreteFactor::shared_ptr f (safely) + virtual DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; /// Create new factor by summing all values with the same separator values diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 6516a4a98..b692e9ba2 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -270,6 +270,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply( return result; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator/(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + return std::make_shared( + this->operator/(TableFactor(f->discreteKeys(), *dtf))); + } else { + TableFactor divisor(f->toDecisionTreeFactor()); + return std::make_shared(this->operator/(divisor)); + } +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index f7d0f5215..a2f74758f 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -191,15 +191,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } /// divide by DiscreteFactor::shared_ptr f (safely) - TableFactor operator/(const DiscreteFactor::shared_ptr& f) const { - if (auto tf = std::dynamic_pointer_cast(f)) { - return apply(*tf, safe_div); - } else if (auto dtf = std::dynamic_pointer_cast(f)) { - return apply(TableFactor(f->discreteKeys(), *dtf), safe_div); - } else { - throw std::runtime_error("Unknown derived type for DiscreteFactor"); - } - } + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 585ca8103..01f50fa3d 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -56,6 +56,12 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr AllDiff::operator/( + const DiscreteFactor::shared_ptr& df) const { + return this->toDecisionTreeFactor() / df; +} + /* ************************************************************************* */ bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { Domain& Dj = domains->at(j); diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 267ddb9fd..7a7b1cecc 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -60,6 +60,10 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override; + /// Compute error for each assignment and return as a tree AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 3035d0620..fbff8a01c 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -76,6 +76,12 @@ class BinaryAllDiff : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override { + return this->toDecisionTreeFactor() / df; + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index 74f621dc7..cecb7cc1a 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -49,6 +49,12 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr Domain::operator/( + const DiscreteFactor::shared_ptr& df) const { + return this->toDecisionTreeFactor() / df; +} + /* ************************************************************************* */ bool Domain::ensureArcConsistency(Key j, Domains* domains) const { if (j != key()) throw invalid_argument("Domain check on wrong domain"); diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 4c2d3f9dd..7362e9caf 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -97,6 +97,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override; + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 220bc9c06..09a8314df 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -41,6 +41,12 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr SingleValue::operator/( + const DiscreteFactor::shared_ptr& df) const { + return this->toDecisionTreeFactor() / df; +} + /* ************************************************************************* */ bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { if (j != keys_[0]) diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index b6c91f912..87c42fc80 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -70,6 +70,10 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override; + /* * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked