diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 8668bedd6..a665f6f92 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -46,7 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, + bool removeDeadModes) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -66,6 +68,30 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // we can prune HybridGaussianConditionals. DiscreteConditional pruned = *result.back()->asDiscrete(); + DiscreteValues deadModesValues; + if (removeDeadModes) { + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); + for (auto dkey : pruned.discreteKeys()) { + Vector probabilities = marginals.marginalProbabilities(dkey); + + int index = -1; + auto threshold = (probabilities.array() > 0.99); + // If atleast 1 value is non-zero, then we can find the index + // Else if all are zero, index would be set to 0 which is incorrect + if (!threshold.isZero()) { + threshold.maxCoeff(&index); + } + + if (index >= 0) { + deadModesValues.insert(std::make_pair(dkey.first, index)); + } + } + + // Remove the modes (imperative) + result.back()->removeModes(deadModesValues); + pruned = *result.back()->asDiscrete(); + } + /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. @@ -80,8 +106,28 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - // Type-erase and add to the pruned Bayes Net fragment. - result.push_back(prunedHybridGaussianConditional); + if (removeDeadModes) { + KeyVector deadKeys, conditionalDiscreteKeys; + for (const auto &kv : deadModesValues) { + deadKeys.push_back(kv.first); + } + for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { + conditionalDiscreteKeys.push_back(dkey.first); + } + // The discrete keys in the conditional are the same as the keys in the + // dead modes, then we just get the corresponding Gaussian conditional. + if (deadKeys == conditionalDiscreteKeys) { + result.push_back( + prunedHybridGaussianConditional->choose(deadModesValues)); + } else { + // Add as-is + result.push_back(prunedHybridGaussianConditional); + } + } else { + // Type-erase and add to the pruned Bayes Net fragment. + result.push_back(prunedHybridGaussianConditional); + } + } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 3e07c71ce..f114a6fa0 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -209,9 +209,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. + * @param removeDeadModes * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves) const; + HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 6a108b941..d2f67aa47 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -407,7 +407,7 @@ TEST(HybridBayesNet, Prune) { HybridBayesNet::shared_ptr posterior = s.linearizedFactorGraph().eliminateSequential(); - EXPECT_LONGS_EQUAL(7, posterior->size()); + EXPECT_LONGS_EQUAL(5, posterior->size()); // Call Max-Product to get MAP HybridValues delta = posterior->optimize(); @@ -421,6 +421,35 @@ TEST(HybridBayesNet, Prune) { EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); } +/* ****************************************************************************/ +// Test Bayes net pruning and dead node removal +TEST(HybridBayesNet, RemoveDeadNodes) { + Switching s(3); + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(5, posterior->size()); + + // Call Max-Product to get MAP + HybridValues delta = posterior->optimize(); + + // Prune the Bayes net + const bool pruneDeadVariables = true; + auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); + + // Check that discrete joint only has M0 and not (M0, M1) + // since M0 is removed + KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); + EXPECT(KeyVector{M(0)} == actual_keys); + + // Check that hybrid conditionals that only depend on M1 are no longer hybrid + EXPECT(prunedBayesNet.at(0)->isDiscrete()); + EXPECT(prunedBayesNet.at(1)->isHybrid()); + // Only P(X2 | X1, M1) depends on M1, so it is Gaussian + EXPECT(prunedBayesNet.at(2)->isContinuous()); + EXPECT(prunedBayesNet.at(3)->isHybrid()); +} + /* ****************************************************************************/ // Test Bayes net error and log-probability after pruning TEST(HybridBayesNet, ErrorAfterPruning) {