Added correction with the normalization constant in the second elimination path.
parent
c3ca31f2f3
commit
e444962aad
|
@ -220,12 +220,11 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// FG has a nullptr as we're looping over the factors.
|
// FG has a nullptr as we're looping over the factors.
|
||||||
factorGraphTree = removeEmpty(factorGraphTree);
|
factorGraphTree = removeEmpty(factorGraphTree);
|
||||||
|
|
||||||
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
using Result = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||||
GaussianMixtureFactor::sharedFactor>;
|
GaussianMixtureFactor::sharedFactor>;
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
auto eliminateFunc =
|
auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
|
||||||
[&](const GaussianFactorGraph &graph) -> EliminationPair {
|
|
||||||
if (graph.empty()) {
|
if (graph.empty()) {
|
||||||
return {nullptr, nullptr};
|
return {nullptr, nullptr};
|
||||||
}
|
}
|
||||||
|
@ -234,21 +233,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
gttic_(hybrid_eliminate);
|
gttic_(hybrid_eliminate);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
boost::shared_ptr<GaussianConditional> conditional;
|
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
||||||
boost::shared_ptr<GaussianFactor> newFactor;
|
|
||||||
boost::tie(conditional, newFactor) =
|
|
||||||
EliminatePreferCholesky(graph, frontalKeys);
|
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
gttoc_(hybrid_eliminate);
|
gttoc_(hybrid_eliminate);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return {conditional, newFactor};
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
DecisionTree<Key, EliminationPair> eliminationResults(factorGraphTree,
|
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
|
||||||
eliminateFunc);
|
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
tictoc_print_();
|
tictoc_print_();
|
||||||
|
@ -264,30 +259,46 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
auto gaussianMixture = boost::make_shared<GaussianMixture>(
|
auto gaussianMixture = boost::make_shared<GaussianMixture>(
|
||||||
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
||||||
|
|
||||||
// If there are no more continuous parents, then we should create a
|
|
||||||
// DiscreteFactor here, with the error for each discrete choice.
|
|
||||||
if (continuousSeparator.empty()) {
|
if (continuousSeparator.empty()) {
|
||||||
auto probPrime = [&](const EliminationPair &pair) {
|
// If there are no more continuous parents, then we create a
|
||||||
// This is the unnormalized probability q(μ;m) at the mean.
|
// DiscreteFactor here, with the error for each discrete choice.
|
||||||
// q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
|
|
||||||
// The factor has no keys, just contains the residual.
|
// Integrate the probability mass in the last continuous conditional using
|
||||||
|
// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
|
||||||
|
// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
|
||||||
|
auto probability = [&](const Result &pair) -> double {
|
||||||
static const VectorValues kEmpty;
|
static const VectorValues kEmpty;
|
||||||
return pair.second ? exp(-pair.second->error(kEmpty)) /
|
// If the factor is not null, it has no keys, just contains the residual.
|
||||||
pair.first->normalizationConstant()
|
const auto &factor = pair.second;
|
||||||
: 1.0;
|
if (!factor) return 1.0; // TODO(dellaert): not loving this.
|
||||||
|
return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
|
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
||||||
discreteSeparator,
|
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||||
DecisionTree<Key, double>(eliminationResults, probPrime));
|
boost::make_shared<DecisionTreeFactor>(discreteSeparator,
|
||||||
|
probabilities)};
|
||||||
|
} else {
|
||||||
|
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
|
||||||
|
// taking care to correct for conditional constant.
|
||||||
|
|
||||||
|
// Correct for the normalization constant used up by the conditional
|
||||||
|
auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
|
||||||
|
const auto &factor = pair.second;
|
||||||
|
if (!factor) return factor; // TODO(dellaert): not loving this.
|
||||||
|
auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
|
||||||
|
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
||||||
|
hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
|
||||||
|
return hf;
|
||||||
|
};
|
||||||
|
|
||||||
|
GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
|
||||||
|
correct);
|
||||||
|
const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
|
||||||
|
continuousSeparator, discreteSeparator, newFactors);
|
||||||
|
|
||||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||||
discreteFactor};
|
mixtureFactor};
|
||||||
} else {
|
|
||||||
// Create a resulting GaussianMixtureFactor on the separator.
|
|
||||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
|
||||||
boost::make_shared<GaussianMixtureFactor>(
|
|
||||||
continuousSeparator, discreteSeparator, newFactors)};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue