diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index f5ad2b98a..9a3b6cf46 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -499,6 +499,40 @@ void DiscreteConditional::prune(size_t maxNrAssignments) { this->root_ = pruned.root_; } +/* ************************************************************************ */ +void DiscreteConditional::removeDiscreteModes(const DiscreteValues& given) { + AlgebraicDecisionTree tree(*this); + for (auto [key, value] : given) { + tree = tree.choose(key, value); + } + + // Get the leftover DiscreteKey frontals + DiscreteKeys frontals; + std::for_each(this->frontals().begin(), this->frontals().end(), [&](Key key) { + // Check if frontal key exists in given, if not add to new frontals + if (given.count(key) == 0) { + frontals.emplace_back(key, this->cardinalities_.at(key)); + } + }); + // Get the leftover DiscreteKey parents + DiscreteKeys parents; + std::for_each(this->parents().begin(), this->parents().end(), [&](Key key) { + // Check if parent key exists in given, if not add to new parents + if (given.count(key) == 0) { + parents.emplace_back(key, this->cardinalities_.at(key)); + } + }); + + DiscreteKeys allDkeys(frontals); + allDkeys.insert(allDkeys.end(), parents.begin(), parents.end()); + + // Update the conditional + this->keys_ = allDkeys.indices(); + this->cardinalities_ = allDkeys.cardinalities(); + this->root_ = tree.root_; + this->nrFrontals_ = frontals.size(); +} + /* ************************************************************************* */ double DiscreteConditional::negLogConstant() const { return 0.0; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 1bca0b09f..c22fcdf85 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -279,6 +279,16 @@ class GTSAM_EXPORT DiscreteConditional /// Prune the conditional virtual void prune(size_t maxNrAssignments); + /** + * @brief Remove the discrete modes whose assignments are given to us. + * Only applies to discrete conditionals. + * + * Imperative method so we can update nodes in the Bayes net or Bayes tree. + * + * @param given The discrete modes whose assignments we know. + */ + void removeDiscreteModes(const DiscreteValues& given); + /// @} protected: diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f83435df2..b6622980b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -46,7 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, + bool removeDeadModes) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -66,6 +68,30 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // we can prune HybridGaussianConditionals. DiscreteConditional pruned = *result.back()->asDiscrete(); + DiscreteValues deadModesValues; + if (removeDeadModes) { + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); + for (auto dkey : pruned.discreteKeys()) { + Vector probabilities = marginals.marginalProbabilities(dkey); + + int index = -1; + auto threshold = (probabilities.array() > 0.99); + // If atleast 1 value is non-zero, then we can find the index + // Else if all are zero, index would be set to 0 which is incorrect + if (!threshold.isZero()) { + threshold.maxCoeff(&index); + } + + if (index >= 0) { + deadModesValues.emplace(dkey.first, index); + } + } + + // Remove the modes (imperative) + result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); + pruned = *result.back()->asDiscrete(); + } + /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. @@ -80,8 +106,28 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - // Type-erase and add to the pruned Bayes Net fragment. - result.push_back(prunedHybridGaussianConditional); + if (removeDeadModes) { + KeyVector deadKeys, conditionalDiscreteKeys; + for (const auto &kv : deadModesValues) { + deadKeys.push_back(kv.first); + } + for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { + conditionalDiscreteKeys.push_back(dkey.first); + } + // The discrete keys in the conditional are the same as the keys in the + // dead modes, then we just get the corresponding Gaussian conditional. + if (deadKeys == conditionalDiscreteKeys) { + result.push_back( + prunedHybridGaussianConditional->choose(deadModesValues)); + } else { + // Add as-is + result.push_back(prunedHybridGaussianConditional); + } + } else { + // Type-erase and add to the pruned Bayes Net fragment. + result.push_back(prunedHybridGaussianConditional); + } + } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 90e3a6814..5d3270f4c 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,9 +217,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. + * @param removeDeadModes Flag to enable removal of modes which only have a + * single possible assignment. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves) const; + HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index cf56b52ed..4943f91cb 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -247,7 +247,7 @@ continuousElimination(const HybridGaussianFactorGraph &factors, * @param errors DecisionTree of (unnormalized) errors. * @return TableFactor::shared_ptr */ -static TableFactor::shared_ptr DiscreteFactorFromErrors( +static DiscreteFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 784d9c95f..56e93499b 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -343,11 +343,20 @@ TEST(HybridBayesNet, Optimize) { } /* ****************************************************************************/ -// Test Bayes net error -TEST(HybridBayesNet, Pruning) { - // Create switching network with three continuous variables and two discrete: - // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) - Switching s(3); +namespace hbn_error { +// Create switching network with three continuous variables and two discrete: +// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) +Switching s(3); + +// The true discrete assignment +const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; + +} // namespace hbn_error + +/* ****************************************************************************/ +// Test Bayes net error and log-probability +TEST(HybridBayesNet, Error) { + using namespace hbn_error; HybridBayesNet::shared_ptr posterior = s.linearizedFactorGraph().eliminateSequential(); @@ -366,7 +375,6 @@ TEST(HybridBayesNet, Pruning) { EXPECT(assert_equal(expected, discretePosterior, 1e-6)); // Verify logProbability computation and check specific logProbability value - const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const HybridValues hybridValues{delta.continuous(), discrete_values}; double logProbability = 0; logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues); @@ -390,17 +398,84 @@ TEST(HybridBayesNet, Pruning) { // Check agreement with discrete posterior double density = exp(logProbability + negLogConstant) / normalizer; EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6); +} + +/* ****************************************************************************/ +// Test Bayes net pruning +TEST(HybridBayesNet, Prune) { + Switching s(3); + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(5, posterior->size()); + + // Call Max-Product to get MAP + HybridValues delta = posterior->optimize(); + + // Prune the Bayes net + auto prunedBayesNet = posterior->prune(2); + + // Test if Max-Product gives the same result as unpruned version + HybridValues pruned_delta = prunedBayesNet.optimize(); + EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); + EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); +} + +/* ****************************************************************************/ +// Test Bayes net pruning and dead node removal +TEST(HybridBayesNet, RemoveDeadNodes) { + Switching s(3); + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(5, posterior->size()); + + // Call Max-Product to get MAP + HybridValues delta = posterior->optimize(); + + // Prune the Bayes net + const bool pruneDeadVariables = true; + auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); + + // Check that discrete joint only has M0 and not (M0, M1) + // since M0 is removed + KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); + EXPECT(KeyVector{M(0)} == actual_keys); + + // Check that hybrid conditionals that only depend on M1 + // are now Gaussian and not Hybrid + EXPECT(prunedBayesNet.at(0)->isDiscrete()); + EXPECT(prunedBayesNet.at(1)->isHybrid()); + // Only P(X2 | X1, M1) depends on M1, + // so it gets convert to a Gaussian P(X2 | X1) + EXPECT(prunedBayesNet.at(2)->isContinuous()); + EXPECT(prunedBayesNet.at(3)->isHybrid()); +} + +/* ****************************************************************************/ +// Test Bayes net error and log-probability after pruning +TEST(HybridBayesNet, ErrorAfterPruning) { + using namespace hbn_error; + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(5, posterior->size()); + + // Optimize + HybridValues delta = posterior->optimize(); // Prune and get probabilities - auto prunedBayesNet = posterior->prune(2); - auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); + HybridBayesNet prunedBayesNet = posterior->prune(2); + AlgebraicDecisionTree prunedTree = + prunedBayesNet.discretePosterior(delta.continuous()); - // Regression test on pruned logProbability tree + // Regression test on pruned probability tree std::vector pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578}; AlgebraicDecisionTree expected_pruned(s.modes, pruned_leaves); EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); - // Regression + // Regression to check specific logProbability value + const HybridValues hybridValues{delta.continuous(), discrete_values}; double pruned_logProbability = 0; pruned_logProbability += prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); @@ -423,24 +498,6 @@ TEST(HybridBayesNet, Pruning) { EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9); } -/* ****************************************************************************/ -// Test Bayes net pruning -TEST(HybridBayesNet, Prune) { - Switching s(4); - - HybridBayesNet::shared_ptr posterior = - s.linearizedFactorGraph().eliminateSequential(); - EXPECT_LONGS_EQUAL(7, posterior->size()); - - HybridValues delta = posterior->optimize(); - - auto prunedBayesNet = posterior->prune(2); - HybridValues pruned_delta = prunedBayesNet.optimize(); - - EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); - EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); -} - /* ****************************************************************************/ // Test Bayes net updateDiscreteConditionals TEST(HybridBayesNet, UpdateDiscreteConditionals) { diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index ef2ae9c41..4b7dd9fd6 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -115,103 +115,6 @@ TEST(HybridEstimation, Full) { EXPECT(assert_equal(expected_continuous, result)); } -/****************************************************************************/ -// Test approximate inference with an additional pruning step. -TEST(HybridEstimation, IncrementalSmoother) { - using namespace estimation_fixture; - - size_t K = 15; - - // Switching example of robot moving in 1D - // with given measurements and equal mode priors. - HybridNonlinearFactorGraph graph; - Values initial; - Switching switching = InitializeEstimationProblem(K, 1.0, 0.1, measurements, - "1/1 1/1", graph, initial); - HybridSmoother smoother; - - HybridGaussianFactorGraph linearized; - - constexpr size_t maxNrLeaves = 3; - for (size_t k = 1; k < K; k++) { - if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain - graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model - graph.push_back(switching.unaryFactors.at(k)); // Measurement - - initial.insert(X(k), switching.linearizationPoint.at(X(k))); - - linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - - smoother.update(linearized, maxNrLeaves, ordering); - graph.resize(0); - } - - HybridValues delta = smoother.hybridBayesNet().optimize(); - - Values result = initial.retract(delta.continuous()); - - DiscreteValues expected_discrete; - for (size_t k = 0; k < K - 1; k++) { - expected_discrete[M(k)] = discrete_seq[k]; - } - EXPECT(assert_equal(expected_discrete, delta.discrete())); - - Values expected_continuous; - for (size_t k = 0; k < K; k++) { - expected_continuous.insert(X(k), measurements[k]); - } - EXPECT(assert_equal(expected_continuous, result)); -} - -/****************************************************************************/ -// Test if pruned factor is set to correct error and no errors are thrown. -TEST(HybridEstimation, ValidPruningError) { - using namespace estimation_fixture; - - size_t K = 8; - - HybridNonlinearFactorGraph graph; - Values initial; - Switching switching = InitializeEstimationProblem(K, 1e-2, 1e-3, measurements, - "1/1 1/1", graph, initial); - HybridSmoother smoother; - - HybridGaussianFactorGraph linearized; - - constexpr size_t maxNrLeaves = 3; - for (size_t k = 1; k < K; k++) { - if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain - graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model - graph.push_back(switching.unaryFactors.at(k)); // Measurement - - initial.insert(X(k), switching.linearizationPoint.at(X(k))); - - linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - - smoother.update(linearized, maxNrLeaves, ordering); - - graph.resize(0); - } - - HybridValues delta = smoother.hybridBayesNet().optimize(); - - Values result = initial.retract(delta.continuous()); - - DiscreteValues expected_discrete; - for (size_t k = 0; k < K - 1; k++) { - expected_discrete[M(k)] = discrete_seq[k]; - } - EXPECT(assert_equal(expected_discrete, delta.discrete())); - - Values expected_continuous; - for (size_t k = 0; k < K; k++) { - expected_continuous.insert(X(k), measurements[k]); - } - EXPECT(assert_equal(expected_continuous, result)); -} - /****************************************************************************/ // Test approximate inference with an additional pruning step. TEST(HybridEstimation, ISAM) { diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp new file mode 100644 index 000000000..145f44d1e --- /dev/null +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -0,0 +1,177 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridSmoother.cpp + * @brief Unit tests for HybridSmoother + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + +#include + +#include "Switching.h" + +using namespace std; +using namespace gtsam; + +using symbol_shorthand::X; +using symbol_shorthand::Z; + +namespace estimation_fixture { +std::vector measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, + 7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; +// Ground truth discrete seq +std::vector discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 1, 1, 0, 0, 0}; + +Switching InitializeEstimationProblem( + const size_t K, const double between_sigma, const double measurement_sigma, + const std::vector& measurements, + const std::string& transitionProbabilityTable, + HybridNonlinearFactorGraph* graph, Values* initial) { + Switching switching(K, between_sigma, measurement_sigma, measurements, + transitionProbabilityTable); + + // Add prior on M(0) + graph->push_back(switching.modeChain.at(0)); + + // Add the X(0) prior + graph->push_back(switching.unaryFactors.at(0)); + initial->insert(X(0), switching.linearizationPoint.at(X(0))); + + return switching; +} + +} // namespace estimation_fixture + +/****************************************************************************/ +// Test approximate inference with an additional pruning step. +TEST(HybridSmoother, IncrementalSmoother) { + using namespace estimation_fixture; + + size_t K = 5; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 1.0, 0.1, measurements, "1/1 1/1", &graph, &initial); + + HybridSmoother smoother; + constexpr size_t maxNrLeaves = 5; + + // Loop over timesteps from 1...K-1 + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + Ordering ordering = smoother.getOrdering(linearized); + + smoother.update(linearized, maxNrLeaves, ordering); + + // Clear all the factors from the graph + graph.resize(0); + } + + EXPECT_LONGS_EQUAL(11, + smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + + // Get the continuous delta update as well as + // the optimal discrete assignment. + HybridValues delta = smoother.hybridBayesNet().optimize(); + + // Check discrete assignment + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, delta.discrete())); + + // Update nonlinear solution and verify + Values result = initial.retract(delta.continuous()); + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + +/****************************************************************************/ +// Test if pruned Bayes net is set to correct error and no errors are thrown. +TEST(HybridSmoother, ValidPruningError) { + using namespace estimation_fixture; + + size_t K = 8; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial); + HybridSmoother smoother; + + constexpr size_t maxNrLeaves = 3; + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + Ordering ordering = smoother.getOrdering(linearized); + + smoother.update(linearized, maxNrLeaves, ordering); + + // Clear all the factors from the graph + graph.resize(0); + } + + EXPECT_LONGS_EQUAL(14, + smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + + // Get the continuous delta update as well as + // the optimal discrete assignment. + HybridValues delta = smoother.hybridBayesNet().optimize(); + + auto errorTree = smoother.hybridBayesNet().errorTree(delta.continuous()); + EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */