Refactoring in elimination
parent
665cb29b3f
commit
2c7b3a2af0
|
|
@ -188,11 +188,28 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
boost::make_shared<HybridDiscreteFactor>(result.second)};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
||||
// otherwise create a GFG with a single (null) factor.
|
||||
GaussianMixtureFactor::Sum removeEmpty(const GaussianMixtureFactor::Sum &sum) {
|
||||
auto emptyGaussian = [](const GaussianMixtureFactor::GraphAndConstant
|
||||
&graph_z) {
|
||||
bool hasNull =
|
||||
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
return hasNull
|
||||
? GaussianMixtureFactor::GraphAndConstant{GaussianFactorGraph(),
|
||||
0.0}
|
||||
: graph_z;
|
||||
};
|
||||
return GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys,
|
||||
const KeySet &continuousSeparator,
|
||||
const KeyVector &continuousSeparator,
|
||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||
// NOTE: since we use the special JunctionTree,
|
||||
// only possibility is continuous conditioned on discrete.
|
||||
|
|
@ -203,27 +220,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
// decision tree indexed by all discrete keys involved.
|
||||
GaussianMixtureFactor::Sum sum = factors.SumFrontals();
|
||||
|
||||
// If a tree leaf contains nullptr, convert that leaf to an empty
|
||||
// GaussianFactorGraph. Needed since the DecisionTree will otherwise create a
|
||||
// GFG with a single (null) factor.
|
||||
// TODO(dellaert): can SumFrontals not guarantee this?
|
||||
auto emptyGaussian =
|
||||
[](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
|
||||
bool hasNull = std::any_of(
|
||||
graph_z.graph.begin(), graph_z.graph.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
|
||||
return hasNull ? GaussianMixtureFactor::GraphAndConstant(
|
||||
GaussianFactorGraph(), 0.0)
|
||||
: graph_z;
|
||||
};
|
||||
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||
// TODO(dellaert): does SumFrontals not guarantee we do not need this?
|
||||
sum = removeEmpty(sum);
|
||||
|
||||
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
GaussianMixtureFactor::FactorAndConstant>;
|
||||
|
||||
KeyVector keysOfSeparator;
|
||||
|
||||
// This is the elimination method on the leaf nodes
|
||||
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
|
||||
-> EliminationPair {
|
||||
|
|
@ -240,15 +242,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
boost::tie(conditional, newFactor) =
|
||||
EliminatePreferCholesky(graph_z.graph, frontalKeys);
|
||||
|
||||
// TODO(dellaert): always the same, and we already computed this in caller?
|
||||
keysOfSeparator = newFactor->keys();
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
gttoc_(hybrid_eliminate);
|
||||
#endif
|
||||
|
||||
// Get the log of the log normalization constant inverse.
|
||||
double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
|
||||
double logZ = graph_z.constant - conditional->logNormalizationConstant();
|
||||
return {conditional, {newFactor, logZ}};
|
||||
};
|
||||
|
||||
|
|
@ -261,35 +260,35 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
#endif
|
||||
|
||||
// Separate out decision tree into conditionals and remaining factors.
|
||||
auto pair = unzip(eliminationResults);
|
||||
GaussianMixture::Conditionals conditionals;
|
||||
GaussianMixtureFactor::Factors newFactors;
|
||||
std::tie(conditionals, newFactors) = unzip(eliminationResults);
|
||||
|
||||
// Create the GaussianMixture from the conditionals
|
||||
auto conditional = boost::make_shared<GaussianMixture>(
|
||||
frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
|
||||
auto gaussianMixture = boost::make_shared<GaussianMixture>(
|
||||
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
||||
|
||||
// If there are no more continuous parents, then we should create a
|
||||
// DiscreteFactor here, with the error for each discrete choice.
|
||||
const auto &separatorFactors = pair.second;
|
||||
if (keysOfSeparator.empty()) {
|
||||
if (continuousSeparator.empty()) {
|
||||
auto factorProb =
|
||||
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
return factor_z.error(VectorValues());
|
||||
};
|
||||
const DecisionTree<Key, double> fdt(separatorFactors, factorProb);
|
||||
const DecisionTree<Key, double> fdt(newFactors, factorProb);
|
||||
const auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
||||
return {boost::make_shared<HybridConditional>(conditional),
|
||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
||||
} else {
|
||||
// Create a resulting GaussianMixtureFactor on the separator.
|
||||
// Keys can be computed from the factors, so we should not need to pass them
|
||||
// in.
|
||||
const auto factor = boost::make_shared<GaussianMixtureFactor>(
|
||||
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
|
||||
discreteSeparator, separatorFactors);
|
||||
return {boost::make_shared<HybridConditional>(conditional), factor};
|
||||
// TODO(dellaert): Keys can be computed from the factors, so we should not
|
||||
// need to pass them in.
|
||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||
boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousSeparator, discreteSeparator, newFactors)};
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -389,12 +388,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
// Fill in discrete discrete separator keys and continuous separator keys.
|
||||
std::set<DiscreteKey> discreteSeparatorSet;
|
||||
KeySet continuousSeparator;
|
||||
KeyVector continuousSeparator;
|
||||
for (auto &k : separatorKeys) {
|
||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
||||
} else {
|
||||
continuousSeparator.insert(k);
|
||||
continuousSeparator.push_back(k);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue