normalize the discrete factor at the continuous-discrete boundary
parent 598edfacce
commit eef9765e4a
@@ -279,21 +279,37 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
 using Result = std::pair<std::shared_ptr<GaussianConditional>,
                          GaussianMixtureFactor::sharedFactor>;
 
-// Integrate the probability mass in the last continuous conditional using
-// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
-// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
+/**
+ * Compute the probability q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
+ * from the residual error at the mean μ.
+ * The residual error contains no keys, and only
+ * depends on the discrete separator if present.
+ */
 static std::shared_ptr<Factor> createDiscreteFactor(
     const DecisionTree<Key, Result> &eliminationResults,
     const DiscreteKeys &discreteSeparator) {
-  auto probability = [&](const Result &pair) -> double {
+  auto logProbability = [&](const Result &pair) -> double {
     const auto &[conditional, factor] = pair;
     static const VectorValues kEmpty;
     // If the factor is not null, it has no keys, just contains the residual.
     if (!factor) return 1.0;  // TODO(dellaert): not loving this.
-    return exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
+
+    // Logspace version of:
+    // exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
+    // We take negative of the logNormalizationConstant `log(1/k)`
+    // to get `log(k)`.
+    return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
   };
 
-  DecisionTree<Key, double> probabilities(eliminationResults, probability);
+  AlgebraicDecisionTree<Key> logProbabilities(
+      DecisionTree<Key, double>(eliminationResults, logProbability));
+
+  // Perform normalization
+  double max_log = logProbabilities.max();
+  AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
+      logProbabilities,
+      [&max_log](const double x) { return exp(x - max_log); });
+  probabilities = probabilities.normalize(probabilities.sum());
 
   return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
 }
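Note on the normalization above: in linear space the weight of each discrete assignment m is exp(-error(μ;m)) / k_m, so in log space it becomes -error(μ;m) - log(k_m), which is what the new logProbability lambda returns. Exponentiating these log-weights directly can underflow to zero when the residual errors are large, so the change subtracts the maximum log-weight before exponentiating and only then divides by the sum. Below is a minimal standalone sketch of that max-subtraction step, written against plain STL containers instead of GTSAM's AlgebraicDecisionTree; the helper name NormalizeLogWeights and the sample values are illustrative assumptions, not GTSAM API.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// Hypothetical helper mirroring the max_log / normalize steps in
// createDiscreteFactor: turn unnormalized log-weights into probabilities
// that sum to one.
std::vector<double> NormalizeLogWeights(const std::vector<double>& log_weights) {
  // Subtract the maximum before exponentiating so exp() cannot underflow
  // for every entry at once.
  const double max_log =
      *std::max_element(log_weights.begin(), log_weights.end());
  std::vector<double> probabilities(log_weights.size());
  std::transform(log_weights.begin(), log_weights.end(), probabilities.begin(),
                 [max_log](double x) { return std::exp(x - max_log); });
  // Divide by the sum, the same effect as probabilities.normalize(sum) above.
  const double sum =
      std::accumulate(probabilities.begin(), probabilities.end(), 0.0);
  for (double& p : probabilities) p /= sum;
  return probabilities;
}

int main() {
  // Made-up log-weights -error(μ;m) - log(k_m) for three discrete modes;
  // exponentiating them directly would underflow double precision to 0.
  const std::vector<double> log_weights = {-1200.0, -1203.5, -1198.2};
  for (double p : NormalizeLogWeights(log_weights)) std::printf("%g\n", p);
  return 0;
}

Subtracting max_log rescales every weight by the same constant, so the ratios, and therefore the normalized probabilities, are unchanged, while the largest term becomes exp(0) = 1 and cannot underflow.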