handle HybridConditional and explicitly set Gaussian Factor Graphs to empty

release/4.3a0
Varun Agrawal 2022-08-21 12:49:13 -04:00
parent 07f0101db7
commit 29c19ee77b
1 changed files with 27 additions and 4 deletions

View File

@ -96,8 +96,15 @@ GaussianMixtureFactor::Sum sumFrontals(
}
} else if (f->isContinuous()) {
deferredFactors.push_back(
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
// Check if f is HybridConditional or HybridGaussianFactor.
if (auto hc = boost::dynamic_pointer_cast<HybridConditional>(f)) {
auto conditional =
boost::dynamic_pointer_cast<GaussianConditional>(hc->inner());
deferredFactors.push_back(conditional);
} else if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)
->inner()) {
deferredFactors.push_back(gf);
}
} else if (f->isDiscrete()) {
// Don't do anything for discrete-only factors
@ -184,6 +191,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// sum out frontals, this is the factor on the separator
GaussianMixtureFactor::Sum sum = sumFrontals(factors);
// If a tree leaf contains nullptr,
// convert that leaf to an empty GaussianFactorGraph.
// Needed since the DecisionTree will otherwise create
// a GFG with a single (null) factor.
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
bool hasNull =
std::any_of(gfg.begin(), gfg.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : gfg;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = GaussianFactorGraph::EliminationResult;
KeyVector keysOfEliminated; // Not the ordering
@ -195,7 +215,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (graph.empty()) {
return {nullptr, nullptr};
}
auto result = EliminatePreferCholesky(graph, frontalKeys);
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
@ -235,7 +258,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
} else {
// Create a resulting DCGaussianMixture on the separator.
// Create a resulting GaussianMixtureFactor on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors);