From 1764b58e8cf47ac0f5c3b8acc0c0da1b1e13f957 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 23:11:42 -0500 Subject: [PATCH] use std::optional for specifying dead mode threshold --- gtsam/hybrid/HybridBayesNet.cpp | 10 +++++----- gtsam/hybrid/HybridBayesNet.h | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e9583b845..68c248119 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, - double deadModeThreshold) const { +HybridBayesNet HybridBayesNet::prune( + size_t maxNrLeaves, const std::optional &deadModeThreshold) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -66,13 +66,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, pruned.prune(maxNrLeaves); DiscreteValues deadModesValues; - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > deadModeThreshold); + auto threshold = (probabilities.array() > *deadModeThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { @@ -121,7 +121,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { KeyVector deadKeys, conditionalDiscreteKeys; for (const auto &kv : deadModesValues) { deadKeys.push_back(kv.first); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0546a7422..86fc9527a 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,15 +217,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param removeDeadModes Flag to enable removal of modes which only have a - * single possible assignment. * @param deadModeThreshold The threshold to check the mode marginals against. * If greater than this threshold, the mode gets assigned that value and is * considered "dead" for hybrid elimination. + * The mode can then be removed since it only has a single possible + * assignment. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false, - double deadModeThreshold = 0.99) const; + HybridBayesNet prune(size_t maxNrLeaves, + const std::optional &deadModeThreshold) const; /** * @brief Error method using HybridValues which returns specific error for