diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9f55f3b63..1e5dc848b 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -196,6 +196,25 @@ namespace gtsam { return this->apply(g, &Ring::div); } + /// Compute sum of all values + double sum() const { + double sum = 0; + auto visitor = [&](int y) { sum += y; }; + this->visit(visitor); + return sum; + } + + /** + * @brief Helper method to perform normalization such that all leaves in the + * tree sum to 1 + * + * @param sum + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree normalize(double sum) const { + return this->apply([&sum](const double& x) { return x / sum; }); + } + /** sum out variable */ AlgebraicDecisionTree sum(const L& label, size_t cardinality) const { return this->combine(label, cardinality, &Ring::add); diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index cd1157576..eea51d329 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -283,16 +283,17 @@ HybridValues HybridBayesNet::optimize() const { error = error + gm->error(continuousValues); // Add the logNormalization constant to the error - // Also compute the mean for normalization (for numerical stability) - double mean = 0.0; - auto addConstant = [&gm, &mean](const double &error) { + // Also compute the sum for discrete probability normalization + // (normalization trick for numerical stability) + double sum = 0.0; + auto addConstant = [&gm, &sum](const double &error) { double e = error + gm->logNormalizationConstant(); - mean += e; + sum += e; return e; }; error = error.apply(addConstant); - // Normalize by the mean - error = error.apply([&mean](double x) { return x / mean; }); + // Normalize by the sum + error = error.normalize(sum); // Include the discrete keys std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),