From d1d440ad3420efb6a35bef80f39699b6b075e810 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 10:53:32 -0500 Subject: [PATCH] add nrValues method --- gtsam/discrete/DecisionTreeFactor.h | 6 ++++++ gtsam/discrete/DiscreteFactor.h | 6 ++++++ gtsam/discrete/TableFactor.h | 6 ++++++ gtsam_unstable/discrete/AllDiff.h | 3 +++ gtsam_unstable/discrete/BinaryAllDiff.h | 3 +++ gtsam_unstable/discrete/Domain.h | 2 +- gtsam_unstable/discrete/SingleValue.h | 3 +++ 7 files changed, 28 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index a8ab2644f..f417a38d7 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -255,6 +255,12 @@ namespace gtsam { */ DecisionTreeFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return nrLeaves(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 19af5bd13..7d5047ec6 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -113,6 +113,12 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + virtual uint64_t nrValues() const = 0; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index f0ecd66a3..b988eebad 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -324,6 +324,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ TableFactor prune(size_t maxNrAssignments) const; + /** + * Get the number of non-zero values contained in this factor. + * It could be much smaller than `prod_{key}(cardinality(key))`. + */ + uint64_t nrValues() const override { return sparse_table_.nonZeros(); } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index d7a63eae0..42a255bbf 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -72,6 +72,9 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const Domains&) const override; + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 18b335092..22acfb092 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -96,6 +96,9 @@ class BinaryAllDiff : public Constraint { AlgebraicDecisionTree errorTree() const override { throw std::runtime_error("BinaryAllDiff::error not implemented"); } + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 7f7b717c2..ba3771eca 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } - size_t nrValues() const { return values_.size(); } + uint64_t nrValues() const override { return values_.size(); } bool isSingleton() const { return nrValues() == 1; } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 3f7f22d6a..7f2eb2c2c 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -77,6 +77,9 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const Domains& domains) const override; + + /// Get the number of non-zero values contained in this factor. + uint64_t nrValues() const override { return 1; }; }; } // namespace gtsam