sum and normalize helper methods for the AlgebraicDecisionTree
parent
50670da07c
commit
af490e9ffc
|
@ -196,6 +196,25 @@ namespace gtsam {
|
|||
return this->apply(g, &Ring::div);
|
||||
}
|
||||
|
||||
/// Compute sum of all values
|
||||
double sum() const {
|
||||
double sum = 0;
|
||||
auto visitor = [&](int y) { sum += y; };
|
||||
this->visit(visitor);
|
||||
return sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper method to perform normalization such that all leaves in the
|
||||
* tree sum to 1
|
||||
*
|
||||
* @param sum
|
||||
* @return AlgebraicDecisionTree
|
||||
*/
|
||||
AlgebraicDecisionTree normalize(double sum) const {
|
||||
return this->apply([&sum](const double& x) { return x / sum; });
|
||||
}
|
||||
|
||||
/** sum out variable */
|
||||
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
|
||||
return this->combine(label, cardinality, &Ring::add);
|
||||
|
|
|
@ -283,16 +283,17 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
error = error + gm->error(continuousValues);
|
||||
|
||||
// Add the logNormalization constant to the error
|
||||
// Also compute the mean for normalization (for numerical stability)
|
||||
double mean = 0.0;
|
||||
auto addConstant = [&gm, &mean](const double &error) {
|
||||
// Also compute the sum for discrete probability normalization
|
||||
// (normalization trick for numerical stability)
|
||||
double sum = 0.0;
|
||||
auto addConstant = [&gm, &sum](const double &error) {
|
||||
double e = error + gm->logNormalizationConstant();
|
||||
mean += e;
|
||||
sum += e;
|
||||
return e;
|
||||
};
|
||||
error = error.apply(addConstant);
|
||||
// Normalize by the mean
|
||||
error = error.apply([&mean](double x) { return x / mean; });
|
||||
// Normalize by the sum
|
||||
error = error.normalize(sum);
|
||||
|
||||
// Include the discrete keys
|
||||
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
|
||||
|
|
Loading…
Reference in New Issue