From d1d440ad3420efb6a35bef80f39699b6b075e810 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 10:53:32 -0500 Subject: [PATCH 01/86] add nrValues method --- gtsam/discrete/DecisionTreeFactor.h | 6 ++++++ gtsam/discrete/DiscreteFactor.h | 6 ++++++ gtsam/discrete/TableFactor.h | 6 ++++++ gtsam_unstable/discrete/AllDiff.h | 3 +++ gtsam_unstable/discrete/BinaryAllDiff.h | 3 +++ gtsam_unstable/discrete/Domain.h | 2 +- gtsam_unstable/discrete/SingleValue.h | 3 +++ 7 files changed, 28 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index a8ab2644f..f417a38d7 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -255,6 +255,12 @@ namespace gtsam { */ DecisionTreeFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return nrLeaves(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 19af5bd13..7d5047ec6 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -113,6 +113,12 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + virtual uint64_t nrValues() const = 0; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index f0ecd66a3..b988eebad 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -324,6 +324,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ TableFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return sparse_table_.nonZeros(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index d7a63eae0..42a255bbf 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -72,6 +72,9 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const Domains&) const override; + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 18b335092..22acfb092 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -96,6 +96,9 @@ class BinaryAllDiff : public Constraint { AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("BinaryAllDiff::error not implemented"); } + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 7f7b717c2..ba3771eca 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } - size_t nrValues() const { return values_.size(); } + uint64_t nrValues() const override { return values_.size(); } bool isSingleton() const { return nrValues() == 1; } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 3f7f22d6a..7f2eb2c2c 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -77,6 +77,9 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const Domains& domains) const override; + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; }; } // namespace gtsam From a68da21527760daff64161f3feffc6cc1d46d1b1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 11:02:30 -0500 Subject: [PATCH 02/86] operator* version which accepts DiscreteFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 11 +++++++++++ gtsam/discrete/DecisionTreeFactor.h | 5 ++++- gtsam/discrete/DiscreteFactor.h | 7 ++++--- gtsam/discrete/TableFactor.cpp | 9 +++++++-- gtsam/discrete/TableFactor.h | 5 +++-- 5 files changed, 29 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9ec3b0ac5..e53f8cb90 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -82,6 +82,17 @@ namespace gtsam { ADT::print("", formatter); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::operator*( + const DiscreteFactor::shared_ptr& f) const { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator*(*derived)); + } else { + throw std::runtime_error( + "Cannot convert DiscreteFactor to DecisionTreeFactor"); + } + } + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { // apply operand diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f417a38d7..7afbab0b0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -144,10 +144,13 @@ namespace gtsam { double error(const DiscreteValues& values) const override; /// multiply two factors - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { return apply(f, ADT::Ring::mul); } + DiscreteFactor::shared_ptr operator*( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 7d5047ec6..4c486dca8 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -107,9 +107,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Compute error for each assignment and return as a tree virtual AlgebraicDecisionTree errorTree() const; - /// Multiply in a DecisionTreeFactor and return the result as - /// DecisionTreeFactor - virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /// Multiply in a DiscreteFactor and return the result as + /// DiscreteFactor + virtual DiscreteFactor::shared_ptr operator*( + const DiscreteFactor::shared_ptr&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index f4e023a4d..7cf520973 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -169,8 +169,13 @@ double TableFactor::error(const HybridValues& values) const { } /* ************************************************************************ */ -DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { - return toDecisionTreeFactor() * f; +DiscreteFactor::shared_ptr TableFactor::operator*( + const DiscreteFactor::shared_ptr& f) const { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator*(*derived)); + } else { + throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor"); + } } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index b988eebad..29cbd5e9b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -186,8 +186,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiply with DecisionTreeFactor - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// multiply with DiscreteFactor + DiscreteFactor::shared_ptr operator*( + const DiscreteFactor::shared_ptr& f) const override; static double safe_div(const double& a, const double& b); From a09b77ef407b7ac91b1604153d7b2ec08301b4b8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 11:07:26 -0500 Subject: [PATCH 03/86] return DiscreteFactor shared_ptr as leftover from elimination --- gtsam/discrete/DiscreteFactorGraph.cpp | 4 ++-- gtsam/discrete/DiscreteFactorGraph.h | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 4ededbb8b..f81c6085e 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -112,7 +112,7 @@ namespace gtsam { /* ************************************************************************ */ // Alternate eliminate function for MPE - std::pair // + std::pair EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { // PRODUCT: multiply all factors @@ -201,7 +201,7 @@ namespace gtsam { } /* ************************************************************************ */ - std::pair // + std::pair EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { // PRODUCT: multiply all factors diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index d0dc282b4..a5324811c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -48,7 +48,7 @@ class DiscreteJunctionTree; * @ingroup discrete */ GTSAM_EXPORT -std::pair +std::pair EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys); @@ -61,7 +61,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors, * @ingroup discrete */ GTSAM_EXPORT -std::pair +std::pair EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys); @@ -133,6 +133,7 @@ class GTSAM_EXPORT DiscreteFactorGraph /// @} + //TODO(Varun): Make compatible with TableFactor /** Add a decision-tree factor */ template void add(Args&&... args) { @@ -146,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - DecisionTreeFactor product() const; + DiscreteFactor::shared_ptr product() const; /** * Evaluates the factor graph given values, returns the joint probability of From 27bbce150aa2dac82e57ff84a91c17a9873fab7b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 11:10:24 -0500 Subject: [PATCH 04/86] generalize DiscreteFactorGraph::product to DiscreteFactor --- gtsam/discrete/DiscreteFactorGraph.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index f81c6085e..b27438130 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -64,10 +64,16 @@ namespace gtsam { } /* ************************************************************************* */ - DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; - for(const sharedFactor& factor: *this) - if (factor) result = (*factor) * result; + DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const { + sharedFactor result = this->at(0); + for (size_t i = 1; i < this->size(); ++i) { + const sharedFactor factor = this->at(i); + if (factor) { + // Predicated on the fact that all discrete factors are of a single type + // so there is no type-conversion happening which can be expensive. + result = result->operator*(factor); + } + } return result; } From 84e419456af55c0d194645c302934316146b3803 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 11:15:06 -0500 Subject: [PATCH 05/86] make normalization code common --- gtsam/discrete/DiscreteFactorGraph.cpp | 50 ++++++++++++++++---------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index b27438130..29bd1f9ac 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -116,6 +116,28 @@ namespace gtsam { // } // } + /** + * @brief Helper method to normalize the product factor by + * the max value to prevent underflow + * + * @param product The product discrete factor. + * @return DiscreteFactor::shared_ptr + */ + static DiscreteFactor::shared_ptr Normalize( + const DiscreteFactor::shared_ptr& product) { + // Max over all the potentials by pretending all keys are frontal: + gttic(DiscreteFindMax); + auto normalization = product->max(product->size()); + gttoc(DiscreteFindMax); + + gttic(DiscreteNormalization); + // Normalize the product factor to prevent underflow. + auto normalized_product = product->operator/(normalization); + gttoc(DiscreteNormalization); + + return normalized_product; + } + /* ************************************************************************ */ // Alternate eliminate function for MPE std::pair @@ -123,27 +145,23 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product; - for (auto&& factor : factors) product = (*factor) * product; + DiscreteFactor::shared_ptr product = factors.product(); gttoc(product); - // Max over all the potentials by pretending all keys are frontal: - auto normalization = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*normalization); + // Normalize the product + product = Normalize(product); // max out frontals, this is the factor on the separator gttic(max); - DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + DecisionTreeFactor::shared_ptr max = product->max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front DiscreteKeys orderedKeys; for (auto&& key : frontalKeys) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); for (auto&& key : max->keys()) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); // Make lookup with product gttic(lookup); @@ -212,19 +230,15 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product; - for (auto&& factor : factors) product = (*factor) * product; + DiscreteFactor::shared_ptr product = factors.product(); gttoc(product); - // Max over all the potentials by pretending all keys are frontal: - auto normalization = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*normalization); + // Normalize the product + product = Normalize(product); // sum out frontals, this is the factor on the separator gttic(sum); - DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); + DecisionTreeFactor::shared_ptr sum = product->sum(frontalKeys); gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front From 4dac37ce2b0000fd51f67d361c9fbcbd744c06ab Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 11:16:49 -0500 Subject: [PATCH 06/86] make sum and max DiscreteFactor methods --- gtsam/discrete/DecisionTreeFactor.h | 19 +++++++++++++++---- gtsam/discrete/DiscreteFactor.h | 17 +++++++++++++++++ gtsam/discrete/TableFactor.h | 18 ++++++++++++++---- 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 7afbab0b0..8445c5332 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -158,26 +158,37 @@ namespace gtsam { return apply(f, safe_div); } + /// divide by factor f (pointer version) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(apply(*derived, safe_div)); + } else { + throw std::runtime_error( + "Cannot convert DiscreteFactor to Table Factor"); + } + } + /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, ADT::Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, ADT::Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, ADT::Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, ADT::Ring::max); } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 4c486dca8..1ada7b7b2 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace gtsam { @@ -114,6 +115,22 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; + + /// Create new factor by summing all values with the same separator values + virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0; + + /// Create new factor by maximizing over all values with the same separator. + virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; + + /// divide by factor f (safely) + virtual DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const = 0; + /** * Get the number of non-zero values contained in this factor. * It could be much smaller than `prod_{key}(cardinality(key))`. diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 29cbd5e9b..e452b5be0 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -197,6 +197,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } + /// divide by factor f (pointer version) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(apply(*derived, safe_div)); + } else { + throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor"); + } + } + /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; @@ -205,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, Ring::add); } /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { return combine(keys, Ring::add); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { return combine(keys, Ring::max); } From 6c4546779a2dff7a03a1f1cc8978e3facbf3cd16 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 18:20:50 -0500 Subject: [PATCH 07/86] add timing info --- gtsam/discrete/DiscreteFactorGraph.cpp | 24 +++++++++++++++--------- gtsam/discrete/TableFactor.cpp | 6 ++++++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 29bd1f9ac..904108afb 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -144,12 +145,15 @@ namespace gtsam { EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { // PRODUCT: multiply all factors - gttic(product); + gttic_(MPEProduct); DiscreteFactor::shared_ptr product = factors.product(); - gttoc(product); + gttoc_(MPEProduct); + + gttic_(Normalize); // Normalize the product product = Normalize(product); + gttoc_(Normalize); // max out frontals, this is the factor on the separator gttic(max); @@ -229,17 +233,19 @@ namespace gtsam { EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { // PRODUCT: multiply all factors - gttic(product); + gttic_(product); DiscreteFactor::shared_ptr product = factors.product(); - gttoc(product); + gttoc_(product); - // Normalize the product + gttic_(Normalize); + // Normalize the product product = Normalize(product); + gttoc_(Normalize); // sum out frontals, this is the factor on the separator - gttic(sum); + gttic_(sum); DecisionTreeFactor::shared_ptr sum = product->sum(frontalKeys); - gttoc(sum); + gttoc_(sum); // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -249,10 +255,10 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional - gttic(divide); + gttic_(divide); auto conditional = std::make_shared(product, *sum, orderedKeys); - gttoc(divide); + gttoc_(divide); return {conditional, sum}; } diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 7cf520973..b867fa916 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -72,7 +72,9 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, */ std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) { + gttic_(ComputeLeafOrdering); std::vector probs = dt.probabilities(); + gttoc_(ComputeLeafOrdering); std::vector ordered; size_t n = dkeys[0].second; @@ -180,12 +182,16 @@ DiscreteFactor::shared_ptr TableFactor::operator*( /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + gttic_(toDecisionTreeFactor); DiscreteKeys dkeys = discreteKeys(); std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); } + gttoc_(toDecisionTreeFactor); + gttic_(toDecisionTreeFactor_Constructor); DecisionTreeFactor f(dkeys, table); + gttoc_(toDecisionTreeFactor_Constructor); return f; } From b0ad350a20a410aa2658839a7707a910d7129bc6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 18:22:30 -0500 Subject: [PATCH 08/86] add note about toDecisionTreeFactor --- gtsam/discrete/TableFactor.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index b867fa916..53131616d 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -182,14 +182,13 @@ DiscreteFactor::shared_ptr TableFactor::operator*( /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { - gttic_(toDecisionTreeFactor); DiscreteKeys dkeys = discreteKeys(); std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); } - gttoc_(toDecisionTreeFactor); gttic_(toDecisionTreeFactor_Constructor); + // NOTE(Varun): This constructor is really expensive!! DecisionTreeFactor f(dkeys, table); gttoc_(toDecisionTreeFactor_Constructor); return f; From 306a3bae527af21cd121cff602bfa97fa3a5966f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 18:25:27 -0500 Subject: [PATCH 09/86] kill toDecisionTreeFactor to force rethink --- gtsam/discrete/DecisionTreeFactor.h | 3 --- gtsam/discrete/DiscreteFactor.h | 2 -- gtsam/discrete/TableFactor.cpp | 14 -------------- gtsam/discrete/TableFactor.h | 3 --- 4 files changed, 22 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8445c5332..642187ff1 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -169,9 +169,6 @@ namespace gtsam { } } - /// Convert into a decision tree - DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } - /// Create new factor by summing all values with the same separator values DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, ADT::Ring::add); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 1ada7b7b2..29984d795 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -113,8 +113,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr operator*( const DiscreteFactor::shared_ptr&) const = 0; - virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; - /// Create new factor by summing all values with the same separator values virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 53131616d..a8adf0918 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -180,20 +180,6 @@ DiscreteFactor::shared_ptr TableFactor::operator*( } } -/* ************************************************************************ */ -DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { - DiscreteKeys dkeys = discreteKeys(); - std::vector table; - for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); - } - gttic_(toDecisionTreeFactor_Constructor); - // NOTE(Varun): This constructor is really expensive!! - DecisionTreeFactor f(dkeys, table); - gttoc_(toDecisionTreeFactor_Constructor); - return f; -} - /* ************************************************************************ */ TableFactor TableFactor::choose(const DiscreteValues parent_assign, DiscreteKeys parent_keys) const { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index e452b5be0..12266b2a5 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -207,9 +207,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } } - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - /// Create a TableFactor that is a subset of this TableFactor TableFactor choose(const DiscreteValues assignments, DiscreteKeys parent_keys) const; From 2cd2ab0a43ea57c33392371d7ec1d285b1f005c3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 18:25:40 -0500 Subject: [PATCH 10/86] DiscreteDistribution from TableFactor --- gtsam/discrete/DiscreteDistribution.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index 4b690da15..abe8f7933 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -40,10 +40,14 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { /// Default constructor needed for serialization. DiscreteDistribution() {} - /// Constructor from factor. + /// Constructor from DecisionTreeFactor. explicit DiscreteDistribution(const DecisionTreeFactor& f) : Base(f.size(), f) {} + /// Constructor from TableFactor. + explicit DiscreteDistribution(const TableFactor& f) + : Base(f.size(), f) {} + /** * Construct from a Signature. * From 9f88a360dfd54f1e00361344f24d9d2ef085b50f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 18:29:13 -0500 Subject: [PATCH 11/86] make evaluate use the Assignment base class --- gtsam/discrete/DecisionTreeFactor.h | 4 ++-- gtsam/discrete/DiscreteFactor.h | 3 +++ gtsam/discrete/TableFactor.cpp | 2 +- gtsam/discrete/TableFactor.h | 9 ++++++--- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 642187ff1..bf0e23b50 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -129,9 +129,9 @@ namespace gtsam { /// @name Standard Interface /// @{ - /// Calculate probability for given values `x`, + /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double evaluate(const Assignment& values) const { + double evaluate(const Assignment& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 29984d795..23abf725e 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -129,6 +129,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr operator/( const DiscreteFactor::shared_ptr& f) const = 0; + /// Calculate probability for given values + virtual double evaluate(const Assignment& values) const = 0; + /** * Get the number of non-zero values contained in this factor. * It could be much smaller than `prod_{key}(cardinality(key))`. diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a8adf0918..727c96ce4 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -135,7 +135,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { } /* ************************************************************************ */ -double TableFactor::operator()(const DiscreteValues& values) const { +double TableFactor::operator()(const Assignment& values) const { // a b c d => D * (C * (B * (a) + b) + c) + d uint64_t idx = 0, card = 1; for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 12266b2a5..ea222ca5c 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -169,14 +169,17 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ - /// Calculate probability for given values `x`, + /// Calculate probability for given values, /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { + double evaluate(const Assignment& values) const override { return operator()(values); } /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const; + double operator()(const DiscreteValues& values) const override { + return operator()(Assignment(values)); + } /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; From 2a3b5e62b785c2c9de4e414f5dcb1551569a297c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 18:59:11 -0500 Subject: [PATCH 12/86] use Assignment for evaluate since it is the base class --- gtsam/discrete/DiscreteFactor.h | 2 +- gtsam/discrete/TableFactor.h | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 23abf725e..4470d97a7 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -94,7 +94,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { size_t cardinality(Key j) const { return cardinalities_.at(j); } /// Find value for given assignment of values to variables - virtual double operator()(const DiscreteValues&) const = 0; + virtual double operator()(const DiscreteValues& values) const = 0; /// Error is just -log(value) virtual double error(const DiscreteValues& values) const; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index ea222ca5c..8fb04fcba 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -177,9 +177,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Evaluate probability distribution, sugar. double operator()(const Assignment& values) const; - double operator()(const DiscreteValues& values) const override { - return operator()(Assignment(values)); - } /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; From fff8458d6bcc0ec54a8c7ebbf0145c3b8b91aafc Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:03:26 -0500 Subject: [PATCH 13/86] remove TableFactor constructor in DiscreteDistribution --- gtsam/discrete/DiscreteDistribution.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index abe8f7933..09ea50332 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -44,10 +44,6 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { explicit DiscreteDistribution(const DecisionTreeFactor& f) : Base(f.size(), f) {} - /// Constructor from TableFactor. - explicit DiscreteDistribution(const TableFactor& f) - : Base(f.size(), f) {} - /** * Construct from a Signature. * From 295b965b6894574c11786dbef03edc9022311d82 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:09:45 -0500 Subject: [PATCH 14/86] use Assignment since it is a base class --- gtsam/discrete/DecisionTreeFactor.h | 2 +- gtsam/discrete/DiscreteConditional.h | 2 +- gtsam/discrete/DiscreteFactor.h | 2 +- gtsam/discrete/TableFactor.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index bf0e23b50..688cd85a6 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -136,7 +136,7 @@ namespace gtsam { } /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override { + double operator()(const Assignment& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index f59e29285..ce4fb96e5 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -169,7 +169,7 @@ class GTSAM_EXPORT DiscreteConditional } /// Evaluate, just look up in AlgebraicDecisionTree - double evaluate(const DiscreteValues& values) const { + double evaluate(const Assignment& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 4470d97a7..4c1d0afb1 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -94,7 +94,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { size_t cardinality(Key j) const { return cardinalities_.at(j); } /// Find value for given assignment of values to variables - virtual double operator()(const DiscreteValues& values) const = 0; + virtual double operator()(const Assignment& values) const = 0; /// Error is just -log(value) virtual double error(const DiscreteValues& values) const; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 8fb04fcba..497c42dc2 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -176,7 +176,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } /// Evaluate probability distribution, sugar. - double operator()(const Assignment& values) const; + double operator()(const Assignment& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; From 261038f93620185ede195605a021c30c5b71f4d8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:09:56 -0500 Subject: [PATCH 15/86] fix DiscreteConditional constructor --- gtsam/discrete/DiscreteConditional.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 5ab0c59ec..92086d143 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -43,7 +43,9 @@ template class GTSAM_EXPORT /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) - : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} + : BaseFactor(f / (*std::dynamic_pointer_cast( + f.sum(nrFrontals)))), + BaseConditional(nrFrontals) {} /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, From 20d6d09e06e705d8384d685226239c109748e7a2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:12:06 -0500 Subject: [PATCH 16/86] use DiscreteFactor everywhere in DiscreteFactorGraph.cpp --- gtsam/discrete/DiscreteFactorGraph.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 904108afb..a4f92e267 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -55,7 +55,7 @@ namespace gtsam { DiscreteKeys DiscreteFactorGraph::discreteKeys() const { DiscreteKeys result; for (auto&& factor : *this) { - if (auto p = std::dynamic_pointer_cast(factor)) { + if (auto p = std::dynamic_pointer_cast(factor)) { DiscreteKeys factor_keys = p->discreteKeys(); result.insert(result.end(), factor_keys.begin(), factor_keys.end()); } @@ -157,7 +157,7 @@ namespace gtsam { // max out frontals, this is the factor on the separator gttic(max); - DecisionTreeFactor::shared_ptr max = product->max(frontalKeys); + DiscreteFactor::shared_ptr max = product->max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front @@ -244,7 +244,7 @@ namespace gtsam { // sum out frontals, this is the factor on the separator gttic_(sum); - DecisionTreeFactor::shared_ptr sum = product->sum(frontalKeys); + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); gttoc_(sum); // Ordering keys for the conditional so that frontalKeys are really in front @@ -257,8 +257,8 @@ namespace gtsam { // now divide product/sum to get conditional gttic_(divide); auto conditional = - std::make_shared(product, *sum, orderedKeys); - gttoc_(divide); + std::make_shared(product, sum, orderedKeys); + gttoc(divide); return {conditional, sum}; } From 32b6bc0a37da20e49cc8d390dc04b3ff9bb77548 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:18:42 -0500 Subject: [PATCH 17/86] update DiscreteConditional --- gtsam/discrete/DiscreteConditional.cpp | 19 ++++++++++--------- gtsam/discrete/DiscreteConditional.h | 16 ++++++++-------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 92086d143..2f900afbe 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -37,8 +37,7 @@ using std::vector; namespace gtsam { // Instantiate base class -template class GTSAM_EXPORT - Conditional; +template class GTSAM_EXPORT Conditional; /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, @@ -54,15 +53,17 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals, : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) - : BaseFactor(joint / marginal), - BaseConditional(joint.size() - marginal.size()) {} +DiscreteConditional::DiscreteConditional( + const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal) + : BaseFactor(*std::dynamic_pointer_cast( + joint->operator/(marginal))), + BaseConditional(joint->size() - marginal->size()) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, - const Ordering& orderedKeys) +DiscreteConditional::DiscreteConditional( + const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys) : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index ce4fb96e5..ec2c5c38d 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal); + DiscreteConditional(const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal); /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Makes sure the keys are ordered as given. Does not check orderedKeys. */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, + DiscreteConditional(const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys); /** @@ -173,8 +173,8 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - using DecisionTreeFactor::error; ///< DiscreteValues version - using DecisionTreeFactor::operator(); ///< DiscreteValues version + using DiscreteFactor::error; ///< DiscreteValues version + using DiscreteFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. @@ -192,11 +192,11 @@ class GTSAM_EXPORT DiscreteConditional shared_ptr choose(const DiscreteValues& given) const; /** Convert to a likelihood factor by providing value before bar. */ - DecisionTreeFactor::shared_ptr likelihood( + DiscreteFactor::shared_ptr likelihood( const DiscreteValues& frontalValues) const; /** Single variable version of likelihood. */ - DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; + DiscreteFactor::shared_ptr likelihood(size_t frontal) const; /** * sample From 38563da342f233f81f356d76cc78d034db6b9d4f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:24:04 -0500 Subject: [PATCH 18/86] Revert "kill toDecisionTreeFactor to force rethink" This reverts commit 306a3bae527af21cd121cff602bfa97fa3a5966f. --- gtsam/discrete/DecisionTreeFactor.h | 3 +++ gtsam/discrete/DiscreteFactor.h | 2 ++ gtsam/discrete/TableFactor.cpp | 14 ++++++++++++++ gtsam/discrete/TableFactor.h | 3 +++ 4 files changed, 22 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 688cd85a6..f8f2835e5 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -169,6 +169,9 @@ namespace gtsam { } } + /// Convert into a decision tree + DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } + /// Create new factor by summing all values with the same separator values DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { return combine(nrFrontals, ADT::Ring::add); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 4c1d0afb1..e2d32e828 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -113,6 +113,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr operator*( const DiscreteFactor::shared_ptr&) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// Create new factor by summing all values with the same separator values virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 727c96ce4..50d15ff5e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -180,6 +180,20 @@ DiscreteFactor::shared_ptr TableFactor::operator*( } } +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); + } + gttic_(toDecisionTreeFactor_Constructor); + // NOTE(Varun): This constructor is really expensive!! + DecisionTreeFactor f(dkeys, table); + gttoc_(toDecisionTreeFactor_Constructor); + return f; +} + /* ************************************************************************ */ TableFactor TableFactor::choose(const DiscreteValues parent_assign, DiscreteKeys parent_keys) const { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 497c42dc2..47a7c6bbb 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -207,6 +207,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } } + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + /// Create a TableFactor that is a subset of this TableFactor TableFactor choose(const DiscreteValues assignments, DiscreteKeys parent_keys) const; From 9633ad1fd8631a3abda75a25e5217e516119e458 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:24:29 -0500 Subject: [PATCH 19/86] make DiscreteConditional::likelihood match the declaration --- gtsam/discrete/DiscreteConditional.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 2f900afbe..048f35e5b 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -201,7 +201,7 @@ DiscreteConditional::shared_ptr DiscreteConditional::choose( } /* ************************************************************************** */ -DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( +DiscreteFactor::shared_ptr DiscreteConditional::likelihood( const DiscreteValues& frontalValues) const { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the frontal variables. @@ -226,7 +226,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( } /* ****************************************************************************/ -DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( +DiscreteFactor::shared_ptr DiscreteConditional::likelihood( size_t frontal) const { if (nrFrontals() != 1) throw std::invalid_argument( From 0b3477fc5ab428a6bbc65ca2b95e02d84b982593 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:39:44 -0500 Subject: [PATCH 20/86] get different classes to play nicely --- gtsam/discrete/DiscreteConditional.h | 4 ++-- gtsam/discrete/DiscreteDistribution.h | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index ec2c5c38d..77003f232 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -173,8 +173,8 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - using DiscreteFactor::error; ///< DiscreteValues version - using DiscreteFactor::operator(); ///< DiscreteValues version + using DecisionTreeFactor::error; ///< DiscreteValues version + using DecisionTreeFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index 09ea50332..28e509f15 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -86,8 +86,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { double operator()(size_t value) const; /// We also want to keep the Base version, taking DiscreteValues: - // TODO(dellaert): does not play well with wrapper! - // using Base::operator(); + using Base::operator(); /// Return entire probability mass function. std::vector pmf() const; From 1d7918841797e3e970587cbec1c60aeaafa21566 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 19:40:20 -0500 Subject: [PATCH 21/86] compiles --- gtsam/discrete/DiscreteFactorGraph.cpp | 6 ++++-- gtsam/discrete/DiscreteLookupDAG.h | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index a4f92e267..c2a16159b 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -170,8 +170,10 @@ namespace gtsam { // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = std::make_shared(nrFrontals, - orderedKeys, product); + //TODO(Varun): Should accept a DiscreteFactor::shared_ptr + auto lookup = std::make_shared( + nrFrontals, orderedKeys, + *std::dynamic_pointer_cast(product)); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index f077a13d9..c811c4c49 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include @@ -54,6 +55,12 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { const ADT& potentials) : DiscreteConditional(nFrontals, keys, potentials) {} + //TODO(Varun): Should accept a DiscreteFactor::shared_ptr + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const TableFactor& potentials) + : DiscreteConditional(nFrontals, keys, + potentials.toDecisionTreeFactor()) {} + /// GTSAM-style print void print( const std::string& s = "Discrete Lookup Table: ", From 77578512f80600d27f25cdfeb0c22526e64ce7b9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 21:45:22 -0500 Subject: [PATCH 22/86] timing --- gtsam/discrete/DiscreteFactorGraph.cpp | 11 +++++------ gtsam/discrete/DiscreteLookupDAG.h | 10 +++++++++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index c2a16159b..fa9d9bdc7 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -22,7 +22,6 @@ #include #include #include -#include #include #include @@ -127,14 +126,14 @@ namespace gtsam { static DiscreteFactor::shared_ptr Normalize( const DiscreteFactor::shared_ptr& product) { // Max over all the potentials by pretending all keys are frontal: - gttic(DiscreteFindMax); + gttic_(DiscreteFindMax); auto normalization = product->max(product->size()); - gttoc(DiscreteFindMax); + gttoc_(DiscreteFindMax); - gttic(DiscreteNormalization); + gttic_(DiscreteNormalization); // Normalize the product factor to prevent underflow. auto normalized_product = product->operator/(normalization); - gttoc(DiscreteNormalization); + gttoc_(DiscreteNormalization); return normalized_product; } @@ -260,7 +259,7 @@ namespace gtsam { gttic_(divide); auto conditional = std::make_shared(product, sum, orderedKeys); - gttoc(divide); + gttoc_(divide); return {conditional, sum}; } diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index c811c4c49..21181d374 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -55,7 +55,15 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { const ADT& potentials) : DiscreteConditional(nFrontals, keys, potentials) {} - //TODO(Varun): Should accept a DiscreteFactor::shared_ptr + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a sorted list of gtsam::Keys + * @param potentials Discrete potentials as a TableFactor. + * + * //TODO(Varun): Should accept a DiscreteFactor::shared_ptr + */ DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, const TableFactor& potentials) : DiscreteConditional(nFrontals, keys, From 9844a555d4debcf3e0cb7c6047c7b81f8701d27b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 10:34:02 -0500 Subject: [PATCH 23/86] move evaluate and operator() next to each other --- gtsam/discrete/DiscreteFactor.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index e2d32e828..3151afe80 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -93,6 +93,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { size_t cardinality(Key j) const { return cardinalities_.at(j); } + /// Calculate probability for given values + virtual double evaluate(const Assignment& values) const = 0; + /// Find value for given assignment of values to variables virtual double operator()(const Assignment& values) const = 0; @@ -131,9 +134,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr operator/( const DiscreteFactor::shared_ptr& f) const = 0; - /// Calculate probability for given values - virtual double evaluate(const Assignment& values) const = 0; - /** * Get the number of non-zero values contained in this factor. * It could be much smaller than `prod_{key}(cardinality(key))`. From aa25ccfa6ecbaef12e549b3e0ac6a09d5c2d30de Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 11:15:57 -0500 Subject: [PATCH 24/86] implement evaluate in DiscreteFactor --- gtsam/discrete/DecisionTreeFactor.h | 5 ----- gtsam/discrete/DiscreteConditional.cpp | 2 +- gtsam/discrete/DiscreteConditional.h | 5 ----- gtsam/discrete/DiscreteFactor.h | 14 ++++++++++++-- gtsam/discrete/TableFactor.h | 8 +------- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f8f2835e5..85491b909 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -131,11 +131,6 @@ namespace gtsam { /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double evaluate(const Assignment& values) const override { - return ADT::operator()(values); - } - - /// Evaluate probability distribution, sugar. double operator()(const Assignment& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 048f35e5b..c90f7f9a0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -476,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, /* ************************************************************************* */ double DiscreteConditional::evaluate(const HybridValues& x) const { - return this->evaluate(x.discrete()); + return this->operator()(x.discrete()); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 77003f232..29292f57e 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -168,11 +168,6 @@ class GTSAM_EXPORT DiscreteConditional static_cast(this)->print(s, formatter); } - /// Evaluate, just look up in AlgebraicDecisionTree - double evaluate(const Assignment& values) const override { - return ADT::operator()(values); - } - using DecisionTreeFactor::error; ///< DiscreteValues version using DecisionTreeFactor::operator(); ///< DiscreteValues version diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 3151afe80..5b4665d4d 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -93,8 +93,18 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { size_t cardinality(Key j) const { return cardinalities_.at(j); } - /// Calculate probability for given values - virtual double evaluate(const Assignment& values) const = 0; + /** + * @brief Calculate probability for given values. + * Calls specialized evaluation under the hood. + * + * Note: Uses Assignment as it is the base class of DiscreteValues. + * + * @param values Discrete assignment. + * @return double + */ + double evaluate(const Assignment& values) const { + return operator()(values); + } /// Find value for given assignment of values to variables virtual double operator()(const Assignment& values) const = 0; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 47a7c6bbb..d8df12821 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -169,13 +169,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ - /// Calculate probability for given values, - /// is just look up in TableFactor. - double evaluate(const Assignment& values) const override { - return operator()(values); - } - - /// Evaluate probability distribution, sugar. + /// Evaluate probability distribution, is just look up in TableFactor. double operator()(const Assignment& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). From 90d7e21941d324df59ceda3487890d403a1af953 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 11:19:35 -0500 Subject: [PATCH 25/86] change from DiscreteValues to Assignment --- gtsam_unstable/discrete/AllDiff.cpp | 2 +- gtsam_unstable/discrete/AllDiff.h | 2 +- gtsam_unstable/discrete/BinaryAllDiff.h | 2 +- gtsam_unstable/discrete/Domain.cpp | 2 +- gtsam_unstable/discrete/Domain.h | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 2bd9e6dfd..a450605b3 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double AllDiff::operator()(const DiscreteValues& values) const { +double AllDiff::operator()(const Assignment& values) const { std::set taken; // record values taken by keys for (Key dkey : keys_) { size_t value = values.at(dkey); // get the value for that key diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 42a255bbf..7f539f4c6 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } /// Calculate value = expensive ! - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree, can be *very* expensive ! DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 22acfb092..0e2fce109 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override { + double operator()(const Assignment& values) const override { return (double)(values.at(keys_[0]) != values.at(keys_[1])); } diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index bbbc87667..752228c18 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -30,7 +30,7 @@ string Domain::base1Str() const { } /* ************************************************************************* */ -double Domain::operator()(const DiscreteValues& values) const { +double Domain::operator()(const Assignment& values) const { return contains(values.at(key())); } diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index ba3771eca..cd11fc8d9 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; From 6665659e9de8a9f89dfda0c468692d93052d0f7a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 11:23:04 -0500 Subject: [PATCH 26/86] use BaseFactor instead of DecisionTreeFactor --- gtsam/discrete/DiscreteConditional.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 29292f57e..3dbadb3e4 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -168,8 +168,8 @@ class GTSAM_EXPORT DiscreteConditional static_cast(this)->print(s, formatter); } - using DecisionTreeFactor::error; ///< DiscreteValues version - using DecisionTreeFactor::operator(); ///< DiscreteValues version + using BaseFactor::error; ///< DiscreteValues version + using BaseFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. From e6b65285217c84f49a07a641c187719c1bbfbf43 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 12:36:23 -0500 Subject: [PATCH 27/86] common definitions of Unary, UnaryAssignment and Binary --- gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++------ gtsam/discrete/DecisionTreeFactor.h | 16 ++++++++++------ gtsam/discrete/DiscreteFactor.h | 5 +++++ 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index e53f8cb90..93a7921aa 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -94,7 +94,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { + DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const { // apply operand ADT result = ADT::apply(op); // Make a new factor @@ -102,7 +102,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { + DecisionTreeFactor DecisionTreeFactor::apply(UnaryAssignment op) const { // apply operand ADT result = ADT::apply(op); // Make a new factor @@ -111,7 +111,7 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, - ADT::Binary op) const { + Binary op) const { map cs; // new cardinalities // make unique key-cardinality map for (Key j : keys()) cs[j] = cardinality(j); @@ -129,8 +129,8 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( - size_t nrFrontals, ADT::Binary op) const { + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, + Binary op) const { if (nrFrontals > size()) { throw invalid_argument( "DecisionTreeFactor::combine: invalid number of frontal " @@ -157,7 +157,7 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( - const Ordering& frontalKeys, ADT::Binary op) const { + const Ordering& frontalKeys, Binary op) const { if (frontalKeys.size() > size()) { throw invalid_argument( "DecisionTreeFactor::combine: invalid number of frontal " diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 983330f4a..f8679af59 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -51,6 +51,10 @@ namespace gtsam { typedef std::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; + using Base::Binary; + using Base::Unary; + using Base::UnaryAssignment; + /// @name Standard Constructors /// @{ @@ -140,7 +144,7 @@ namespace gtsam { double error(const DiscreteValues& values) const override; /// multiply two factors - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { return apply(f, Ring::mul); } @@ -196,21 +200,21 @@ namespace gtsam { * Apply unary operator (*this) "op" f * @param op a unary operator that operates on AlgebraicDecisionTree */ - DecisionTreeFactor apply(ADT::Unary op) const; + DecisionTreeFactor apply(Unary op) const; /** * Apply unary operator (*this) "op" f * @param op a unary operator that operates on AlgebraicDecisionTree. Takes * both the assignment and the value. */ - DecisionTreeFactor apply(ADT::UnaryAssignment op) const; + DecisionTreeFactor apply(UnaryAssignment op) const; /** * Apply binary operator (*this) "op" f * @param f the second argument for op * @param op a binary operator that operates on AlgebraicDecisionTree */ - DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; + DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const; /** * Combine frontal variables using binary operator "op" @@ -218,7 +222,7 @@ namespace gtsam { * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ - shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; + shared_ptr combine(size_t nrFrontals, Binary op) const; /** * Combine frontal variables in an Ordering using binary operator "op" @@ -226,7 +230,7 @@ namespace gtsam { * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ - shared_ptr combine(const Ordering& keys, ADT::Binary op) const; + shared_ptr combine(const Ordering& keys, Binary op) const; /// Enumerate all values into a map from values to double. std::vector> enumerate() const; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 5b4665d4d..6e9f69619 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -47,6 +47,11 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { using Values = DiscreteValues; ///< backwards compatibility + using Unary = std::function; + using UnaryAssignment = + std::function&, const double&)>; + using Binary = std::function; + protected: /// Map of Keys and their cardinalities. std::map cardinalities_; From f85284afb2f07ab82284628e45ef68949f15fa1a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 12:37:36 -0500 Subject: [PATCH 28/86] some cleanup based on previous commit --- gtsam/discrete/DecisionTreeFactor.h | 1 + gtsam/discrete/TableFactor.h | 5 ----- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f8679af59..af0df47a2 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -51,6 +51,7 @@ namespace gtsam { typedef std::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; + // Needed since we have definitions in both DiscreteFactor and DecisionTree using Base::Binary; using Base::Unary; using Base::UnaryAssignment; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 317596ba9..345cbc254 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -94,12 +94,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { typedef std::shared_ptr shared_ptr; typedef Eigen::SparseVector::InnerIterator SparseIt; typedef std::vector> AssignValList; - using Unary = std::function; - using UnaryAssignment = - std::function&, const double&)>; - using Binary = std::function; - public: /// @name Standard Constructors /// @{ From 5e86f7ee5122fc6cf4fac609fdf2a639c1e5079f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 15:31:35 -0500 Subject: [PATCH 29/86] remove previously added code --- gtsam/discrete/DecisionTreeFactor.cpp | 11 ----------- gtsam/discrete/DecisionTreeFactor.h | 16 +--------------- gtsam/discrete/DiscreteConditional.cpp | 13 +++++++------ gtsam/discrete/DiscreteConditional.h | 8 ++++---- gtsam/discrete/DiscreteLookupDAG.cpp | 2 +- 5 files changed, 13 insertions(+), 37 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 93a7921aa..776d4bd90 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -82,17 +82,6 @@ namespace gtsam { ADT::print("", formatter); } - /* ************************************************************************ */ - DiscreteFactor::shared_ptr DecisionTreeFactor::operator*( - const DiscreteFactor::shared_ptr& f) const { - if (auto derived = std::dynamic_pointer_cast(f)) { - return std::make_shared(this->operator*(*derived)); - } else { - throw std::runtime_error( - "Cannot convert DiscreteFactor to DecisionTreeFactor"); - } - } - /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const { // apply operand diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index af0df47a2..11793f984 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -145,13 +145,10 @@ namespace gtsam { double error(const DiscreteValues& values) const override; /// multiply two factors - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); } - DiscreteFactor::shared_ptr operator*( - const DiscreteFactor::shared_ptr& f) const override; - static double safe_div(const double& a, const double& b); /// divide by factor f (safely) @@ -159,17 +156,6 @@ namespace gtsam { return apply(f, safe_div); } - /// divide by factor f (pointer version) - DiscreteFactor::shared_ptr operator/( - const DiscreteFactor::shared_ptr& f) const override { - if (auto derived = std::dynamic_pointer_cast(f)) { - return std::make_shared(apply(*derived, safe_div)); - } else { - throw std::runtime_error( - "Cannot convert DiscreteFactor to Table Factor"); - } - } - /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 6dff84859..58acb21b0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -38,7 +38,8 @@ using std::vector; namespace gtsam { // Instantiate base class -template class GTSAM_EXPORT Conditional; +template class GTSAM_EXPORT + Conditional; /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, @@ -152,11 +153,11 @@ void DiscreteConditional::print(const string& s, /* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, double tol) const { - if (!dynamic_cast(&other)) { + if (!dynamic_cast(&other)) { return false; } else { - const DecisionTreeFactor& f(static_cast(other)); - return DecisionTreeFactor::equals(f, tol); + const BaseFactor& f(static_cast(other)); + return BaseFactor::equals(f, tol); } } @@ -377,7 +378,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, ss << "*\n" << std::endl; if (nrParents() == 0) { // We have no parents, call factor method. - ss << DecisionTreeFactor::markdown(keyFormatter, names); + ss << BaseFactor::markdown(keyFormatter, names); return ss.str(); } @@ -429,7 +430,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, ss << "

\n"; if (nrParents() == 0) { // We have no parents, call factor method. - ss << DecisionTreeFactor::html(keyFormatter, names); + ss << BaseFactor::html(keyFormatter, names); return ss.str(); } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 3dbadb3e4..298a7b004 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). */ - DiscreteConditional(const DiscreteFactor::shared_ptr& joint, - const DiscreteFactor::shared_ptr& marginal); + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal); /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Makes sure the keys are ordered as given. Does not check orderedKeys. */ - DiscreteConditional(const DiscreteFactor::shared_ptr& joint, - const DiscreteFactor::shared_ptr& marginal, + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, const Ordering& orderedKeys); /** diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 4d02e007b..ee381fe44 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -47,7 +47,7 @@ void DiscreteLookupTable::print(const std::string& s, } } cout << "):\n"; - ADT::print("", formatter); + BaseFactor::print("", formatter); cout << endl; } From 1c14a56f5d9351355107b181b87565fa0613856b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 15:58:07 -0500 Subject: [PATCH 30/86] revert changes to make code generic --- gtsam/discrete/DiscreteConditional.cpp | 16 +++++++--------- gtsam/discrete/DiscreteFactor.h | 10 +++------- gtsam/discrete/TableFactor.cpp | 9 ++------- gtsam/discrete/TableFactor.h | 14 ++------------ 4 files changed, 14 insertions(+), 35 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 58acb21b0..bd10e28b4 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -55,17 +55,15 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals, : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional( - const DiscreteFactor::shared_ptr& joint, - const DiscreteFactor::shared_ptr& marginal) - : BaseFactor(*std::dynamic_pointer_cast( - joint->operator/(marginal))), - BaseConditional(joint->size() - marginal->size()) {} +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional( - const DiscreteFactor::shared_ptr& joint, - const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys) +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 6e9f69619..a6356a045 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -126,10 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Compute error for each assignment and return as a tree virtual AlgebraicDecisionTree errorTree() const; - /// Multiply in a DiscreteFactor and return the result as - /// DiscreteFactor - virtual DiscreteFactor::shared_ptr operator*( - const DiscreteFactor::shared_ptr&) const = 0; + /// Multiply in a DecisionTreeFactor and return the result as + /// DecisionTreeFactor + virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; @@ -145,9 +144,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Create new factor by maximizing over all values with the same separator. virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; - /// divide by factor f (safely) - virtual DiscreteFactor::shared_ptr operator/( - const DiscreteFactor::shared_ptr& f) const = 0; /** * Get the number of non-zero values contained in this factor. diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 50d15ff5e..a4947012e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -171,13 +171,8 @@ double TableFactor::error(const HybridValues& values) const { } /* ************************************************************************ */ -DiscreteFactor::shared_ptr TableFactor::operator*( - const DiscreteFactor::shared_ptr& f) const { - if (auto derived = std::dynamic_pointer_cast(f)) { - return std::make_shared(this->operator*(*derived)); - } else { - throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor"); - } +DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 345cbc254..ba1d05fe9 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -161,9 +161,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiply with DiscreteFactor - DiscreteFactor::shared_ptr operator*( - const DiscreteFactor::shared_ptr& f) const override; + /// multiply with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; static double safe_div(const double& a, const double& b); @@ -172,15 +171,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } - /// divide by factor f (pointer version) - DiscreteFactor::shared_ptr operator/( - const DiscreteFactor::shared_ptr& f) const override { - if (auto derived = std::dynamic_pointer_cast(f)) { - return std::make_shared(apply(*derived, safe_div)); - } else { - throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor"); - } - } /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; From b325150b3751b7f5c43b0268f76be8e60f10fe38 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 16:18:42 -0500 Subject: [PATCH 31/86] revert DiscreteFactorGraph::product --- gtsam/discrete/DiscreteFactorGraph.cpp | 13 ++++--------- gtsam/discrete/DiscreteFactorGraph.h | 2 +- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index fa9d9bdc7..9e64b0f6d 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -64,15 +64,10 @@ namespace gtsam { } /* ************************************************************************* */ - DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const { - sharedFactor result = this->at(0); - for (size_t i = 1; i < this->size(); ++i) { - const sharedFactor factor = this->at(i); - if (factor) { - // Predicated on the fact that all discrete factors are of a single type - // so there is no type-conversion happening which can be expensive. - result = result->operator*(factor); - } + DecisionTreeFactor DiscreteFactorGraph::product() const { + DecisionTreeFactor result; + for (const sharedFactor& factor : *this) { + if (factor) result = result * (*factor); } return result; } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index a5324811c..43c48c2d0 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -147,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - DiscreteFactor::shared_ptr product() const; + DecisionTreeFactor product() const; /** * Evaluates the factor graph given values, returns the joint probability of From 0afc1984118d8bec1c2327f542cac8b22d11b96f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 16:26:03 -0500 Subject: [PATCH 32/86] revert some DiscreteFactorGraph changes --- gtsam/discrete/DiscreteFactorGraph.cpp | 29 +++++++++++++------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 9e64b0f6d..04849985f 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -118,16 +118,17 @@ namespace gtsam { * @param product The product discrete factor. * @return DiscreteFactor::shared_ptr */ - static DiscreteFactor::shared_ptr Normalize( - const DiscreteFactor::shared_ptr& product) { + static DecisionTreeFactor Normalize(const DecisionTreeFactor& product) { // Max over all the potentials by pretending all keys are frontal: gttic_(DiscreteFindMax); - auto normalization = product->max(product->size()); + auto normalization = product.max(product.size()); gttoc_(DiscreteFindMax); gttic_(DiscreteNormalization); // Normalize the product factor to prevent underflow. - auto normalized_product = product->operator/(normalization); + auto normalized_product = + product / + (*std::dynamic_pointer_cast(normalization)); gttoc_(DiscreteNormalization); return normalized_product; @@ -140,7 +141,7 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic_(MPEProduct); - DiscreteFactor::shared_ptr product = factors.product(); + DecisionTreeFactor product = factors.product(); gttoc_(MPEProduct); gttic_(Normalize); @@ -151,23 +152,22 @@ namespace gtsam { // max out frontals, this is the factor on the separator gttic(max); - DiscreteFactor::shared_ptr max = product->max(frontalKeys); + DecisionTreeFactor::shared_ptr max = + std::dynamic_pointer_cast(product.max(frontalKeys)); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front DiscreteKeys orderedKeys; for (auto&& key : frontalKeys) - orderedKeys.emplace_back(key, product->cardinality(key)); + orderedKeys.emplace_back(key, product.cardinality(key)); for (auto&& key : max->keys()) - orderedKeys.emplace_back(key, product->cardinality(key)); + orderedKeys.emplace_back(key, product.cardinality(key)); // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - //TODO(Varun): Should accept a DiscreteFactor::shared_ptr - auto lookup = std::make_shared( - nrFrontals, orderedKeys, - *std::dynamic_pointer_cast(product)); + auto lookup = + std::make_shared(nrFrontals, orderedKeys, product); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -230,7 +230,7 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic_(product); - DiscreteFactor::shared_ptr product = factors.product(); + DecisionTreeFactor product = factors.product(); gttoc_(product); gttic_(Normalize); @@ -240,7 +240,8 @@ namespace gtsam { // sum out frontals, this is the factor on the separator gttic_(sum); - DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); + DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast( + product.sum(frontalKeys)); gttoc_(sum); // Ordering keys for the conditional so that frontalKeys are really in front From 975fe627d91fc3900e0553e4e043ddf60aef7e53 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 16:58:19 -0500 Subject: [PATCH 33/86] add methods in gtsam_unstable --- gtsam_unstable/discrete/AllDiff.h | 16 ++++++++++++++++ gtsam_unstable/discrete/BinaryAllDiff.h | 16 ++++++++++++++++ gtsam_unstable/discrete/Domain.h | 16 ++++++++++++++++ gtsam_unstable/discrete/SingleValue.h | 18 +++++++++++++++++- 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 7f539f4c6..fb956146b 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -75,6 +75,22 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Get the number of non-zero values contained in this factor. uint64_t nrValues() const override { return 1; }; + + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 0e2fce109..fe04fd807 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -99,6 +99,22 @@ class BinaryAllDiff : public Constraint { /// Get the number of non-zero values contained in this factor. uint64_t nrValues() const override { return 1; }; + + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index cd11fc8d9..1fbd7b110 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -114,6 +114,22 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply(const Domains& domains) const override; + + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 7f2eb2c2c..8dc7114ec 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; @@ -80,6 +80,22 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Get the number of non-zero values contained in this factor. uint64_t nrValues() const override { return 1; }; + + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { + throw std::runtime_error("Not implemented"); + } + + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { + throw std::runtime_error("Not implemented"); + } }; } // namespace gtsam From fc2d33f437ab65b394ff563ff9f8872101487189 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 17:00:04 -0500 Subject: [PATCH 34/86] add division with DiscreteFactor::shared_ptr for convenience --- gtsam/discrete/DecisionTreeFactor.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 11793f984..a5178b66f 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -156,6 +156,11 @@ namespace gtsam { return apply(f, safe_div); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const { + return apply(*std::dynamic_pointer_cast(f), safe_div); + } + /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } From 2c02efcae2e2b320834baeff157329c6908c332e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 17:02:47 -0500 Subject: [PATCH 35/86] fix tests --- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 6 +++--- gtsam/discrete/tests/testDiscreteConditional.cpp | 4 ++-- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 6 +++--- gtsam/discrete/tests/testTableFactor.cpp | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 756a0cebe..b4c5acc1b 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -111,15 +111,15 @@ TEST(DecisionTreeFactor, sum_max) { DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor expected(v1, "9 12"); - DecisionTreeFactor::shared_ptr actual = f1.sum(1); + auto actual = std::dynamic_pointer_cast(f1.sum(1)); CHECK(assert_equal(expected, *actual, 1e-5)); DecisionTreeFactor expected2(v1, "5 6"); - DecisionTreeFactor::shared_ptr actual2 = f1.max(1); + auto actual2 = std::dynamic_pointer_cast(f1.max(1)); CHECK(assert_equal(expected2, *actual2)); DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); - DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); + auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 2482a86a2..d17c76837 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / *f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1); EXPECT(assert_equal(expected2, static_cast(actual2))); std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; @@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / *f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1); EXPECT(assert_equal(expected2, static_cast(actual2))); } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 341eb63e3..e0d696d91 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -113,12 +113,12 @@ TEST(DiscreteFactorGraph, test) { const Ordering frontalKeys{0}; const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); - DecisionTreeFactor newFactor = *newFactorPtr; + auto newFactor = *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected auto normalization = newFactor.max(newFactor.size()); - newFactor = newFactor / *normalization; + newFactor = newFactor / normalization; // Check Conditional CHECK(conditional); @@ -132,7 +132,7 @@ TEST(DiscreteFactorGraph, test) { // Normalize by max. normalization = expectedFactor.max(expectedFactor.size()); // Ensure normalization is correct. - expectedFactor = expectedFactor / *normalization; + expectedFactor = expectedFactor / normalization; EXPECT(assert_equal(expectedFactor, newFactor)); // Test using elimination tree diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 0f7d7a615..bd3dac514 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -242,15 +242,15 @@ TEST(TableFactor, sum_max) { TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor expected(v1, "9 12"); - TableFactor::shared_ptr actual = f1.sum(1); + auto actual = std::dynamic_pointer_cast(f1.sum(1)); CHECK(assert_equal(expected, *actual, 1e-5)); TableFactor expected2(v1, "5 6"); - TableFactor::shared_ptr actual2 = f1.max(1); + auto actual2 = std::dynamic_pointer_cast(f1.max(1)); CHECK(assert_equal(expected2, *actual2)); TableFactor f2(v1 & v0, "1 2 3 4 5 6"); - TableFactor::shared_ptr actual22 = f2.sum(1); + auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); } /* ************************************************************************* */ From 360598d3d596a49f0e0778d5b4c5b49e6f17340f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 17:03:31 -0500 Subject: [PATCH 36/86] undo uncomment --- gtsam/discrete/DiscreteDistribution.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index 28e509f15..09ea50332 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -86,7 +86,8 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { double operator()(size_t value) const; /// We also want to keep the Base version, taking DiscreteValues: - using Base::operator(); + // TODO(dellaert): does not play well with wrapper! + // using Base::operator(); /// Return entire probability mass function. std::vector pmf() const; From 853241c6d09da2a71375cfc507b95ba647e59d29 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 17:07:40 -0500 Subject: [PATCH 37/86] add evaluate to DiscreteConditional --- gtsam/discrete/DiscreteConditional.h | 1 + 1 file changed, 1 insertion(+) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 298a7b004..75eaf9154 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -169,6 +169,7 @@ class GTSAM_EXPORT DiscreteConditional } using BaseFactor::error; ///< DiscreteValues version + using BaseFactor::evaluate; // DiscreteValues version using BaseFactor::operator(); ///< DiscreteValues version /** From 199c0a0f24da5e040ec128b83ad1f9cb3e2ef8eb Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 17:15:22 -0500 Subject: [PATCH 38/86] keep using DecisionTreeFactor for DiscreteConditional --- gtsam/discrete/DiscreteConditional.cpp | 4 ++-- gtsam/discrete/DiscreteConditional.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index bd10e28b4..399126d51 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -201,7 +201,7 @@ DiscreteConditional::shared_ptr DiscreteConditional::choose( } /* ************************************************************************** */ -DiscreteFactor::shared_ptr DiscreteConditional::likelihood( +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( const DiscreteValues& frontalValues) const { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the frontal variables. @@ -226,7 +226,7 @@ DiscreteFactor::shared_ptr DiscreteConditional::likelihood( } /* ****************************************************************************/ -DiscreteFactor::shared_ptr DiscreteConditional::likelihood( +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( size_t frontal) const { if (nrFrontals() != 1) throw std::invalid_argument( diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 75eaf9154..96ecefd7a 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -188,11 +188,11 @@ class GTSAM_EXPORT DiscreteConditional shared_ptr choose(const DiscreteValues& given) const; /** Convert to a likelihood factor by providing value before bar. */ - DiscreteFactor::shared_ptr likelihood( + DecisionTreeFactor::shared_ptr likelihood( const DiscreteValues& frontalValues) const; /** Single variable version of likelihood. */ - DiscreteFactor::shared_ptr likelihood(size_t frontal) const; + DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; /** * sample From 214bf4ec1a2ce6329d930b03b8f8f279c6f71de9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 8 Dec 2024 17:15:40 -0500 Subject: [PATCH 39/86] more fixes --- gtsam/discrete/DiscreteFactorGraph.cpp | 4 ++-- gtsam_unstable/discrete/SingleValue.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 04849985f..e31f94eae 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -67,7 +67,7 @@ namespace gtsam { DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor result; for (const sharedFactor& factor : *this) { - if (factor) result = result * (*factor); + if (factor) result = (*factor) * result; } return result; } @@ -254,7 +254,7 @@ namespace gtsam { // now divide product/sum to get conditional gttic_(divide); auto conditional = - std::make_shared(product, sum, orderedKeys); + std::make_shared(product, *sum, orderedKeys); gttoc_(divide); return {conditional, sum}; diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 6b78f38f5..9762aec0f 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double SingleValue::operator()(const DiscreteValues& values) const { +double SingleValue::operator()(const Assignment& values) const { return (double)(values.at(keys_[0]) == value_); } From e46cd5499351c67abd8c799d6ba317d37d21991e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 15:52:42 -0500 Subject: [PATCH 40/86] TableFactor cleanup --- gtsam/discrete/TableFactor.cpp | 2 -- gtsam/discrete/TableFactor.h | 1 - 2 files changed, 3 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a4947012e..32049fde1 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -72,9 +72,7 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, */ std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) { - gttic_(ComputeLeafOrdering); std::vector probs = dt.probabilities(); - gttoc_(ComputeLeafOrdering); std::vector ordered; size_t n = dkeys[0].second; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index ba1d05fe9..002c276ca 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -171,7 +171,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } - /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; From 52c8034d41e8703ad149ee161c9d93ded4df5f0c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 16:16:18 -0500 Subject: [PATCH 41/86] add division by DiscreteFactor in TableFactor --- gtsam/discrete/TableFactor.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 002c276ca..64e98c6a1 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -171,6 +171,15 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, safe_div); } + /// divide by DiscreteFactor::shared_ptr f (safely) + TableFactor operator/(const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + return apply(*tf, safe_div); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + return apply(TableFactor(f->discreteKeys(), *dtf), safe_div); + } + } + /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; From e0e833c2fc71b5848c405d1d5fa20078d1df0cab Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 16:23:55 -0500 Subject: [PATCH 42/86] cleanup --- gtsam/discrete/DiscreteLookupDAG.h | 2 -- gtsam/discrete/TableFactor.cpp | 2 -- 2 files changed, 4 deletions(-) diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index 21181d374..7f31a3f48 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -61,8 +61,6 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { * @param nFrontals number of frontal variables * @param keys a sorted list of gtsam::Keys * @param potentials Discrete potentials as a TableFactor. - * - * //TODO(Varun): Should accept a DiscreteFactor::shared_ptr */ DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, const TableFactor& potentials) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 32049fde1..32cba84ed 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -180,10 +180,8 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); } - gttic_(toDecisionTreeFactor_Constructor); // NOTE(Varun): This constructor is really expensive!! DecisionTreeFactor f(dkeys, table); - gttoc_(toDecisionTreeFactor_Constructor); return f; } From 84627c0c579805590e703f487fc1a072e6637da6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 16:30:46 -0500 Subject: [PATCH 43/86] fix error --- gtsam/discrete/DiscreteFactor.h | 1 - gtsam/discrete/TableFactor.h | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index a6356a045..2fe80a54c 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -144,7 +144,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Create new factor by maximizing over all values with the same separator. virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0; - /** * Get the number of non-zero values contained in this factor. * It could be much smaller than `prod_{key}(cardinality(key))`. diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 64e98c6a1..41b6287b8 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -177,6 +178,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(*tf, safe_div); } else if (auto dtf = std::dynamic_pointer_cast(f)) { return apply(TableFactor(f->discreteKeys(), *dtf), safe_div); + } else { + throw std::runtime_error("Unknown derived type for DiscreteFactor"); } } From 22d11d7af49ebdec102e5ec16597a00a175621ce Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 11 Dec 2024 04:00:56 -0500 Subject: [PATCH 44/86] don't print timing info by default --- gtsam/discrete/DiscreteFactorGraph.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 50072f547..ba05417aa 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -224,10 +224,10 @@ namespace gtsam { DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator - gttic_(sum); + gttic(sum); DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast( product.sum(frontalKeys)); - gttoc_(sum); + gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -237,10 +237,10 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional - gttic_(divide); + gttic(divide); auto conditional = std::make_shared(product, *sum, orderedKeys); - gttoc_(divide); + gttoc(divide); return {conditional, sum}; } From cbcfab4176b0577cc62e3cb77c3522cba90e2015 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:09:25 -0500 Subject: [PATCH 45/86] serialize functions for Eigen::SparseVector --- gtsam/base/MatrixSerialization.h | 40 ++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/gtsam/base/MatrixSerialization.h b/gtsam/base/MatrixSerialization.h index 11b6a417a..43c97097d 100644 --- a/gtsam/base/MatrixSerialization.h +++ b/gtsam/base/MatrixSerialization.h @@ -24,6 +24,7 @@ #include +#include #include #include #include @@ -87,6 +88,45 @@ void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) { split_free(ar, m, version); } +/******************************************************************************/ +/// Customized functions for serializing Eigen::SparseVector +template +void save(Archive& ar, const Eigen::SparseVector<_Scalar, _Options, _Index>& m, + const unsigned int /*version*/) { + _Index size = m.size(); + + std::vector> data; + for (typename Eigen::SparseVector<_Scalar, _Options, _Index>::InnerIterator + it(m); + it; ++it) + data.push_back({it.index(), it.value()}); + + ar << BOOST_SERIALIZATION_NVP(size); + ar << BOOST_SERIALIZATION_NVP(data); +} + +template +void load(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m, + const unsigned int /*version*/) { + _Index size; + ar >> BOOST_SERIALIZATION_NVP(size); + m.resize(size); + + std::vector> data; + ar >> BOOST_SERIALIZATION_NVP(data); + + for (auto&& d : data) { + m.coeffRef(d.first) = d.second; + } +} + +template +void serialize(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m, + const unsigned int version) { + split_free(ar, m, version); +} +/******************************************************************************/ + } // namespace serialization } // namespace boost #endif From 5c4194e7cd800c8cd232f92a63dbd8e10b935340 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:13:51 -0500 Subject: [PATCH 46/86] add serialization code to TableFactor --- gtsam/discrete/TableFactor.h | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 5ddb4ab43..a2fdb4d32 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -31,6 +31,12 @@ #include #include +#if GTSAM_ENABLE_BOOST_SERIALIZATION +#include + +#include +#endif + namespace gtsam { class DiscreteConditional; @@ -342,6 +348,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { double error(const HybridValues& values) const override; /// @} + + private: +#if GTSAM_ENABLE_BOOST_SERIALIZATION + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(sparse_table_); + ar& BOOST_SERIALIZATION_NVP(denominators_); + ar& BOOST_SERIALIZATION_NVP(sorted_dkeys_); + } +#endif }; // traits From ef2843c5b213e11463de8217e790a3e6e9cbbd25 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:14:07 -0500 Subject: [PATCH 47/86] test for TableFactor serialization --- .../discrete/tests/testSerializationDiscrete.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/gtsam/discrete/tests/testSerializationDiscrete.cpp b/gtsam/discrete/tests/testSerializationDiscrete.cpp index df7df0b7e..9d15d0536 100644 --- a/gtsam/discrete/tests/testSerializationDiscrete.cpp +++ b/gtsam/discrete/tests/testSerializationDiscrete.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include using namespace std; @@ -32,6 +33,7 @@ BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf") BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice") BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); +BOOST_CLASS_EXPORT_GUID(TableFactor, "gtsam_TableFactor"); using ADT = AlgebraicDecisionTree; BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); @@ -79,6 +81,19 @@ TEST(DiscreteSerialization, DecisionTreeFactor) { EXPECT(equalsBinary(f)); } +/* ************************************************************************* */ +// Check serialization for TableFactor +TEST(DiscreteSerialization, TableFactor) { + using namespace serializationTestHelpers; + + DiscreteKey A(Symbol('x', 1), 3); + TableFactor tf(A % "1/2/2"); + + EXPECT(equalsObj(tf)); + EXPECT(equalsXML(tf)); + EXPECT(equalsBinary(tf)); +} + /* ************************************************************************* */ // Check serialization for DiscreteConditional & DiscreteDistribution TEST(DiscreteSerialization, DiscreteConditional) { From 834288f9748992b24bc4d4f4cffc77c7d8461d8c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:18:24 -0500 Subject: [PATCH 48/86] additional Signature based constructor for DecisionTreeFactor --- gtsam/discrete/DecisionTreeFactor.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b..24a699d42 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -116,6 +117,10 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) : DecisionTreeFactor(DiscreteKeys{key}, row) {} + /// Construct from Signature + DecisionTreeFactor(const Signature& signature) + : DecisionTreeFactor(signature.discreteKeys(), signature.cpt()) {} + /** Construct from a DiscreteConditional type */ explicit DecisionTreeFactor(const DiscreteConditional& c); From e6567457b511a6ff993efcf2710c98f72c71bdad Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:21:19 -0500 Subject: [PATCH 49/86] update tests --- .../discrete/tests/testDecisionTreeFactor.cpp | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 756a0cebe..1828db525 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -217,12 +217,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } -/** Convert Signature into CPT */ -DecisionTreeFactor create(const Signature& signature) { - DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); - return p; -} - /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) { @@ -230,14 +224,14 @@ TEST(DecisionTreeFactor, joint) { D(7, 2); gttic_(asiaCPTs); - DecisionTreeFactor pA = create(A % "99/1"); - DecisionTreeFactor pS = create(S % "50/50"); - DecisionTreeFactor pT = create(T | A = "99/1 95/5"); - DecisionTreeFactor pL = create(L | S = "99/1 90/10"); - DecisionTreeFactor pB = create(B | S = "70/30 40/60"); - DecisionTreeFactor pE = create((E | T, L) = "F T T T"); - DecisionTreeFactor pX = create(X | E = "95/5 2/98"); - DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + DecisionTreeFactor pA(A % "99/1"); + DecisionTreeFactor pS(S % "50/50"); + DecisionTreeFactor pT(T | A = "99/1 95/5"); + DecisionTreeFactor pL(L | S = "99/1 90/10"); + DecisionTreeFactor pB(B | S = "70/30 40/60"); + DecisionTreeFactor pE((E | T, L) = "F T T T"); + DecisionTreeFactor pX(X | E = "95/5 2/98"); + DecisionTreeFactor pD((D | E, B) = "9/1 2/8 3/7 1/9"); // Create joint gttic_(asiaJoint); From cb9cec30e39895b4745a4727aa896718dcccc467 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:28:21 -0500 Subject: [PATCH 50/86] unit test exposing division bug --- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 1828db525..73420c860 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -69,6 +69,15 @@ TEST(DecisionTreeFactor, constructors) { EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9); } +/* ************************************************************************* */ +TEST(DecisionTreeFactor, Divide) { + DiscreteKey A(0, 2), S(1, 2); + DecisionTreeFactor pA(A % "99/1"), pS(S % "50/50"); + DecisionTreeFactor joint = pA * pS; + DecisionTreeFactor s = joint / pA; + EXPECT(assert_equal(pS, s)); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, Error) { // Declare a bunch of keys From c6c451bee102f624cf25660ec6a820c9d5c1c49c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:28:40 -0500 Subject: [PATCH 51/86] compute correct subset of keys for division --- gtsam/discrete/DecisionTreeFactor.h | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 24a699d42..0b94140da 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -161,7 +161,15 @@ namespace gtsam { /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { - return apply(f, safe_div); + KeyVector diff; + std::set_difference(this->keys().begin(), this->keys().end(), + f.keys().begin(), f.keys().end(), + std::back_inserter(diff)); + DiscreteKeys keys; + for (Key key : diff) { + keys.push_back({key, this->cardinality(key)}); + } + return DecisionTreeFactor(keys, apply(f, safe_div)); } /// Convert into a decision tree From 3718cb19edd21c718fe280c712a793031b6fd707 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:35:13 -0500 Subject: [PATCH 52/86] use string based constructor --- gtsam/discrete/tests/testSerializationDiscrete.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/tests/testSerializationDiscrete.cpp b/gtsam/discrete/tests/testSerializationDiscrete.cpp index 9d15d0536..b118a00f6 100644 --- a/gtsam/discrete/tests/testSerializationDiscrete.cpp +++ b/gtsam/discrete/tests/testSerializationDiscrete.cpp @@ -87,7 +87,7 @@ TEST(DiscreteSerialization, TableFactor) { using namespace serializationTestHelpers; DiscreteKey A(Symbol('x', 1), 3); - TableFactor tf(A % "1/2/2"); + TableFactor tf(A, "1 2 2"); EXPECT(equalsObj(tf)); EXPECT(equalsXML(tf)); From ffe14d39aae1609c4d85b863a3ffd5ebca0089ac Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:42:48 -0500 Subject: [PATCH 53/86] Revert "update tests" This reverts commit e6567457b511a6ff993efcf2710c98f72c71bdad. --- .../discrete/tests/testDecisionTreeFactor.cpp | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 73420c860..61ce9038d 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -226,6 +226,12 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) { @@ -233,14 +239,14 @@ TEST(DecisionTreeFactor, joint) { D(7, 2); gttic_(asiaCPTs); - DecisionTreeFactor pA(A % "99/1"); - DecisionTreeFactor pS(S % "50/50"); - DecisionTreeFactor pT(T | A = "99/1 95/5"); - DecisionTreeFactor pL(L | S = "99/1 90/10"); - DecisionTreeFactor pB(B | S = "70/30 40/60"); - DecisionTreeFactor pE((E | T, L) = "F T T T"); - DecisionTreeFactor pX(X | E = "95/5 2/98"); - DecisionTreeFactor pD((D | E, B) = "9/1 2/8 3/7 1/9"); + DecisionTreeFactor pA = create(A % "99/1"); + DecisionTreeFactor pS = create(S % "50/50"); + DecisionTreeFactor pT = create(T | A = "99/1 95/5"); + DecisionTreeFactor pL = create(L | S = "99/1 90/10"); + DecisionTreeFactor pB = create(B | S = "70/30 40/60"); + DecisionTreeFactor pE = create((E | T, L) = "F T T T"); + DecisionTreeFactor pX = create(X | E = "95/5 2/98"); + DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); // Create joint gttic_(asiaJoint); From be7be376a9ef95b12c08b83d72bb602cc93eb852 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:42:56 -0500 Subject: [PATCH 54/86] Revert "additional Signature based constructor for DecisionTreeFactor" This reverts commit 834288f9748992b24bc4d4f4cffc77c7d8461d8c. --- gtsam/discrete/DecisionTreeFactor.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 0b94140da..804b956fe 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,7 +18,6 @@ #pragma once -#include #include #include #include @@ -117,10 +116,6 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) : DecisionTreeFactor(DiscreteKeys{key}, row) {} - /// Construct from Signature - DecisionTreeFactor(const Signature& signature) - : DecisionTreeFactor(signature.discreteKeys(), signature.cpt()) {} - /** Construct from a DiscreteConditional type */ explicit DecisionTreeFactor(const DiscreteConditional& c); From 6b6283c1512467819918c90191aa8372e96a00dd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 15:21:49 -0500 Subject: [PATCH 55/86] fix factor construction --- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 61ce9038d..dc18e0ab2 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -72,7 +72,7 @@ TEST(DecisionTreeFactor, constructors) { /* ************************************************************************* */ TEST(DecisionTreeFactor, Divide) { DiscreteKey A(0, 2), S(1, 2); - DecisionTreeFactor pA(A % "99/1"), pS(S % "50/50"); + DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); DecisionTreeFactor joint = pA * pS; DecisionTreeFactor s = joint / pA; EXPECT(assert_equal(pS, s)); From a142556c52064b5adb2926e76e4eb7e238ed0cb5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 08:45:11 -0500 Subject: [PATCH 56/86] move create to the top --- .../discrete/tests/testDecisionTreeFactor.cpp | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index dc18e0ab2..7210622d8 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -30,6 +30,12 @@ using namespace std; using namespace gtsam; +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, ConstructorsMatch) { // Declare two keys @@ -69,15 +75,6 @@ TEST(DecisionTreeFactor, constructors) { EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9); } -/* ************************************************************************* */ -TEST(DecisionTreeFactor, Divide) { - DiscreteKey A(0, 2), S(1, 2); - DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); - DecisionTreeFactor joint = pA * pS; - DecisionTreeFactor s = joint / pA; - EXPECT(assert_equal(pS, s)); -} - /* ************************************************************************* */ TEST(DecisionTreeFactor, Error) { // Declare a bunch of keys @@ -114,6 +111,15 @@ TEST(DecisionTreeFactor, multiplication) { CHECK(assert_equal(expected2, actual)); } +/* ************************************************************************* */ +TEST(DecisionTreeFactor, Divide) { + DiscreteKey A(0, 2), S(1, 2); + DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); + DecisionTreeFactor joint = pA * pS; + DecisionTreeFactor s = joint / pA; + EXPECT(assert_equal(pS, s)); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, sum_max) { DiscreteKey v0(0, 3), v1(1, 2); @@ -226,12 +232,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } -/** Convert Signature into CPT */ -DecisionTreeFactor create(const Signature& signature) { - DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); - return p; -} - /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) { From 5fa04d7622548a8072ac621ff355a8caa622930d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 09:08:57 -0500 Subject: [PATCH 57/86] small improvements --- gtsam/discrete/DiscreteConditional.cpp | 6 ++---- gtsam/discrete/DiscreteFactorGraph.cpp | 2 +- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 10 +++++----- gtsam/discrete/tests/testTableFactor.cpp | 14 ++++++++------ 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index be0f42bea..26f38e278 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -24,13 +24,13 @@ #include #include +#include #include #include #include #include #include #include -#include using namespace std; using std::pair; @@ -45,9 +45,7 @@ template class GTSAM_EXPORT /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) - : BaseFactor(f / (*std::dynamic_pointer_cast( - f.sum(nrFrontals)))), - BaseConditional(nrFrontals) {} + : BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {} /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 8c950050b..678c70e2d 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -128,7 +128,7 @@ namespace gtsam { auto denominator = product.max(product.size()); // Normalize the product factor to prevent underflow. - product = product / (*denominator); + product = product / denominator; return product; } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index cbcf5234e..4ee36f0ab 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) { *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected - auto normalizer = newFactor.max(newFactor.size()); + auto denominator = newFactor.max(newFactor.size()); - newFactor = newFactor / *normalizer; + newFactor = newFactor / denominator; // Check Conditional CHECK(conditional); @@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) { CHECK(&newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); // Normalize by max. - normalizer = expectedFactor.max(expectedFactor.size()); - // Ensure normalizer is correct. - expectedFactor = expectedFactor / *normalizer; + denominator = expectedFactor.max(expectedFactor.size()); + // Ensure denominator is correct. + expectedFactor = expectedFactor / denominator; EXPECT(assert_equal(expectedFactor, newFactor)); // Test using elimination tree diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 147c3aea9..4f6ec2f39 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) { TEST(TableFactor, Empty) { DiscreteKey X(1, 2); - TableFactor single = *TableFactor({X}, "1 1").sum(1); + auto single = TableFactor({X}, "1 1").sum(1); // Should not throw a segfault - EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), - single.toDecisionTreeFactor())); + auto expected_single = DecisionTreeFactor(X, "1 1").sum(1); + EXPECT(assert_equal(expected_single->toDecisionTreeFactor(), + single->toDecisionTreeFactor())); - TableFactor empty = *TableFactor({X}, "0 0").sum(1); + auto empty = TableFactor({X}, "0 0").sum(1); // Should not throw a segfault - EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), - empty.toDecisionTreeFactor())); + auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1); + EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(), + empty->toDecisionTreeFactor())); } /* ************************************************************************* */ From e309bf370bd195d5f4e2171cd451ca58588e9fb1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 11:54:09 -0500 Subject: [PATCH 58/86] improve operator/ documentation and also showcase understanding in test --- gtsam/discrete/DecisionTreeFactor.h | 24 +++++++++---------- .../discrete/tests/testDecisionTreeFactor.cpp | 14 ++++++++++- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 804b956fe..a5b82f277 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -154,17 +154,17 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - /// divide by factor f (safely) + /** + * @brief Divide by factor f (safely). + * Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$ + * results in a function which involves all keys + * \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$ + * + * @param f The DecisinTreeFactor to divide by. + * @return DecisionTreeFactor + */ DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { - KeyVector diff; - std::set_difference(this->keys().begin(), this->keys().end(), - f.keys().begin(), f.keys().end(), - std::back_inserter(diff)); - DiscreteKeys keys; - for (Key key : diff) { - keys.push_back({key, this->cardinality(key)}); - } - return DecisionTreeFactor(keys, apply(f, safe_div)); + return apply(f, safe_div); } /// Convert into a decision tree @@ -181,12 +181,12 @@ namespace gtsam { } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const { return combine(keys, Ring::max); } diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 7210622d8..ba8714783 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -116,8 +116,20 @@ TEST(DecisionTreeFactor, Divide) { DiscreteKey A(0, 2), S(1, 2); DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); DecisionTreeFactor joint = pA * pS; + DecisionTreeFactor s = joint / pA; - EXPECT(assert_equal(pS, s)); + + // Factors are not equal due to difference in keys + EXPECT(assert_inequal(pS, s)); + + // The underlying data should be the same + using ADT = AlgebraicDecisionTree; + EXPECT(assert_equal(ADT(pS), ADT(s))); + + KeySet keys(joint.keys()); + keys.insert(pA.keys().begin(), pA.keys().end()); + EXPECT(assert_inequal(KeySet(pS.keys()), keys)); + } /* ************************************************************************* */ From 268290dbf25195ce2cdb4ef1ed4099eed6338f78 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:22:36 -0500 Subject: [PATCH 59/86] multiply method for DiscreteFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 5 +++++ gtsam/discrete/DiscreteFactor.h | 10 ++++++++++ gtsam/discrete/TableFactor.cpp | 12 ++++++++++++ gtsam/discrete/TableFactor.h | 4 ++++ 5 files changed, 43 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ac782b88..cf22fe153 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -62,6 +62,18 @@ namespace gtsam { return error(values.discrete()); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared((*tf) * TableFactor(*this)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*dtf)); + } + return result; + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8b6d91be7..d3cb55fa5 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -147,6 +148,10 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 63cd7844c..adc79bbd5 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -130,6 +130,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /** + * @brief Multiply in a DiscreteFactor and return the result as + * DiscreteFactor, both via shared pointers. + * + * @param df DiscreteFactor shared_ptr + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; /// Create new factor by summing all values with the same separator values diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a59095d40..cfa56b43a 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(TableFactor(*dtf))); + } + return result; +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 16354026d..f7d0f5215 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -179,6 +179,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely) From 2f09e860e1109cf543a92c32d4fca14ce5d6c28e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:36:06 -0500 Subject: [PATCH 60/86] remove override from definition --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cf22fe153..c15fd4e2e 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -64,7 +64,7 @@ namespace gtsam { /* ************************************************************************ */ DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( - const DiscreteFactor::shared_ptr& f) const override { + const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { result = std::make_shared((*tf) * TableFactor(*this)); From 5d865a8cc7908f02c206a3d9801af0fbaf8d1eaa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:22:36 -0500 Subject: [PATCH 61/86] multiply method for DiscreteFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 5 +++++ gtsam/discrete/DiscreteFactor.h | 10 ++++++++++ gtsam/discrete/TableFactor.cpp | 12 ++++++++++++ gtsam/discrete/TableFactor.h | 4 ++++ 5 files changed, 43 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ac782b88..cf22fe153 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -62,6 +62,18 @@ namespace gtsam { return error(values.discrete()); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared((*tf) * TableFactor(*this)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*dtf)); + } + return result; + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b..3e70c0df9 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -147,6 +148,10 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index a1fde0f86..c18eaae2f 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /** + * @brief Multiply in a DiscreteFactor and return the result as + * DiscreteFactor, both via shared pointers. + * + * @param df DiscreteFactor shared_ptr + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; /// @} diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a59095d40..cfa56b43a 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::multiply( + const DiscreteFactor::shared_ptr& f) const override { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + result = std::make_shared(this->operator*(TableFactor(*dtf))); + } + return result; +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index a2fdb4d32..4b53d7e2b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely) From 75a4e98715caca504bacf4cc92a768db3dd89303 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:36:06 -0500 Subject: [PATCH 62/86] remove override from definition --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cf22fe153..c15fd4e2e 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -64,7 +64,7 @@ namespace gtsam { /* ************************************************************************ */ DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( - const DiscreteFactor::shared_ptr& f) const override { + const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { result = std::make_shared((*tf) * TableFactor(*this)); From 700ad2bae326f3f89ecc78b56d55a74a64c3785a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:56:54 -0500 Subject: [PATCH 63/86] remove override from TableFactor definition --- gtsam/discrete/TableFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index cfa56b43a..3ca8fecda 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -256,7 +256,7 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { /* ************************************************************************ */ DiscreteFactor::shared_ptr TableFactor::multiply( - const DiscreteFactor::shared_ptr& f) const override { + const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(*tf)); From 260d448887447a9a9f3216a544ad34ddd1dfbd8c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 13:57:07 -0500 Subject: [PATCH 64/86] use new multiply method --- gtsam/discrete/DiscreteFactorGraph.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 227bb4da3..0444f47ae 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,11 +65,11 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; - for (const sharedFactor& factor : *this) { - if (factor) result = (*factor) * result; + DiscreteFactor::shared_ptr result = *this->begin(); + for (auto it = this->begin() + 1; it != this->end(); ++it) { + if (*it) result = result->multiply(*it); } - return result; + return result->toDecisionTreeFactor(); } /* ************************************************************************ */ From 453059bd61f4b08acdc43c37b9b098a5579d36a3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 14:03:08 -0500 Subject: [PATCH 65/86] simplify to remove DiscreteProduct static function --- gtsam/discrete/DiscreteFactorGraph.cpp | 44 ++++++++++---------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 0444f47ae..b3029111a 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,11 +65,23 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DiscreteFactor::shared_ptr result = *this->begin(); + // PRODUCT: multiply all factors + gttic(product); + DiscreteFactor::shared_ptr product = *this->begin(); for (auto it = this->begin() + 1; it != this->end(); ++it) { - if (*it) result = result->multiply(*it); + if (*it) product = product->multiply(*it); } - return result->toDecisionTreeFactor(); + gttoc(product); + + DecisionTreeFactor = result->toDecisionTreeFactor(); + + // Max over all the potentials by pretending all keys are frontal: + auto denominator = product.max(product.size()); + + // Normalize the product factor to prevent underflow. + product = product / (*denominator); + + return product; } /* ************************************************************************ */ @@ -111,34 +123,12 @@ namespace gtsam { // } // } - /** - * @brief Multiply all the `factors`. - * - * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DecisionTreeFactor - */ - static DecisionTreeFactor DiscreteProduct( - const DiscreteFactorGraph& factors) { - // PRODUCT: multiply all factors - gttic(product); - DecisionTreeFactor product = factors.product(); - gttoc(product); - - // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*denominator); - - return product; - } - /* ************************************************************************ */ // Alternate eliminate function for MPE std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DecisionTreeFactor product = factors.product(); // max out frontals, this is the factor on the separator gttic(max); @@ -216,7 +206,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DecisionTreeFactor product = factors.product(); // sum out frontals, this is the factor on the separator gttic(sum); From a02baec0119ca9b670a8b5b64ebecc5db492bcbd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 14:23:34 -0500 Subject: [PATCH 66/86] naive implementation of multiply for unstable --- gtsam_unstable/discrete/AllDiff.h | 7 +++++++ gtsam_unstable/discrete/BinaryAllDiff.h | 7 +++++++ gtsam_unstable/discrete/Domain.h | 7 +++++++ gtsam_unstable/discrete/SingleValue.h | 7 +++++++ 4 files changed, 28 insertions(+) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 1180abad4..cfbd76e7c 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -53,6 +53,13 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /// Compute error for each assignment and return as a tree AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index e96bfdfde..a1a2bf0a6 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -69,6 +69,13 @@ class BinaryAllDiff : public Constraint { return toDecisionTreeFactor() * f; } + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 23a566d24..dea85934f 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -90,6 +90,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 3df1209b8..8675c929b 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -63,6 +63,13 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /* * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked From 13bafb0a48bae3c5598b7db112ac2757bd02c431 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 14:28:57 -0500 Subject: [PATCH 67/86] Revert "simplify to remove DiscreteProduct static function" This reverts commit 453059bd61f4b08acdc43c37b9b098a5579d36a3. --- gtsam/discrete/DiscreteFactorGraph.cpp | 44 ++++++++++++++++---------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index b3029111a..0444f47ae 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,23 +65,11 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - // PRODUCT: multiply all factors - gttic(product); - DiscreteFactor::shared_ptr product = *this->begin(); + DiscreteFactor::shared_ptr result = *this->begin(); for (auto it = this->begin() + 1; it != this->end(); ++it) { - if (*it) product = product->multiply(*it); + if (*it) result = result->multiply(*it); } - gttoc(product); - - DecisionTreeFactor = result->toDecisionTreeFactor(); - - // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*denominator); - - return product; + return result->toDecisionTreeFactor(); } /* ************************************************************************ */ @@ -123,12 +111,34 @@ namespace gtsam { // } // } + /** + * @brief Multiply all the `factors`. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return DecisionTreeFactor + */ + static DecisionTreeFactor DiscreteProduct( + const DiscreteFactorGraph& factors) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product = factors.product(); + gttoc(product); + + // Max over all the potentials by pretending all keys are frontal: + auto denominator = product.max(product.size()); + + // Normalize the product factor to prevent underflow. + product = product / (*denominator); + + return product; + } + /* ************************************************************************ */ // Alternate eliminate function for MPE std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = factors.product(); + DecisionTreeFactor product = DiscreteProduct(factors); // max out frontals, this is the factor on the separator gttic(max); @@ -206,7 +216,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = factors.product(); + DecisionTreeFactor product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator gttic(sum); From a7fc6e3763d4dc5ea5019906be37e05613d3f9d9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:08:58 -0500 Subject: [PATCH 68/86] convert everything to DecisionTreeFactor so we can use override operator* method --- gtsam_unstable/discrete/AllDiff.h | 4 ++-- gtsam_unstable/discrete/BinaryAllDiff.h | 4 ++-- gtsam_unstable/discrete/Domain.h | 4 ++-- gtsam_unstable/discrete/SingleValue.h | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index cfbd76e7c..032808dcd 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -56,8 +56,8 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /// Compute error for each assignment and return as a tree diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index a1a2bf0a6..0ebae4d77 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -72,8 +72,8 @@ class BinaryAllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /* diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index dea85934f..9a4a21847 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -93,8 +93,8 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /* diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 8675c929b..ebe23f7e4 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -66,8 +66,8 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); + return std::make_shared(this->toDecisionTreeFactor() * + df->toDecisionTreeFactor()); } /* From 8390ffa2cbde062f7628a953bba6f43ba77cc7d1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:19:16 -0500 Subject: [PATCH 69/86] revert previous commit --- gtsam_unstable/discrete/AllDiff.h | 4 ++-- gtsam_unstable/discrete/BinaryAllDiff.h | 4 ++-- gtsam_unstable/discrete/Domain.h | 4 ++-- gtsam_unstable/discrete/SingleValue.h | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 032808dcd..cfbd76e7c 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -56,8 +56,8 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /// Compute error for each assignment and return as a tree diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 0ebae4d77..a1a2bf0a6 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -72,8 +72,8 @@ class BinaryAllDiff : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /* diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 9a4a21847..dea85934f 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -93,8 +93,8 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /* diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index ebe23f7e4..8675c929b 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -66,8 +66,8 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply factors, DiscreteFactor::shared_ptr edition DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared(this->toDecisionTreeFactor() * - df->toDecisionTreeFactor()); + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); } /* From bc63cc8cb88c7edebf86a4136d7231ab51a59e47 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:19:57 -0500 Subject: [PATCH 70/86] use double dispatch for else case --- gtsam/discrete/DecisionTreeFactor.cpp | 3 +++ gtsam/discrete/TableFactor.cpp | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index c15fd4e2e..4b16dad8a 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -70,6 +70,9 @@ namespace gtsam { result = std::make_shared((*tf) * TableFactor(*this)); } else if (auto dtf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(*dtf)); + } else { + // Simulate double dispatch in C++ + result = std::make_shared(f->operator*(*this)); } return result; } diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 3ca8fecda..6516a4a98 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -262,6 +262,10 @@ DiscreteFactor::shared_ptr TableFactor::multiply( result = std::make_shared(this->operator*(*tf)); } else if (auto dtf = std::dynamic_pointer_cast(f)) { result = std::make_shared(this->operator*(TableFactor(*dtf))); + } else { + // Simulate double dispatch in C++ + result = std::make_shared( + f->operator*(this->toDecisionTreeFactor())); } return result; } From 713c49c9153f9f023621a6cd9a06997594ea52ab Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:20:09 -0500 Subject: [PATCH 71/86] more robust product --- gtsam/discrete/DiscreteFactorGraph.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 0444f47ae..a2b896286 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,9 +65,16 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DiscreteFactor::shared_ptr result = *this->begin(); - for (auto it = this->begin() + 1; it != this->end(); ++it) { - if (*it) result = result->multiply(*it); + DiscreteFactor::shared_ptr result; + for (auto it = this->begin(); it != this->end(); ++it) { + if (*it) { + if (result) { + result = result->multiply(*it); + } else { + // Assign to the first non-null factor + result = *it; + } + } } return result->toDecisionTreeFactor(); } From b83aadb20487f69c6ab932245f8524c8ec92fdde Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:37:37 -0500 Subject: [PATCH 72/86] remove accidental type change --- gtsam/discrete/DecisionTreeFactor.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index a5b82f277..eb6d9eaa2 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -181,12 +181,12 @@ namespace gtsam { } /// Create new factor by maximizing over all values with the same separator. - DiscreteFactor::shared_ptr max(size_t nrFrontals) const { + shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - DiscreteFactor::shared_ptr max(const Ordering& keys) const { + shared_ptr max(const Ordering& keys) const { return combine(keys, Ring::max); } From b5128b2c9fcf31d04e4760c3c4178d8449ad44c9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 19:40:37 -0500 Subject: [PATCH 73/86] use DecisionTreeFactor version of sum and max where not available --- gtsam_unstable/discrete/AllDiff.h | 8 ++++---- gtsam_unstable/discrete/BinaryAllDiff.h | 8 ++++---- gtsam_unstable/discrete/Domain.h | 8 ++++---- gtsam_unstable/discrete/SingleValue.h | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index cf0e5e3cf..267ddb9fd 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -84,19 +84,19 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { uint64_t nrValues() const override { return 1; }; DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(nrFrontals); } DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(keys); } DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(nrFrontals); } DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(keys); } }; diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index c15ac8aec..3035d0620 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -108,19 +108,19 @@ class BinaryAllDiff : public Constraint { uint64_t nrValues() const override { return 1; }; DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(nrFrontals); } DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(keys); } DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(nrFrontals); } DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(keys); } }; diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index f64716028..4c2d3f9dd 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -123,19 +123,19 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { Constraint::shared_ptr partiallyApply(const Domains& domains) const override; DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(nrFrontals); } DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(keys); } DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(nrFrontals); } DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(keys); } }; diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index c5824a96a..b6c91f912 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -89,19 +89,19 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { uint64_t nrValues() const override { return 1; }; DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(nrFrontals); } DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().sum(keys); } DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(nrFrontals); } DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - throw std::runtime_error("Not implemented"); + return toDecisionTreeFactor().max(keys); } }; From 4ebca711461eb3fd914b74b4532242aedf38c048 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 20:44:10 -0500 Subject: [PATCH 74/86] divide operator for DiscreteFactor::shared_ptr --- gtsam/discrete/DecisionTreeFactor.cpp | 13 +++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 5 ++--- gtsam/discrete/DiscreteFactor.h | 4 ++++ gtsam/discrete/TableFactor.cpp | 14 ++++++++++++++ gtsam/discrete/TableFactor.h | 11 ++--------- gtsam_unstable/discrete/AllDiff.cpp | 6 ++++++ gtsam_unstable/discrete/AllDiff.h | 4 ++++ gtsam_unstable/discrete/BinaryAllDiff.h | 6 ++++++ gtsam_unstable/discrete/Domain.cpp | 6 ++++++ gtsam_unstable/discrete/Domain.h | 4 ++++ gtsam_unstable/discrete/SingleValue.cpp | 6 ++++++ gtsam_unstable/discrete/SingleValue.h | 4 ++++ 12 files changed, 71 insertions(+), 12 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4b16dad8a..2f2c039a4 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -77,6 +77,19 @@ namespace gtsam { return result; } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + return std::make_shared(tf->operator/(TableFactor(*this))); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator/(*dtf)); + } else { + return std::make_shared( + this->operator/(this->toDecisionTreeFactor())); + } + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index d3cb55fa5..a5327bdd0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -165,9 +165,8 @@ namespace gtsam { } /// divide by DiscreteFactor::shared_ptr f (safely) - DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const { - return apply(*std::dynamic_pointer_cast(f), safe_div); - } + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index adc79bbd5..6cbc00d09 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -140,6 +140,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& df) const = 0; + /// divide by DiscreteFactor::shared_ptr f (safely) + virtual DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; /// Create new factor by summing all values with the same separator values diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 6516a4a98..b692e9ba2 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -270,6 +270,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply( return result; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::operator/( + const DiscreteFactor::shared_ptr& f) const { + if (auto tf = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator/(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + return std::make_shared( + this->operator/(TableFactor(f->discreteKeys(), *dtf))); + } else { + TableFactor divisor(f->toDecisionTreeFactor()); + return std::make_shared(this->operator/(divisor)); + } +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index f7d0f5215..a2f74758f 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -191,15 +191,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { } /// divide by DiscreteFactor::shared_ptr f (safely) - TableFactor operator/(const DiscreteFactor::shared_ptr& f) const { - if (auto tf = std::dynamic_pointer_cast(f)) { - return apply(*tf, safe_div); - } else if (auto dtf = std::dynamic_pointer_cast(f)) { - return apply(TableFactor(f->discreteKeys(), *dtf), safe_div); - } else { - throw std::runtime_error("Unknown derived type for DiscreteFactor"); - } - } + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 585ca8103..01f50fa3d 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -56,6 +56,12 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr AllDiff::operator/( + const DiscreteFactor::shared_ptr& df) const { + return this->toDecisionTreeFactor() / df; +} + /* ************************************************************************* */ bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { Domain& Dj = domains->at(j); diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 267ddb9fd..7a7b1cecc 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -60,6 +60,10 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override; + /// Compute error for each assignment and return as a tree AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 3035d0620..fbff8a01c 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -76,6 +76,12 @@ class BinaryAllDiff : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override { + return this->toDecisionTreeFactor() / df; + } + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index 74f621dc7..cecb7cc1a 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -49,6 +49,12 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr Domain::operator/( + const DiscreteFactor::shared_ptr& df) const { + return this->toDecisionTreeFactor() / df; +} + /* ************************************************************************* */ bool Domain::ensureArcConsistency(Key j, Domains* domains) const { if (j != key()) throw invalid_argument("Domain check on wrong domain"); diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 4c2d3f9dd..7362e9caf 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -97,6 +97,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override; + /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 220bc9c06..09a8314df 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -41,6 +41,12 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr SingleValue::operator/( + const DiscreteFactor::shared_ptr& df) const { + return this->toDecisionTreeFactor() / df; +} + /* ************************************************************************* */ bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { if (j != keys_[0]) diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index b6c91f912..87c42fc80 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -70,6 +70,10 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override; + /* * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked From fb1d52a18eec0e804276ef542edf4bdd34f69d3f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 20:49:24 -0500 Subject: [PATCH 75/86] fix constructor --- gtsam/discrete/DiscreteConditional.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 26f38e278..06a08eca5 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -45,7 +45,8 @@ template class GTSAM_EXPORT /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) - : BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {} + : BaseFactor(f / f.sum(nrFrontals)->toDecisionTreeFactor()), + BaseConditional(nrFrontals) {} /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, From e9822a70d2921fdbd433f55fa29d91ac9467f7f1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 20:50:24 -0500 Subject: [PATCH 76/86] update DiscreteFactorGraph to use DiscreteFactor::shared_ptr for elimination --- gtsam/discrete/DiscreteFactorGraph.cpp | 37 +++++++++++++------------- gtsam/discrete/DiscreteFactorGraph.h | 2 +- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index e05cf9e33..eb3221819 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -64,7 +64,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DiscreteFactorGraph::product() const { + DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const { DiscreteFactor::shared_ptr result; for (auto it = this->begin(); it != this->end(); ++it) { if (*it) { @@ -76,7 +76,7 @@ namespace gtsam { } } } - return result->toDecisionTreeFactor(); + return result; } /* ************************************************************************ */ @@ -122,20 +122,20 @@ namespace gtsam { * @brief Multiply all the `factors`. * * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DecisionTreeFactor + * @return DiscreteFactor::shared_ptr */ - static DecisionTreeFactor DiscreteProduct( + static DiscreteFactor::shared_ptr DiscreteProduct( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product = factors.product(); + DiscreteFactor::shared_ptr product = factors.product(); gttoc(product); // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); + auto denominator = product->max(product->size()); // Normalize the product factor to prevent underflow. - product = product / denominator; + product = product->operator/(denominator); return product; } @@ -145,26 +145,25 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // max out frontals, this is the factor on the separator gttic(max); - DecisionTreeFactor::shared_ptr max = - std::dynamic_pointer_cast(product.max(frontalKeys)); + DiscreteFactor::shared_ptr max = product->max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front DiscreteKeys orderedKeys; for (auto&& key : frontalKeys) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); for (auto&& key : max->keys()) - orderedKeys.emplace_back(key, product.cardinality(key)); + orderedKeys.emplace_back(key, product->cardinality(key)); // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = - std::make_shared(nrFrontals, orderedKeys, product); + auto lookup = std::make_shared( + nrFrontals, orderedKeys, product->toDecisionTreeFactor()); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -224,12 +223,11 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator gttic(sum); - DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast( - product.sum(frontalKeys)); + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front @@ -241,8 +239,9 @@ namespace gtsam { // now divide product/sum to get conditional gttic(divide); - auto conditional = - std::make_shared(product, *sum, orderedKeys); + auto conditional = std::make_shared( + product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(), + orderedKeys); gttoc(divide); return {conditional, sum}; diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index b311cb78b..3d9e86cd1 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -148,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - DecisionTreeFactor product() const; + DiscreteFactor::shared_ptr product() const; /** * Evaluates the factor graph given values, returns the joint probability of From 2f8c8ddb75cb96b30b550bd03e0a659746938857 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 20:50:40 -0500 Subject: [PATCH 77/86] update tests --- gtsam/discrete/tests/testDiscreteConditional.cpp | 4 ++-- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 7 ++++--- gtsam_unstable/discrete/tests/testCSP.cpp | 2 +- gtsam_unstable/discrete/tests/testScheduler.cpp | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index d17c76837..b91e1bd8a 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; @@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - DecisionTreeFactor expected2 = f2 / f2.sum(1); + DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor(); EXPECT(assert_equal(expected2, static_cast(actual2))); } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 4ee36f0ab..0c1dd7a2a 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9); // Check if graph product works - DecisionTreeFactor product = graph.product(); + DecisionTreeFactor product = graph.product()->toDecisionTreeFactor(); EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9); } @@ -117,7 +117,7 @@ TEST(DiscreteFactorGraph, test) { *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected - auto denominator = newFactor.max(newFactor.size()); + auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor(); newFactor = newFactor / denominator; @@ -131,7 +131,8 @@ TEST(DiscreteFactorGraph, test) { CHECK(&newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); // Normalize by max. - denominator = expectedFactor.max(expectedFactor.size()); + denominator = + expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor(); // Ensure denominator is correct. expectedFactor = expectedFactor / denominator; EXPECT(assert_equal(expectedFactor, newFactor)); diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 2b9a20ca6..6806bfe58 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -124,7 +124,7 @@ TEST(CSP, allInOne) { EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Just for fun, create the product and check it - DecisionTreeFactor product = csp.product(); + DecisionTreeFactor product = csp.product()->toDecisionTreeFactor(); // product.dot("product"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); EXPECT(assert_equal(expectedProduct, product)); diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index f868abb5e..5f9b7f287 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -113,7 +113,7 @@ TEST(schedulingExample, test) { EXPECT(assert_equal(expected, (DiscreteFactorGraph)s)); // Do brute force product and output that to file - DecisionTreeFactor product = s.product(); + DecisionTreeFactor product = s.product()->toDecisionTreeFactor(); // product.dot("scheduling", false); // Do exact inference From 2434e248e95d0efac75a18cddae2748ff8be4502 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 20:54:56 -0500 Subject: [PATCH 78/86] undo print change in DiscreteLookupTable --- gtsam/discrete/DiscreteLookupDAG.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index c1c301525..d1108e7b7 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -48,7 +48,7 @@ void DiscreteLookupTable::print(const std::string& s, } } cout << "):\n"; - BaseFactor::print("", formatter); + ADT::print("", formatter); cout << endl; } From 43f755d9d8d0ede5868ac18ae33e9350c43a2e78 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 11:17:03 -0500 Subject: [PATCH 79/86] move multiply to Constraint.h --- gtsam_unstable/discrete/AllDiff.h | 7 ------- gtsam_unstable/discrete/BinaryAllDiff.h | 7 ------- gtsam_unstable/discrete/Constraint.h | 8 ++++++++ gtsam_unstable/discrete/Domain.h | 7 ------- gtsam_unstable/discrete/SingleValue.h | 7 ------- 5 files changed, 8 insertions(+), 28 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index cfbd76e7c..1180abad4 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -53,13 +53,6 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Multiply factors, DiscreteFactor::shared_ptr edition - DiscreteFactor::shared_ptr multiply( - const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); - } - /// Compute error for each assignment and return as a tree AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("AllDiff::error not implemented"); diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index a1a2bf0a6..e96bfdfde 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -69,13 +69,6 @@ class BinaryAllDiff : public Constraint { return toDecisionTreeFactor() * f; } - /// Multiply factors, DiscreteFactor::shared_ptr edition - DiscreteFactor::shared_ptr multiply( - const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); - } - /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 3526a282d..71ed7647a 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -78,6 +78,14 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { /// Partially apply known values, domain version virtual shared_ptr partiallyApply(const Domains&) const = 0; + + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index dea85934f..23a566d24 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -90,13 +90,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Multiply factors, DiscreteFactor::shared_ptr edition - DiscreteFactor::shared_ptr multiply( - const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); - } - /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 8675c929b..3df1209b8 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -63,13 +63,6 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Multiply into a decisiontree DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Multiply factors, DiscreteFactor::shared_ptr edition - DiscreteFactor::shared_ptr multiply( - const DiscreteFactor::shared_ptr& df) const override { - return std::make_shared( - this->operator*(df->toDecisionTreeFactor())); - } - /* * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked From 7561da4df2df5740a808d0e6ea4e957eb65cb4f2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 13:35:45 -0500 Subject: [PATCH 80/86] move operator/ to Constraint.h --- gtsam_unstable/discrete/AllDiff.cpp | 6 ------ gtsam_unstable/discrete/Constraint.h | 6 ++++++ gtsam_unstable/discrete/Domain.cpp | 6 ------ gtsam_unstable/discrete/SingleValue.cpp | 6 ------ 4 files changed, 6 insertions(+), 18 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 01f50fa3d..585ca8103 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -56,12 +56,6 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } -/* ************************************************************************* */ -DiscreteFactor::shared_ptr AllDiff::operator/( - const DiscreteFactor::shared_ptr& df) const { - return this->toDecisionTreeFactor() / df; -} - /* ************************************************************************* */ bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { Domain& Dj = domains->at(j); diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 71ed7647a..2d98ab40b 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -86,6 +86,12 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { this->operator*(df->toDecisionTreeFactor())); } + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& df) const override { + return this->toDecisionTreeFactor() / df; + } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index cecb7cc1a..74f621dc7 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -49,12 +49,6 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } -/* ************************************************************************* */ -DiscreteFactor::shared_ptr Domain::operator/( - const DiscreteFactor::shared_ptr& df) const { - return this->toDecisionTreeFactor() / df; -} - /* ************************************************************************* */ bool Domain::ensureArcConsistency(Key j, Domains* domains) const { if (j != key()) throw invalid_argument("Domain check on wrong domain"); diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 09a8314df..220bc9c06 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -41,12 +41,6 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } -/* ************************************************************************* */ -DiscreteFactor::shared_ptr SingleValue::operator/( - const DiscreteFactor::shared_ptr& df) const { - return this->toDecisionTreeFactor() / df; -} - /* ************************************************************************* */ bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { if (j != keys_[0]) From ff5371fd4a42cb91fdf9d0a222dbef2d5036a294 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 13:38:45 -0500 Subject: [PATCH 81/86] move sum, max and nrValues to Constraint class as well --- gtsam_unstable/discrete/AllDiff.h | 19 ------------------- gtsam_unstable/discrete/BinaryAllDiff.h | 19 ------------------- gtsam_unstable/discrete/Constraint.h | 22 +++++++++++++++++++++- gtsam_unstable/discrete/Domain.h | 16 ---------------- gtsam_unstable/discrete/SingleValue.h | 19 ------------------- 5 files changed, 21 insertions(+), 74 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 34e5c4700..1180abad4 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -72,25 +72,6 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const Domains&) const override; - - /// Get the number of non-zero values contained in this factor. - uint64_t nrValues() const override { return 1; }; - - DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - return toDecisionTreeFactor().sum(nrFrontals); - } - - DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - return toDecisionTreeFactor().sum(keys); - } - - DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - return toDecisionTreeFactor().max(nrFrontals); - } - - DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - return toDecisionTreeFactor().max(keys); - } }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 0cd51ec61..e96bfdfde 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -96,25 +96,6 @@ class BinaryAllDiff : public Constraint { AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("BinaryAllDiff::error not implemented"); } - - /// Get the number of non-zero values contained in this factor. - uint64_t nrValues() const override { return 1; }; - - DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - return toDecisionTreeFactor().sum(nrFrontals); - } - - DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - return toDecisionTreeFactor().sum(keys); - } - - DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - return toDecisionTreeFactor().max(nrFrontals); - } - - DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - return toDecisionTreeFactor().max(keys); - } }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 2d98ab40b..328fabbea 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { /* * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. * @return true if domains->at(j) was changed, false otherwise. */ virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; @@ -92,6 +93,25 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { return this->toDecisionTreeFactor() / df; } + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; + + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { + return toDecisionTreeFactor().sum(nrFrontals); + } + + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { + return toDecisionTreeFactor().sum(keys); + } + + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { + return toDecisionTreeFactor().max(nrFrontals); + } + + DiscreteFactor::shared_ptr max(const Ordering& keys) const override { + return toDecisionTreeFactor().max(keys); + } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 2372cf499..6ce846201 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -114,22 +114,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply(const Domains& domains) const override; - - DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - return toDecisionTreeFactor().sum(nrFrontals); - } - - DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - return toDecisionTreeFactor().sum(keys); - } - - DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - return toDecisionTreeFactor().max(nrFrontals); - } - - DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - return toDecisionTreeFactor().max(keys); - } }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 5d4c2dca1..3df1209b8 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -77,25 +77,6 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const Domains& domains) const override; - - /// Get the number of non-zero values contained in this factor. - uint64_t nrValues() const override { return 1; }; - - DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override { - return toDecisionTreeFactor().sum(nrFrontals); - } - - DiscreteFactor::shared_ptr sum(const Ordering& keys) const override { - return toDecisionTreeFactor().sum(keys); - } - - DiscreteFactor::shared_ptr max(size_t nrFrontals) const override { - return toDecisionTreeFactor().max(nrFrontals); - } - - DiscreteFactor::shared_ptr max(const Ordering& keys) const override { - return toDecisionTreeFactor().max(keys); - } }; } // namespace gtsam From ab90e0b0f3d68f3f144b5fc4d7dd87a7e6425901 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 13:44:55 -0500 Subject: [PATCH 82/86] move include to .cpp --- gtsam/discrete/DecisionTreeFactor.cpp | 3 ++- gtsam/discrete/DecisionTreeFactor.h | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4b16dad8a..e353fdebf 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -18,9 +18,10 @@ */ #include -#include #include #include +#include +#include #include diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 3e70c0df9..ff9bf0df9 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -22,7 +22,6 @@ #include #include #include -#include #include #include From f043ac43a7c559b9707c693fbb4dd265b0483838 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 14:08:08 -0500 Subject: [PATCH 83/86] address review comments --- gtsam/discrete/DecisionTreeFactor.cpp | 9 +++++++++ gtsam/discrete/DecisionTreeFactor.h | 15 ++++++++++++++- gtsam/discrete/TableFactor.cpp | 10 ++++++++++ gtsam/discrete/TableFactor.h | 15 ++++++++++++++- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index e353fdebf..ef7979d0a 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -68,11 +68,20 @@ namespace gtsam { const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { + // If f is a TableFactor, we convert `this` to a TableFactor since this + // conversion is cheaper than converting `f` to a DecisionTreeFactor. We + // then return a TableFactor. result = std::make_shared((*tf) * TableFactor(*this)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, simply call operator*. result = std::make_shared(this->operator*(*dtf)); + } else { // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated. result = std::make_shared(f->operator*(*this)); } return result; diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index ff9bf0df9..907f29a45 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,7 +147,20 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; - /// Multiply factors, DiscreteFactor::shared_ptr edition + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. Hence we specialize the code to return a TableFactor if + * `f` is a TableFactor, and DecisionTreeFactor otherwise. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& f) const override; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 6516a4a98..fe901aac1 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -259,11 +259,21 @@ DiscreteFactor::shared_ptr TableFactor::multiply( const DiscreteFactor::shared_ptr& f) const { DiscreteFactor::shared_ptr result; if (auto tf = std::dynamic_pointer_cast(f)) { + // If `f` is a TableFactor, we can simply call `operator*`. result = std::make_shared(this->operator*(*tf)); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, we convert to a TableFactor which is + // cheaper than converting `this` to a DecisionTreeFactor. result = std::make_shared(this->operator*(TableFactor(*dtf))); + } else { // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated to know about TableFactor. + // Those classes can be specialized to use TableFactor + // if efficiency is a problem. result = std::make_shared( f->operator*(this->toDecisionTreeFactor())); } diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 4b53d7e2b..a2e89b302 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -178,7 +178,20 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Multiply factors, DiscreteFactor::shared_ptr edition + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. + * Hence we specialize the code to return a TableFactor always. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ virtual DiscreteFactor::shared_ptr multiply( const DiscreteFactor::shared_ptr& f) const override; From f932945652b6f60f266513904efd04fbb6eddef5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 18:30:34 -0500 Subject: [PATCH 84/86] check pointer casts --- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 3 +++ gtsam/discrete/tests/testTableFactor.cpp | 3 +++ 2 files changed, 6 insertions(+) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index b4c5acc1b..88045ce3d 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -112,14 +112,17 @@ TEST(DecisionTreeFactor, sum_max) { DecisionTreeFactor expected(v1, "9 12"); auto actual = std::dynamic_pointer_cast(f1.sum(1)); + CHECK(actual); CHECK(assert_equal(expected, *actual, 1e-5)); DecisionTreeFactor expected2(v1, "5 6"); auto actual2 = std::dynamic_pointer_cast(f1.max(1)); + CHECK(actual2); CHECK(assert_equal(expected2, *actual2)); DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); + CHECK(actual22); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 4f6ec2f39..76a0f2b5c 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -306,14 +306,17 @@ TEST(TableFactor, sum_max) { TableFactor expected(v1, "9 12"); auto actual = std::dynamic_pointer_cast(f1.sum(1)); + CHECK(actual); CHECK(assert_equal(expected, *actual, 1e-5)); TableFactor expected2(v1, "5 6"); auto actual2 = std::dynamic_pointer_cast(f1.max(1)); + CHECK(actual2); CHECK(assert_equal(expected2, *actual2)); TableFactor f2(v1 & v0, "1 2 3 4 5 6"); auto actual22 = std::dynamic_pointer_cast(f2.sum(1)); + CHECK(actual22); } /* ************************************************************************* */ From f8dedb503592f648c4d1d22a353dd201a760b697 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 18:37:40 -0500 Subject: [PATCH 85/86] use DiscreteFactor for DiscreteConditional constructor --- gtsam/discrete/DiscreteConditional.cpp | 4 ++-- gtsam/discrete/DiscreteConditional.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 06a08eca5..19dcdc729 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -44,8 +44,8 @@ template class GTSAM_EXPORT /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) - : BaseFactor(f / f.sum(nrFrontals)->toDecisionTreeFactor()), + const DiscreteFactor& f) + : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()), BaseConditional(nrFrontals) {} /* ************************************************************************** */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 3ec9ae590..67f8a0056 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional() {} /// Construct from factor, taking the first `nFrontals` keys as frontals. - DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + DiscreteConditional(size_t nFrontals, const DiscreteFactor& f); /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first From c754f9bfdcb6c213003a511924a5c4df69f5b91c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jan 2025 18:47:44 -0500 Subject: [PATCH 86/86] add comments --- gtsam/discrete/DecisionTreeFactor.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 2f2c039a4..6e25c6452 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -81,12 +81,18 @@ namespace gtsam { DiscreteFactor::shared_ptr DecisionTreeFactor::operator/( const DiscreteFactor::shared_ptr& f) const { if (auto tf = std::dynamic_pointer_cast(f)) { + // Check if `f` is a TableFactor. If yes, then + // convert `this` to a TableFactor which is cheaper. return std::make_shared(tf->operator/(TableFactor(*this))); + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, divide normally. return std::make_shared(this->operator/(*dtf)); + } else { + // Else, convert `f` to a DecisionTreeFactor so we can divide return std::make_shared( - this->operator/(this->toDecisionTreeFactor())); + this->operator/(f->toDecisionTreeFactor())); } }