normalize the discrete factor at the continuous-discrete boundary

release/4.3a0
Varun Agrawal 2024-08-20 16:29:02 -04:00
parent 598edfacce
commit eef9765e4a
1 changed files with 22 additions and 6 deletions

View File

@ -279,21 +279,37 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
using Result = std::pair<std::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::sharedFactor>;
// Integrate the probability mass in the last continuous conditional using
// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
/**
* Compute the probability q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m)
* from the residual error at the mean μ.
* The residual error contains no keys, and only
* depends on the discrete separator if present.
*/
static std::shared_ptr<Factor> createDiscreteFactor(
const DecisionTree<Key, Result> &eliminationResults,
const DiscreteKeys &discreteSeparator) {
auto probability = [&](const Result &pair) -> double {
auto logProbability = [&](const Result &pair) -> double {
const auto &[conditional, factor] = pair;
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
if (!factor) return 1.0; // TODO(dellaert): not loving this.
return exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// Logspace version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// We take negative of the logNormalizationConstant `log(1/k)`
// to get `log(k)`.
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
};
DecisionTree<Key, double> probabilities(eliminationResults, probability);
AlgebraicDecisionTree<Key> logProbabilities(
DecisionTree<Key, double>(eliminationResults, logProbability));
// Perform normalization
double max_log = logProbabilities.max();
AlgebraicDecisionTree probabilities = DecisionTree<Key, double>(
logProbabilities,
[&max_log](const double x) { return exp(x - max_log); });
probabilities = probabilities.normalize(probabilities.sum());
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
}