normalize values in sparse_table so it forms a proper distribution

release/4.3a0
Varun Agrawal 2025-01-04 06:12:01 -05:00
parent d6bc1e11a6
commit 9a40be6f32
2 changed files with 14 additions and 4 deletions

View File

@ -37,6 +37,12 @@ using std::stringstream;
using std::vector;
namespace gtsam {
/// Normalize sparse_table
static Eigen::SparseVector<double> normalizeSparseTable(
const Eigen::SparseVector<double>& sparse_table) {
return sparse_table / sparse_table.sum();
}
/* ************************************************************************** */
TableDistribution::TableDistribution(const TableFactor& f)
: BaseConditional(f.keys().size(),
@ -47,19 +53,23 @@ TableDistribution::TableDistribution(const TableFactor& f)
TableDistribution::TableDistribution(
const DiscreteKeys& keys, const Eigen::SparseVector<double>& potentials)
: BaseConditional(keys.size(), keys, DecisionTreeFactor(keys, ADT())),
table_(TableFactor(keys, potentials)) {}
table_(TableFactor(keys, normalizeSparseTable(potentials))) {}
/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::vector<double>& potentials)
: BaseConditional(keys.size(), keys, DecisionTreeFactor(keys, ADT())),
table_(TableFactor(keys, potentials)) {}
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}
/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::string& potentials)
: BaseConditional(keys.size(), keys, DecisionTreeFactor(keys, ADT())),
table_(TableFactor(keys, potentials)) {}
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}
/* **************************************************************************
*/

View File

@ -86,6 +86,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
public:
/**
* Convert probability table given as doubles to SparseVector.
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
@ -97,7 +98,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
const std::string& table);
public:
// typedefs needed to play nice with gtsam
typedef TableFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class