common logProbability normalization function

release/4.3a0
Varun Agrawal 2024-08-26 17:27:44 -04:00
parent f636212fec
commit f4830baa5e
1 changed files with 31 additions and 12 deletions

View File

@ -233,6 +233,26 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
return {std::make_shared<HybridConditional>(result.first), result.second};
}
/* ************************************************************************ */
/**
* @brief Exponential log-probabilities after performing
* the necessary normalizations.
*
* @param logProbabilities DecisionTree of log-probabilities.
* @return AlgebraicDecisionTree<Key>
*/
static AlgebraicDecisionTree<Key> exponentiateLogProbabilities(
const AlgebraicDecisionTree<Key> &logProbabilities) {
// Perform normalization
double max_log = logProbabilities.max();
AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
logProbabilities,
[&max_log](const double x) { return exp(x - max_log); });
probabilities = probabilities.normalize(probabilities.sum());
return probabilities;
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
@ -245,14 +265,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Case where we have a GaussianMixtureFactor with no continuous keys.
// In this case, compute discrete probabilities.
auto probability =
auto logProbability =
[&](const GaussianFactor::shared_ptr &factor) -> double {
if (!factor) return 0.0;
return exp(-factor->error(VectorValues()));
return -factor->error(VectorValues());
};
dfg.emplace_shared<DecisionTreeFactor>(
gmf->discreteKeys(),
DecisionTree<Key, double>(gmf->factors(), probability));
AlgebraicDecisionTree<Key> logProbabilities =
DecisionTree<Key, double>(gmf->factors(), logProbability);
AlgebraicDecisionTree<Key> probabilities =
exponentiateLogProbabilities(logProbabilities);
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(),
probabilities);
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
@ -315,13 +339,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(
AlgebraicDecisionTree<Key> logProbabilities(
DecisionTree<Key, double>(eliminationResults, logProbability));
// Perform normalization
double max_log = logProbabilities.max();
AlgebraicDecisionTree probabilities = DecisionTree<Key, double>(
logProbabilities,
[&max_log](const double x) { return exp(x - max_log); });
probabilities = probabilities.normalize(probabilities.sum());
AlgebraicDecisionTree<Key> probabilities =
exponentiateLogProbabilities(logProbabilities);
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
}