normalize model selection term

release/4.3a0
Varun Agrawal 2023-12-15 15:21:25 -05:00
parent 7b56c96b43
commit e549a9b41f
1 changed files with 11 additions and 6 deletions

View File

@ -239,7 +239,7 @@ HybridValues HybridBayesNet::optimize() const {
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
the joint Gaussian distribution.
This can be computed by multiplying all the exponentiaed errors
This can be computed by multiplying all the exponentiated errors
of each of the conditionals, which we do below in hybrid case.
*/
if (conditional->isContinuous()) {
@ -288,7 +288,7 @@ HybridValues HybridBayesNet::optimize() const {
double sum = 0.0;
auto addConstant = [&gm, &sum](const double &error) {
double e = error + gm->logNormalizationConstant();
sum += e;
sum += std::abs(e);
return e;
};
error = error.apply(addConstant);
@ -302,12 +302,17 @@ HybridValues HybridBayesNet::optimize() const {
}
}
double min_log = error.min();
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
error, [](const double &error) { return std::exp(-error); });
error, [&min_log](const double &x) { return std::exp(-(x - min_log)); });
model_selection = model_selection + exp(-min_log);
discrete_fg.push_back(DecisionTreeFactor(
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
model_selection));
// Only add model_selection if we have discrete keys
if (discreteKeySet.size() > 0) {
discrete_fg.push_back(DecisionTreeFactor(
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
model_selection));
}
// Solve for the MPE
DiscreteValues mpe = discrete_fg.optimize();