sum and normalize helper methods for the AlgebraicDecisionTree

release/4.3a0
Varun Agrawal 2023-12-12 06:26:57 -05:00
parent 50670da07c
commit af490e9ffc
2 changed files with 26 additions and 6 deletions

View File

@ -196,6 +196,25 @@ namespace gtsam {
return this->apply(g, &Ring::div); return this->apply(g, &Ring::div);
} }
/// Compute sum of all values
double sum() const {
double sum = 0;
auto visitor = [&](int y) { sum += y; };
this->visit(visitor);
return sum;
}
/**
* @brief Helper method to perform normalization such that all leaves in the
* tree sum to 1
*
* @param sum
* @return AlgebraicDecisionTree
*/
AlgebraicDecisionTree normalize(double sum) const {
return this->apply([&sum](const double& x) { return x / sum; });
}
/** sum out variable */ /** sum out variable */
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const { AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
return this->combine(label, cardinality, &Ring::add); return this->combine(label, cardinality, &Ring::add);

View File

@ -283,16 +283,17 @@ HybridValues HybridBayesNet::optimize() const {
error = error + gm->error(continuousValues); error = error + gm->error(continuousValues);
// Add the logNormalization constant to the error // Add the logNormalization constant to the error
// Also compute the mean for normalization (for numerical stability) // Also compute the sum for discrete probability normalization
double mean = 0.0; // (normalization trick for numerical stability)
auto addConstant = [&gm, &mean](const double &error) { double sum = 0.0;
auto addConstant = [&gm, &sum](const double &error) {
double e = error + gm->logNormalizationConstant(); double e = error + gm->logNormalizationConstant();
mean += e; sum += e;
return e; return e;
}; };
error = error.apply(addConstant); error = error.apply(addConstant);
// Normalize by the mean // Normalize by the sum
error = error.apply([&mean](double x) { return x / mean; }); error = error.normalize(sum);
// Include the discrete keys // Include the discrete keys
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(), std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),