improve dead mode removal by checking for empty discrete joints and adding the marginals for future factors

release/4.3a0
Varun Agrawal 2025-01-24 19:57:47 -05:00
parent 938ae06031
commit 7ca7e4549e
1 changed files with 25 additions and 9 deletions

View File

@ -58,15 +58,12 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
joint = joint * (*conditional);
}
// Create the result starting with the pruned joint.
// Initialize the resulting HybridBayesNet.
HybridBayesNet result;
result.emplace_shared<DiscreteConditional>(joint);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
result.back()->asDiscrete()->prune(maxNrLeaves);
// Get pruned discrete probabilities so
// we can prune HybridGaussianConditionals.
DiscreteConditional pruned = *result.back()->asDiscrete();
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
joint.prune(maxNrLeaves);
DiscreteValues deadModesValues;
if (removeDeadModes) {
@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
}
// Remove the modes (imperative)
result.back()->asDiscrete()->removeDiscreteModes(deadModesValues);
pruned = *result.back()->asDiscrete();
pruned.removeDiscreteModes(deadModesValues);
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors
for (auto &&[key, _] : deadModesValues) {
result.push_back(
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
}
} else {
result.emplace_shared<DiscreteConditional>(pruned);
}
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const {
}
}
}
return discrete_fg.optimize();
}