common logProbability normalization function
parent
f636212fec
commit
f4830baa5e
|
|
@ -233,6 +233,26 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
/**
|
||||||
|
* @brief Exponential log-probabilities after performing
|
||||||
|
* the necessary normalizations.
|
||||||
|
*
|
||||||
|
* @param logProbabilities DecisionTree of log-probabilities.
|
||||||
|
* @return AlgebraicDecisionTree<Key>
|
||||||
|
*/
|
||||||
|
static AlgebraicDecisionTree<Key> exponentiateLogProbabilities(
|
||||||
|
const AlgebraicDecisionTree<Key> &logProbabilities) {
|
||||||
|
// Perform normalization
|
||||||
|
double max_log = logProbabilities.max();
|
||||||
|
AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
|
||||||
|
logProbabilities,
|
||||||
|
[&max_log](const double x) { return exp(x - max_log); });
|
||||||
|
probabilities = probabilities.normalize(probabilities.sum());
|
||||||
|
|
||||||
|
return probabilities;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
@ -245,14 +265,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
// Case where we have a GaussianMixtureFactor with no continuous keys.
|
// Case where we have a GaussianMixtureFactor with no continuous keys.
|
||||||
// In this case, compute discrete probabilities.
|
// In this case, compute discrete probabilities.
|
||||||
auto probability =
|
auto logProbability =
|
||||||
[&](const GaussianFactor::shared_ptr &factor) -> double {
|
[&](const GaussianFactor::shared_ptr &factor) -> double {
|
||||||
if (!factor) return 0.0;
|
if (!factor) return 0.0;
|
||||||
return exp(-factor->error(VectorValues()));
|
return -factor->error(VectorValues());
|
||||||
};
|
};
|
||||||
dfg.emplace_shared<DecisionTreeFactor>(
|
AlgebraicDecisionTree<Key> logProbabilities =
|
||||||
gmf->discreteKeys(),
|
DecisionTree<Key, double>(gmf->factors(), logProbability);
|
||||||
DecisionTree<Key, double>(gmf->factors(), probability));
|
|
||||||
|
AlgebraicDecisionTree<Key> probabilities =
|
||||||
|
exponentiateLogProbabilities(logProbabilities);
|
||||||
|
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(),
|
||||||
|
probabilities);
|
||||||
|
|
||||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||||
// Ignore orphaned clique.
|
// Ignore orphaned clique.
|
||||||
|
|
@ -315,13 +339,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
|
|
||||||
AlgebraicDecisionTree<Key> logProbabilities(
|
AlgebraicDecisionTree<Key> logProbabilities(
|
||||||
DecisionTree<Key, double>(eliminationResults, logProbability));
|
DecisionTree<Key, double>(eliminationResults, logProbability));
|
||||||
|
AlgebraicDecisionTree<Key> probabilities =
|
||||||
// Perform normalization
|
exponentiateLogProbabilities(logProbabilities);
|
||||||
double max_log = logProbabilities.max();
|
|
||||||
AlgebraicDecisionTree probabilities = DecisionTree<Key, double>(
|
|
||||||
logProbabilities,
|
|
||||||
[&max_log](const double x) { return exp(x - max_log); });
|
|
||||||
probabilities = probabilities.normalize(probabilities.sum());
|
|
||||||
|
|
||||||
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
|
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue