From d1d440ad3420efb6a35bef80f39699b6b075e810 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 10:53:32 -0500
Subject: [PATCH 01/86] add nrValues method
---
gtsam/discrete/DecisionTreeFactor.h | 6 ++++++
gtsam/discrete/DiscreteFactor.h | 6 ++++++
gtsam/discrete/TableFactor.h | 6 ++++++
gtsam_unstable/discrete/AllDiff.h | 3 +++
gtsam_unstable/discrete/BinaryAllDiff.h | 3 +++
gtsam_unstable/discrete/Domain.h | 2 +-
gtsam_unstable/discrete/SingleValue.h | 3 +++
7 files changed, 28 insertions(+), 1 deletion(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index a8ab2644f..f417a38d7 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -255,6 +255,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ uint64_t nrValues() const override { return nrLeaves(); }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 19af5bd13..7d5047ec6 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -113,6 +113,12 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ virtual uint64_t nrValues() const = 0;
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index f0ecd66a3..b988eebad 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -324,6 +324,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
TableFactor prune(size_t maxNrAssignments) const;
+ /**
+ * Get the number of non-zero values contained in this factor.
+ * It could be much smaller than `prod_{key}(cardinality(key))`.
+ */
+ uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index d7a63eae0..42a255bbf 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -72,6 +72,9 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const Domains&) const override;
+
+ /// Get the number of non-zero values contained in this factor.
+ uint64_t nrValues() const override { return 1; };
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index 18b335092..22acfb092 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -96,6 +96,9 @@ class BinaryAllDiff : public Constraint {
AlgebraicDecisionTree errorTree() const override {
throw std::runtime_error("BinaryAllDiff::error not implemented");
}
+
+ /// Get the number of non-zero values contained in this factor.
+ uint64_t nrValues() const override { return 1; };
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index 7f7b717c2..ba3771eca 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Erase a value, non const :-(
void erase(size_t value) { values_.erase(value); }
- size_t nrValues() const { return values_.size(); }
+ uint64_t nrValues() const override { return values_.size(); }
bool isSingleton() const { return nrValues() == 1; }
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index 3f7f22d6a..7f2eb2c2c 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -77,6 +77,9 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const Domains& domains) const override;
+
+ /// Get the number of non-zero values contained in this factor.
+ uint64_t nrValues() const override { return 1; };
};
} // namespace gtsam
From a68da21527760daff64161f3feffc6cc1d46d1b1 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 11:02:30 -0500
Subject: [PATCH 02/86] operator* version which accepts DiscreteFactor
---
gtsam/discrete/DecisionTreeFactor.cpp | 11 +++++++++++
gtsam/discrete/DecisionTreeFactor.h | 5 ++++-
gtsam/discrete/DiscreteFactor.h | 7 ++++---
gtsam/discrete/TableFactor.cpp | 9 +++++++--
gtsam/discrete/TableFactor.h | 5 +++--
5 files changed, 29 insertions(+), 8 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 9ec3b0ac5..e53f8cb90 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -82,6 +82,17 @@ namespace gtsam {
ADT::print("", formatter);
}
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DecisionTreeFactor::operator*(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto derived = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(this->operator*(*derived));
+ } else {
+ throw std::runtime_error(
+ "Cannot convert DiscreteFactor to DecisionTreeFactor");
+ }
+ }
+
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
// apply operand
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index f417a38d7..7afbab0b0 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -144,10 +144,13 @@ namespace gtsam {
double error(const DiscreteValues& values) const override;
/// multiply two factors
- DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
+ DecisionTreeFactor operator*(const DecisionTreeFactor& f) const {
return apply(f, ADT::Ring::mul);
}
+ DiscreteFactor::shared_ptr operator*(
+ const DiscreteFactor::shared_ptr& f) const override;
+
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 7d5047ec6..4c486dca8 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -107,9 +107,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree errorTree() const;
- /// Multiply in a DecisionTreeFactor and return the result as
- /// DecisionTreeFactor
- virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
+ /// Multiply in a DiscreteFactor and return the result as
+ /// DiscreteFactor
+ virtual DiscreteFactor::shared_ptr operator*(
+ const DiscreteFactor::shared_ptr&) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index f4e023a4d..7cf520973 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -169,8 +169,13 @@ double TableFactor::error(const HybridValues& values) const {
}
/* ************************************************************************ */
-DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
- return toDecisionTreeFactor() * f;
+DiscreteFactor::shared_ptr TableFactor::operator*(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto derived = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(this->operator*(*derived));
+ } else {
+ throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor");
+ }
}
/* ************************************************************************ */
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index b988eebad..29cbd5e9b 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -186,8 +186,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, Ring::mul);
};
- /// multiply with DecisionTreeFactor
- DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
+ /// multiply with DiscreteFactor
+ DiscreteFactor::shared_ptr operator*(
+ const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b);
From a09b77ef407b7ac91b1604153d7b2ec08301b4b8 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 11:07:26 -0500
Subject: [PATCH 03/86] return DiscreteFactor shared_ptr as leftover from
elimination
---
gtsam/discrete/DiscreteFactorGraph.cpp | 4 ++--
gtsam/discrete/DiscreteFactorGraph.h | 7 ++++---
2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 4ededbb8b..f81c6085e 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -112,7 +112,7 @@ namespace gtsam {
/* ************************************************************************ */
// Alternate eliminate function for MPE
- std::pair //
+ std::pair
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
@@ -201,7 +201,7 @@ namespace gtsam {
}
/* ************************************************************************ */
- std::pair //
+ std::pair
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h
index d0dc282b4..a5324811c 100644
--- a/gtsam/discrete/DiscreteFactorGraph.h
+++ b/gtsam/discrete/DiscreteFactorGraph.h
@@ -48,7 +48,7 @@ class DiscreteJunctionTree;
* @ingroup discrete
*/
GTSAM_EXPORT
-std::pair
+std::pair
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
@@ -61,7 +61,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors,
* @ingroup discrete
*/
GTSAM_EXPORT
-std::pair
+std::pair
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
@@ -133,6 +133,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
+ //TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */
template
void add(Args&&... args) {
@@ -146,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
- DecisionTreeFactor product() const;
+ DiscreteFactor::shared_ptr product() const;
/**
* Evaluates the factor graph given values, returns the joint probability of
From 27bbce150aa2dac82e57ff84a91c17a9873fab7b Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 11:10:24 -0500
Subject: [PATCH 04/86] generalize DiscreteFactorGraph::product to
DiscreteFactor
---
gtsam/discrete/DiscreteFactorGraph.cpp | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index f81c6085e..b27438130 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -64,10 +64,16 @@ namespace gtsam {
}
/* ************************************************************************* */
- DecisionTreeFactor DiscreteFactorGraph::product() const {
- DecisionTreeFactor result;
- for(const sharedFactor& factor: *this)
- if (factor) result = (*factor) * result;
+ DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
+ sharedFactor result = this->at(0);
+ for (size_t i = 1; i < this->size(); ++i) {
+ const sharedFactor factor = this->at(i);
+ if (factor) {
+ // Predicated on the fact that all discrete factors are of a single type
+ // so there is no type-conversion happening which can be expensive.
+ result = result->operator*(factor);
+ }
+ }
return result;
}
From 84e419456af55c0d194645c302934316146b3803 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 11:15:06 -0500
Subject: [PATCH 05/86] make normalization code common
---
gtsam/discrete/DiscreteFactorGraph.cpp | 50 ++++++++++++++++----------
1 file changed, 32 insertions(+), 18 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index b27438130..29bd1f9ac 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -116,6 +116,28 @@ namespace gtsam {
// }
// }
+ /**
+ * @brief Helper method to normalize the product factor by
+ * the max value to prevent underflow
+ *
+ * @param product The product discrete factor.
+ * @return DiscreteFactor::shared_ptr
+ */
+ static DiscreteFactor::shared_ptr Normalize(
+ const DiscreteFactor::shared_ptr& product) {
+ // Max over all the potentials by pretending all keys are frontal:
+ gttic(DiscreteFindMax);
+ auto normalization = product->max(product->size());
+ gttoc(DiscreteFindMax);
+
+ gttic(DiscreteNormalization);
+ // Normalize the product factor to prevent underflow.
+ auto normalized_product = product->operator/(normalization);
+ gttoc(DiscreteNormalization);
+
+ return normalized_product;
+ }
+
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair
@@ -123,27 +145,23 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
- DecisionTreeFactor product;
- for (auto&& factor : factors) product = (*factor) * product;
+ DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
- // Max over all the potentials by pretending all keys are frontal:
- auto normalization = product.max(product.size());
-
- // Normalize the product factor to prevent underflow.
- product = product / (*normalization);
+ // Normalize the product
+ product = Normalize(product);
// max out frontals, this is the factor on the separator
gttic(max);
- DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
+ DecisionTreeFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product
gttic(lookup);
@@ -212,19 +230,15 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
- DecisionTreeFactor product;
- for (auto&& factor : factors) product = (*factor) * product;
+ DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
- // Max over all the potentials by pretending all keys are frontal:
- auto normalization = product.max(product.size());
-
- // Normalize the product factor to prevent underflow.
- product = product / (*normalization);
+ // Normalize the product
+ product = Normalize(product);
// sum out frontals, this is the factor on the separator
gttic(sum);
- DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
+ DecisionTreeFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
From 4dac37ce2b0000fd51f67d361c9fbcbd744c06ab Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 11:16:49 -0500
Subject: [PATCH 06/86] make sum and max DiscreteFactor methods
---
gtsam/discrete/DecisionTreeFactor.h | 19 +++++++++++++++----
gtsam/discrete/DiscreteFactor.h | 17 +++++++++++++++++
gtsam/discrete/TableFactor.h | 18 ++++++++++++++----
3 files changed, 46 insertions(+), 8 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 7afbab0b0..8445c5332 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -158,26 +158,37 @@ namespace gtsam {
return apply(f, safe_div);
}
+ /// divide by factor f (pointer version)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override {
+ if (auto derived = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(apply(*derived, safe_div));
+ } else {
+ throw std::runtime_error(
+ "Cannot convert DiscreteFactor to Table Factor");
+ }
+ }
+
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, ADT::Ring::add);
}
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, ADT::Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, ADT::Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, ADT::Ring::max);
}
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 4c486dca8..1ada7b7b2 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -22,6 +22,7 @@
#include
#include
#include
+#include
#include
namespace gtsam {
@@ -114,6 +115,22 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
+ /// Create new factor by summing all values with the same separator values
+ virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
+
+ /// Create new factor by summing all values with the same separator values
+ virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
+
+ /// Create new factor by maximizing over all values with the same separator.
+ virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
+
+ /// Create new factor by maximizing over all values with the same separator.
+ virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
+
+ /// divide by factor f (safely)
+ virtual DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const = 0;
+
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 29cbd5e9b..e452b5be0 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -197,6 +197,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}
+ /// divide by factor f (pointer version)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override {
+ if (auto derived = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(apply(*derived, safe_div));
+ } else {
+ throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor");
+ }
+ }
+
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
@@ -205,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
- shared_ptr sum(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
From 6c4546779a2dff7a03a1f1cc8978e3facbf3cd16 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 18:20:50 -0500
Subject: [PATCH 07/86] add timing info
---
gtsam/discrete/DiscreteFactorGraph.cpp | 24 +++++++++++++++---------
gtsam/discrete/TableFactor.cpp | 6 ++++++
2 files changed, 21 insertions(+), 9 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 29bd1f9ac..904108afb 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -22,6 +22,7 @@
#include
#include
#include
+#include
#include
#include
@@ -144,12 +145,15 @@ namespace gtsam {
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
- gttic(product);
+ gttic_(MPEProduct);
DiscreteFactor::shared_ptr product = factors.product();
- gttoc(product);
+ gttoc_(MPEProduct);
+
+ gttic_(Normalize);
// Normalize the product
product = Normalize(product);
+ gttoc_(Normalize);
// max out frontals, this is the factor on the separator
gttic(max);
@@ -229,17 +233,19 @@ namespace gtsam {
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
- gttic(product);
+ gttic_(product);
DiscreteFactor::shared_ptr product = factors.product();
- gttoc(product);
+ gttoc_(product);
- // Normalize the product
+ gttic_(Normalize);
+ // Normalize the product
product = Normalize(product);
+ gttoc_(Normalize);
// sum out frontals, this is the factor on the separator
- gttic(sum);
+ gttic_(sum);
DecisionTreeFactor::shared_ptr sum = product->sum(frontalKeys);
- gttoc(sum);
+ gttoc_(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
@@ -249,10 +255,10 @@ namespace gtsam {
sum->keys().end());
// now divide product/sum to get conditional
- gttic(divide);
+ gttic_(divide);
auto conditional =
std::make_shared(product, *sum, orderedKeys);
- gttoc(divide);
+ gttoc_(divide);
return {conditional, sum};
}
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 7cf520973..b867fa916 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -72,7 +72,9 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
*/
std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dt) {
+ gttic_(ComputeLeafOrdering);
std::vector probs = dt.probabilities();
+ gttoc_(ComputeLeafOrdering);
std::vector ordered;
size_t n = dkeys[0].second;
@@ -180,12 +182,16 @@ DiscreteFactor::shared_ptr TableFactor::operator*(
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
+ gttic_(toDecisionTreeFactor);
DiscreteKeys dkeys = discreteKeys();
std::vector table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}
+ gttoc_(toDecisionTreeFactor);
+ gttic_(toDecisionTreeFactor_Constructor);
DecisionTreeFactor f(dkeys, table);
+ gttoc_(toDecisionTreeFactor_Constructor);
return f;
}
From b0ad350a20a410aa2658839a7707a910d7129bc6 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 18:22:30 -0500
Subject: [PATCH 08/86] add note about toDecisionTreeFactor
---
gtsam/discrete/TableFactor.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index b867fa916..53131616d 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -182,14 +182,13 @@ DiscreteFactor::shared_ptr TableFactor::operator*(
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
- gttic_(toDecisionTreeFactor);
DiscreteKeys dkeys = discreteKeys();
std::vector table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}
- gttoc_(toDecisionTreeFactor);
gttic_(toDecisionTreeFactor_Constructor);
+ // NOTE(Varun): This constructor is really expensive!!
DecisionTreeFactor f(dkeys, table);
gttoc_(toDecisionTreeFactor_Constructor);
return f;
From 306a3bae527af21cd121cff602bfa97fa3a5966f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 18:25:27 -0500
Subject: [PATCH 09/86] kill toDecisionTreeFactor to force rethink
---
gtsam/discrete/DecisionTreeFactor.h | 3 ---
gtsam/discrete/DiscreteFactor.h | 2 --
gtsam/discrete/TableFactor.cpp | 14 --------------
gtsam/discrete/TableFactor.h | 3 ---
4 files changed, 22 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 8445c5332..642187ff1 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -169,9 +169,6 @@ namespace gtsam {
}
}
- /// Convert into a decision tree
- DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
-
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, ADT::Ring::add);
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 1ada7b7b2..29984d795 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -113,8 +113,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr operator*(
const DiscreteFactor::shared_ptr&) const = 0;
- virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
-
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 53131616d..a8adf0918 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -180,20 +180,6 @@ DiscreteFactor::shared_ptr TableFactor::operator*(
}
}
-/* ************************************************************************ */
-DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
- DiscreteKeys dkeys = discreteKeys();
- std::vector table;
- for (auto i = 0; i < sparse_table_.size(); i++) {
- table.push_back(sparse_table_.coeff(i));
- }
- gttic_(toDecisionTreeFactor_Constructor);
- // NOTE(Varun): This constructor is really expensive!!
- DecisionTreeFactor f(dkeys, table);
- gttoc_(toDecisionTreeFactor_Constructor);
- return f;
-}
-
/* ************************************************************************ */
TableFactor TableFactor::choose(const DiscreteValues parent_assign,
DiscreteKeys parent_keys) const {
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index e452b5be0..12266b2a5 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -207,9 +207,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
}
}
- /// Convert into a decisiontree
- DecisionTreeFactor toDecisionTreeFactor() const override;
-
/// Create a TableFactor that is a subset of this TableFactor
TableFactor choose(const DiscreteValues assignments,
DiscreteKeys parent_keys) const;
From 2cd2ab0a43ea57c33392371d7ec1d285b1f005c3 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 18:25:40 -0500
Subject: [PATCH 10/86] DiscreteDistribution from TableFactor
---
gtsam/discrete/DiscreteDistribution.h | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h
index 4b690da15..abe8f7933 100644
--- a/gtsam/discrete/DiscreteDistribution.h
+++ b/gtsam/discrete/DiscreteDistribution.h
@@ -40,10 +40,14 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Default constructor needed for serialization.
DiscreteDistribution() {}
- /// Constructor from factor.
+ /// Constructor from DecisionTreeFactor.
explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}
+ /// Constructor from TableFactor.
+ explicit DiscreteDistribution(const TableFactor& f)
+ : Base(f.size(), f) {}
+
/**
* Construct from a Signature.
*
From 9f88a360dfd54f1e00361344f24d9d2ef085b50f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 18:29:13 -0500
Subject: [PATCH 11/86] make evaluate use the Assignment base class
---
gtsam/discrete/DecisionTreeFactor.h | 4 ++--
gtsam/discrete/DiscreteFactor.h | 3 +++
gtsam/discrete/TableFactor.cpp | 2 +-
gtsam/discrete/TableFactor.h | 9 ++++++---
4 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 642187ff1..bf0e23b50 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -129,9 +129,9 @@ namespace gtsam {
/// @name Standard Interface
/// @{
- /// Calculate probability for given values `x`,
+ /// Calculate probability for given values,
/// is just look up in AlgebraicDecisionTree.
- double evaluate(const Assignment& values) const {
+ double evaluate(const Assignment& values) const override {
return ADT::operator()(values);
}
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 29984d795..23abf725e 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -129,6 +129,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const = 0;
+ /// Calculate probability for given values
+ virtual double evaluate(const Assignment& values) const = 0;
+
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index a8adf0918..727c96ce4 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -135,7 +135,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
}
/* ************************************************************************ */
-double TableFactor::operator()(const DiscreteValues& values) const {
+double TableFactor::operator()(const Assignment& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1;
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 12266b2a5..ea222ca5c 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -169,14 +169,17 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
// /// @name Standard Interface
// /// @{
- /// Calculate probability for given values `x`,
+ /// Calculate probability for given values,
/// is just look up in TableFactor.
- double evaluate(const DiscreteValues& values) const {
+ double evaluate(const Assignment& values) const override {
return operator()(values);
}
/// Evaluate probability distribution, sugar.
- double operator()(const DiscreteValues& values) const override;
+ double operator()(const Assignment& values) const;
+ double operator()(const DiscreteValues& values) const override {
+ return operator()(Assignment(values));
+ }
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
From 2a3b5e62b785c2c9de4e414f5dcb1551569a297c Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 18:59:11 -0500
Subject: [PATCH 12/86] use Assignment for evaluate since it is the base
class
---
gtsam/discrete/DiscreteFactor.h | 2 +-
gtsam/discrete/TableFactor.h | 3 ---
2 files changed, 1 insertion(+), 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 23abf725e..4470d97a7 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -94,7 +94,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// Find value for given assignment of values to variables
- virtual double operator()(const DiscreteValues&) const = 0;
+ virtual double operator()(const DiscreteValues& values) const = 0;
/// Error is just -log(value)
virtual double error(const DiscreteValues& values) const;
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index ea222ca5c..8fb04fcba 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -177,9 +177,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Evaluate probability distribution, sugar.
double operator()(const Assignment& values) const;
- double operator()(const DiscreteValues& values) const override {
- return operator()(Assignment(values));
- }
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
From fff8458d6bcc0ec54a8c7ebbf0145c3b8b91aafc Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:03:26 -0500
Subject: [PATCH 13/86] remove TableFactor constructor in DiscreteDistribution
---
gtsam/discrete/DiscreteDistribution.h | 4 ----
1 file changed, 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h
index abe8f7933..09ea50332 100644
--- a/gtsam/discrete/DiscreteDistribution.h
+++ b/gtsam/discrete/DiscreteDistribution.h
@@ -44,10 +44,6 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}
- /// Constructor from TableFactor.
- explicit DiscreteDistribution(const TableFactor& f)
- : Base(f.size(), f) {}
-
/**
* Construct from a Signature.
*
From 295b965b6894574c11786dbef03edc9022311d82 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:09:45 -0500
Subject: [PATCH 14/86] use Assignment since it is a base class
---
gtsam/discrete/DecisionTreeFactor.h | 2 +-
gtsam/discrete/DiscreteConditional.h | 2 +-
gtsam/discrete/DiscreteFactor.h | 2 +-
gtsam/discrete/TableFactor.h | 2 +-
4 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index bf0e23b50..688cd85a6 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -136,7 +136,7 @@ namespace gtsam {
}
/// Evaluate probability distribution, sugar.
- double operator()(const DiscreteValues& values) const override {
+ double operator()(const Assignment& values) const override {
return ADT::operator()(values);
}
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index f59e29285..ce4fb96e5 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -169,7 +169,7 @@ class GTSAM_EXPORT DiscreteConditional
}
/// Evaluate, just look up in AlgebraicDecisionTree
- double evaluate(const DiscreteValues& values) const {
+ double evaluate(const Assignment& values) const override {
return ADT::operator()(values);
}
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 4470d97a7..4c1d0afb1 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -94,7 +94,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// Find value for given assignment of values to variables
- virtual double operator()(const DiscreteValues& values) const = 0;
+ virtual double operator()(const Assignment& values) const = 0;
/// Error is just -log(value)
virtual double error(const DiscreteValues& values) const;
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 8fb04fcba..497c42dc2 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -176,7 +176,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
}
/// Evaluate probability distribution, sugar.
- double operator()(const Assignment& values) const;
+ double operator()(const Assignment& values) const override;
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
From 261038f93620185ede195605a021c30c5b71f4d8 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:09:56 -0500
Subject: [PATCH 15/86] fix DiscreteConditional constructor
---
gtsam/discrete/DiscreteConditional.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 5ab0c59ec..92086d143 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -43,7 +43,9 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
- : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
+ : BaseFactor(f / (*std::dynamic_pointer_cast(
+ f.sum(nrFrontals)))),
+ BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
From 20d6d09e06e705d8384d685226239c109748e7a2 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:12:06 -0500
Subject: [PATCH 16/86] use DiscreteFactor everywhere in
DiscreteFactorGraph.cpp
---
gtsam/discrete/DiscreteFactorGraph.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 904108afb..a4f92e267 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -55,7 +55,7 @@ namespace gtsam {
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
- if (auto p = std::dynamic_pointer_cast(factor)) {
+ if (auto p = std::dynamic_pointer_cast(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
@@ -157,7 +157,7 @@ namespace gtsam {
// max out frontals, this is the factor on the separator
gttic(max);
- DecisionTreeFactor::shared_ptr max = product->max(frontalKeys);
+ DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
@@ -244,7 +244,7 @@ namespace gtsam {
// sum out frontals, this is the factor on the separator
gttic_(sum);
- DecisionTreeFactor::shared_ptr sum = product->sum(frontalKeys);
+ DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc_(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
@@ -257,8 +257,8 @@ namespace gtsam {
// now divide product/sum to get conditional
gttic_(divide);
auto conditional =
- std::make_shared(product, *sum, orderedKeys);
- gttoc_(divide);
+ std::make_shared(product, sum, orderedKeys);
+ gttoc(divide);
return {conditional, sum};
}
From 32b6bc0a37da20e49cc8d390dc04b3ff9bb77548 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:18:42 -0500
Subject: [PATCH 17/86] update DiscreteConditional
---
gtsam/discrete/DiscreteConditional.cpp | 19 ++++++++++---------
gtsam/discrete/DiscreteConditional.h | 16 ++++++++--------
2 files changed, 18 insertions(+), 17 deletions(-)
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 92086d143..2f900afbe 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -37,8 +37,7 @@ using std::vector;
namespace gtsam {
// Instantiate base class
-template class GTSAM_EXPORT
- Conditional;
+template class GTSAM_EXPORT Conditional;
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
@@ -54,15 +53,17 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals,
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
/* ************************************************************************** */
-DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
- const DecisionTreeFactor& marginal)
- : BaseFactor(joint / marginal),
- BaseConditional(joint.size() - marginal.size()) {}
+DiscreteConditional::DiscreteConditional(
+ const DiscreteFactor::shared_ptr& joint,
+ const DiscreteFactor::shared_ptr& marginal)
+ : BaseFactor(*std::dynamic_pointer_cast(
+ joint->operator/(marginal))),
+ BaseConditional(joint->size() - marginal->size()) {}
/* ************************************************************************** */
-DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
- const DecisionTreeFactor& marginal,
- const Ordering& orderedKeys)
+DiscreteConditional::DiscreteConditional(
+ const DiscreteFactor::shared_ptr& joint,
+ const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index ce4fb96e5..ec2c5c38d 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
- DiscreteConditional(const DecisionTreeFactor& joint,
- const DecisionTreeFactor& marginal);
+ DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
+ const DiscreteFactor::shared_ptr& marginal);
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
*/
- DiscreteConditional(const DecisionTreeFactor& joint,
- const DecisionTreeFactor& marginal,
+ DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
+ const DiscreteFactor::shared_ptr& marginal,
const Ordering& orderedKeys);
/**
@@ -173,8 +173,8 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}
- using DecisionTreeFactor::error; ///< DiscreteValues version
- using DecisionTreeFactor::operator(); ///< DiscreteValues version
+ using DiscreteFactor::error; ///< DiscreteValues version
+ using DiscreteFactor::operator(); ///< DiscreteValues version
/**
* @brief restrict to given *parent* values.
@@ -192,11 +192,11 @@ class GTSAM_EXPORT DiscreteConditional
shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */
- DecisionTreeFactor::shared_ptr likelihood(
+ DiscreteFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */
- DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
+ DiscreteFactor::shared_ptr likelihood(size_t frontal) const;
/**
* sample
From 38563da342f233f81f356d76cc78d034db6b9d4f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:24:04 -0500
Subject: [PATCH 18/86] Revert "kill toDecisionTreeFactor to force rethink"
This reverts commit 306a3bae527af21cd121cff602bfa97fa3a5966f.
---
gtsam/discrete/DecisionTreeFactor.h | 3 +++
gtsam/discrete/DiscreteFactor.h | 2 ++
gtsam/discrete/TableFactor.cpp | 14 ++++++++++++++
gtsam/discrete/TableFactor.h | 3 +++
4 files changed, 22 insertions(+)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 688cd85a6..f8f2835e5 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -169,6 +169,9 @@ namespace gtsam {
}
}
+ /// Convert into a decision tree
+ DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
+
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, ADT::Ring::add);
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 4c1d0afb1..e2d32e828 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -113,6 +113,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr operator*(
const DiscreteFactor::shared_ptr&) const = 0;
+ virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
+
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 727c96ce4..50d15ff5e 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -180,6 +180,20 @@ DiscreteFactor::shared_ptr TableFactor::operator*(
}
}
+/* ************************************************************************ */
+DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
+ DiscreteKeys dkeys = discreteKeys();
+ std::vector table;
+ for (auto i = 0; i < sparse_table_.size(); i++) {
+ table.push_back(sparse_table_.coeff(i));
+ }
+ gttic_(toDecisionTreeFactor_Constructor);
+ // NOTE(Varun): This constructor is really expensive!!
+ DecisionTreeFactor f(dkeys, table);
+ gttoc_(toDecisionTreeFactor_Constructor);
+ return f;
+}
+
/* ************************************************************************ */
TableFactor TableFactor::choose(const DiscreteValues parent_assign,
DiscreteKeys parent_keys) const {
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 497c42dc2..47a7c6bbb 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -207,6 +207,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
}
}
+ /// Convert into a decisiontree
+ DecisionTreeFactor toDecisionTreeFactor() const override;
+
/// Create a TableFactor that is a subset of this TableFactor
TableFactor choose(const DiscreteValues assignments,
DiscreteKeys parent_keys) const;
From 9633ad1fd8631a3abda75a25e5217e516119e458 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:24:29 -0500
Subject: [PATCH 19/86] make DiscreteConditional::likelihood match the
declaration
---
gtsam/discrete/DiscreteConditional.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 2f900afbe..048f35e5b 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -201,7 +201,7 @@ DiscreteConditional::shared_ptr DiscreteConditional::choose(
}
/* ************************************************************************** */
-DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
+DiscreteFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the frontal variables.
@@ -226,7 +226,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
}
/* ****************************************************************************/
-DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
+DiscreteFactor::shared_ptr DiscreteConditional::likelihood(
size_t frontal) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
From 0b3477fc5ab428a6bbc65ca2b95e02d84b982593 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:39:44 -0500
Subject: [PATCH 20/86] get different classes to play nicely
---
gtsam/discrete/DiscreteConditional.h | 4 ++--
gtsam/discrete/DiscreteDistribution.h | 3 +--
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index ec2c5c38d..77003f232 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -173,8 +173,8 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}
- using DiscreteFactor::error; ///< DiscreteValues version
- using DiscreteFactor::operator(); ///< DiscreteValues version
+ using DecisionTreeFactor::error; ///< DiscreteValues version
+ using DecisionTreeFactor::operator(); ///< DiscreteValues version
/**
* @brief restrict to given *parent* values.
diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h
index 09ea50332..28e509f15 100644
--- a/gtsam/discrete/DiscreteDistribution.h
+++ b/gtsam/discrete/DiscreteDistribution.h
@@ -86,8 +86,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
double operator()(size_t value) const;
/// We also want to keep the Base version, taking DiscreteValues:
- // TODO(dellaert): does not play well with wrapper!
- // using Base::operator();
+ using Base::operator();
/// Return entire probability mass function.
std::vector pmf() const;
From 1d7918841797e3e970587cbec1c60aeaafa21566 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 19:40:20 -0500
Subject: [PATCH 21/86] compiles
---
gtsam/discrete/DiscreteFactorGraph.cpp | 6 ++++--
gtsam/discrete/DiscreteLookupDAG.h | 7 +++++++
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index a4f92e267..c2a16159b 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -170,8 +170,10 @@ namespace gtsam {
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
- auto lookup = std::make_shared(nrFrontals,
- orderedKeys, product);
+ //TODO(Varun): Should accept a DiscreteFactor::shared_ptr
+ auto lookup = std::make_shared(
+ nrFrontals, orderedKeys,
+ *std::dynamic_pointer_cast(product));
gttoc(lookup);
return {std::dynamic_pointer_cast(lookup), max};
diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h
index f077a13d9..c811c4c49 100644
--- a/gtsam/discrete/DiscreteLookupDAG.h
+++ b/gtsam/discrete/DiscreteLookupDAG.h
@@ -18,6 +18,7 @@
#pragma once
#include
+#include
#include
#include
@@ -54,6 +55,12 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
+ //TODO(Varun): Should accept a DiscreteFactor::shared_ptr
+ DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
+ const TableFactor& potentials)
+ : DiscreteConditional(nFrontals, keys,
+ potentials.toDecisionTreeFactor()) {}
+
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
From 77578512f80600d27f25cdfeb0c22526e64ce7b9 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sat, 7 Dec 2024 21:45:22 -0500
Subject: [PATCH 22/86] timing
---
gtsam/discrete/DiscreteFactorGraph.cpp | 11 +++++------
gtsam/discrete/DiscreteLookupDAG.h | 10 +++++++++-
2 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index c2a16159b..fa9d9bdc7 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -22,7 +22,6 @@
#include
#include
#include
-#include
#include
#include
@@ -127,14 +126,14 @@ namespace gtsam {
static DiscreteFactor::shared_ptr Normalize(
const DiscreteFactor::shared_ptr& product) {
// Max over all the potentials by pretending all keys are frontal:
- gttic(DiscreteFindMax);
+ gttic_(DiscreteFindMax);
auto normalization = product->max(product->size());
- gttoc(DiscreteFindMax);
+ gttoc_(DiscreteFindMax);
- gttic(DiscreteNormalization);
+ gttic_(DiscreteNormalization);
// Normalize the product factor to prevent underflow.
auto normalized_product = product->operator/(normalization);
- gttoc(DiscreteNormalization);
+ gttoc_(DiscreteNormalization);
return normalized_product;
}
@@ -260,7 +259,7 @@ namespace gtsam {
gttic_(divide);
auto conditional =
std::make_shared(product, sum, orderedKeys);
- gttoc(divide);
+ gttoc_(divide);
return {conditional, sum};
}
diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h
index c811c4c49..21181d374 100644
--- a/gtsam/discrete/DiscreteLookupDAG.h
+++ b/gtsam/discrete/DiscreteLookupDAG.h
@@ -55,7 +55,15 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
- //TODO(Varun): Should accept a DiscreteFactor::shared_ptr
+ /**
+ * @brief Construct a new Discrete Lookup Table object
+ *
+ * @param nFrontals number of frontal variables
+ * @param keys a sorted list of gtsam::Keys
+ * @param potentials Discrete potentials as a TableFactor.
+ *
+ * //TODO(Varun): Should accept a DiscreteFactor::shared_ptr
+ */
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
: DiscreteConditional(nFrontals, keys,
From 9844a555d4debcf3e0cb7c6047c7b81f8701d27b Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 10:34:02 -0500
Subject: [PATCH 23/86] move evaluate and operator() next to each other
---
gtsam/discrete/DiscreteFactor.h | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index e2d32e828..3151afe80 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -93,6 +93,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
size_t cardinality(Key j) const { return cardinalities_.at(j); }
+ /// Calculate probability for given values
+ virtual double evaluate(const Assignment& values) const = 0;
+
/// Find value for given assignment of values to variables
virtual double operator()(const Assignment& values) const = 0;
@@ -131,9 +134,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const = 0;
- /// Calculate probability for given values
- virtual double evaluate(const Assignment& values) const = 0;
-
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
From aa25ccfa6ecbaef12e549b3e0ac6a09d5c2d30de Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 11:15:57 -0500
Subject: [PATCH 24/86] implement evaluate in DiscreteFactor
---
gtsam/discrete/DecisionTreeFactor.h | 5 -----
gtsam/discrete/DiscreteConditional.cpp | 2 +-
gtsam/discrete/DiscreteConditional.h | 5 -----
gtsam/discrete/DiscreteFactor.h | 14 ++++++++++++--
gtsam/discrete/TableFactor.h | 8 +-------
5 files changed, 14 insertions(+), 20 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index f8f2835e5..85491b909 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -131,11 +131,6 @@ namespace gtsam {
/// Calculate probability for given values,
/// is just look up in AlgebraicDecisionTree.
- double evaluate(const Assignment& values) const override {
- return ADT::operator()(values);
- }
-
- /// Evaluate probability distribution, sugar.
double operator()(const Assignment& values) const override {
return ADT::operator()(values);
}
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 048f35e5b..c90f7f9a0 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -476,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
- return this->evaluate(x.discrete());
+ return this->operator()(x.discrete());
}
/* ************************************************************************* */
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index 77003f232..29292f57e 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -168,11 +168,6 @@ class GTSAM_EXPORT DiscreteConditional
static_cast(this)->print(s, formatter);
}
- /// Evaluate, just look up in AlgebraicDecisionTree
- double evaluate(const Assignment& values) const override {
- return ADT::operator()(values);
- }
-
using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 3151afe80..5b4665d4d 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -93,8 +93,18 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
size_t cardinality(Key j) const { return cardinalities_.at(j); }
- /// Calculate probability for given values
- virtual double evaluate(const Assignment& values) const = 0;
+ /**
+ * @brief Calculate probability for given values.
+ * Calls specialized evaluation under the hood.
+ *
+ * Note: Uses Assignment as it is the base class of DiscreteValues.
+ *
+ * @param values Discrete assignment.
+ * @return double
+ */
+ double evaluate(const Assignment& values) const {
+ return operator()(values);
+ }
/// Find value for given assignment of values to variables
virtual double operator()(const Assignment& values) const = 0;
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 47a7c6bbb..d8df12821 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -169,13 +169,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
// /// @name Standard Interface
// /// @{
- /// Calculate probability for given values,
- /// is just look up in TableFactor.
- double evaluate(const Assignment& values) const override {
- return operator()(values);
- }
-
- /// Evaluate probability distribution, sugar.
+ /// Evaluate probability distribution, is just look up in TableFactor.
double operator()(const Assignment& values) const override;
/// Calculate error for DiscreteValues `x`, is -log(probability).
From 90d7e21941d324df59ceda3487890d403a1af953 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 11:19:35 -0500
Subject: [PATCH 25/86] change from DiscreteValues to Assignment
---
gtsam_unstable/discrete/AllDiff.cpp | 2 +-
gtsam_unstable/discrete/AllDiff.h | 2 +-
gtsam_unstable/discrete/BinaryAllDiff.h | 2 +-
gtsam_unstable/discrete/Domain.cpp | 2 +-
gtsam_unstable/discrete/Domain.h | 2 +-
5 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp
index 2bd9e6dfd..a450605b3 100644
--- a/gtsam_unstable/discrete/AllDiff.cpp
+++ b/gtsam_unstable/discrete/AllDiff.cpp
@@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
}
/* ************************************************************************* */
-double AllDiff::operator()(const DiscreteValues& values) const {
+double AllDiff::operator()(const Assignment& values) const {
std::set taken; // record values taken by keys
for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index 42a255bbf..7f539f4c6 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
}
/// Calculate value = expensive !
- double operator()(const DiscreteValues& values) const override;
+ double operator()(const Assignment& values) const override;
/// Convert into a decisiontree, can be *very* expensive !
DecisionTreeFactor toDecisionTreeFactor() const override;
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index 22acfb092..0e2fce109 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint {
}
/// Calculate value
- double operator()(const DiscreteValues& values) const override {
+ double operator()(const Assignment& values) const override {
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
}
diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp
index bbbc87667..752228c18 100644
--- a/gtsam_unstable/discrete/Domain.cpp
+++ b/gtsam_unstable/discrete/Domain.cpp
@@ -30,7 +30,7 @@ string Domain::base1Str() const {
}
/* ************************************************************************* */
-double Domain::operator()(const DiscreteValues& values) const {
+double Domain::operator()(const Assignment& values) const {
return contains(values.at(key()));
}
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index ba3771eca..cd11fc8d9 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
bool contains(size_t value) const { return values_.count(value) > 0; }
/// Calculate value
- double operator()(const DiscreteValues& values) const override;
+ double operator()(const Assignment& values) const override;
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
From 6665659e9de8a9f89dfda0c468692d93052d0f7a Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 11:23:04 -0500
Subject: [PATCH 26/86] use BaseFactor instead of DecisionTreeFactor
---
gtsam/discrete/DiscreteConditional.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index 29292f57e..3dbadb3e4 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -168,8 +168,8 @@ class GTSAM_EXPORT DiscreteConditional
static_cast(this)->print(s, formatter);
}
- using DecisionTreeFactor::error; ///< DiscreteValues version
- using DecisionTreeFactor::operator(); ///< DiscreteValues version
+ using BaseFactor::error; ///< DiscreteValues version
+ using BaseFactor::operator(); ///< DiscreteValues version
/**
* @brief restrict to given *parent* values.
From e6b65285217c84f49a07a641c187719c1bbfbf43 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 12:36:23 -0500
Subject: [PATCH 27/86] common definitions of Unary, UnaryAssignment and Binary
---
gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++------
gtsam/discrete/DecisionTreeFactor.h | 16 ++++++++++------
gtsam/discrete/DiscreteFactor.h | 5 +++++
3 files changed, 21 insertions(+), 12 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index e53f8cb90..93a7921aa 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -94,7 +94,7 @@ namespace gtsam {
}
/* ************************************************************************ */
- DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
+ DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
@@ -102,7 +102,7 @@ namespace gtsam {
}
/* ************************************************************************ */
- DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
+ DecisionTreeFactor DecisionTreeFactor::apply(UnaryAssignment op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
@@ -111,7 +111,7 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
- ADT::Binary op) const {
+ Binary op) const {
map cs; // new cardinalities
// make unique key-cardinality map
for (Key j : keys()) cs[j] = cardinality(j);
@@ -129,8 +129,8 @@ namespace gtsam {
}
/* ************************************************************************ */
- DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
- size_t nrFrontals, ADT::Binary op) const {
+ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
+ Binary op) const {
if (nrFrontals > size()) {
throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal "
@@ -157,7 +157,7 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
- const Ordering& frontalKeys, ADT::Binary op) const {
+ const Ordering& frontalKeys, Binary op) const {
if (frontalKeys.size() > size()) {
throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal "
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 983330f4a..f8679af59 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -51,6 +51,10 @@ namespace gtsam {
typedef std::shared_ptr shared_ptr;
typedef AlgebraicDecisionTree ADT;
+ using Base::Binary;
+ using Base::Unary;
+ using Base::UnaryAssignment;
+
/// @name Standard Constructors
/// @{
@@ -140,7 +144,7 @@ namespace gtsam {
double error(const DiscreteValues& values) const override;
/// multiply two factors
- DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
+ DecisionTreeFactor operator*(const DecisionTreeFactor& f) const {
return apply(f, Ring::mul);
}
@@ -196,21 +200,21 @@ namespace gtsam {
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
- DecisionTreeFactor apply(ADT::Unary op) const;
+ DecisionTreeFactor apply(Unary op) const;
/**
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree. Takes
* both the assignment and the value.
*/
- DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
+ DecisionTreeFactor apply(UnaryAssignment op) const;
/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on AlgebraicDecisionTree
*/
- DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
+ DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const;
/**
* Combine frontal variables using binary operator "op"
@@ -218,7 +222,7 @@ namespace gtsam {
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
- shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
+ shared_ptr combine(size_t nrFrontals, Binary op) const;
/**
* Combine frontal variables in an Ordering using binary operator "op"
@@ -226,7 +230,7 @@ namespace gtsam {
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
- shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
+ shared_ptr combine(const Ordering& keys, Binary op) const;
/// Enumerate all values into a map from values to double.
std::vector> enumerate() const;
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 5b4665d4d..6e9f69619 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -47,6 +47,11 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
using Values = DiscreteValues; ///< backwards compatibility
+ using Unary = std::function;
+ using UnaryAssignment =
+ std::function&, const double&)>;
+ using Binary = std::function;
+
protected:
/// Map of Keys and their cardinalities.
std::map cardinalities_;
From f85284afb2f07ab82284628e45ef68949f15fa1a Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 12:37:36 -0500
Subject: [PATCH 28/86] some cleanup based on previous commit
---
gtsam/discrete/DecisionTreeFactor.h | 1 +
gtsam/discrete/TableFactor.h | 5 -----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index f8679af59..af0df47a2 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -51,6 +51,7 @@ namespace gtsam {
typedef std::shared_ptr shared_ptr;
typedef AlgebraicDecisionTree ADT;
+ // Needed since we have definitions in both DiscreteFactor and DecisionTree
using Base::Binary;
using Base::Unary;
using Base::UnaryAssignment;
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 317596ba9..345cbc254 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -94,12 +94,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
typedef std::shared_ptr shared_ptr;
typedef Eigen::SparseVector::InnerIterator SparseIt;
typedef std::vector> AssignValList;
- using Unary = std::function;
- using UnaryAssignment =
- std::function&, const double&)>;
- using Binary = std::function;
- public:
/// @name Standard Constructors
/// @{
From 5e86f7ee5122fc6cf4fac609fdf2a639c1e5079f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 15:31:35 -0500
Subject: [PATCH 29/86] remove previously added code
---
gtsam/discrete/DecisionTreeFactor.cpp | 11 -----------
gtsam/discrete/DecisionTreeFactor.h | 16 +---------------
gtsam/discrete/DiscreteConditional.cpp | 13 +++++++------
gtsam/discrete/DiscreteConditional.h | 8 ++++----
gtsam/discrete/DiscreteLookupDAG.cpp | 2 +-
5 files changed, 13 insertions(+), 37 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 93a7921aa..776d4bd90 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -82,17 +82,6 @@ namespace gtsam {
ADT::print("", formatter);
}
- /* ************************************************************************ */
- DiscreteFactor::shared_ptr DecisionTreeFactor::operator*(
- const DiscreteFactor::shared_ptr& f) const {
- if (auto derived = std::dynamic_pointer_cast(f)) {
- return std::make_shared(this->operator*(*derived));
- } else {
- throw std::runtime_error(
- "Cannot convert DiscreteFactor to DecisionTreeFactor");
- }
- }
-
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const {
// apply operand
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index af0df47a2..11793f984 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -145,13 +145,10 @@ namespace gtsam {
double error(const DiscreteValues& values) const override;
/// multiply two factors
- DecisionTreeFactor operator*(const DecisionTreeFactor& f) const {
+ DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
}
- DiscreteFactor::shared_ptr operator*(
- const DiscreteFactor::shared_ptr& f) const override;
-
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)
@@ -159,17 +156,6 @@ namespace gtsam {
return apply(f, safe_div);
}
- /// divide by factor f (pointer version)
- DiscreteFactor::shared_ptr operator/(
- const DiscreteFactor::shared_ptr& f) const override {
- if (auto derived = std::dynamic_pointer_cast(f)) {
- return std::make_shared(apply(*derived, safe_div));
- } else {
- throw std::runtime_error(
- "Cannot convert DiscreteFactor to Table Factor");
- }
- }
-
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 6dff84859..58acb21b0 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -38,7 +38,8 @@ using std::vector;
namespace gtsam {
// Instantiate base class
-template class GTSAM_EXPORT Conditional;
+template class GTSAM_EXPORT
+ Conditional;
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
@@ -152,11 +153,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
- if (!dynamic_cast(&other)) {
+ if (!dynamic_cast(&other)) {
return false;
} else {
- const DecisionTreeFactor& f(static_cast(other));
- return DecisionTreeFactor::equals(f, tol);
+ const BaseFactor& f(static_cast(other));
+ return BaseFactor::equals(f, tol);
}
}
@@ -377,7 +378,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
- ss << DecisionTreeFactor::markdown(keyFormatter, names);
+ ss << BaseFactor::markdown(keyFormatter, names);
return ss.str();
}
@@ -429,7 +430,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "
\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
- ss << DecisionTreeFactor::html(keyFormatter, names);
+ ss << BaseFactor::html(keyFormatter, names);
return ss.str();
}
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index 3dbadb3e4..298a7b004 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
- DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
- const DiscreteFactor::shared_ptr& marginal);
+ DiscreteConditional(const DecisionTreeFactor& joint,
+ const DecisionTreeFactor& marginal);
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
*/
- DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
- const DiscreteFactor::shared_ptr& marginal,
+ DiscreteConditional(const DecisionTreeFactor& joint,
+ const DecisionTreeFactor& marginal,
const Ordering& orderedKeys);
/**
diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp
index 4d02e007b..ee381fe44 100644
--- a/gtsam/discrete/DiscreteLookupDAG.cpp
+++ b/gtsam/discrete/DiscreteLookupDAG.cpp
@@ -47,7 +47,7 @@ void DiscreteLookupTable::print(const std::string& s,
}
}
cout << "):\n";
- ADT::print("", formatter);
+ BaseFactor::print("", formatter);
cout << endl;
}
From 1c14a56f5d9351355107b181b87565fa0613856b Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 15:58:07 -0500
Subject: [PATCH 30/86] revert changes to make code generic
---
gtsam/discrete/DiscreteConditional.cpp | 16 +++++++---------
gtsam/discrete/DiscreteFactor.h | 10 +++-------
gtsam/discrete/TableFactor.cpp | 9 ++-------
gtsam/discrete/TableFactor.h | 14 ++------------
4 files changed, 14 insertions(+), 35 deletions(-)
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 58acb21b0..bd10e28b4 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -55,17 +55,15 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals,
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
/* ************************************************************************** */
-DiscreteConditional::DiscreteConditional(
- const DiscreteFactor::shared_ptr& joint,
- const DiscreteFactor::shared_ptr& marginal)
- : BaseFactor(*std::dynamic_pointer_cast(
- joint->operator/(marginal))),
- BaseConditional(joint->size() - marginal->size()) {}
+DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
+ const DecisionTreeFactor& marginal)
+ : BaseFactor(joint / marginal),
+ BaseConditional(joint.size() - marginal.size()) {}
/* ************************************************************************** */
-DiscreteConditional::DiscreteConditional(
- const DiscreteFactor::shared_ptr& joint,
- const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys)
+DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
+ const DecisionTreeFactor& marginal,
+ const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 6e9f69619..a6356a045 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -126,10 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree errorTree() const;
- /// Multiply in a DiscreteFactor and return the result as
- /// DiscreteFactor
- virtual DiscreteFactor::shared_ptr operator*(
- const DiscreteFactor::shared_ptr&) const = 0;
+ /// Multiply in a DecisionTreeFactor and return the result as
+ /// DecisionTreeFactor
+ virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
@@ -145,9 +144,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
- /// divide by factor f (safely)
- virtual DiscreteFactor::shared_ptr operator/(
- const DiscreteFactor::shared_ptr& f) const = 0;
/**
* Get the number of non-zero values contained in this factor.
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 50d15ff5e..a4947012e 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -171,13 +171,8 @@ double TableFactor::error(const HybridValues& values) const {
}
/* ************************************************************************ */
-DiscreteFactor::shared_ptr TableFactor::operator*(
- const DiscreteFactor::shared_ptr& f) const {
- if (auto derived = std::dynamic_pointer_cast(f)) {
- return std::make_shared(this->operator*(*derived));
- } else {
- throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor");
- }
+DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
+ return toDecisionTreeFactor() * f;
}
/* ************************************************************************ */
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 345cbc254..ba1d05fe9 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -161,9 +161,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, Ring::mul);
};
- /// multiply with DiscreteFactor
- DiscreteFactor::shared_ptr operator*(
- const DiscreteFactor::shared_ptr& f) const override;
+ /// multiply with DecisionTreeFactor
+ DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
static double safe_div(const double& a, const double& b);
@@ -172,15 +171,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}
- /// divide by factor f (pointer version)
- DiscreteFactor::shared_ptr operator/(
- const DiscreteFactor::shared_ptr& f) const override {
- if (auto derived = std::dynamic_pointer_cast(f)) {
- return std::make_shared(apply(*derived, safe_div));
- } else {
- throw std::runtime_error("Cannot convert DiscreteFactor to Table Factor");
- }
- }
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
From b325150b3751b7f5c43b0268f76be8e60f10fe38 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 16:18:42 -0500
Subject: [PATCH 31/86] revert DiscreteFactorGraph::product
---
gtsam/discrete/DiscreteFactorGraph.cpp | 13 ++++---------
gtsam/discrete/DiscreteFactorGraph.h | 2 +-
2 files changed, 5 insertions(+), 10 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index fa9d9bdc7..9e64b0f6d 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -64,15 +64,10 @@ namespace gtsam {
}
/* ************************************************************************* */
- DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
- sharedFactor result = this->at(0);
- for (size_t i = 1; i < this->size(); ++i) {
- const sharedFactor factor = this->at(i);
- if (factor) {
- // Predicated on the fact that all discrete factors are of a single type
- // so there is no type-conversion happening which can be expensive.
- result = result->operator*(factor);
- }
+ DecisionTreeFactor DiscreteFactorGraph::product() const {
+ DecisionTreeFactor result;
+ for (const sharedFactor& factor : *this) {
+ if (factor) result = result * (*factor);
}
return result;
}
diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h
index a5324811c..43c48c2d0 100644
--- a/gtsam/discrete/DiscreteFactorGraph.h
+++ b/gtsam/discrete/DiscreteFactorGraph.h
@@ -147,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
- DiscreteFactor::shared_ptr product() const;
+ DecisionTreeFactor product() const;
/**
* Evaluates the factor graph given values, returns the joint probability of
From 0afc1984118d8bec1c2327f542cac8b22d11b96f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 16:26:03 -0500
Subject: [PATCH 32/86] revert some DiscreteFactorGraph changes
---
gtsam/discrete/DiscreteFactorGraph.cpp | 29 +++++++++++++-------------
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 9e64b0f6d..04849985f 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -118,16 +118,17 @@ namespace gtsam {
* @param product The product discrete factor.
* @return DiscreteFactor::shared_ptr
*/
- static DiscreteFactor::shared_ptr Normalize(
- const DiscreteFactor::shared_ptr& product) {
+ static DecisionTreeFactor Normalize(const DecisionTreeFactor& product) {
// Max over all the potentials by pretending all keys are frontal:
gttic_(DiscreteFindMax);
- auto normalization = product->max(product->size());
+ auto normalization = product.max(product.size());
gttoc_(DiscreteFindMax);
gttic_(DiscreteNormalization);
// Normalize the product factor to prevent underflow.
- auto normalized_product = product->operator/(normalization);
+ auto normalized_product =
+ product /
+ (*std::dynamic_pointer_cast(normalization));
gttoc_(DiscreteNormalization);
return normalized_product;
@@ -140,7 +141,7 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic_(MPEProduct);
- DiscreteFactor::shared_ptr product = factors.product();
+ DecisionTreeFactor product = factors.product();
gttoc_(MPEProduct);
gttic_(Normalize);
@@ -151,23 +152,22 @@ namespace gtsam {
// max out frontals, this is the factor on the separator
gttic(max);
- DiscreteFactor::shared_ptr max = product->max(frontalKeys);
+ DecisionTreeFactor::shared_ptr max =
+ std::dynamic_pointer_cast(product.max(frontalKeys));
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
- orderedKeys.emplace_back(key, product->cardinality(key));
+ orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
- orderedKeys.emplace_back(key, product->cardinality(key));
+ orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
- //TODO(Varun): Should accept a DiscreteFactor::shared_ptr
- auto lookup = std::make_shared(
- nrFrontals, orderedKeys,
- *std::dynamic_pointer_cast(product));
+ auto lookup =
+ std::make_shared(nrFrontals, orderedKeys, product);
gttoc(lookup);
return {std::dynamic_pointer_cast(lookup), max};
@@ -230,7 +230,7 @@ namespace gtsam {
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic_(product);
- DiscreteFactor::shared_ptr product = factors.product();
+ DecisionTreeFactor product = factors.product();
gttoc_(product);
gttic_(Normalize);
@@ -240,7 +240,8 @@ namespace gtsam {
// sum out frontals, this is the factor on the separator
gttic_(sum);
- DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
+ DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast(
+ product.sum(frontalKeys));
gttoc_(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
From 975fe627d91fc3900e0553e4e043ddf60aef7e53 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 16:58:19 -0500
Subject: [PATCH 33/86] add methods in gtsam_unstable
---
gtsam_unstable/discrete/AllDiff.h | 16 ++++++++++++++++
gtsam_unstable/discrete/BinaryAllDiff.h | 16 ++++++++++++++++
gtsam_unstable/discrete/Domain.h | 16 ++++++++++++++++
gtsam_unstable/discrete/SingleValue.h | 18 +++++++++++++++++-
4 files changed, 65 insertions(+), 1 deletion(-)
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index 7f539f4c6..fb956146b 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -75,6 +75,22 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Get the number of non-zero values contained in this factor.
uint64_t nrValues() const override { return 1; };
+
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index 0e2fce109..fe04fd807 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -99,6 +99,22 @@ class BinaryAllDiff : public Constraint {
/// Get the number of non-zero values contained in this factor.
uint64_t nrValues() const override { return 1; };
+
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index cd11fc8d9..1fbd7b110 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -114,6 +114,22 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(const Domains& domains) const override;
+
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index 7f2eb2c2c..8dc7114ec 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
}
/// Calculate value
- double operator()(const DiscreteValues& values) const override;
+ double operator()(const Assignment& values) const override;
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
@@ -80,6 +80,22 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Get the number of non-zero values contained in this factor.
uint64_t nrValues() const override { return 1; };
+
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
+ throw std::runtime_error("Not implemented");
+ }
+
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
+ throw std::runtime_error("Not implemented");
+ }
};
} // namespace gtsam
From fc2d33f437ab65b394ff563ff9f8872101487189 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 17:00:04 -0500
Subject: [PATCH 34/86] add division with DiscreteFactor::shared_ptr for
convenience
---
gtsam/discrete/DecisionTreeFactor.h | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 11793f984..a5178b66f 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -156,6 +156,11 @@ namespace gtsam {
return apply(f, safe_div);
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const {
+ return apply(*std::dynamic_pointer_cast(f), safe_div);
+ }
+
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
From 2c02efcae2e2b320834baeff157329c6908c332e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 17:02:47 -0500
Subject: [PATCH 35/86] fix tests
---
gtsam/discrete/tests/testDecisionTreeFactor.cpp | 6 +++---
gtsam/discrete/tests/testDiscreteConditional.cpp | 4 ++--
gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 6 +++---
gtsam/discrete/tests/testTableFactor.cpp | 6 +++---
4 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index 756a0cebe..b4c5acc1b 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -111,15 +111,15 @@ TEST(DecisionTreeFactor, sum_max) {
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12");
- DecisionTreeFactor::shared_ptr actual = f1.sum(1);
+ auto actual = std::dynamic_pointer_cast(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6");
- DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
+ auto actual2 = std::dynamic_pointer_cast(f1.max(1));
CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
- DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
+ auto actual22 = std::dynamic_pointer_cast(f2.sum(1));
}
/* ************************************************************************* */
diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp
index 2482a86a2..d17c76837 100644
--- a/gtsam/discrete/tests/testDiscreteConditional.cpp
+++ b/gtsam/discrete/tests/testDiscreteConditional.cpp
@@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / *f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast(actual2)));
std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / *f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast(actual2)));
}
diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
index 341eb63e3..e0d696d91 100644
--- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
+++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
@@ -113,12 +113,12 @@ TEST(DiscreteFactorGraph, test) {
const Ordering frontalKeys{0};
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
- DecisionTreeFactor newFactor = *newFactorPtr;
+ auto newFactor = *std::dynamic_pointer_cast(newFactorPtr);
// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());
- newFactor = newFactor / *normalization;
+ newFactor = newFactor / normalization;
// Check Conditional
CHECK(conditional);
@@ -132,7 +132,7 @@ TEST(DiscreteFactorGraph, test) {
// Normalize by max.
normalization = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct.
- expectedFactor = expectedFactor / *normalization;
+ expectedFactor = expectedFactor / normalization;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree
diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp
index 0f7d7a615..bd3dac514 100644
--- a/gtsam/discrete/tests/testTableFactor.cpp
+++ b/gtsam/discrete/tests/testTableFactor.cpp
@@ -242,15 +242,15 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12");
- TableFactor::shared_ptr actual = f1.sum(1);
+ auto actual = std::dynamic_pointer_cast(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6");
- TableFactor::shared_ptr actual2 = f1.max(1);
+ auto actual2 = std::dynamic_pointer_cast(f1.max(1));
CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
- TableFactor::shared_ptr actual22 = f2.sum(1);
+ auto actual22 = std::dynamic_pointer_cast(f2.sum(1));
}
/* ************************************************************************* */
From 360598d3d596a49f0e0778d5b4c5b49e6f17340f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 17:03:31 -0500
Subject: [PATCH 36/86] undo uncomment
---
gtsam/discrete/DiscreteDistribution.h | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h
index 28e509f15..09ea50332 100644
--- a/gtsam/discrete/DiscreteDistribution.h
+++ b/gtsam/discrete/DiscreteDistribution.h
@@ -86,7 +86,8 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
double operator()(size_t value) const;
/// We also want to keep the Base version, taking DiscreteValues:
- using Base::operator();
+ // TODO(dellaert): does not play well with wrapper!
+ // using Base::operator();
/// Return entire probability mass function.
std::vector pmf() const;
From 853241c6d09da2a71375cfc507b95ba647e59d29 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 17:07:40 -0500
Subject: [PATCH 37/86] add evaluate to DiscreteConditional
---
gtsam/discrete/DiscreteConditional.h | 1 +
1 file changed, 1 insertion(+)
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index 298a7b004..75eaf9154 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -169,6 +169,7 @@ class GTSAM_EXPORT DiscreteConditional
}
using BaseFactor::error; ///< DiscreteValues version
+ using BaseFactor::evaluate; // DiscreteValues version
using BaseFactor::operator(); ///< DiscreteValues version
/**
From 199c0a0f24da5e040ec128b83ad1f9cb3e2ef8eb Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 17:15:22 -0500
Subject: [PATCH 38/86] keep using DecisionTreeFactor for DiscreteConditional
---
gtsam/discrete/DiscreteConditional.cpp | 4 ++--
gtsam/discrete/DiscreteConditional.h | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index bd10e28b4..399126d51 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -201,7 +201,7 @@ DiscreteConditional::shared_ptr DiscreteConditional::choose(
}
/* ************************************************************************** */
-DiscreteFactor::shared_ptr DiscreteConditional::likelihood(
+DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the frontal variables.
@@ -226,7 +226,7 @@ DiscreteFactor::shared_ptr DiscreteConditional::likelihood(
}
/* ****************************************************************************/
-DiscreteFactor::shared_ptr DiscreteConditional::likelihood(
+DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t frontal) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h
index 75eaf9154..96ecefd7a 100644
--- a/gtsam/discrete/DiscreteConditional.h
+++ b/gtsam/discrete/DiscreteConditional.h
@@ -188,11 +188,11 @@ class GTSAM_EXPORT DiscreteConditional
shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */
- DiscreteFactor::shared_ptr likelihood(
+ DecisionTreeFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */
- DiscreteFactor::shared_ptr likelihood(size_t frontal) const;
+ DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
/**
* sample
From 214bf4ec1a2ce6329d930b03b8f8f279c6f71de9 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 8 Dec 2024 17:15:40 -0500
Subject: [PATCH 39/86] more fixes
---
gtsam/discrete/DiscreteFactorGraph.cpp | 4 ++--
gtsam_unstable/discrete/SingleValue.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 04849985f..e31f94eae 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -67,7 +67,7 @@ namespace gtsam {
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
- if (factor) result = result * (*factor);
+ if (factor) result = (*factor) * result;
}
return result;
}
@@ -254,7 +254,7 @@ namespace gtsam {
// now divide product/sum to get conditional
gttic_(divide);
auto conditional =
- std::make_shared(product, sum, orderedKeys);
+ std::make_shared(product, *sum, orderedKeys);
gttoc_(divide);
return {conditional, sum};
diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp
index 6b78f38f5..9762aec0f 100644
--- a/gtsam_unstable/discrete/SingleValue.cpp
+++ b/gtsam_unstable/discrete/SingleValue.cpp
@@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
}
/* ************************************************************************* */
-double SingleValue::operator()(const DiscreteValues& values) const {
+double SingleValue::operator()(const Assignment& values) const {
return (double)(values.at(keys_[0]) == value_);
}
From e46cd5499351c67abd8c799d6ba317d37d21991e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 9 Dec 2024 15:52:42 -0500
Subject: [PATCH 40/86] TableFactor cleanup
---
gtsam/discrete/TableFactor.cpp | 2 --
gtsam/discrete/TableFactor.h | 1 -
2 files changed, 3 deletions(-)
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index a4947012e..32049fde1 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -72,9 +72,7 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
*/
std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dt) {
- gttic_(ComputeLeafOrdering);
std::vector probs = dt.probabilities();
- gttoc_(ComputeLeafOrdering);
std::vector ordered;
size_t n = dkeys[0].second;
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index ba1d05fe9..002c276ca 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -171,7 +171,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}
-
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
From 52c8034d41e8703ad149ee161c9d93ded4df5f0c Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 9 Dec 2024 16:16:18 -0500
Subject: [PATCH 41/86] add division by DiscreteFactor in TableFactor
---
gtsam/discrete/TableFactor.h | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 002c276ca..64e98c6a1 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -171,6 +171,15 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ TableFactor operator/(const DiscreteFactor::shared_ptr& f) const {
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ return apply(*tf, safe_div);
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ return apply(TableFactor(f->discreteKeys(), *dtf), safe_div);
+ }
+ }
+
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
From e0e833c2fc71b5848c405d1d5fa20078d1df0cab Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 9 Dec 2024 16:23:55 -0500
Subject: [PATCH 42/86] cleanup
---
gtsam/discrete/DiscreteLookupDAG.h | 2 --
gtsam/discrete/TableFactor.cpp | 2 --
2 files changed, 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h
index 21181d374..7f31a3f48 100644
--- a/gtsam/discrete/DiscreteLookupDAG.h
+++ b/gtsam/discrete/DiscreteLookupDAG.h
@@ -61,8 +61,6 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials Discrete potentials as a TableFactor.
- *
- * //TODO(Varun): Should accept a DiscreteFactor::shared_ptr
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 32049fde1..32cba84ed 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -180,10 +180,8 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}
- gttic_(toDecisionTreeFactor_Constructor);
// NOTE(Varun): This constructor is really expensive!!
DecisionTreeFactor f(dkeys, table);
- gttoc_(toDecisionTreeFactor_Constructor);
return f;
}
From 84627c0c579805590e703f487fc1a072e6637da6 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 9 Dec 2024 16:30:46 -0500
Subject: [PATCH 43/86] fix error
---
gtsam/discrete/DiscreteFactor.h | 1 -
gtsam/discrete/TableFactor.h | 3 +++
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index a6356a045..2fe80a54c 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -144,7 +144,6 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
-
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 64e98c6a1..41b6287b8 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -17,6 +17,7 @@
#pragma once
+#include
#include
#include
#include
@@ -177,6 +178,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(*tf, safe_div);
} else if (auto dtf = std::dynamic_pointer_cast(f)) {
return apply(TableFactor(f->discreteKeys(), *dtf), safe_div);
+ } else {
+ throw std::runtime_error("Unknown derived type for DiscreteFactor");
}
}
From 22d11d7af49ebdec102e5ec16597a00a175621ce Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Wed, 11 Dec 2024 04:00:56 -0500
Subject: [PATCH 44/86] don't print timing info by default
---
gtsam/discrete/DiscreteFactorGraph.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 50072f547..ba05417aa 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -224,10 +224,10 @@ namespace gtsam {
DecisionTreeFactor product = ProductAndNormalize(factors);
// sum out frontals, this is the factor on the separator
- gttic_(sum);
+ gttic(sum);
DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast(
product.sum(frontalKeys));
- gttoc_(sum);
+ gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
@@ -237,10 +237,10 @@ namespace gtsam {
sum->keys().end());
// now divide product/sum to get conditional
- gttic_(divide);
+ gttic(divide);
auto conditional =
std::make_shared(product, *sum, orderedKeys);
- gttoc_(divide);
+ gttoc(divide);
return {conditional, sum};
}
From cbcfab4176b0577cc62e3cb77c3522cba90e2015 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:09:25 -0500
Subject: [PATCH 45/86] serialize functions for Eigen::SparseVector
---
gtsam/base/MatrixSerialization.h | 40 ++++++++++++++++++++++++++++++++
1 file changed, 40 insertions(+)
diff --git a/gtsam/base/MatrixSerialization.h b/gtsam/base/MatrixSerialization.h
index 11b6a417a..43c97097d 100644
--- a/gtsam/base/MatrixSerialization.h
+++ b/gtsam/base/MatrixSerialization.h
@@ -24,6 +24,7 @@
#include
+#include
#include
#include
#include
@@ -87,6 +88,45 @@ void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) {
split_free(ar, m, version);
}
+/******************************************************************************/
+/// Customized functions for serializing Eigen::SparseVector
+template
+void save(Archive& ar, const Eigen::SparseVector<_Scalar, _Options, _Index>& m,
+ const unsigned int /*version*/) {
+ _Index size = m.size();
+
+ std::vector> data;
+ for (typename Eigen::SparseVector<_Scalar, _Options, _Index>::InnerIterator
+ it(m);
+ it; ++it)
+ data.push_back({it.index(), it.value()});
+
+ ar << BOOST_SERIALIZATION_NVP(size);
+ ar << BOOST_SERIALIZATION_NVP(data);
+}
+
+template
+void load(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m,
+ const unsigned int /*version*/) {
+ _Index size;
+ ar >> BOOST_SERIALIZATION_NVP(size);
+ m.resize(size);
+
+ std::vector> data;
+ ar >> BOOST_SERIALIZATION_NVP(data);
+
+ for (auto&& d : data) {
+ m.coeffRef(d.first) = d.second;
+ }
+}
+
+template
+void serialize(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m,
+ const unsigned int version) {
+ split_free(ar, m, version);
+}
+/******************************************************************************/
+
} // namespace serialization
} // namespace boost
#endif
From 5c4194e7cd800c8cd232f92a63dbd8e10b935340 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:13:51 -0500
Subject: [PATCH 46/86] add serialization code to TableFactor
---
gtsam/discrete/TableFactor.h | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 5ddb4ab43..a2fdb4d32 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -31,6 +31,12 @@
#include
#include
+#if GTSAM_ENABLE_BOOST_SERIALIZATION
+#include
+
+#include
+#endif
+
namespace gtsam {
class DiscreteConditional;
@@ -342,6 +348,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
double error(const HybridValues& values) const override;
/// @}
+
+ private:
+#if GTSAM_ENABLE_BOOST_SERIALIZATION
+ /** Serialization function */
+ friend class boost::serialization::access;
+ template
+ void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
+ ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
+ ar& BOOST_SERIALIZATION_NVP(sparse_table_);
+ ar& BOOST_SERIALIZATION_NVP(denominators_);
+ ar& BOOST_SERIALIZATION_NVP(sorted_dkeys_);
+ }
+#endif
};
// traits
From ef2843c5b213e11463de8217e790a3e6e9cbbd25 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:14:07 -0500
Subject: [PATCH 47/86] test for TableFactor serialization
---
.../discrete/tests/testSerializationDiscrete.cpp | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/gtsam/discrete/tests/testSerializationDiscrete.cpp b/gtsam/discrete/tests/testSerializationDiscrete.cpp
index df7df0b7e..9d15d0536 100644
--- a/gtsam/discrete/tests/testSerializationDiscrete.cpp
+++ b/gtsam/discrete/tests/testSerializationDiscrete.cpp
@@ -20,6 +20,7 @@
#include
#include
#include
+#include
#include
using namespace std;
@@ -32,6 +33,7 @@ BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
+BOOST_CLASS_EXPORT_GUID(TableFactor, "gtsam_TableFactor");
using ADT = AlgebraicDecisionTree;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
@@ -79,6 +81,19 @@ TEST(DiscreteSerialization, DecisionTreeFactor) {
EXPECT(equalsBinary(f));
}
+/* ************************************************************************* */
+// Check serialization for TableFactor
+TEST(DiscreteSerialization, TableFactor) {
+ using namespace serializationTestHelpers;
+
+ DiscreteKey A(Symbol('x', 1), 3);
+ TableFactor tf(A % "1/2/2");
+
+ EXPECT(equalsObj(tf));
+ EXPECT(equalsXML(tf));
+ EXPECT(equalsBinary(tf));
+}
+
/* ************************************************************************* */
// Check serialization for DiscreteConditional & DiscreteDistribution
TEST(DiscreteSerialization, DiscreteConditional) {
From 834288f9748992b24bc4d4f4cffc77c7d8461d8c Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:18:24 -0500
Subject: [PATCH 48/86] additional Signature based constructor for
DecisionTreeFactor
---
gtsam/discrete/DecisionTreeFactor.h | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 80ee10a7b..24a699d42 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -18,6 +18,7 @@
#pragma once
+#include
#include
#include
#include
@@ -116,6 +117,10 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKey& key, const std::vector& row)
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
+ /// Construct from Signature
+ DecisionTreeFactor(const Signature& signature)
+ : DecisionTreeFactor(signature.discreteKeys(), signature.cpt()) {}
+
/** Construct from a DiscreteConditional type */
explicit DecisionTreeFactor(const DiscreteConditional& c);
From e6567457b511a6ff993efcf2710c98f72c71bdad Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:21:19 -0500
Subject: [PATCH 49/86] update tests
---
.../discrete/tests/testDecisionTreeFactor.cpp | 22 +++++++------------
1 file changed, 8 insertions(+), 14 deletions(-)
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index 756a0cebe..1828db525 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -217,12 +217,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}
-/** Convert Signature into CPT */
-DecisionTreeFactor create(const Signature& signature) {
- DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
- return p;
-}
-
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
@@ -230,14 +224,14 @@ TEST(DecisionTreeFactor, joint) {
D(7, 2);
gttic_(asiaCPTs);
- DecisionTreeFactor pA = create(A % "99/1");
- DecisionTreeFactor pS = create(S % "50/50");
- DecisionTreeFactor pT = create(T | A = "99/1 95/5");
- DecisionTreeFactor pL = create(L | S = "99/1 90/10");
- DecisionTreeFactor pB = create(B | S = "70/30 40/60");
- DecisionTreeFactor pE = create((E | T, L) = "F T T T");
- DecisionTreeFactor pX = create(X | E = "95/5 2/98");
- DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
+ DecisionTreeFactor pA(A % "99/1");
+ DecisionTreeFactor pS(S % "50/50");
+ DecisionTreeFactor pT(T | A = "99/1 95/5");
+ DecisionTreeFactor pL(L | S = "99/1 90/10");
+ DecisionTreeFactor pB(B | S = "70/30 40/60");
+ DecisionTreeFactor pE((E | T, L) = "F T T T");
+ DecisionTreeFactor pX(X | E = "95/5 2/98");
+ DecisionTreeFactor pD((D | E, B) = "9/1 2/8 3/7 1/9");
// Create joint
gttic_(asiaJoint);
From cb9cec30e39895b4745a4727aa896718dcccc467 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:28:21 -0500
Subject: [PATCH 50/86] unit test exposing division bug
---
gtsam/discrete/tests/testDecisionTreeFactor.cpp | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index 1828db525..73420c860 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -69,6 +69,15 @@ TEST(DecisionTreeFactor, constructors) {
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
}
+/* ************************************************************************* */
+TEST(DecisionTreeFactor, Divide) {
+ DiscreteKey A(0, 2), S(1, 2);
+ DecisionTreeFactor pA(A % "99/1"), pS(S % "50/50");
+ DecisionTreeFactor joint = pA * pS;
+ DecisionTreeFactor s = joint / pA;
+ EXPECT(assert_equal(pS, s));
+}
+
/* ************************************************************************* */
TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys
From c6c451bee102f624cf25660ec6a820c9d5c1c49c Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:28:40 -0500
Subject: [PATCH 51/86] compute correct subset of keys for division
---
gtsam/discrete/DecisionTreeFactor.h | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 24a699d42..0b94140da 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -161,7 +161,15 @@ namespace gtsam {
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
- return apply(f, safe_div);
+ KeyVector diff;
+ std::set_difference(this->keys().begin(), this->keys().end(),
+ f.keys().begin(), f.keys().end(),
+ std::back_inserter(diff));
+ DiscreteKeys keys;
+ for (Key key : diff) {
+ keys.push_back({key, this->cardinality(key)});
+ }
+ return DecisionTreeFactor(keys, apply(f, safe_div));
}
/// Convert into a decision tree
From 3718cb19edd21c718fe280c712a793031b6fd707 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:35:13 -0500
Subject: [PATCH 52/86] use string based constructor
---
gtsam/discrete/tests/testSerializationDiscrete.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gtsam/discrete/tests/testSerializationDiscrete.cpp b/gtsam/discrete/tests/testSerializationDiscrete.cpp
index 9d15d0536..b118a00f6 100644
--- a/gtsam/discrete/tests/testSerializationDiscrete.cpp
+++ b/gtsam/discrete/tests/testSerializationDiscrete.cpp
@@ -87,7 +87,7 @@ TEST(DiscreteSerialization, TableFactor) {
using namespace serializationTestHelpers;
DiscreteKey A(Symbol('x', 1), 3);
- TableFactor tf(A % "1/2/2");
+ TableFactor tf(A, "1 2 2");
EXPECT(equalsObj(tf));
EXPECT(equalsXML(tf));
From ffe14d39aae1609c4d85b863a3ffd5ebca0089ac Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:42:48 -0500
Subject: [PATCH 53/86] Revert "update tests"
This reverts commit e6567457b511a6ff993efcf2710c98f72c71bdad.
---
.../discrete/tests/testDecisionTreeFactor.cpp | 22 ++++++++++++-------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index 73420c860..61ce9038d 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -226,6 +226,12 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}
+/** Convert Signature into CPT */
+DecisionTreeFactor create(const Signature& signature) {
+ DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
+ return p;
+}
+
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
@@ -233,14 +239,14 @@ TEST(DecisionTreeFactor, joint) {
D(7, 2);
gttic_(asiaCPTs);
- DecisionTreeFactor pA(A % "99/1");
- DecisionTreeFactor pS(S % "50/50");
- DecisionTreeFactor pT(T | A = "99/1 95/5");
- DecisionTreeFactor pL(L | S = "99/1 90/10");
- DecisionTreeFactor pB(B | S = "70/30 40/60");
- DecisionTreeFactor pE((E | T, L) = "F T T T");
- DecisionTreeFactor pX(X | E = "95/5 2/98");
- DecisionTreeFactor pD((D | E, B) = "9/1 2/8 3/7 1/9");
+ DecisionTreeFactor pA = create(A % "99/1");
+ DecisionTreeFactor pS = create(S % "50/50");
+ DecisionTreeFactor pT = create(T | A = "99/1 95/5");
+ DecisionTreeFactor pL = create(L | S = "99/1 90/10");
+ DecisionTreeFactor pB = create(B | S = "70/30 40/60");
+ DecisionTreeFactor pE = create((E | T, L) = "F T T T");
+ DecisionTreeFactor pX = create(X | E = "95/5 2/98");
+ DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
// Create joint
gttic_(asiaJoint);
From be7be376a9ef95b12c08b83d72bb602cc93eb852 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 13:42:56 -0500
Subject: [PATCH 54/86] Revert "additional Signature based constructor for
DecisionTreeFactor"
This reverts commit 834288f9748992b24bc4d4f4cffc77c7d8461d8c.
---
gtsam/discrete/DecisionTreeFactor.h | 5 -----
1 file changed, 5 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 0b94140da..804b956fe 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -18,7 +18,6 @@
#pragma once
-#include
#include
#include
#include
@@ -117,10 +116,6 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKey& key, const std::vector& row)
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
- /// Construct from Signature
- DecisionTreeFactor(const Signature& signature)
- : DecisionTreeFactor(signature.discreteKeys(), signature.cpt()) {}
-
/** Construct from a DiscreteConditional type */
explicit DecisionTreeFactor(const DiscreteConditional& c);
From 6b6283c1512467819918c90191aa8372e96a00dd Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Fri, 3 Jan 2025 15:21:49 -0500
Subject: [PATCH 55/86] fix factor construction
---
gtsam/discrete/tests/testDecisionTreeFactor.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index 61ce9038d..dc18e0ab2 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -72,7 +72,7 @@ TEST(DecisionTreeFactor, constructors) {
/* ************************************************************************* */
TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2);
- DecisionTreeFactor pA(A % "99/1"), pS(S % "50/50");
+ DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
DecisionTreeFactor joint = pA * pS;
DecisionTreeFactor s = joint / pA;
EXPECT(assert_equal(pS, s));
From a142556c52064b5adb2926e76e4eb7e238ed0cb5 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 08:45:11 -0500
Subject: [PATCH 56/86] move create to the top
---
.../discrete/tests/testDecisionTreeFactor.cpp | 30 +++++++++----------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index dc18e0ab2..7210622d8 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -30,6 +30,12 @@
using namespace std;
using namespace gtsam;
+/** Convert Signature into CPT */
+DecisionTreeFactor create(const Signature& signature) {
+ DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
+ return p;
+}
+
/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
@@ -69,15 +75,6 @@ TEST(DecisionTreeFactor, constructors) {
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
}
-/* ************************************************************************* */
-TEST(DecisionTreeFactor, Divide) {
- DiscreteKey A(0, 2), S(1, 2);
- DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
- DecisionTreeFactor joint = pA * pS;
- DecisionTreeFactor s = joint / pA;
- EXPECT(assert_equal(pS, s));
-}
-
/* ************************************************************************* */
TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys
@@ -114,6 +111,15 @@ TEST(DecisionTreeFactor, multiplication) {
CHECK(assert_equal(expected2, actual));
}
+/* ************************************************************************* */
+TEST(DecisionTreeFactor, Divide) {
+ DiscreteKey A(0, 2), S(1, 2);
+ DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
+ DecisionTreeFactor joint = pA * pS;
+ DecisionTreeFactor s = joint / pA;
+ EXPECT(assert_equal(pS, s));
+}
+
/* ************************************************************************* */
TEST(DecisionTreeFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
@@ -226,12 +232,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}
-/** Convert Signature into CPT */
-DecisionTreeFactor create(const Signature& signature) {
- DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
- return p;
-}
-
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
From 5fa04d7622548a8072ac621ff355a8caa622930d Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 09:08:57 -0500
Subject: [PATCH 57/86] small improvements
---
gtsam/discrete/DiscreteConditional.cpp | 6 ++----
gtsam/discrete/DiscreteFactorGraph.cpp | 2 +-
gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 10 +++++-----
gtsam/discrete/tests/testTableFactor.cpp | 14 ++++++++------
4 files changed, 16 insertions(+), 16 deletions(-)
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index be0f42bea..26f38e278 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -24,13 +24,13 @@
#include
#include
+#include
#include
#include
#include
#include
#include
#include
-#include
using namespace std;
using std::pair;
@@ -45,9 +45,7 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
- : BaseFactor(f / (*std::dynamic_pointer_cast(
- f.sum(nrFrontals)))),
- BaseConditional(nrFrontals) {}
+ : BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 8c950050b..678c70e2d 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -128,7 +128,7 @@ namespace gtsam {
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
- product = product / (*denominator);
+ product = product / denominator;
return product;
}
diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
index cbcf5234e..4ee36f0ab 100644
--- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
+++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
@@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast(newFactorPtr);
// Normalize newFactor by max for comparison with expected
- auto normalizer = newFactor.max(newFactor.size());
+ auto denominator = newFactor.max(newFactor.size());
- newFactor = newFactor / *normalizer;
+ newFactor = newFactor / denominator;
// Check Conditional
CHECK(conditional);
@@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
- normalizer = expectedFactor.max(expectedFactor.size());
- // Ensure normalizer is correct.
- expectedFactor = expectedFactor / *normalizer;
+ denominator = expectedFactor.max(expectedFactor.size());
+ // Ensure denominator is correct.
+ expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree
diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp
index 147c3aea9..4f6ec2f39 100644
--- a/gtsam/discrete/tests/testTableFactor.cpp
+++ b/gtsam/discrete/tests/testTableFactor.cpp
@@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
TEST(TableFactor, Empty) {
DiscreteKey X(1, 2);
- TableFactor single = *TableFactor({X}, "1 1").sum(1);
+ auto single = TableFactor({X}, "1 1").sum(1);
// Should not throw a segfault
- EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
- single.toDecisionTreeFactor()));
+ auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
+ EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
+ single->toDecisionTreeFactor()));
- TableFactor empty = *TableFactor({X}, "0 0").sum(1);
+ auto empty = TableFactor({X}, "0 0").sum(1);
// Should not throw a segfault
- EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
- empty.toDecisionTreeFactor()));
+ auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
+ EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
+ empty->toDecisionTreeFactor()));
}
/* ************************************************************************* */
From e309bf370bd195d5f4e2171cd451ca58588e9fb1 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 11:54:09 -0500
Subject: [PATCH 58/86] improve operator/ documentation and also showcase
understanding in test
---
gtsam/discrete/DecisionTreeFactor.h | 24 +++++++++----------
.../discrete/tests/testDecisionTreeFactor.cpp | 14 ++++++++++-
2 files changed, 25 insertions(+), 13 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 804b956fe..a5b82f277 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -154,17 +154,17 @@ namespace gtsam {
static double safe_div(const double& a, const double& b);
- /// divide by factor f (safely)
+ /**
+ * @brief Divide by factor f (safely).
+ * Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
+ * results in a function which involves all keys
+ * \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
+ *
+ * @param f The DecisinTreeFactor to divide by.
+ * @return DecisionTreeFactor
+ */
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
- KeyVector diff;
- std::set_difference(this->keys().begin(), this->keys().end(),
- f.keys().begin(), f.keys().end(),
- std::back_inserter(diff));
- DiscreteKeys keys;
- for (Key key : diff) {
- keys.push_back({key, this->cardinality(key)});
- }
- return DecisionTreeFactor(keys, apply(f, safe_div));
+ return apply(f, safe_div);
}
/// Convert into a decision tree
@@ -181,12 +181,12 @@ namespace gtsam {
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(size_t nrFrontals) const {
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- shared_ptr max(const Ordering& keys) const {
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const {
return combine(keys, Ring::max);
}
diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
index 7210622d8..ba8714783 100644
--- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp
+++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp
@@ -116,8 +116,20 @@ TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2);
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
DecisionTreeFactor joint = pA * pS;
+
DecisionTreeFactor s = joint / pA;
- EXPECT(assert_equal(pS, s));
+
+ // Factors are not equal due to difference in keys
+ EXPECT(assert_inequal(pS, s));
+
+ // The underlying data should be the same
+ using ADT = AlgebraicDecisionTree;
+ EXPECT(assert_equal(ADT(pS), ADT(s)));
+
+ KeySet keys(joint.keys());
+ keys.insert(pA.keys().begin(), pA.keys().end());
+ EXPECT(assert_inequal(KeySet(pS.keys()), keys));
+
}
/* ************************************************************************* */
From 268290dbf25195ce2cdb4ef1ed4099eed6338f78 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 13:22:36 -0500
Subject: [PATCH 59/86] multiply method for DiscreteFactor
---
gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++++++++
gtsam/discrete/DecisionTreeFactor.h | 5 +++++
gtsam/discrete/DiscreteFactor.h | 10 ++++++++++
gtsam/discrete/TableFactor.cpp | 12 ++++++++++++
gtsam/discrete/TableFactor.h | 4 ++++
5 files changed, 43 insertions(+)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 1ac782b88..cf22fe153 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -62,6 +62,18 @@ namespace gtsam {
return error(values.discrete());
}
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
+ const DiscreteFactor::shared_ptr& f) const override {
+ DiscreteFactor::shared_ptr result;
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared((*tf) * TableFactor(*this));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared(this->operator*(*dtf));
+ }
+ return result;
+ }
+
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 8b6d91be7..d3cb55fa5 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -22,6 +22,7 @@
#include
#include
#include
+#include
#include
#include
@@ -147,6 +148,10 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& f) const override;
+
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 63cd7844c..adc79bbd5 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -130,6 +130,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
+ /**
+ * @brief Multiply in a DiscreteFactor and return the result as
+ * DiscreteFactor, both via shared pointers.
+ *
+ * @param df DiscreteFactor shared_ptr
+ * @return DiscreteFactor::shared_ptr
+ */
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const = 0;
+
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// Create new factor by summing all values with the same separator values
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index a59095d40..cfa56b43a 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
+/* ************************************************************************ */
+DiscreteFactor::shared_ptr TableFactor::multiply(
+ const DiscreteFactor::shared_ptr& f) const override {
+ DiscreteFactor::shared_ptr result;
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared(this->operator*(*tf));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared(this->operator*(TableFactor(*dtf)));
+ }
+ return result;
+}
+
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 16354026d..f7d0f5215 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -179,6 +179,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& f) const override;
+
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)
From 2f09e860e1109cf543a92c32d4fca14ce5d6c28e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 13:36:06 -0500
Subject: [PATCH 60/86] remove override from definition
---
gtsam/discrete/DecisionTreeFactor.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index cf22fe153..c15fd4e2e 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -64,7 +64,7 @@ namespace gtsam {
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
- const DiscreteFactor::shared_ptr& f) const override {
+ const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast(f)) {
result = std::make_shared((*tf) * TableFactor(*this));
From 5d865a8cc7908f02c206a3d9801af0fbaf8d1eaa Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 13:22:36 -0500
Subject: [PATCH 61/86] multiply method for DiscreteFactor
---
gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++++++++
gtsam/discrete/DecisionTreeFactor.h | 5 +++++
gtsam/discrete/DiscreteFactor.h | 10 ++++++++++
gtsam/discrete/TableFactor.cpp | 12 ++++++++++++
gtsam/discrete/TableFactor.h | 4 ++++
5 files changed, 43 insertions(+)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 1ac782b88..cf22fe153 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -62,6 +62,18 @@ namespace gtsam {
return error(values.discrete());
}
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
+ const DiscreteFactor::shared_ptr& f) const override {
+ DiscreteFactor::shared_ptr result;
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared((*tf) * TableFactor(*this));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared(this->operator*(*dtf));
+ }
+ return result;
+ }
+
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 80ee10a7b..3e70c0df9 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -22,6 +22,7 @@
#include
#include
#include
+#include
#include
#include
@@ -147,6 +148,10 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& f) const override;
+
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index a1fde0f86..c18eaae2f 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
+ /**
+ * @brief Multiply in a DiscreteFactor and return the result as
+ * DiscreteFactor, both via shared pointers.
+ *
+ * @param df DiscreteFactor shared_ptr
+ * @return DiscreteFactor::shared_ptr
+ */
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const = 0;
+
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// @}
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index a59095d40..cfa56b43a 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
+/* ************************************************************************ */
+DiscreteFactor::shared_ptr TableFactor::multiply(
+ const DiscreteFactor::shared_ptr& f) const override {
+ DiscreteFactor::shared_ptr result;
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared(this->operator*(*tf));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ result = std::make_shared(this->operator*(TableFactor(*dtf)));
+ }
+ return result;
+}
+
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index a2fdb4d32..4b53d7e2b 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ virtual DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& f) const override;
+
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)
From 75a4e98715caca504bacf4cc92a768db3dd89303 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 13:36:06 -0500
Subject: [PATCH 62/86] remove override from definition
---
gtsam/discrete/DecisionTreeFactor.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index cf22fe153..c15fd4e2e 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -64,7 +64,7 @@ namespace gtsam {
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
- const DiscreteFactor::shared_ptr& f) const override {
+ const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast(f)) {
result = std::make_shared((*tf) * TableFactor(*this));
From 700ad2bae326f3f89ecc78b56d55a74a64c3785a Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 13:56:54 -0500
Subject: [PATCH 63/86] remove override from TableFactor definition
---
gtsam/discrete/TableFactor.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index cfa56b43a..3ca8fecda 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -256,7 +256,7 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
- const DiscreteFactor::shared_ptr& f) const override {
+ const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast(f)) {
result = std::make_shared(this->operator*(*tf));
From 260d448887447a9a9f3216a544ad34ddd1dfbd8c Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 13:57:07 -0500
Subject: [PATCH 64/86] use new multiply method
---
gtsam/discrete/DiscreteFactorGraph.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 227bb4da3..0444f47ae 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -65,11 +65,11 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
- DecisionTreeFactor result;
- for (const sharedFactor& factor : *this) {
- if (factor) result = (*factor) * result;
+ DiscreteFactor::shared_ptr result = *this->begin();
+ for (auto it = this->begin() + 1; it != this->end(); ++it) {
+ if (*it) result = result->multiply(*it);
}
- return result;
+ return result->toDecisionTreeFactor();
}
/* ************************************************************************ */
From 453059bd61f4b08acdc43c37b9b098a5579d36a3 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 14:03:08 -0500
Subject: [PATCH 65/86] simplify to remove DiscreteProduct static function
---
gtsam/discrete/DiscreteFactorGraph.cpp | 44 ++++++++++----------------
1 file changed, 17 insertions(+), 27 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 0444f47ae..b3029111a 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -65,11 +65,23 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
- DiscreteFactor::shared_ptr result = *this->begin();
+ // PRODUCT: multiply all factors
+ gttic(product);
+ DiscreteFactor::shared_ptr product = *this->begin();
for (auto it = this->begin() + 1; it != this->end(); ++it) {
- if (*it) result = result->multiply(*it);
+ if (*it) product = product->multiply(*it);
}
- return result->toDecisionTreeFactor();
+ gttoc(product);
+
+ DecisionTreeFactor = result->toDecisionTreeFactor();
+
+ // Max over all the potentials by pretending all keys are frontal:
+ auto denominator = product.max(product.size());
+
+ // Normalize the product factor to prevent underflow.
+ product = product / (*denominator);
+
+ return product;
}
/* ************************************************************************ */
@@ -111,34 +123,12 @@ namespace gtsam {
// }
// }
- /**
- * @brief Multiply all the `factors`.
- *
- * @param factors The factors to multiply as a DiscreteFactorGraph.
- * @return DecisionTreeFactor
- */
- static DecisionTreeFactor DiscreteProduct(
- const DiscreteFactorGraph& factors) {
- // PRODUCT: multiply all factors
- gttic(product);
- DecisionTreeFactor product = factors.product();
- gttoc(product);
-
- // Max over all the potentials by pretending all keys are frontal:
- auto denominator = product.max(product.size());
-
- // Normalize the product factor to prevent underflow.
- product = product / (*denominator);
-
- return product;
- }
-
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DecisionTreeFactor product = factors.product();
// max out frontals, this is the factor on the separator
gttic(max);
@@ -216,7 +206,7 @@ namespace gtsam {
std::pair //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DecisionTreeFactor product = factors.product();
// sum out frontals, this is the factor on the separator
gttic(sum);
From a02baec0119ca9b670a8b5b64ebecc5db492bcbd Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 14:23:34 -0500
Subject: [PATCH 66/86] naive implementation of multiply for unstable
---
gtsam_unstable/discrete/AllDiff.h | 7 +++++++
gtsam_unstable/discrete/BinaryAllDiff.h | 7 +++++++
gtsam_unstable/discrete/Domain.h | 7 +++++++
gtsam_unstable/discrete/SingleValue.h | 7 +++++++
4 files changed, 28 insertions(+)
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index 1180abad4..cfbd76e7c 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -53,6 +53,13 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
+ }
+
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index e96bfdfde..a1a2bf0a6 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -69,6 +69,13 @@ class BinaryAllDiff : public Constraint {
return toDecisionTreeFactor() * f;
}
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
+ }
+
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index 23a566d24..dea85934f 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -90,6 +90,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
+ }
+
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index 3df1209b8..8675c929b 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -63,6 +63,13 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
+ }
+
/*
* Ensure Arc-consistency: just sets domain[j] to {value_}.
* @param j domain to be checked
From 13bafb0a48bae3c5598b7db112ac2757bd02c431 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 14:28:57 -0500
Subject: [PATCH 67/86] Revert "simplify to remove DiscreteProduct static
function"
This reverts commit 453059bd61f4b08acdc43c37b9b098a5579d36a3.
---
gtsam/discrete/DiscreteFactorGraph.cpp | 44 ++++++++++++++++----------
1 file changed, 27 insertions(+), 17 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index b3029111a..0444f47ae 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -65,23 +65,11 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
- // PRODUCT: multiply all factors
- gttic(product);
- DiscreteFactor::shared_ptr product = *this->begin();
+ DiscreteFactor::shared_ptr result = *this->begin();
for (auto it = this->begin() + 1; it != this->end(); ++it) {
- if (*it) product = product->multiply(*it);
+ if (*it) result = result->multiply(*it);
}
- gttoc(product);
-
- DecisionTreeFactor = result->toDecisionTreeFactor();
-
- // Max over all the potentials by pretending all keys are frontal:
- auto denominator = product.max(product.size());
-
- // Normalize the product factor to prevent underflow.
- product = product / (*denominator);
-
- return product;
+ return result->toDecisionTreeFactor();
}
/* ************************************************************************ */
@@ -123,12 +111,34 @@ namespace gtsam {
// }
// }
+ /**
+ * @brief Multiply all the `factors`.
+ *
+ * @param factors The factors to multiply as a DiscreteFactorGraph.
+ * @return DecisionTreeFactor
+ */
+ static DecisionTreeFactor DiscreteProduct(
+ const DiscreteFactorGraph& factors) {
+ // PRODUCT: multiply all factors
+ gttic(product);
+ DecisionTreeFactor product = factors.product();
+ gttoc(product);
+
+ // Max over all the potentials by pretending all keys are frontal:
+ auto denominator = product.max(product.size());
+
+ // Normalize the product factor to prevent underflow.
+ product = product / (*denominator);
+
+ return product;
+ }
+
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = factors.product();
+ DecisionTreeFactor product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator
gttic(max);
@@ -206,7 +216,7 @@ namespace gtsam {
std::pair //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = factors.product();
+ DecisionTreeFactor product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator
gttic(sum);
From a7fc6e3763d4dc5ea5019906be37e05613d3f9d9 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 15:08:58 -0500
Subject: [PATCH 68/86] convert everything to DecisionTreeFactor so we can use
override operator* method
---
gtsam_unstable/discrete/AllDiff.h | 4 ++--
gtsam_unstable/discrete/BinaryAllDiff.h | 4 ++--
gtsam_unstable/discrete/Domain.h | 4 ++--
gtsam_unstable/discrete/SingleValue.h | 4 ++--
4 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index cfbd76e7c..032808dcd 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -56,8 +56,8 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
+ return std::make_shared(this->toDecisionTreeFactor() *
+ df->toDecisionTreeFactor());
}
/// Compute error for each assignment and return as a tree
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index a1a2bf0a6..0ebae4d77 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -72,8 +72,8 @@ class BinaryAllDiff : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
+ return std::make_shared(this->toDecisionTreeFactor() *
+ df->toDecisionTreeFactor());
}
/*
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index dea85934f..9a4a21847 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -93,8 +93,8 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
+ return std::make_shared(this->toDecisionTreeFactor() *
+ df->toDecisionTreeFactor());
}
/*
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index 8675c929b..ebe23f7e4 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -66,8 +66,8 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
+ return std::make_shared(this->toDecisionTreeFactor() *
+ df->toDecisionTreeFactor());
}
/*
From 8390ffa2cbde062f7628a953bba6f43ba77cc7d1 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 15:19:16 -0500
Subject: [PATCH 69/86] revert previous commit
---
gtsam_unstable/discrete/AllDiff.h | 4 ++--
gtsam_unstable/discrete/BinaryAllDiff.h | 4 ++--
gtsam_unstable/discrete/Domain.h | 4 ++--
gtsam_unstable/discrete/SingleValue.h | 4 ++--
4 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index 032808dcd..cfbd76e7c 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -56,8 +56,8 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(this->toDecisionTreeFactor() *
- df->toDecisionTreeFactor());
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
}
/// Compute error for each assignment and return as a tree
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index 0ebae4d77..a1a2bf0a6 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -72,8 +72,8 @@ class BinaryAllDiff : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(this->toDecisionTreeFactor() *
- df->toDecisionTreeFactor());
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
}
/*
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index 9a4a21847..dea85934f 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -93,8 +93,8 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(this->toDecisionTreeFactor() *
- df->toDecisionTreeFactor());
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
}
/*
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index ebe23f7e4..8675c929b 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -66,8 +66,8 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(this->toDecisionTreeFactor() *
- df->toDecisionTreeFactor());
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
}
/*
From bc63cc8cb88c7edebf86a4136d7231ab51a59e47 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 15:19:57 -0500
Subject: [PATCH 70/86] use double dispatch for else case
---
gtsam/discrete/DecisionTreeFactor.cpp | 3 +++
gtsam/discrete/TableFactor.cpp | 4 ++++
2 files changed, 7 insertions(+)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index c15fd4e2e..4b16dad8a 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -70,6 +70,9 @@ namespace gtsam {
result = std::make_shared((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast(f)) {
result = std::make_shared(this->operator*(*dtf));
+ } else {
+ // Simulate double dispatch in C++
+ result = std::make_shared(f->operator*(*this));
}
return result;
}
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 3ca8fecda..6516a4a98 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -262,6 +262,10 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
result = std::make_shared(this->operator*(*tf));
} else if (auto dtf = std::dynamic_pointer_cast(f)) {
result = std::make_shared(this->operator*(TableFactor(*dtf)));
+ } else {
+ // Simulate double dispatch in C++
+ result = std::make_shared(
+ f->operator*(this->toDecisionTreeFactor()));
}
return result;
}
From 713c49c9153f9f023621a6cd9a06997594ea52ab Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 15:20:09 -0500
Subject: [PATCH 71/86] more robust product
---
gtsam/discrete/DiscreteFactorGraph.cpp | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 0444f47ae..a2b896286 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -65,9 +65,16 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
- DiscreteFactor::shared_ptr result = *this->begin();
- for (auto it = this->begin() + 1; it != this->end(); ++it) {
- if (*it) result = result->multiply(*it);
+ DiscreteFactor::shared_ptr result;
+ for (auto it = this->begin(); it != this->end(); ++it) {
+ if (*it) {
+ if (result) {
+ result = result->multiply(*it);
+ } else {
+ // Assign to the first non-null factor
+ result = *it;
+ }
+ }
}
return result->toDecisionTreeFactor();
}
From b83aadb20487f69c6ab932245f8524c8ec92fdde Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 15:37:37 -0500
Subject: [PATCH 72/86] remove accidental type change
---
gtsam/discrete/DecisionTreeFactor.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index a5b82f277..eb6d9eaa2 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -181,12 +181,12 @@ namespace gtsam {
}
/// Create new factor by maximizing over all values with the same separator.
- DiscreteFactor::shared_ptr max(size_t nrFrontals) const {
+ shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
- DiscreteFactor::shared_ptr max(const Ordering& keys) const {
+ shared_ptr max(const Ordering& keys) const {
return combine(keys, Ring::max);
}
From b5128b2c9fcf31d04e4760c3c4178d8449ad44c9 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 19:40:37 -0500
Subject: [PATCH 73/86] use DecisionTreeFactor version of sum and max where not
available
---
gtsam_unstable/discrete/AllDiff.h | 8 ++++----
gtsam_unstable/discrete/BinaryAllDiff.h | 8 ++++----
gtsam_unstable/discrete/Domain.h | 8 ++++----
gtsam_unstable/discrete/SingleValue.h | 8 ++++----
4 files changed, 16 insertions(+), 16 deletions(-)
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index cf0e5e3cf..267ddb9fd 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -84,19 +84,19 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
uint64_t nrValues() const override { return 1; };
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(nrFrontals);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(keys);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(nrFrontals);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(keys);
}
};
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index c15ac8aec..3035d0620 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -108,19 +108,19 @@ class BinaryAllDiff : public Constraint {
uint64_t nrValues() const override { return 1; };
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(nrFrontals);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(keys);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(nrFrontals);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(keys);
}
};
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index f64716028..4c2d3f9dd 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -123,19 +123,19 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
Constraint::shared_ptr partiallyApply(const Domains& domains) const override;
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(nrFrontals);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(keys);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(nrFrontals);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(keys);
}
};
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index c5824a96a..b6c91f912 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -89,19 +89,19 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
uint64_t nrValues() const override { return 1; };
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(nrFrontals);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().sum(keys);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(nrFrontals);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- throw std::runtime_error("Not implemented");
+ return toDecisionTreeFactor().max(keys);
}
};
From 4ebca711461eb3fd914b74b4532242aedf38c048 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 20:44:10 -0500
Subject: [PATCH 74/86] divide operator for DiscreteFactor::shared_ptr
---
gtsam/discrete/DecisionTreeFactor.cpp | 13 +++++++++++++
gtsam/discrete/DecisionTreeFactor.h | 5 ++---
gtsam/discrete/DiscreteFactor.h | 4 ++++
gtsam/discrete/TableFactor.cpp | 14 ++++++++++++++
gtsam/discrete/TableFactor.h | 11 ++---------
gtsam_unstable/discrete/AllDiff.cpp | 6 ++++++
gtsam_unstable/discrete/AllDiff.h | 4 ++++
gtsam_unstable/discrete/BinaryAllDiff.h | 6 ++++++
gtsam_unstable/discrete/Domain.cpp | 6 ++++++
gtsam_unstable/discrete/Domain.h | 4 ++++
gtsam_unstable/discrete/SingleValue.cpp | 6 ++++++
gtsam_unstable/discrete/SingleValue.h | 4 ++++
12 files changed, 71 insertions(+), 12 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 4b16dad8a..2f2c039a4 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -77,6 +77,19 @@ namespace gtsam {
return result;
}
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(tf->operator/(TableFactor(*this)));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(this->operator/(*dtf));
+ } else {
+ return std::make_shared(
+ this->operator/(this->toDecisionTreeFactor()));
+ }
+ }
+
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index d3cb55fa5..a5327bdd0 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -165,9 +165,8 @@ namespace gtsam {
}
/// divide by DiscreteFactor::shared_ptr f (safely)
- DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const {
- return apply(*std::dynamic_pointer_cast(f), safe_div);
- }
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index adc79bbd5..6cbc00d09 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -140,6 +140,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ virtual DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const = 0;
+
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// Create new factor by summing all values with the same separator values
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 6516a4a98..b692e9ba2 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -270,6 +270,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
return result;
}
+/* ************************************************************************ */
+DiscreteFactor::shared_ptr TableFactor::operator/(
+ const DiscreteFactor::shared_ptr& f) const {
+ if (auto tf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(this->operator/(*tf));
+ } else if (auto dtf = std::dynamic_pointer_cast(f)) {
+ return std::make_shared(
+ this->operator/(TableFactor(f->discreteKeys(), *dtf)));
+ } else {
+ TableFactor divisor(f->toDecisionTreeFactor());
+ return std::make_shared(this->operator/(divisor));
+ }
+}
+
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index f7d0f5215..a2f74758f 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -191,15 +191,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
}
/// divide by DiscreteFactor::shared_ptr f (safely)
- TableFactor operator/(const DiscreteFactor::shared_ptr& f) const {
- if (auto tf = std::dynamic_pointer_cast(f)) {
- return apply(*tf, safe_div);
- } else if (auto dtf = std::dynamic_pointer_cast(f)) {
- return apply(TableFactor(f->discreteKeys(), *dtf), safe_div);
- } else {
- throw std::runtime_error("Unknown derived type for DiscreteFactor");
- }
- }
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp
index 585ca8103..01f50fa3d 100644
--- a/gtsam_unstable/discrete/AllDiff.cpp
+++ b/gtsam_unstable/discrete/AllDiff.cpp
@@ -56,6 +56,12 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
+/* ************************************************************************* */
+DiscreteFactor::shared_ptr AllDiff::operator/(
+ const DiscreteFactor::shared_ptr& df) const {
+ return this->toDecisionTreeFactor() / df;
+}
+
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
Domain& Dj = domains->at(j);
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index 267ddb9fd..7a7b1cecc 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -60,6 +60,10 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const override;
+
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index 3035d0620..fbff8a01c 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -76,6 +76,12 @@ class BinaryAllDiff : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return this->toDecisionTreeFactor() / df;
+ }
+
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp
index 74f621dc7..cecb7cc1a 100644
--- a/gtsam_unstable/discrete/Domain.cpp
+++ b/gtsam_unstable/discrete/Domain.cpp
@@ -49,6 +49,12 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
+/* ************************************************************************* */
+DiscreteFactor::shared_ptr Domain::operator/(
+ const DiscreteFactor::shared_ptr& df) const {
+ return this->toDecisionTreeFactor() / df;
+}
+
/* ************************************************************************* */
bool Domain::ensureArcConsistency(Key j, Domains* domains) const {
if (j != key()) throw invalid_argument("Domain check on wrong domain");
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index 4c2d3f9dd..7362e9caf 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -97,6 +97,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const override;
+
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp
index 220bc9c06..09a8314df 100644
--- a/gtsam_unstable/discrete/SingleValue.cpp
+++ b/gtsam_unstable/discrete/SingleValue.cpp
@@ -41,6 +41,12 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
+/* ************************************************************************* */
+DiscreteFactor::shared_ptr SingleValue::operator/(
+ const DiscreteFactor::shared_ptr& df) const {
+ return this->toDecisionTreeFactor() / df;
+}
+
/* ************************************************************************* */
bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const {
if (j != keys_[0])
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index b6c91f912..87c42fc80 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -70,6 +70,10 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
this->operator*(df->toDecisionTreeFactor()));
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const override;
+
/*
* Ensure Arc-consistency: just sets domain[j] to {value_}.
* @param j domain to be checked
From fb1d52a18eec0e804276ef542edf4bdd34f69d3f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 20:49:24 -0500
Subject: [PATCH 75/86] fix constructor
---
gtsam/discrete/DiscreteConditional.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 26f38e278..06a08eca5 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -45,7 +45,8 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
- : BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {}
+ : BaseFactor(f / f.sum(nrFrontals)->toDecisionTreeFactor()),
+ BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
From e9822a70d2921fdbd433f55fa29d91ac9467f7f1 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 20:50:24 -0500
Subject: [PATCH 76/86] update DiscreteFactorGraph to use
DiscreteFactor::shared_ptr for elimination
---
gtsam/discrete/DiscreteFactorGraph.cpp | 37 +++++++++++++-------------
gtsam/discrete/DiscreteFactorGraph.h | 2 +-
2 files changed, 19 insertions(+), 20 deletions(-)
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index e05cf9e33..eb3221819 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -64,7 +64,7 @@ namespace gtsam {
}
/* ************************************************************************ */
- DecisionTreeFactor DiscreteFactorGraph::product() const {
+ DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
@@ -76,7 +76,7 @@ namespace gtsam {
}
}
}
- return result->toDecisionTreeFactor();
+ return result;
}
/* ************************************************************************ */
@@ -122,20 +122,20 @@ namespace gtsam {
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
- * @return DecisionTreeFactor
+ * @return DiscreteFactor::shared_ptr
*/
- static DecisionTreeFactor DiscreteProduct(
+ static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
- DecisionTreeFactor product = factors.product();
+ DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
- auto denominator = product.max(product.size());
+ auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow.
- product = product / denominator;
+ product = product->operator/(denominator);
return product;
}
@@ -145,26 +145,25 @@ namespace gtsam {
std::pair //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator
gttic(max);
- DecisionTreeFactor::shared_ptr max =
- std::dynamic_pointer_cast(product.max(frontalKeys));
+ DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
- orderedKeys.emplace_back(key, product.cardinality(key));
+ orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
- auto lookup =
- std::make_shared(nrFrontals, orderedKeys, product);
+ auto lookup = std::make_shared(
+ nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup);
return {std::dynamic_pointer_cast(lookup), max};
@@ -224,12 +223,11 @@ namespace gtsam {
std::pair //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DecisionTreeFactor product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator
gttic(sum);
- DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast(
- product.sum(frontalKeys));
+ DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
@@ -241,8 +239,9 @@ namespace gtsam {
// now divide product/sum to get conditional
gttic(divide);
- auto conditional =
- std::make_shared(product, *sum, orderedKeys);
+ auto conditional = std::make_shared(
+ product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
+ orderedKeys);
gttoc(divide);
return {conditional, sum};
diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h
index b311cb78b..3d9e86cd1 100644
--- a/gtsam/discrete/DiscreteFactorGraph.h
+++ b/gtsam/discrete/DiscreteFactorGraph.h
@@ -148,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
- DecisionTreeFactor product() const;
+ DiscreteFactor::shared_ptr product() const;
/**
* Evaluates the factor graph given values, returns the joint probability of
From 2f8c8ddb75cb96b30b550bd03e0a659746938857 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 20:50:40 -0500
Subject: [PATCH 77/86] update tests
---
gtsam/discrete/tests/testDiscreteConditional.cpp | 4 ++--
gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 7 ++++---
gtsam_unstable/discrete/tests/testCSP.cpp | 2 +-
gtsam_unstable/discrete/tests/testScheduler.cpp | 2 +-
4 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp
index d17c76837..b91e1bd8a 100644
--- a/gtsam/discrete/tests/testDiscreteConditional.cpp
+++ b/gtsam/discrete/tests/testDiscreteConditional.cpp
@@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast(actual2)));
std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
- DecisionTreeFactor expected2 = f2 / f2.sum(1);
+ DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast(actual2)));
}
diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
index 4ee36f0ab..0c1dd7a2a 100644
--- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
+++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp
@@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
// Check if graph product works
- DecisionTreeFactor product = graph.product();
+ DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
}
@@ -117,7 +117,7 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast(newFactorPtr);
// Normalize newFactor by max for comparison with expected
- auto denominator = newFactor.max(newFactor.size());
+ auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
newFactor = newFactor / denominator;
@@ -131,7 +131,8 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
- denominator = expectedFactor.max(expectedFactor.size());
+ denominator =
+ expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
// Ensure denominator is correct.
expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor));
diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp
index 2b9a20ca6..6806bfe58 100644
--- a/gtsam_unstable/discrete/tests/testCSP.cpp
+++ b/gtsam_unstable/discrete/tests/testCSP.cpp
@@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Just for fun, create the product and check it
- DecisionTreeFactor product = csp.product();
+ DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
// product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct, product));
diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp
index f868abb5e..5f9b7f287 100644
--- a/gtsam_unstable/discrete/tests/testScheduler.cpp
+++ b/gtsam_unstable/discrete/tests/testScheduler.cpp
@@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
// Do brute force product and output that to file
- DecisionTreeFactor product = s.product();
+ DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
// product.dot("scheduling", false);
// Do exact inference
From 2434e248e95d0efac75a18cddae2748ff8be4502 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Sun, 5 Jan 2025 20:54:56 -0500
Subject: [PATCH 78/86] undo print change in DiscreteLookupTable
---
gtsam/discrete/DiscreteLookupDAG.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp
index c1c301525..d1108e7b7 100644
--- a/gtsam/discrete/DiscreteLookupDAG.cpp
+++ b/gtsam/discrete/DiscreteLookupDAG.cpp
@@ -48,7 +48,7 @@ void DiscreteLookupTable::print(const std::string& s,
}
}
cout << "):\n";
- BaseFactor::print("", formatter);
+ ADT::print("", formatter);
cout << endl;
}
From 43f755d9d8d0ede5868ac18ae33e9350c43a2e78 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 6 Jan 2025 11:17:03 -0500
Subject: [PATCH 79/86] move multiply to Constraint.h
---
gtsam_unstable/discrete/AllDiff.h | 7 -------
gtsam_unstable/discrete/BinaryAllDiff.h | 7 -------
gtsam_unstable/discrete/Constraint.h | 8 ++++++++
gtsam_unstable/discrete/Domain.h | 7 -------
gtsam_unstable/discrete/SingleValue.h | 7 -------
5 files changed, 8 insertions(+), 28 deletions(-)
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index cfbd76e7c..1180abad4 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -53,13 +53,6 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
- /// Multiply factors, DiscreteFactor::shared_ptr edition
- DiscreteFactor::shared_ptr multiply(
- const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
- }
-
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index a1a2bf0a6..e96bfdfde 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -69,13 +69,6 @@ class BinaryAllDiff : public Constraint {
return toDecisionTreeFactor() * f;
}
- /// Multiply factors, DiscreteFactor::shared_ptr edition
- DiscreteFactor::shared_ptr multiply(
- const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
- }
-
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h
index 3526a282d..71ed7647a 100644
--- a/gtsam_unstable/discrete/Constraint.h
+++ b/gtsam_unstable/discrete/Constraint.h
@@ -78,6 +78,14 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0;
+
+ /// Multiply factors, DiscreteFactor::shared_ptr edition
+ DiscreteFactor::shared_ptr multiply(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return std::make_shared(
+ this->operator*(df->toDecisionTreeFactor()));
+ }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index dea85934f..23a566d24 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -90,13 +90,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
- /// Multiply factors, DiscreteFactor::shared_ptr edition
- DiscreteFactor::shared_ptr multiply(
- const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
- }
-
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index 8675c929b..3df1209b8 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -63,13 +63,6 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
- /// Multiply factors, DiscreteFactor::shared_ptr edition
- DiscreteFactor::shared_ptr multiply(
- const DiscreteFactor::shared_ptr& df) const override {
- return std::make_shared(
- this->operator*(df->toDecisionTreeFactor()));
- }
-
/*
* Ensure Arc-consistency: just sets domain[j] to {value_}.
* @param j domain to be checked
From 7561da4df2df5740a808d0e6ea4e957eb65cb4f2 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 6 Jan 2025 13:35:45 -0500
Subject: [PATCH 80/86] move operator/ to Constraint.h
---
gtsam_unstable/discrete/AllDiff.cpp | 6 ------
gtsam_unstable/discrete/Constraint.h | 6 ++++++
gtsam_unstable/discrete/Domain.cpp | 6 ------
gtsam_unstable/discrete/SingleValue.cpp | 6 ------
4 files changed, 6 insertions(+), 18 deletions(-)
diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp
index 01f50fa3d..585ca8103 100644
--- a/gtsam_unstable/discrete/AllDiff.cpp
+++ b/gtsam_unstable/discrete/AllDiff.cpp
@@ -56,12 +56,6 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
-/* ************************************************************************* */
-DiscreteFactor::shared_ptr AllDiff::operator/(
- const DiscreteFactor::shared_ptr& df) const {
- return this->toDecisionTreeFactor() / df;
-}
-
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
Domain& Dj = domains->at(j);
diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h
index 71ed7647a..2d98ab40b 100644
--- a/gtsam_unstable/discrete/Constraint.h
+++ b/gtsam_unstable/discrete/Constraint.h
@@ -86,6 +86,12 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
this->operator*(df->toDecisionTreeFactor()));
}
+ /// divide by DiscreteFactor::shared_ptr f (safely)
+ DiscreteFactor::shared_ptr operator/(
+ const DiscreteFactor::shared_ptr& df) const override {
+ return this->toDecisionTreeFactor() / df;
+ }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp
index cecb7cc1a..74f621dc7 100644
--- a/gtsam_unstable/discrete/Domain.cpp
+++ b/gtsam_unstable/discrete/Domain.cpp
@@ -49,12 +49,6 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
-/* ************************************************************************* */
-DiscreteFactor::shared_ptr Domain::operator/(
- const DiscreteFactor::shared_ptr& df) const {
- return this->toDecisionTreeFactor() / df;
-}
-
/* ************************************************************************* */
bool Domain::ensureArcConsistency(Key j, Domains* domains) const {
if (j != key()) throw invalid_argument("Domain check on wrong domain");
diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp
index 09a8314df..220bc9c06 100644
--- a/gtsam_unstable/discrete/SingleValue.cpp
+++ b/gtsam_unstable/discrete/SingleValue.cpp
@@ -41,12 +41,6 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
-/* ************************************************************************* */
-DiscreteFactor::shared_ptr SingleValue::operator/(
- const DiscreteFactor::shared_ptr& df) const {
- return this->toDecisionTreeFactor() / df;
-}
-
/* ************************************************************************* */
bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const {
if (j != keys_[0])
From ff5371fd4a42cb91fdf9d0a222dbef2d5036a294 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 6 Jan 2025 13:38:45 -0500
Subject: [PATCH 81/86] move sum, max and nrValues to Constraint class as well
---
gtsam_unstable/discrete/AllDiff.h | 19 -------------------
gtsam_unstable/discrete/BinaryAllDiff.h | 19 -------------------
gtsam_unstable/discrete/Constraint.h | 22 +++++++++++++++++++++-
gtsam_unstable/discrete/Domain.h | 16 ----------------
gtsam_unstable/discrete/SingleValue.h | 19 -------------------
5 files changed, 21 insertions(+), 74 deletions(-)
diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h
index 34e5c4700..1180abad4 100644
--- a/gtsam_unstable/discrete/AllDiff.h
+++ b/gtsam_unstable/discrete/AllDiff.h
@@ -72,25 +72,6 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const Domains&) const override;
-
- /// Get the number of non-zero values contained in this factor.
- uint64_t nrValues() const override { return 1; };
-
- DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- return toDecisionTreeFactor().sum(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- return toDecisionTreeFactor().sum(keys);
- }
-
- DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- return toDecisionTreeFactor().max(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- return toDecisionTreeFactor().max(keys);
- }
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h
index 0cd51ec61..e96bfdfde 100644
--- a/gtsam_unstable/discrete/BinaryAllDiff.h
+++ b/gtsam_unstable/discrete/BinaryAllDiff.h
@@ -96,25 +96,6 @@ class BinaryAllDiff : public Constraint {
AlgebraicDecisionTree errorTree() const override {
throw std::runtime_error("BinaryAllDiff::error not implemented");
}
-
- /// Get the number of non-zero values contained in this factor.
- uint64_t nrValues() const override { return 1; };
-
- DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- return toDecisionTreeFactor().sum(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- return toDecisionTreeFactor().sum(keys);
- }
-
- DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- return toDecisionTreeFactor().max(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- return toDecisionTreeFactor().max(keys);
- }
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h
index 2d98ab40b..328fabbea 100644
--- a/gtsam_unstable/discrete/Constraint.h
+++ b/gtsam_unstable/discrete/Constraint.h
@@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
- * @param (in/out) domains all domains, but only domains->at(j) will be checked.
+ * @param (in/out) domains all domains, but only domains->at(j) will be
+ * checked.
* @return true if domains->at(j) was changed, false otherwise.
*/
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
@@ -92,6 +93,25 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
return this->toDecisionTreeFactor() / df;
}
+ /// Get the number of non-zero values contained in this factor.
+ uint64_t nrValues() const override { return 1; };
+
+ DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
+ return toDecisionTreeFactor().sum(nrFrontals);
+ }
+
+ DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
+ return toDecisionTreeFactor().sum(keys);
+ }
+
+ DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
+ return toDecisionTreeFactor().max(nrFrontals);
+ }
+
+ DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
+ return toDecisionTreeFactor().max(keys);
+ }
+
/// @}
/// @name Wrapper support
/// @{
diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h
index 2372cf499..6ce846201 100644
--- a/gtsam_unstable/discrete/Domain.h
+++ b/gtsam_unstable/discrete/Domain.h
@@ -114,22 +114,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(const Domains& domains) const override;
-
- DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- return toDecisionTreeFactor().sum(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- return toDecisionTreeFactor().sum(keys);
- }
-
- DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- return toDecisionTreeFactor().max(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- return toDecisionTreeFactor().max(keys);
- }
};
} // namespace gtsam
diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h
index 5d4c2dca1..3df1209b8 100644
--- a/gtsam_unstable/discrete/SingleValue.h
+++ b/gtsam_unstable/discrete/SingleValue.h
@@ -77,25 +77,6 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const Domains& domains) const override;
-
- /// Get the number of non-zero values contained in this factor.
- uint64_t nrValues() const override { return 1; };
-
- DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
- return toDecisionTreeFactor().sum(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
- return toDecisionTreeFactor().sum(keys);
- }
-
- DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
- return toDecisionTreeFactor().max(nrFrontals);
- }
-
- DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
- return toDecisionTreeFactor().max(keys);
- }
};
} // namespace gtsam
From ab90e0b0f3d68f3f144b5fc4d7dd87a7e6425901 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Mon, 6 Jan 2025 13:44:55 -0500
Subject: [PATCH 82/86] move include to .cpp
---
gtsam/discrete/DecisionTreeFactor.cpp | 3 ++-
gtsam/discrete/DecisionTreeFactor.h | 1 -
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 4b16dad8a..e353fdebf 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -18,9 +18,10 @@
*/
#include
-#include
#include
#include
+#include
+#include
#include
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index 3e70c0df9..ff9bf0df9 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -22,7 +22,6 @@
#include
#include