diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 77393ab1b..68c248119 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, - bool removeDeadModes) const { +HybridBayesNet HybridBayesNet::prune( + size_t maxNrLeaves, const std::optional &deadModeThreshold) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -58,25 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, joint = joint * (*conditional); } - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. - joint.prune(maxNrLeaves); - - // Create the result starting with the pruned joint. + // Initialize the resulting HybridBayesNet. HybridBayesNet result; - result.emplace_shared(joint); - // Get pruned discrete probabilities so - // we can prune HybridGaussianConditionals. - DiscreteConditional pruned = *result.back()->asDiscrete(); + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + pruned.prune(maxNrLeaves); DiscreteValues deadModesValues; - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > 0.99); + auto threshold = (probabilities.array() > *deadModeThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { @@ -89,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } // Remove the modes (imperative) - result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); - pruned = *result.back()->asDiscrete(); + pruned.removeDiscreteModes(deadModesValues); + + /* + If the pruned discrete conditional has any keys left, + we add it to the HybridBayesNet. + If not, it means it is an orphan so we don't add this pruned joint, + and instead add only the marginals below. + */ + if (pruned.keys().size() > 0) { + result.emplace_shared(pruned); + } + + // Add the marginals for future factors + for (auto &&[key, _] : deadModesValues) { + result.push_back( + std::dynamic_pointer_cast(marginals(key))); + } + + } else { + result.emplace_shared(pruned); } /* To prune, we visitWith every leaf in the HybridGaussianConditional. @@ -101,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, */ // Go through all the Gaussian conditionals in the Bayes Net and prune them as - // per pruned Discrete joint. + // per pruned discrete joint. for (auto &&conditional : *this) { if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { KeyVector deadKeys, conditionalDiscreteKeys; for (const auto &kv : deadModesValues) { deadKeys.push_back(kv.first); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 5d3270f4c..fb05e2407 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param removeDeadModes Flag to enable removal of modes which only have a - * single possible assignment. + * @param deadModeThreshold The threshold to check the mode marginals against. + * If greater than this threshold, the mode gets assigned that value and is + * considered "dead" for hybrid elimination. + * The mode can then be removed since it only has a single possible + * assignment. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; + HybridBayesNet prune( + size_t maxNrLeaves, + const std::optional &deadModeThreshold = {}) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 3f429198e..45320896a 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); } // Add the partial bayes net to the posterior bayes net. @@ -116,6 +116,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, factorKeys.end()) { newConditionals.push_back(conditional); + // Add the conditional parents to factorKeys + // so we add those conditionals too. + // NOTE: This assumes we have a structure where + // variables depend on those in the future. + for (auto &&parentKey : conditional->parents()) { + factorKeys.insert(parentKey); + } + // Remove the conditional from the updated Bayes net auto it = find(updatedHybridBayesNet.begin(), updatedHybridBayesNet.end(), conditional); diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index f941385bd..c3f022c62 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother { HybridBayesNet hybridBayesNet_; HybridGaussianFactorGraph remainingFactorGraph_; + /// The threshold above which we make a decision about a mode. + std::optional deadModeThreshold_; + public: + /** + * @brief Constructor + * + * @param removeDeadModes Flag indicating whether to remove dead modes. + * @param deadModeThreshold The threshold above which a mode gets assigned a + * value and is considered "dead". 0.99 is a good starting value. + */ + HybridSmoother(const std::optional deadModeThreshold = {}) + : deadModeThreshold_(deadModeThreshold) {} + /** * Given new factors, perform an incremental update. * The relevant densities in the `hybridBayesNet` will be added to the input @@ -53,17 +66,6 @@ class GTSAM_EXPORT HybridSmoother { std::optional maxNrLeaves = {}, const std::optional given_ordering = {}); - /** - * @brief Get an elimination ordering which eliminates continuous and then - * discrete. - * - * Expects `newFactors` to already have the necessary conditionals connected - * to the - * - * @param factors - * @return Ordering - */ - /** * @brief Get an elimination ordering which eliminates continuous * and then discrete. diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 56e93499b..86dcd48e4 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) { HybridValues delta = posterior->optimize(); // Prune the Bayes net - const bool pruneDeadVariables = true; + const double pruneDeadVariables = 0.99; auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); // Check that discrete joint only has M0 and not (M0, M1) @@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) { // Check that hybrid conditionals that only depend on M1 // are now Gaussian and not Hybrid EXPECT(prunedBayesNet.at(0)->isDiscrete()); - EXPECT(prunedBayesNet.at(1)->isHybrid()); + EXPECT(prunedBayesNet.at(1)->isDiscrete()); + EXPECT(prunedBayesNet.at(2)->isHybrid()); // Only P(X2 | X1, M1) depends on M1, // so it gets convert to a Gaussian P(X2 | X1) - EXPECT(prunedBayesNet.at(2)->isContinuous()); - EXPECT(prunedBayesNet.at(3)->isHybrid()); + EXPECT(prunedBayesNet.at(3)->isContinuous()); + EXPECT(prunedBayesNet.at(4)->isHybrid()); } /* ****************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 5493ce53e..3a0f376cc 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -167,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) { EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); } +/****************************************************************************/ +// Test if dead mode removal works. +TEST(HybridSmoother, DeadModeRemoval) { + using namespace estimation_fixture; + + size_t K = 8; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial); + + // Smoother with dead mode removal enabled. + HybridSmoother smoother(true); + + constexpr size_t maxNrLeaves = 3; + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + + smoother.update(linearized, maxNrLeaves); + + // Clear all the factors from the graph + graph.resize(0); + } + + // Get the continuous delta update as well as + // the optimal discrete assignment. + HybridValues delta = smoother.hybridBayesNet().optimize(); + + // Check discrete assignment + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, delta.discrete())); + + // Update nonlinear solution and verify + Values result = initial.retract(delta.continuous()); + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + /* ************************************************************************* */ int main() { TestResult tr;