From 5d865a8cc7908f02c206a3d9801af0fbaf8d1eaa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:22:36 -0500 Subject: [PATCH 01/11] multiply method for DiscreteFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 5 +++++ gtsam/discrete/DiscreteFactor.h | 10 ++++++++++ gtsam/discrete/TableFactor.cpp | 12 ++++++++++++ gtsam/discrete/TableFactor.h | 4 ++++ 5 files changed, 43 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ac782b88..cf22fe153 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -62,6 +62,18 @@ namespace gtsam { return error(values.discrete()); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared((*tf) * TableFactor(*this)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*dtf)); + } + return result; + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b..3e70c0df9 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -147,6 +148,10 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index a1fde0f86..c18eaae2f 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /** + * @brief Multiply in a DiscreteFactor and return the result as + * DiscreteFactor, both via shared pointers. + * + * @param df DiscreteFactor shared_ptr + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; /// @} diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a59095d40..cfa56b43a 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(TableFactor(*dtf))); + } + return result; +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index a2fdb4d32..4b53d7e2b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely) From 75a4e98715caca504bacf4cc92a768db3dd89303 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:36:06 -0500 Subject: [PATCH 02/11] remove override from definition --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cf22fe153..c15fd4e2e 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -64,7 +64,7 @@ namespace gtsam { /* ************************************************************************ */ DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( - const DiscreteFactor::shared_ptr& f) const override { + const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { result = std::make_shared((*tf) * TableFactor(*this)); From 700ad2bae326f3f89ecc78b56d55a74a64c3785a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:56:54 -0500 Subject: [PATCH 03/11] remove override from TableFactor definition --- gtsam/discrete/TableFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index cfa56b43a..3ca8fecda 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -256,7 +256,7 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { /* ************************************************************************ */ DiscreteFactor::shared_ptr TableFactor::multiply( - const DiscreteFactor::shared_ptr& f) const override { + const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(*tf)); From 260d448887447a9a9f3216a544ad34ddd1dfbd8c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:57:07 -0500 Subject: [PATCH 04/11] use new multiply method --- gtsam/discrete/DiscreteFactorGraph.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 227bb4da3..0444f47ae 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,11 +65,11 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; - for (const sharedFactor& factor : *this) { - if (factor) result = (*factor) * result; + DiscreteFactor::shared_ptr result = *this->begin(); + for (auto it = this->begin() + 1; it != this->end(); ++it) { + if (*it) result = result->multiply(*it); } - return result; + return result->toDecisionTreeFactor(); } /* ************************************************************************ */ From 453059bd61f4b08acdc43c37b9b098a5579d36a3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 14:03:08 -0500 Subject: [PATCH 05/11] simplify to remove DiscreteProduct static function --- gtsam/discrete/DiscreteFactorGraph.cpp | 44 ++++++++++---------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 0444f47ae..b3029111a 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,11 +65,23 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DiscreteFactor::shared_ptr result = *this->begin(); + // PRODUCT: multiply all factors + gttic(product); + DiscreteFactor::shared_ptr product = *this->begin(); for (auto it = this->begin() + 1; it != this->end(); ++it) { - if (*it) result = result->multiply(*it); + if (*it) product = product->multiply(*it); } - return result->toDecisionTreeFactor(); + gttoc(product); + + DecisionTreeFactor = result->toDecisionTreeFactor(); + + // Max over all the potentials by pretending all keys are frontal: + auto denominator = product.max(product.size()); + + // Normalize the product factor to prevent underflow. + product = product / (*denominator); + + return product; } /* ************************************************************************ */ @@ -111,34 +123,12 @@ namespace gtsam { // } // } - /** - * @brief Multiply all the `factors`. - * - * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DecisionTreeFactor - */ - static DecisionTreeFactor DiscreteProduct( - const DiscreteFactorGraph& factors) { - // PRODUCT: multiply all factors - gttic(product); - DecisionTreeFactor product = factors.product(); - gttoc(product); - - // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*denominator); - - return product; - } - /* ************************************************************************ */ // Alternate eliminate function for MPE std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DecisionTreeFactor product = factors.product(); // max out frontals, this is the factor on the separator gttic(max); @@ -216,7 +206,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DecisionTreeFactor product = factors.product(); // sum out frontals, this is the factor on the separator gttic(sum); From a02baec0119ca9b670a8b5b64ebecc5db492bcbd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 14:23:34 -0500 Subject: [PATCH 06/11] naive implementation of multiply for unstable --- gtsam_unstable/discrete/AllDiff.h | 7 +++++++ gtsam_unstable/discrete/BinaryAllDiff.h | 7 +++++++ gtsam_unstable/discrete/Domain.h | 7 +++++++ gtsam_unstable/discrete/SingleValue.h | 7 +++++++ 4 files changed, 28 insertions(+) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 1180abad4..cfbd76e7c 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -53,6 +53,13 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /// Compute error for each assignment and return as a tree AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index e96bfdfde..a1a2bf0a6 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -69,6 +69,13 @@ class BinaryAllDiff : public Constraint { return toDecisionTreeFactor() * f; } + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 23a566d24..dea85934f 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -90,6 +90,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 3df1209b8..8675c929b 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -63,6 +63,13 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked From 13bafb0a48bae3c5598b7db112ac2757bd02c431 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 14:28:57 -0500 Subject: [PATCH 07/11] Revert "simplify to remove DiscreteProduct static function" This reverts commit 453059bd61f4b08acdc43c37b9b098a5579d36a3. --- gtsam/discrete/DiscreteFactorGraph.cpp | 44 ++++++++++++++++---------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index b3029111a..0444f47ae 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,23 +65,11 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - // PRODUCT: multiply all factors - gttic(product); - DiscreteFactor::shared_ptr product = *this->begin(); + DiscreteFactor::shared_ptr result = *this->begin(); for (auto it = this->begin() + 1; it != this->end(); ++it) { - if (*it) product = product->multiply(*it); + if (*it) result = result->multiply(*it); } - gttoc(product); - - DecisionTreeFactor = result->toDecisionTreeFactor(); - - // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*denominator); - - return product; + return result->toDecisionTreeFactor(); } /* ************************************************************************ */ @@ -123,12 +111,34 @@ namespace gtsam { // } // } + /** + * @brief Multiply all the `factors`. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return DecisionTreeFactor + */ + static DecisionTreeFactor DiscreteProduct( + const DiscreteFactorGraph& factors) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product = factors.product(); + gttoc(product); + + // Max over all the potentials by pretending all keys are frontal: + auto denominator = product.max(product.size()); + + // Normalize the product factor to prevent underflow. + product = product / (*denominator); + + return product; + } + /* ************************************************************************ */ // Alternate eliminate function for MPE std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = factors.product(); + DecisionTreeFactor product = DiscreteProduct(factors); // max out frontals, this is the factor on the separator gttic(max); @@ -206,7 +216,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = factors.product(); + DecisionTreeFactor product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator gttic(sum); From a7fc6e3763d4dc5ea5019906be37e05613d3f9d9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:08:58 -0500 Subject: [PATCH 08/11] convert everything to DecisionTreeFactor so we can use override operator* method --- gtsam_unstable/discrete/AllDiff.h | 4 ++-- gtsam_unstable/discrete/BinaryAllDiff.h | 4 ++-- gtsam_unstable/discrete/Domain.h | 4 ++-- gtsam_unstable/discrete/SingleValue.h | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index cfbd76e7c..032808dcd 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -56,8 +56,8 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /// Compute error for each assignment and return as a tree diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index a1a2bf0a6..0ebae4d77 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -72,8 +72,8 @@ class BinaryAllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /* diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index dea85934f..9a4a21847 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -93,8 +93,8 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /* diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 8675c929b..ebe23f7e4 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -66,8 +66,8 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /* From 8390ffa2cbde062f7628a953bba6f43ba77cc7d1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:19:16 -0500 Subject: [PATCH 09/11] revert previous commit --- gtsam_unstable/discrete/AllDiff.h | 4 ++-- gtsam_unstable/discrete/BinaryAllDiff.h | 4 ++-- gtsam_unstable/discrete/Domain.h | 4 ++-- gtsam_unstable/discrete/SingleValue.h | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 032808dcd..cfbd76e7c 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -56,8 +56,8 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /// Compute error for each assignment and return as a tree diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 0ebae4d77..a1a2bf0a6 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -72,8 +72,8 @@ class BinaryAllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /* diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 9a4a21847..dea85934f 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -93,8 +93,8 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /* diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index ebe23f7e4..8675c929b 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -66,8 +66,8 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /* From bc63cc8cb88c7edebf86a4136d7231ab51a59e47 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:19:57 -0500 Subject: [PATCH 10/11] use double dispatch for else case --- gtsam/discrete/DecisionTreeFactor.cpp | 3 +++ gtsam/discrete/TableFactor.cpp | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index c15fd4e2e..4b16dad8a 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -70,6 +70,9 @@ namespace gtsam { result = std::make_shared((*tf) * TableFactor(*this)); } else if (auto dtf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(*dtf)); + } else { + // Simulate double dispatch in C++ + result = std::make_shared(f->operator*(*this)); } return result; } diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 3ca8fecda..6516a4a98 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -262,6 +262,10 @@ DiscreteFactor::shared_ptr TableFactor::multiply( result = std::make_shared(this->operator*(*tf)); } else if (auto dtf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(TableFactor(*dtf))); + } else { + // Simulate double dispatch in C++ + result = std::make_shared( + f->operator*(this->toDecisionTreeFactor())); } return result; } From 713c49c9153f9f023621a6cd9a06997594ea52ab Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:20:09 -0500 Subject: [PATCH 11/11] more robust product --- gtsam/discrete/DiscreteFactorGraph.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 0444f47ae..a2b896286 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,9 +65,16 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DiscreteFactor::shared_ptr result = *this->begin(); - for (auto it = this->begin() + 1; it != this->end(); ++it) { - if (*it) result = result->multiply(*it); + DiscreteFactor::shared_ptr result; + for (auto it = this->begin(); it != this->end(); ++it) { + if (*it) { + if (result) { + result = result->multiply(*it); + } else { + // Assign to the first non-null factor + result = *it; + } + } } return result->toDecisionTreeFactor(); }