Merge pull request #1698 from borglab/frank/cleaner_eliminate
commit
6b098c70d5
|
@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
|
||||||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||||
// keys, and then loop over all assignments to populate a vector.
|
// keys, and then loop over all assignments to populate a vector.
|
||||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
|
|
||||||
GaussianFactorGraphTree result;
|
GaussianFactorGraphTree result;
|
||||||
|
|
||||||
for (auto &f : factors_) {
|
for (auto &f : factors_) {
|
||||||
|
@ -198,6 +197,51 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
using Result = std::pair<std::shared_ptr<GaussianConditional>,
|
||||||
|
GaussianMixtureFactor::sharedFactor>;
|
||||||
|
|
||||||
|
// Integrate the probability mass in the last continuous conditional using
|
||||||
|
// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
|
||||||
|
// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
|
||||||
|
static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
|
const DecisionTree<Key, Result> &eliminationResults,
|
||||||
|
const DiscreteKeys &discreteSeparator) {
|
||||||
|
auto probability = [&](const Result &pair) -> double {
|
||||||
|
const auto &[conditional, factor] = pair;
|
||||||
|
static const VectorValues kEmpty;
|
||||||
|
// If the factor is not null, it has no keys, just contains the residual.
|
||||||
|
if (!factor) return 1.0; // TODO(dellaert): not loving this.
|
||||||
|
return exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
|
||||||
|
};
|
||||||
|
|
||||||
|
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
||||||
|
|
||||||
|
return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create GaussianMixtureFactor on the separator, taking care to correct
|
||||||
|
// for conditional constants.
|
||||||
|
static std::shared_ptr<Factor> createGaussianMixtureFactor(
|
||||||
|
const DecisionTree<Key, Result> &eliminationResults,
|
||||||
|
const KeyVector &continuousSeparator,
|
||||||
|
const DiscreteKeys &discreteSeparator) {
|
||||||
|
// Correct for the normalization constant used up by the conditional
|
||||||
|
auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
|
||||||
|
const auto &[conditional, factor] = pair;
|
||||||
|
if (factor) {
|
||||||
|
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
|
||||||
|
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
||||||
|
hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
|
||||||
|
}
|
||||||
|
return factor;
|
||||||
|
};
|
||||||
|
DecisionTree<Key, GaussianFactor::shared_ptr> newFactors(eliminationResults,
|
||||||
|
correct);
|
||||||
|
|
||||||
|
return std::make_shared<GaussianMixtureFactor>(continuousSeparator,
|
||||||
|
discreteSeparator, newFactors);
|
||||||
|
}
|
||||||
|
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
hybridElimination(const HybridGaussianFactorGraph &factors,
|
hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
const Ordering &frontalKeys,
|
const Ordering &frontalKeys,
|
||||||
|
@ -217,9 +261,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// FG has a nullptr as we're looping over the factors.
|
// FG has a nullptr as we're looping over the factors.
|
||||||
factorGraphTree = removeEmpty(factorGraphTree);
|
factorGraphTree = removeEmpty(factorGraphTree);
|
||||||
|
|
||||||
using Result = std::pair<std::shared_ptr<GaussianConditional>,
|
|
||||||
GaussianMixtureFactor::sharedFactor>;
|
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
|
auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
|
||||||
if (graph.empty()) {
|
if (graph.empty()) {
|
||||||
|
@ -234,53 +275,22 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
|
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
|
||||||
|
|
||||||
// Separate out decision tree into conditionals and remaining factors.
|
// If there are no more continuous parents we create a DiscreteFactor with the
|
||||||
const auto [conditionals, newFactors] = unzip(eliminationResults);
|
// error for each discrete choice. Otherwise, create a GaussianMixtureFactor
|
||||||
|
// on the separator, taking care to correct for conditional constants.
|
||||||
|
auto newFactor =
|
||||||
|
continuousSeparator.empty()
|
||||||
|
? createDiscreteFactor(eliminationResults, discreteSeparator)
|
||||||
|
: createGaussianMixtureFactor(eliminationResults, continuousSeparator,
|
||||||
|
discreteSeparator);
|
||||||
|
|
||||||
// Create the GaussianMixture from the conditionals
|
// Create the GaussianMixture from the conditionals
|
||||||
|
GaussianMixture::Conditionals conditionals(
|
||||||
|
eliminationResults, [](const Result &pair) { return pair.first; });
|
||||||
auto gaussianMixture = std::make_shared<GaussianMixture>(
|
auto gaussianMixture = std::make_shared<GaussianMixture>(
|
||||||
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
||||||
|
|
||||||
if (continuousSeparator.empty()) {
|
return {std::make_shared<HybridConditional>(gaussianMixture), newFactor};
|
||||||
// If there are no more continuous parents, then we create a
|
|
||||||
// DiscreteFactor here, with the error for each discrete choice.
|
|
||||||
|
|
||||||
// Integrate the probability mass in the last continuous conditional using
|
|
||||||
// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
|
|
||||||
// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
|
|
||||||
auto probability = [&](const Result &pair) -> double {
|
|
||||||
static const VectorValues kEmpty;
|
|
||||||
// If the factor is not null, it has no keys, just contains the residual.
|
|
||||||
const auto &factor = pair.second;
|
|
||||||
if (!factor) return 1.0; // TODO(dellaert): not loving this.
|
|
||||||
return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
|
|
||||||
};
|
|
||||||
|
|
||||||
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
|
||||||
|
|
||||||
return {
|
|
||||||
std::make_shared<HybridConditional>(gaussianMixture),
|
|
||||||
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
|
|
||||||
} else {
|
|
||||||
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
|
|
||||||
// taking care to correct for conditional constant.
|
|
||||||
|
|
||||||
// Correct for the normalization constant used up by the conditional
|
|
||||||
auto correct = [&](const Result &pair) {
|
|
||||||
const auto &factor = pair.second;
|
|
||||||
if (!factor) return;
|
|
||||||
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
|
|
||||||
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
|
||||||
hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
|
|
||||||
};
|
|
||||||
eliminationResults.visit(correct);
|
|
||||||
|
|
||||||
const auto mixtureFactor = std::make_shared<GaussianMixtureFactor>(
|
|
||||||
continuousSeparator, discreteSeparator, newFactors);
|
|
||||||
|
|
||||||
return {std::make_shared<HybridConditional>(gaussianMixture),
|
|
||||||
mixtureFactor};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************
|
/* ************************************************************************
|
||||||
|
|
Loading…
Reference in New Issue