sum and normalize helper methods for the AlgebraicDecisionTree

release/4.3a0
Varun Agrawal 2023-12-12 06:26:57 -05:00
parent 50670da07c
commit af490e9ffc
2 changed files with 26 additions and 6 deletions

View File

@ -196,6 +196,25 @@ namespace gtsam {
return this->apply(g, &Ring::div);
}
/// Compute sum of all values
double sum() const {
double sum = 0;
auto visitor = [&](int y) { sum += y; };
this->visit(visitor);
return sum;
}
/**
* @brief Helper method to perform normalization such that all leaves in the
* tree sum to 1
*
* @param sum
* @return AlgebraicDecisionTree
*/
AlgebraicDecisionTree normalize(double sum) const {
return this->apply([&sum](const double& x) { return x / sum; });
}
/** sum out variable */
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
return this->combine(label, cardinality, &Ring::add);

View File

@ -283,16 +283,17 @@ HybridValues HybridBayesNet::optimize() const {
error = error + gm->error(continuousValues);
// Add the logNormalization constant to the error
// Also compute the mean for normalization (for numerical stability)
double mean = 0.0;
auto addConstant = [&gm, &mean](const double &error) {
// Also compute the sum for discrete probability normalization
// (normalization trick for numerical stability)
double sum = 0.0;
auto addConstant = [&gm, &sum](const double &error) {
double e = error + gm->logNormalizationConstant();
mean += e;
sum += e;
return e;
};
error = error.apply(addConstant);
// Normalize by the mean
error = error.apply([&mean](double x) { return x / mean; });
// Normalize by the sum
error = error.normalize(sum);
// Include the discrete keys
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),