sum and normalize helper methods for the AlgebraicDecisionTree
parent
50670da07c
commit
af490e9ffc
|
@ -196,6 +196,25 @@ namespace gtsam {
|
||||||
return this->apply(g, &Ring::div);
|
return this->apply(g, &Ring::div);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute sum of all values
|
||||||
|
double sum() const {
|
||||||
|
double sum = 0;
|
||||||
|
auto visitor = [&](int y) { sum += y; };
|
||||||
|
this->visit(visitor);
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Helper method to perform normalization such that all leaves in the
|
||||||
|
* tree sum to 1
|
||||||
|
*
|
||||||
|
* @param sum
|
||||||
|
* @return AlgebraicDecisionTree
|
||||||
|
*/
|
||||||
|
AlgebraicDecisionTree normalize(double sum) const {
|
||||||
|
return this->apply([&sum](const double& x) { return x / sum; });
|
||||||
|
}
|
||||||
|
|
||||||
/** sum out variable */
|
/** sum out variable */
|
||||||
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
|
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
|
||||||
return this->combine(label, cardinality, &Ring::add);
|
return this->combine(label, cardinality, &Ring::add);
|
||||||
|
|
|
@ -283,16 +283,17 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
error = error + gm->error(continuousValues);
|
error = error + gm->error(continuousValues);
|
||||||
|
|
||||||
// Add the logNormalization constant to the error
|
// Add the logNormalization constant to the error
|
||||||
// Also compute the mean for normalization (for numerical stability)
|
// Also compute the sum for discrete probability normalization
|
||||||
double mean = 0.0;
|
// (normalization trick for numerical stability)
|
||||||
auto addConstant = [&gm, &mean](const double &error) {
|
double sum = 0.0;
|
||||||
|
auto addConstant = [&gm, &sum](const double &error) {
|
||||||
double e = error + gm->logNormalizationConstant();
|
double e = error + gm->logNormalizationConstant();
|
||||||
mean += e;
|
sum += e;
|
||||||
return e;
|
return e;
|
||||||
};
|
};
|
||||||
error = error.apply(addConstant);
|
error = error.apply(addConstant);
|
||||||
// Normalize by the mean
|
// Normalize by the sum
|
||||||
error = error.apply([&mean](double x) { return x / mean; });
|
error = error.normalize(sum);
|
||||||
|
|
||||||
// Include the discrete keys
|
// Include the discrete keys
|
||||||
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
|
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
|
||||||
|
|
Loading…
Reference in New Issue