Fix pruning

release/4.3a0
Frank Dellaert 2025-01-29 17:54:29 -05:00
parent bb0c70b482
commit 98cdf1193f
1 changed files with 63 additions and 27 deletions

View File

@ -49,6 +49,9 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
// search to find the K-best leaves and then create a single pruned conditional. // search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune( HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const { size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
#if GTSAM_HYBRID_TIMING
gttic_(HybridPruning);
#endif
// Collect all the discrete conditionals. Could be small if already pruned. // Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal(); const DiscreteBayesNet marginal = discreteMarginal();
@ -69,6 +72,10 @@ HybridBayesNet HybridBayesNet::prune(
// If we have a dead mode threshold and discrete variables left after pruning, // If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal. // then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
#if GTSAM_HYBRID_TIMING
gttic_(DeadModeRemoval);
#endif
DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) { for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey); Vector probabilities = marginals.marginalProbabilities(dkey);
@ -89,24 +96,11 @@ HybridBayesNet HybridBayesNet::prune(
// Remove the modes (imperative) // Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues); pruned.removeDiscreteModes(deadModesValues);
/* GTSAM_PRINT(deadModesValues);
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors #if GTSAM_HYBRID_TIMING
for (auto &&[key, _] : deadModesValues) { gttoc_(DeadModeRemoval);
result.push_back( #endif
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
}
} else {
result.emplace_shared<DiscreteConditional>(pruned);
} }
/* To prune, we visitWith every leaf in the HybridGaussianConditional. /* To prune, we visitWith every leaf in the HybridGaussianConditional.
@ -122,20 +116,37 @@ HybridBayesNet HybridBayesNet::prune(
if (auto hgc = conditional->asHybrid()) { if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional! // Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned); auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (!prunedHybridGaussianConditional) {
GTSAM_PRINT(marginal);
GTSAM_PRINT(pruned);
throw std::runtime_error(
"A HybridGaussianConditional had all its conditionals pruned");
}
if (deadModeThreshold.has_value()) { if (deadModeThreshold.has_value()) {
KeyVector deadKeys, conditionalDiscreteKeys; const auto &discreteParents =
for (const auto &kv : deadModesValues) { prunedHybridGaussianConditional->discreteKeys();
deadKeys.push_back(kv.first); DiscreteValues deadParentValues;
DiscreteKeys liveParents;
for (const auto &key : discreteParents) {
auto it = deadModesValues.find(key.first);
if (it != deadModesValues.end())
deadParentValues[key.first] = it->second;
else
liveParents.emplace_back(key);
} }
for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { // If so then we just get the corresponding Gaussian conditional:
conditionalDiscreteKeys.push_back(dkey.first); if (deadParentValues.size() == discreteParents.size()) {
} // print on how many discreteParents we are choosing:
// The discrete keys in the conditional are the same as the keys in the
// dead modes, then we just get the corresponding Gaussian conditional.
if (deadKeys == conditionalDiscreteKeys) {
result.push_back( result.push_back(
prunedHybridGaussianConditional->choose(deadModesValues)); prunedHybridGaussianConditional->choose(deadParentValues));
} else if (liveParents.size() > 0) {
auto newTree = prunedHybridGaussianConditional->factors();
for (auto &&[key, value] : deadModesValues) {
newTree = newTree.choose(key, value);
}
result.emplace_shared<HybridGaussianConditional>(liveParents,
newTree);
} else { } else {
// Add as-is // Add as-is
result.push_back(prunedHybridGaussianConditional); result.push_back(prunedHybridGaussianConditional);
@ -152,6 +163,31 @@ HybridBayesNet HybridBayesNet::prune(
// We ignore DiscreteConditional as they are already pruned and added. // We ignore DiscreteConditional as they are already pruned and added.
} }
#if GTSAM_HYBRID_TIMING
gttoc_(HybridPruning);
#endif
if (deadModeThreshold.has_value()) {
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors
// for (auto &&[key, _] : deadModesValues) {
// result.push_back(
// std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
// }
} else {
result.emplace_shared<DiscreteConditional>(pruned);
}
return result; return result;
} }