diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b6622980b..68c248119 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, - bool removeDeadModes) const { +HybridBayesNet HybridBayesNet::prune( + size_t maxNrLeaves, const std::optional &deadModeThreshold) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -58,24 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, joint = joint * (*conditional); } - // Create the result starting with the pruned joint. + // Initialize the resulting HybridBayesNet. HybridBayesNet result; - result.emplace_shared(joint); - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. - result.back()->asDiscrete()->prune(maxNrLeaves); - // Get pruned discrete probabilities so - // we can prune HybridGaussianConditionals. - DiscreteConditional pruned = *result.back()->asDiscrete(); + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + pruned.prune(maxNrLeaves); DiscreteValues deadModesValues; - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > 0.99); + auto threshold = (probabilities.array() > *deadModeThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { @@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } // Remove the modes (imperative) - result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); - pruned = *result.back()->asDiscrete(); + pruned.removeDiscreteModes(deadModesValues); + + /* + If the pruned discrete conditional has any keys left, + we add it to the HybridBayesNet. + If not, it means it is an orphan so we don't add this pruned joint, + and instead add only the marginals below. + */ + if (pruned.keys().size() > 0) { + result.emplace_shared(pruned); + } + + // Add the marginals for future factors + for (auto &&[key, _] : deadModesValues) { + result.push_back( + std::dynamic_pointer_cast(marginals(key))); + } + + } else { + result.emplace_shared(pruned); } /* To prune, we visitWith every leaf in the HybridGaussianConditional. @@ -100,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, */ // Go through all the Gaussian conditionals in the Bayes Net and prune them as - // per pruned Discrete joint. + // per pruned discrete joint. for (auto &&conditional : *this) { if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { KeyVector deadKeys, conditionalDiscreteKeys; for (const auto &kv : deadModesValues) { deadKeys.push_back(kv.first); @@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const { } } } + return discrete_fg.optimize(); } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 5d3270f4c..fb05e2407 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param removeDeadModes Flag to enable removal of modes which only have a - * single possible assignment. + * @param deadModeThreshold The threshold to check the mode marginals against. + * If greater than this threshold, the mode gets assigned that value and is + * considered "dead" for hybrid elimination. + * The mode can then be removed since it only has a single possible + * assignment. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; + HybridBayesNet prune( + size_t maxNrLeaves, + const std::optional &deadModeThreshold = {}) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index ca3e27252..45320896a 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -24,17 +24,14 @@ namespace gtsam { /* ************************************************************************* */ -Ordering HybridSmoother::getOrdering( - const HybridGaussianFactorGraph &newFactors) { - HybridGaussianFactorGraph factors(hybridBayesNet()); - factors.push_back(newFactors); - +Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors, + const KeySet &newFactorKeys) { // Get all the discrete keys from the factors KeySet allDiscrete = factors.discreteKeySet(); // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; - const KeySet newFactorKeys = newFactors.keys(); + // Insert continuous keys first. for (auto &k : newFactorKeys) { if (!allDiscrete.exists(k)) { @@ -56,29 +53,35 @@ Ordering HybridSmoother::getOrdering( } /* ************************************************************************* */ -void HybridSmoother::update(HybridGaussianFactorGraph graph, +void HybridSmoother::update(const HybridGaussianFactorGraph &graph, std::optional maxNrLeaves, const std::optional given_ordering) { + HybridGaussianFactorGraph updatedGraph; + // Add the necessary conditionals from the previous timestep(s). + std::tie(updatedGraph, hybridBayesNet_) = + addConditionals(graph, hybridBayesNet_); + Ordering ordering; // If no ordering provided, then we compute one if (!given_ordering.has_value()) { - ordering = this->getOrdering(graph); + // Get the keys from the new factors + const KeySet newFactorKeys = graph.keys(); + + // Since updatedGraph now has all the connected conditionals, + // we can get the correct ordering. + ordering = this->getOrdering(updatedGraph, newFactorKeys); } else { ordering = *given_ordering; } - // Add the necessary conditionals from the previous timestep(s). - std::tie(graph, hybridBayesNet_) = - addConditionals(graph, hybridBayesNet_, ordering); - // Eliminate. - HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering); + HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering); /// Prune if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); } // Add the partial bayes net to the posterior bayes net. @@ -88,10 +91,11 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, /* ************************************************************************* */ std::pair HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, - const HybridBayesNet &originalHybridBayesNet, - const Ordering &ordering) const { + const HybridBayesNet &hybridBayesNet) const { HybridGaussianFactorGraph graph(originalGraph); - HybridBayesNet hybridBayesNet(originalHybridBayesNet); + HybridBayesNet updatedHybridBayesNet(hybridBayesNet); + + KeySet factorKeys = graph.keys(); // If hybridBayesNet is not empty, // it means we have conditionals to add to the factor graph. @@ -99,10 +103,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, // We add all relevant hybrid conditionals on the last continuous variable // in the previous `hybridBayesNet` to the graph - // Conditionals to remove from the bayes net - // since the conditional will be updated. - std::vector conditionals_to_erase; - // New conditionals to add to the graph gtsam::HybridBayesNet newConditionals; @@ -112,25 +112,32 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, auto conditional = hybridBayesNet.at(i); for (auto &key : conditional->frontals()) { - if (std::find(ordering.begin(), ordering.end(), key) != - ordering.end()) { + if (std::find(factorKeys.begin(), factorKeys.end(), key) != + factorKeys.end()) { newConditionals.push_back(conditional); - conditionals_to_erase.push_back(conditional); + + // Add the conditional parents to factorKeys + // so we add those conditionals too. + // NOTE: This assumes we have a structure where + // variables depend on those in the future. + for (auto &&parentKey : conditional->parents()) { + factorKeys.insert(parentKey); + } + + // Remove the conditional from the updated Bayes net + auto it = find(updatedHybridBayesNet.begin(), + updatedHybridBayesNet.end(), conditional); + updatedHybridBayesNet.erase(it); break; } } } - // Remove conditionals at the end so we don't affect the order in the - // original bayes net. - for (auto &&conditional : conditionals_to_erase) { - auto it = find(hybridBayesNet.begin(), hybridBayesNet.end(), conditional); - hybridBayesNet.erase(it); - } graph.push_back(newConditionals); } - return {graph, hybridBayesNet}; + + return {graph, updatedHybridBayesNet}; } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 66edf86d6..c3f022c62 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother { HybridBayesNet hybridBayesNet_; HybridGaussianFactorGraph remainingFactorGraph_; + /// The threshold above which we make a decision about a mode. + std::optional deadModeThreshold_; + public: + /** + * @brief Constructor + * + * @param removeDeadModes Flag indicating whether to remove dead modes. + * @param deadModeThreshold The threshold above which a mode gets assigned a + * value and is considered "dead". 0.99 is a good starting value. + */ + HybridSmoother(const std::optional deadModeThreshold = {}) + : deadModeThreshold_(deadModeThreshold) {} + /** * Given new factors, perform an incremental update. * The relevant densities in the `hybridBayesNet` will be added to the input @@ -49,11 +62,24 @@ class GTSAM_EXPORT HybridSmoother { * @param given_ordering The (optional) ordering for elimination, only * continuous variables are allowed */ - void update(HybridGaussianFactorGraph graph, + void update(const HybridGaussianFactorGraph& graph, std::optional maxNrLeaves = {}, const std::optional given_ordering = {}); - Ordering getOrdering(const HybridGaussianFactorGraph& newFactors); + /** + * @brief Get an elimination ordering which eliminates continuous + * and then discrete. + * + * Expects `factors` to already have the necessary conditionals + * which were connected to the variables in the newly added factors. + * Those variables should be in `newFactorKeys`. + * + * @param factors All the new factors and connected conditionals. + * @param newFactorKeys The keys/variables in the newly added factors. + * @return Ordering + */ + Ordering getOrdering(const HybridGaussianFactorGraph& factors, + const KeySet& newFactorKeys); /** * @brief Add conditionals from previous timestep as part of liquefication. @@ -66,7 +92,7 @@ class GTSAM_EXPORT HybridSmoother { */ std::pair addConditionals( const HybridGaussianFactorGraph& graph, - const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const; + const HybridBayesNet& hybridBayesNet) const; /** * @brief Get the hybrid Gaussian conditional from diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 56e93499b..86dcd48e4 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) { HybridValues delta = posterior->optimize(); // Prune the Bayes net - const bool pruneDeadVariables = true; + const double pruneDeadVariables = 0.99; auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); // Check that discrete joint only has M0 and not (M0, M1) @@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) { // Check that hybrid conditionals that only depend on M1 // are now Gaussian and not Hybrid EXPECT(prunedBayesNet.at(0)->isDiscrete()); - EXPECT(prunedBayesNet.at(1)->isHybrid()); + EXPECT(prunedBayesNet.at(1)->isDiscrete()); + EXPECT(prunedBayesNet.at(2)->isHybrid()); // Only P(X2 | X1, M1) depends on M1, // so it gets convert to a Gaussian P(X2 | X1) - EXPECT(prunedBayesNet.at(2)->isContinuous()); - EXPECT(prunedBayesNet.at(3)->isHybrid()); + EXPECT(prunedBayesNet.at(3)->isContinuous()); + EXPECT(prunedBayesNet.at(4)->isHybrid()); } /* ****************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 145f44d1e..3a0f376cc 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -95,16 +95,15 @@ TEST(HybridSmoother, IncrementalSmoother) { initial.insert(X(k), switching.linearizationPoint.at(X(k))); HybridGaussianFactorGraph linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - smoother.update(linearized, maxNrLeaves, ordering); + smoother.update(linearized, maxNrLeaves); // Clear all the factors from the graph graph.resize(0); } EXPECT_LONGS_EQUAL(11, - smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. @@ -150,16 +149,15 @@ TEST(HybridSmoother, ValidPruningError) { initial.insert(X(k), switching.linearizationPoint.at(X(k))); HybridGaussianFactorGraph linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - smoother.update(linearized, maxNrLeaves, ordering); + smoother.update(linearized, maxNrLeaves); // Clear all the factors from the graph graph.resize(0); } EXPECT_LONGS_EQUAL(14, - smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. @@ -169,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) { EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); } +/****************************************************************************/ +// Test if dead mode removal works. +TEST(HybridSmoother, DeadModeRemoval) { + using namespace estimation_fixture; + + size_t K = 8; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial); + + // Smoother with dead mode removal enabled. + HybridSmoother smoother(true); + + constexpr size_t maxNrLeaves = 3; + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + + smoother.update(linearized, maxNrLeaves); + + // Clear all the factors from the graph + graph.resize(0); + } + + // Get the continuous delta update as well as + // the optimal discrete assignment. + HybridValues delta = smoother.hybridBayesNet().optimize(); + + // Check discrete assignment + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, delta.discrete())); + + // Update nonlinear solution and verify + Values result = initial.retract(delta.continuous()); + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + /* ************************************************************************* */ int main() { TestResult tr;