From 9a40be6f32ccdcafe50774795b5f8df84bc36d52 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 4 Jan 2025 06:12:01 -0500 Subject: [PATCH] normalize values in sparse_table so it forms a proper distribution --- gtsam/discrete/TableDistribution.cpp | 16 +++++++++++++--- gtsam/discrete/TableFactor.h | 2 +- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/TableDistribution.cpp b/gtsam/discrete/TableDistribution.cpp index 6669cea4a..e62d3ecec 100644 --- a/gtsam/discrete/TableDistribution.cpp +++ b/gtsam/discrete/TableDistribution.cpp @@ -37,6 +37,12 @@ using std::stringstream; using std::vector; namespace gtsam { +/// Normalize sparse_table +static Eigen::SparseVector normalizeSparseTable( + const Eigen::SparseVector& sparse_table) { + return sparse_table / sparse_table.sum(); +} + /* ************************************************************************** */ TableDistribution::TableDistribution(const TableFactor& f) : BaseConditional(f.keys().size(), @@ -47,19 +53,23 @@ TableDistribution::TableDistribution(const TableFactor& f) TableDistribution::TableDistribution( const DiscreteKeys& keys, const Eigen::SparseVector& potentials) : BaseConditional(keys.size(), keys, DecisionTreeFactor(keys, ADT())), - table_(TableFactor(keys, potentials)) {} + table_(TableFactor(keys, normalizeSparseTable(potentials))) {} /* ************************************************************************** */ TableDistribution::TableDistribution(const DiscreteKeys& keys, const std::vector& potentials) : BaseConditional(keys.size(), keys, DecisionTreeFactor(keys, ADT())), - table_(TableFactor(keys, potentials)) {} + table_(TableFactor( + keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { +} /* ************************************************************************** */ TableDistribution::TableDistribution(const DiscreteKeys& keys, const std::string& potentials) : BaseConditional(keys.size(), keys, DecisionTreeFactor(keys, ADT())), - table_(TableFactor(keys, potentials)) {} + table_(TableFactor( + keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { +} /* ************************************************************************** */ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index a2fdb4d32..72778d711 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -86,6 +86,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); } + public: /** * Convert probability table given as doubles to SparseVector. * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} @@ -97,7 +98,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { static Eigen::SparseVector Convert(const DiscreteKeys& keys, const std::string& table); - public: // typedefs needed to play nice with gtsam typedef TableFactor This; typedef DiscreteFactor Base; ///< Typedef to base class