common logProbability normalization function
parent
f636212fec
commit
f4830baa5e
|
|
@ -233,6 +233,26 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
/**
|
||||
* @brief Exponential log-probabilities after performing
|
||||
* the necessary normalizations.
|
||||
*
|
||||
* @param logProbabilities DecisionTree of log-probabilities.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
static AlgebraicDecisionTree<Key> exponentiateLogProbabilities(
|
||||
const AlgebraicDecisionTree<Key> &logProbabilities) {
|
||||
// Perform normalization
|
||||
double max_log = logProbabilities.max();
|
||||
AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
|
||||
logProbabilities,
|
||||
[&max_log](const double x) { return exp(x - max_log); });
|
||||
probabilities = probabilities.normalize(probabilities.sum());
|
||||
|
||||
return probabilities;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
|
|
@ -245,14 +265,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
// Case where we have a GaussianMixtureFactor with no continuous keys.
|
||||
// In this case, compute discrete probabilities.
|
||||
auto probability =
|
||||
auto logProbability =
|
||||
[&](const GaussianFactor::shared_ptr &factor) -> double {
|
||||
if (!factor) return 0.0;
|
||||
return exp(-factor->error(VectorValues()));
|
||||
return -factor->error(VectorValues());
|
||||
};
|
||||
dfg.emplace_shared<DecisionTreeFactor>(
|
||||
gmf->discreteKeys(),
|
||||
DecisionTree<Key, double>(gmf->factors(), probability));
|
||||
AlgebraicDecisionTree<Key> logProbabilities =
|
||||
DecisionTree<Key, double>(gmf->factors(), logProbability);
|
||||
|
||||
AlgebraicDecisionTree<Key> probabilities =
|
||||
exponentiateLogProbabilities(logProbabilities);
|
||||
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(),
|
||||
probabilities);
|
||||
|
||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||
// Ignore orphaned clique.
|
||||
|
|
@ -315,13 +339,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
|||
|
||||
AlgebraicDecisionTree<Key> logProbabilities(
|
||||
DecisionTree<Key, double>(eliminationResults, logProbability));
|
||||
|
||||
// Perform normalization
|
||||
double max_log = logProbabilities.max();
|
||||
AlgebraicDecisionTree probabilities = DecisionTree<Key, double>(
|
||||
logProbabilities,
|
||||
[&max_log](const double x) { return exp(x - max_log); });
|
||||
probabilities = probabilities.normalize(probabilities.sum());
|
||||
AlgebraicDecisionTree<Key> probabilities =
|
||||
exponentiateLogProbabilities(logProbabilities);
|
||||
|
||||
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue