Refactoring in elimination

release/4.3a0
Frank Dellaert 2023-01-01 17:44:53 -05:00
parent 665cb29b3f
commit 2c7b3a2af0
1 changed files with 36 additions and 37 deletions

View File

@ -188,11 +188,28 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
boost::make_shared<HybridDiscreteFactor>(result.second)};
}
/* ************************************************************************ */
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor.
GaussianMixtureFactor::Sum removeEmpty(const GaussianMixtureFactor::Sum &sum) {
auto emptyGaussian = [](const GaussianMixtureFactor::GraphAndConstant
&graph_z) {
bool hasNull =
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull
? GaussianMixtureFactor::GraphAndConstant{GaussianFactorGraph(),
0.0}
: graph_z;
};
return GaussianMixtureFactor::Sum(sum, emptyGaussian);
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys,
const KeySet &continuousSeparator,
const KeyVector &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possibility is continuous conditioned on discrete.
@ -203,27 +220,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// decision tree indexed by all discrete keys involved.
GaussianMixtureFactor::Sum sum = factors.SumFrontals();
// If a tree leaf contains nullptr, convert that leaf to an empty
// GaussianFactorGraph. Needed since the DecisionTree will otherwise create a
// GFG with a single (null) factor.
// TODO(dellaert): can SumFrontals not guarantee this?
auto emptyGaussian =
[](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
bool hasNull = std::any_of(
graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianMixtureFactor::GraphAndConstant(
GaussianFactorGraph(), 0.0)
: graph_z;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
// TODO(dellaert): does SumFrontals not guarantee we do not need this?
sum = removeEmpty(sum);
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfSeparator;
// This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
-> EliminationPair {
@ -240,15 +242,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
boost::tie(conditional, newFactor) =
EliminatePreferCholesky(graph_z.graph, frontalKeys);
// TODO(dellaert): always the same, and we already computed this in caller?
keysOfSeparator = newFactor->keys();
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
// Get the log of the log normalization constant inverse.
double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
double logZ = graph_z.constant - conditional->logNormalizationConstant();
return {conditional, {newFactor, logZ}};
};
@ -261,35 +260,35 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
#endif
// Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults);
GaussianMixture::Conditionals conditionals;
GaussianMixtureFactor::Factors newFactors;
std::tie(conditionals, newFactors) = unzip(eliminationResults);
// Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>(
frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
auto gaussianMixture = boost::make_shared<GaussianMixture>(
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
// If there are no more continuous parents, then we should create a
// DiscreteFactor here, with the error for each discrete choice.
const auto &separatorFactors = pair.second;
if (keysOfSeparator.empty()) {
if (continuousSeparator.empty()) {
auto factorProb =
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
// This is the probability q(μ) at the MLE point.
return factor_z.error(VectorValues());
};
const DecisionTree<Key, double> fdt(separatorFactors, factorProb);
const DecisionTree<Key, double> fdt(newFactors, factorProb);
const auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(conditional),
return {boost::make_shared<HybridConditional>(gaussianMixture),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
} else {
// Create a resulting GaussianMixtureFactor on the separator.
// Keys can be computed from the factors, so we should not need to pass them
// in.
const auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor};
// TODO(dellaert): Keys can be computed from the factors, so we should not
// need to pass them in.
return {boost::make_shared<HybridConditional>(gaussianMixture),
boost::make_shared<GaussianMixtureFactor>(
continuousSeparator, discreteSeparator, newFactors)};
}
}
@ -389,12 +388,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Fill in discrete discrete separator keys and continuous separator keys.
std::set<DiscreteKey> discreteSeparatorSet;
KeySet continuousSeparator;
KeyVector continuousSeparator;
for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousSeparator.insert(k);
continuousSeparator.push_back(k);
}
}