diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 7d929f12c..8c04cb91c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. DiscreteBayesNet DiscreteBayesNet::prune( - size_t maxNrLeaves, const std::optional& deadModeThreshold, + size_t maxNrLeaves, const std::optional& marginalThreshold, DiscreteValues* fixedValues) const { // Multiply into one big conditional. NOTE: possibly quite expensive. DiscreteConditional joint; @@ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune( DiscreteValues deadModesValues; // If we have a dead mode threshold and discrete variables left after pruning, // then we run dead mode removal. - if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { + if (marginalThreshold.has_value() && pruned.keys().size() > 0) { DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { const Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > *deadModeThreshold); + auto threshold = (probabilities.array() > *marginalThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 01b452865..eea1739f6 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { * @brief Prune the Bayes net * * @param maxNrLeaves The maximum number of leaves to keep. - * @param deadModeThreshold If given, threshold on marginals to prune variables. + * @param marginalThreshold If given, threshold on marginals to prune variables. * @param fixedValues If given, return the fixed values removed. * @return A new DiscreteBayesNet with pruned conditionals. */ DiscreteBayesNet prune(size_t maxNrLeaves, - const std::optional& deadModeThreshold = {}, + const std::optional& marginalThreshold = {}, DiscreteValues* fixedValues = nullptr) const; ///@} diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index ea4c3e80b..5353fe2e0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune( - size_t maxNrLeaves, const std::optional &deadModeThreshold) const { + size_t maxNrLeaves, const std::optional &marginalThreshold, + DiscreteValues *fixedValues) const { #if GTSAM_HYBRID_TIMING gttic_(HybridPruning); #endif @@ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune( // Prune discrete Bayes net DiscreteValues fixed; - auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); + auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); // Multiply into one big conditional. NOTE: possibly quite expensive. DiscreteConditional pruned; for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); // Set the fixed values if requested. - if (deadModeThreshold && fixedValues) { + if (marginalThreshold && fixedValues) { *fixedValues = fixed; } @@ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune( if (conditional->isDiscrete()) continue; // No-op if not a HybridGaussianConditional. - if (deadModeThreshold) conditional = conditional->restrict(fixed); + if (marginalThreshold) conditional = conditional->restrict(fixed); // Now decide on type what to do: if (auto hgc = conditional->asHybrid()) { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index fb05e2407..0840cb381 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param deadModeThreshold The threshold to check the mode marginals against. - * If greater than this threshold, the mode gets assigned that value and is - * considered "dead" for hybrid elimination. - * The mode can then be removed since it only has a single possible - * assignment. + * @param marginalThreshold The threshold to check the mode marginals against. + * @param fixedValues The fixed values resulting from dead mode removal. + * + * @note If marginal greater than this threshold, the mode gets assigned that + * value and is considered "dead" for hybrid elimination. The mode can then be + * removed since it only has a single possible assignment. + * @return A pruned HybridBayesNet */ - HybridBayesNet prune( - size_t maxNrLeaves, - const std::optional &deadModeThreshold = {}) const; + HybridBayesNet prune(size_t maxNrLeaves, + const std::optional &marginalThreshold = {}, + DiscreteValues *fixedValues = nullptr) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 45320896a..594c12825 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_); } // Add the partial bayes net to the posterior bayes net. diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index c3f022c62..2f7bfcebb 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother { HybridGaussianFactorGraph remainingFactorGraph_; /// The threshold above which we make a decision about a mode. - std::optional deadModeThreshold_; + std::optional marginalThreshold_; + DiscreteValues fixedValues_; public: /** * @brief Constructor * * @param removeDeadModes Flag indicating whether to remove dead modes. - * @param deadModeThreshold The threshold above which a mode gets assigned a + * @param marginalThreshold The threshold above which a mode gets assigned a * value and is considered "dead". 0.99 is a good starting value. */ - HybridSmoother(const std::optional deadModeThreshold = {}) - : deadModeThreshold_(deadModeThreshold) {} + HybridSmoother(const std::optional marginalThreshold = {}) + : marginalThreshold_(marginalThreshold) {} /** * Given new factors, perform an incremental update.