diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 56265b0a4..7d929f12c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,7 @@ #include #include +#include #include namespace gtsam { @@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { return result; } +/* ************************************************************************* */ +// The implementation is: build the entire joint into one factor and then prune. +// TODO(Frank): This can be quite expensive *unless* the factors have already +// been pruned before. Another, possibly faster approach is branch and bound +// search to find the K-best leaves and then create a single pruned conditional. +DiscreteBayesNet DiscreteBayesNet::prune( + size_t maxNrLeaves, const std::optional& deadModeThreshold, + DiscreteValues* fixedValues) const { + // Multiply into one big conditional. NOTE: possibly quite expensive. + DiscreteConditional joint; + for (const DiscreteConditional::shared_ptr& conditional : *this) + joint = joint * (*conditional); + + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + pruned.prune(maxNrLeaves); + + DiscreteValues deadModesValues; + // If we have a dead mode threshold and discrete variables left after pruning, + // then we run dead mode removal. + if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); + for (auto dkey : pruned.discreteKeys()) { + const Vector probabilities = marginals.marginalProbabilities(dkey); + + int index = -1; + auto threshold = (probabilities.array() > *deadModeThreshold); + // If atleast 1 value is non-zero, then we can find the index + // Else if all are zero, index would be set to 0 which is incorrect + if (!threshold.isZero()) { + threshold.maxCoeff(&index); + } + + if (index >= 0) { + deadModesValues.emplace(dkey.first, index); + } + } + + // Remove the modes (imperative) + pruned.removeDiscreteModes(deadModesValues); + + // Set the fixed values if requested. + if (fixedValues) { + *fixedValues = deadModesValues; + } + } + + // Return the resulting DiscreteBayesNet. + DiscreteBayesNet result; + if (pruned.keys().size() > 0) result.push_back(pruned); + return result; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 738b91aa5..01b452865 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { */ DiscreteValues sample(DiscreteValues given) const; + /** + * @brief Prune the Bayes net + * + * @param maxNrLeaves The maximum number of leaves to keep. + * @param deadModeThreshold If given, threshold on marginals to prune variables. + * @param fixedValues If given, return the fixed values removed. + * @return A new DiscreteBayesNet with pruned conditionals. + */ + DiscreteBayesNet prune(size_t maxNrLeaves, + const std::optional& deadModeThreshold = {}, + DiscreteValues* fixedValues = nullptr) const; + ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index a911a047a..ea4c3e80b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -43,10 +42,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { } /* ************************************************************************* */ -// The implementation is: build the entire joint into one factor and then prune. -// TODO(Frank): This can be quite expensive *unless* the factors have already -// been pruned before. Another, possibly faster approach is branch and bound -// search to find the K-best leaves and then create a single pruned conditional. HybridBayesNet HybridBayesNet::prune( size_t maxNrLeaves, const std::optional &deadModeThreshold) const { #if GTSAM_HYBRID_TIMING @@ -55,63 +50,31 @@ HybridBayesNet HybridBayesNet::prune( // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); + // Prune discrete Bayes net + DiscreteValues fixed; + auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); + // Multiply into one big conditional. NOTE: possibly quite expensive. - DiscreteConditional joint; - for (auto &&conditional : marginal) { - joint = joint * (*conditional); + DiscreteConditional pruned; + for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); + + // Set the fixed values if requested. + if (deadModeThreshold && fixedValues) { + *fixedValues = fixed; } - // Initialize the resulting HybridBayesNet. HybridBayesNet result; - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. - DiscreteConditional pruned = joint; - pruned.prune(maxNrLeaves); - - DiscreteValues deadModesValues; - // If we have a dead mode threshold and discrete variables left after pruning, - // then we run dead mode removal. - if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { -#if GTSAM_HYBRID_TIMING - gttic_(DeadModeRemoval); -#endif - - DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); - for (auto dkey : pruned.discreteKeys()) { - Vector probabilities = marginals.marginalProbabilities(dkey); - - int index = -1; - auto threshold = (probabilities.array() > *deadModeThreshold); - // If atleast 1 value is non-zero, then we can find the index - // Else if all are zero, index would be set to 0 which is incorrect - if (!threshold.isZero()) { - threshold.maxCoeff(&index); - } - - if (index >= 0) { - deadModesValues.emplace(dkey.first, index); - } - } - - // Remove the modes (imperative) - pruned.removeDiscreteModes(deadModesValues); - -#if GTSAM_HYBRID_TIMING - gttoc_(DeadModeRemoval); -#endif - } - // Go through all the Gaussian conditionals, restrict them according to - // deadModesValues, and then prune further. - for (auto &&conditional : *this) { + // fixed values, and then prune further. + for (std::shared_ptr conditional : *this) { if (conditional->isDiscrete()) continue; - // Restrict conditional using deadModesValues. - // No-op if not a HybridGaussianConditional or deadModesValues empty. - auto restricted = conditional->restrict(deadModesValues); + // No-op if not a HybridGaussianConditional. + if (deadModeThreshold) conditional = conditional->restrict(fixed); // Now decide on type what to do: - if (auto hgc = restricted->asHybrid()) { + if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); if (!prunedHybridGaussianConditional) { @@ -120,7 +83,7 @@ HybridBayesNet HybridBayesNet::prune( } // Type-erase and add to the pruned Bayes Net fragment. result.push_back(prunedHybridGaussianConditional); - } else if (auto gc = restricted->asGaussian()) { + } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); } else @@ -128,23 +91,9 @@ HybridBayesNet HybridBayesNet::prune( "HybrdiBayesNet::prune: Unknown HybridConditional type."); } -#if GTSAM_HYBRID_TIMING - gttoc_(HybridPruning); -#endif - - if (deadModeThreshold.has_value()) { - /* - If the pruned discrete conditional has any keys left, we add it to the - HybridBayesNet. If not, it means it is an orphan so we don't add this - pruned joint, and instead add only the marginals below. - */ - if (pruned.keys().size() > 0) { - result.emplace_shared(pruned); - } - - } else { - result.emplace_shared(pruned); - } + // Add the pruned discrete conditionals to the result. + for (const DiscreteConditional::shared_ptr &discrete : prunedBN) + result.push_back(discrete); return result; }