From 05ad198ca6097627469833d3a9080789aa9a55d5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 23:37:58 -0500 Subject: [PATCH] Use restrict inside prune --- gtsam/hybrid/HybridBayesNet.cpp | 80 ++++++++------------------------- 1 file changed, 19 insertions(+), 61 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 2efb8030e..a911a047a 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -96,71 +96,36 @@ HybridBayesNet HybridBayesNet::prune( // Remove the modes (imperative) pruned.removeDiscreteModes(deadModesValues); - GTSAM_PRINT(deadModesValues); - #if GTSAM_HYBRID_TIMING gttoc_(DeadModeRemoval); #endif } - /* To prune, we visitWith every leaf in the HybridGaussianConditional. - * For each leaf, using the assignment we can check the discrete decision tree - * for 0.0 probability, then just set the leaf to a nullptr. - * - * We can later check the HybridGaussianConditional for just nullptrs. - */ - - // Go through all the Gaussian conditionals in the Bayes Net and prune them as - // per pruned discrete joint. + // Go through all the Gaussian conditionals, restrict them according to + // deadModesValues, and then prune further. for (auto &&conditional : *this) { - if (auto hgc = conditional->asHybrid()) { + if (conditional->isDiscrete()) continue; + + // Restrict conditional using deadModesValues. + // No-op if not a HybridGaussianConditional or deadModesValues empty. + auto restricted = conditional->restrict(deadModesValues); + + // Now decide on type what to do: + if (auto hgc = restricted->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); if (!prunedHybridGaussianConditional) { - GTSAM_PRINT(marginal); - GTSAM_PRINT(pruned); throw std::runtime_error( "A HybridGaussianConditional had all its conditionals pruned"); } - - if (deadModeThreshold.has_value()) { - const auto &discreteParents = - prunedHybridGaussianConditional->discreteKeys(); - DiscreteValues deadParentValues; - DiscreteKeys liveParents; - for (const auto &key : discreteParents) { - auto it = deadModesValues.find(key.first); - if (it != deadModesValues.end()) - deadParentValues[key.first] = it->second; - else - liveParents.emplace_back(key); - } - // If so then we just get the corresponding Gaussian conditional: - if (deadParentValues.size() == discreteParents.size()) { - // print on how many discreteParents we are choosing: - result.push_back( - prunedHybridGaussianConditional->choose(deadParentValues)); - } else if (liveParents.size() > 0) { - auto newTree = prunedHybridGaussianConditional->factors(); - for (auto &&[key, value] : deadModesValues) { - newTree = newTree.choose(key, value); - } - result.emplace_shared(liveParents, - newTree); - } else { - // Add as-is - result.push_back(prunedHybridGaussianConditional); - } - } else { - // Type-erase and add to the pruned Bayes Net fragment. - result.push_back(prunedHybridGaussianConditional); - } - - } else if (auto gc = conditional->asGaussian()) { + // Type-erase and add to the pruned Bayes Net fragment. + result.push_back(prunedHybridGaussianConditional); + } else if (auto gc = restricted->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); - } - // We ignore DiscreteConditional as they are already pruned and added. + } else + throw std::runtime_error( + "HybrdiBayesNet::prune: Unknown HybridConditional type."); } #if GTSAM_HYBRID_TIMING @@ -169,21 +134,14 @@ HybridBayesNet HybridBayesNet::prune( if (deadModeThreshold.has_value()) { /* - If the pruned discrete conditional has any keys left, - we add it to the HybridBayesNet. - If not, it means it is an orphan so we don't add this pruned joint, - and instead add only the marginals below. + If the pruned discrete conditional has any keys left, we add it to the + HybridBayesNet. If not, it means it is an orphan so we don't add this + pruned joint, and instead add only the marginals below. */ if (pruned.keys().size() > 0) { result.emplace_shared(pruned); } - // Add the marginals for future factors - // for (auto &&[key, _] : deadModesValues) { - // result.push_back( - // std::dynamic_pointer_cast(marginals(key))); - // } - } else { result.emplace_shared(pruned); }