From 7ca7e4549e198580f59cc17701b44505fa8e96ee Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 19:57:47 -0500 Subject: [PATCH] improve dead mode removal by checking for empty discrete joints and adding the marginals for future factors --- gtsam/hybrid/HybridBayesNet.cpp | 34 ++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b6622980b..d27b1026e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -58,15 +58,12 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, joint = joint * (*conditional); } - // Create the result starting with the pruned joint. + // Initialize the resulting HybridBayesNet. HybridBayesNet result; - result.emplace_shared(joint); - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. - result.back()->asDiscrete()->prune(maxNrLeaves); - // Get pruned discrete probabilities so - // we can prune HybridGaussianConditionals. - DiscreteConditional pruned = *result.back()->asDiscrete(); + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + joint.prune(maxNrLeaves); DiscreteValues deadModesValues; if (removeDeadModes) { @@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } // Remove the modes (imperative) - result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); - pruned = *result.back()->asDiscrete(); + pruned.removeDiscreteModes(deadModesValues); + + /* + If the pruned discrete conditional has any keys left, + we add it to the HybridBayesNet. + If not, it means it is an orphan so we don't add this pruned joint, + and instead add only the marginals below. + */ + if (pruned.keys().size() > 0) { + result.emplace_shared(pruned); + } + + // Add the marginals for future factors + for (auto &&[key, _] : deadModesValues) { + result.push_back( + std::dynamic_pointer_cast(marginals(key))); + } + + } else { + result.emplace_shared(pruned); } /* To prune, we visitWith every leaf in the HybridGaussianConditional. @@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const { } } } + return discrete_fg.optimize(); }