common logProbability normalization function

release/4.3a0
Varun Agrawal 2024-08-26 17:27:44 -04:00
parent f636212fec
commit f4830baa5e
1 changed files with 31 additions and 12 deletions

View File

@ -233,6 +233,26 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
return {std::make_shared<HybridConditional>(result.first), result.second}; return {std::make_shared<HybridConditional>(result.first), result.second};
} }
/* ************************************************************************ */
/**
* @brief Exponential log-probabilities after performing
* the necessary normalizations.
*
* @param logProbabilities DecisionTree of log-probabilities.
* @return AlgebraicDecisionTree<Key>
*/
static AlgebraicDecisionTree<Key> exponentiateLogProbabilities(
const AlgebraicDecisionTree<Key> &logProbabilities) {
// Perform normalization
double max_log = logProbabilities.max();
AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
logProbabilities,
[&max_log](const double x) { return exp(x - max_log); });
probabilities = probabilities.normalize(probabilities.sum());
return probabilities;
}
/* ************************************************************************ */ /* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
@ -245,14 +265,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Case where we have a GaussianMixtureFactor with no continuous keys. // Case where we have a GaussianMixtureFactor with no continuous keys.
// In this case, compute discrete probabilities. // In this case, compute discrete probabilities.
auto probability = auto logProbability =
[&](const GaussianFactor::shared_ptr &factor) -> double { [&](const GaussianFactor::shared_ptr &factor) -> double {
if (!factor) return 0.0; if (!factor) return 0.0;
return exp(-factor->error(VectorValues())); return -factor->error(VectorValues());
}; };
dfg.emplace_shared<DecisionTreeFactor>( AlgebraicDecisionTree<Key> logProbabilities =
gmf->discreteKeys(), DecisionTree<Key, double>(gmf->factors(), logProbability);
DecisionTree<Key, double>(gmf->factors(), probability));
AlgebraicDecisionTree<Key> probabilities =
exponentiateLogProbabilities(logProbabilities);
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(),
probabilities);
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) { } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique. // Ignore orphaned clique.
@ -315,13 +339,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(
AlgebraicDecisionTree<Key> logProbabilities( AlgebraicDecisionTree<Key> logProbabilities(
DecisionTree<Key, double>(eliminationResults, logProbability)); DecisionTree<Key, double>(eliminationResults, logProbability));
AlgebraicDecisionTree<Key> probabilities =
// Perform normalization exponentiateLogProbabilities(logProbabilities);
double max_log = logProbabilities.max();
AlgebraicDecisionTree probabilities = DecisionTree<Key, double>(
logProbabilities,
[&max_log](const double x) { return exp(x - max_log); });
probabilities = probabilities.normalize(probabilities.sum());
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities); return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
} }