From 5d7089a5a978f41e8cbfb3288bda323814bb0cbd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 15:30:03 -0500 Subject: [PATCH 01/13] break big error unit test in HBN to two smaller ones --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 61 ++++++++++++----------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 784d9c95f..327b5b3d0 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -343,11 +343,20 @@ TEST(HybridBayesNet, Optimize) { } /* ****************************************************************************/ -// Test Bayes net error -TEST(HybridBayesNet, Pruning) { - // Create switching network with three continuous variables and two discrete: - // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) - Switching s(3); +namespace hbn_error { +// Create switching network with three continuous variables and two discrete: +// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) +Switching s(3); + +// The true discrete assignment +const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; + +} // namespace hbn_error + +/* ****************************************************************************/ +// Test Bayes net error and log-probability +TEST(HybridBayesNet, Error) { + using namespace hbn_error; HybridBayesNet::shared_ptr posterior = s.linearizedFactorGraph().eliminateSequential(); @@ -366,7 +375,6 @@ TEST(HybridBayesNet, Pruning) { EXPECT(assert_equal(expected, discretePosterior, 1e-6)); // Verify logProbability computation and check specific logProbability value - const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const HybridValues hybridValues{delta.continuous(), discrete_values}; double logProbability = 0; logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues); @@ -390,17 +398,32 @@ TEST(HybridBayesNet, Pruning) { // Check agreement with discrete posterior double density = exp(logProbability + negLogConstant) / normalizer; EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6); +} + +/* ****************************************************************************/ +// Test Bayes net error and log-probability after pruning +TEST(HybridBayesNet, ErrorAfterPruning) { + using namespace hbn_error; + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(5, posterior->size()); + + // Optimize + HybridValues delta = posterior->optimize(); // Prune and get probabilities - auto prunedBayesNet = posterior->prune(2); - auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); + HybridBayesNet prunedBayesNet = posterior->prune(2); + AlgebraicDecisionTree prunedTree = + prunedBayesNet.discretePosterior(delta.continuous()); - // Regression test on pruned logProbability tree + // Regression test on pruned probability tree std::vector pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578}; AlgebraicDecisionTree expected_pruned(s.modes, pruned_leaves); EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); - // Regression + // Regression to check specific logProbability value + const HybridValues hybridValues{delta.continuous(), discrete_values}; double pruned_logProbability = 0; pruned_logProbability += prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); @@ -423,24 +446,6 @@ TEST(HybridBayesNet, Pruning) { EXPECT_DOUBLES_EQUAL(pruned_density, 
prunedTree(discrete_values), 1e-9); } -/* ****************************************************************************/ -// Test Bayes net pruning -TEST(HybridBayesNet, Prune) { - Switching s(4); - - HybridBayesNet::shared_ptr posterior = - s.linearizedFactorGraph().eliminateSequential(); - EXPECT_LONGS_EQUAL(7, posterior->size()); - - HybridValues delta = posterior->optimize(); - - auto prunedBayesNet = posterior->prune(2); - HybridValues pruned_delta = prunedBayesNet.optimize(); - - EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); - EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); -} - /* ****************************************************************************/ // Test Bayes net updateDiscreteConditionals TEST(HybridBayesNet, UpdateDiscreteConditionals) { From fff828f5998a9e3e83336cc35565a060745e5c41 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 15:31:01 -0500 Subject: [PATCH 02/13] move unit test for pruning --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 327b5b3d0..6a108b941 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -400,6 +400,27 @@ TEST(HybridBayesNet, Error) { EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6); } +/* ****************************************************************************/ +// Test Bayes net pruning +TEST(HybridBayesNet, Prune) { + Switching s(3); + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(7, posterior->size()); + + // Call Max-Product to get MAP + HybridValues delta = posterior->optimize(); + + // Prune the Bayes net + auto prunedBayesNet = posterior->prune(2); + + // Test if Max-Product gives the same result as unpruned version + HybridValues pruned_delta = prunedBayesNet.optimize(); + EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); + EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); +} + /* ****************************************************************************/ // Test Bayes net error and log-probability after pruning TEST(HybridBayesNet, ErrorAfterPruning) { From 47f47fedc1d5305e9dceeb31750785deb3df6de7 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 16:33:48 -0500 Subject: [PATCH 03/13] HybridSmoother tests --- gtsam/hybrid/tests/testHybridEstimation.cpp | 97 ----------- gtsam/hybrid/tests/testHybridSmoother.cpp | 177 ++++++++++++++++++++ 2 files changed, 177 insertions(+), 97 deletions(-) create mode 100644 gtsam/hybrid/tests/testHybridSmoother.cpp diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index ef2ae9c41..4b7dd9fd6 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -115,103 +115,6 @@ TEST(HybridEstimation, Full) { EXPECT(assert_equal(expected_continuous, result)); } -/****************************************************************************/ -// Test approximate inference with an additional pruning step. -TEST(HybridEstimation, IncrementalSmoother) { - using namespace estimation_fixture; - - size_t K = 15; - - // Switching example of robot moving in 1D - // with given measurements and equal mode priors. 
- HybridNonlinearFactorGraph graph; - Values initial; - Switching switching = InitializeEstimationProblem(K, 1.0, 0.1, measurements, - "1/1 1/1", graph, initial); - HybridSmoother smoother; - - HybridGaussianFactorGraph linearized; - - constexpr size_t maxNrLeaves = 3; - for (size_t k = 1; k < K; k++) { - if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain - graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model - graph.push_back(switching.unaryFactors.at(k)); // Measurement - - initial.insert(X(k), switching.linearizationPoint.at(X(k))); - - linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - - smoother.update(linearized, maxNrLeaves, ordering); - graph.resize(0); - } - - HybridValues delta = smoother.hybridBayesNet().optimize(); - - Values result = initial.retract(delta.continuous()); - - DiscreteValues expected_discrete; - for (size_t k = 0; k < K - 1; k++) { - expected_discrete[M(k)] = discrete_seq[k]; - } - EXPECT(assert_equal(expected_discrete, delta.discrete())); - - Values expected_continuous; - for (size_t k = 0; k < K; k++) { - expected_continuous.insert(X(k), measurements[k]); - } - EXPECT(assert_equal(expected_continuous, result)); -} - -/****************************************************************************/ -// Test if pruned factor is set to correct error and no errors are thrown. -TEST(HybridEstimation, ValidPruningError) { - using namespace estimation_fixture; - - size_t K = 8; - - HybridNonlinearFactorGraph graph; - Values initial; - Switching switching = InitializeEstimationProblem(K, 1e-2, 1e-3, measurements, - "1/1 1/1", graph, initial); - HybridSmoother smoother; - - HybridGaussianFactorGraph linearized; - - constexpr size_t maxNrLeaves = 3; - for (size_t k = 1; k < K; k++) { - if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain - graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model - graph.push_back(switching.unaryFactors.at(k)); // Measurement - - initial.insert(X(k), switching.linearizationPoint.at(X(k))); - - linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - - smoother.update(linearized, maxNrLeaves, ordering); - - graph.resize(0); - } - - HybridValues delta = smoother.hybridBayesNet().optimize(); - - Values result = initial.retract(delta.continuous()); - - DiscreteValues expected_discrete; - for (size_t k = 0; k < K - 1; k++) { - expected_discrete[M(k)] = discrete_seq[k]; - } - EXPECT(assert_equal(expected_discrete, delta.discrete())); - - Values expected_continuous; - for (size_t k = 0; k < K; k++) { - expected_continuous.insert(X(k), measurements[k]); - } - EXPECT(assert_equal(expected_continuous, result)); -} - /****************************************************************************/ // Test approximate inference with an additional pruning step. TEST(HybridEstimation, ISAM) { diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp new file mode 100644 index 000000000..2a3be124e --- /dev/null +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -0,0 +1,177 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. 
(see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridSmoother.cpp + * @brief Unit tests for HybridSmoother + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + +#include + +#include "Switching.h" + +using namespace std; +using namespace gtsam; + +using symbol_shorthand::X; +using symbol_shorthand::Z; + +namespace estimation_fixture { +std::vector measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, + 7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; +// Ground truth discrete seq +std::vector discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 1, 1, 0, 0, 0}; + +Switching InitializeEstimationProblem( + const size_t K, const double between_sigma, const double measurement_sigma, + const std::vector& measurements, + const std::string& transitionProbabilityTable, + HybridNonlinearFactorGraph* graph, Values* initial) { + Switching switching(K, between_sigma, measurement_sigma, measurements, + transitionProbabilityTable); + + // Add prior on M(0) + graph->push_back(switching.modeChain.at(0)); + + // Add the X(0) prior + graph->push_back(switching.unaryFactors.at(0)); + initial->insert(X(0), switching.linearizationPoint.at(X(0))); + + return switching; +} + +} // namespace estimation_fixture + +/****************************************************************************/ +// Test approximate inference with an additional pruning step. +TEST(HybridSmoother, IncrementalSmoother) { + using namespace estimation_fixture; + + size_t K = 5; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 1.0, 0.1, measurements, "1/1 1/1", &graph, &initial); + + HybridSmoother smoother; + constexpr size_t maxNrLeaves = 5; + + // Loop over timesteps from 1...K-1 + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + Ordering ordering = smoother.getOrdering(linearized); + + smoother.update(linearized, maxNrLeaves, ordering); + + // Clear all the factors from the graph + graph.resize(0); + } + + EXPECT_LONGS_EQUAL(11, + smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + + // Get the continuous delta update as well as + // the optimal discrete assignment. 
+ HybridValues delta = smoother.hybridBayesNet().optimize(); + + // Check discrete assignment + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, delta.discrete())); + + // Update nonlinear solution and verify + Values result = initial.retract(delta.continuous()); + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + +/****************************************************************************/ +// Test if pruned Bayes net is set to correct error and no errors are thrown. +TEST(HybridSmoother, ValidPruningError) { + using namespace estimation_fixture; + + size_t K = 8; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial); + HybridSmoother smoother; + + constexpr size_t maxNrLeaves = 3; + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + Ordering ordering = smoother.getOrdering(linearized); + + smoother.update(linearized, maxNrLeaves, ordering); + + // Clear all the factors from the graph + graph.resize(0); + } + + EXPECT_LONGS_EQUAL(14, + smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + + // Get the continuous delta update as well as + // the optimal discrete assignment. + HybridValues delta = smoother.hybridBayesNet().optimize(); + + auto errorTree = smoother.hybridBayesNet().errorTree(delta.continuous()); + EXPECT_DOUBLES_EQUAL(0.0, errorTree(delta.discrete()), 1e-8); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ From 80d8d88fdc388d7457fb52040606de08c2fa3700 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 16:35:57 -0500 Subject: [PATCH 04/13] return DiscreteFactor from DiscreteFactorFromErrors --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index cf56b52ed..4943f91cb 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -247,7 +247,7 @@ continuousElimination(const HybridGaussianFactorGraph &factors, * @param errors DecisionTree of (unnormalized) errors. 
* @return TableFactor::shared_ptr */ -static TableFactor::shared_ptr DiscreteFactorFromErrors( +static DiscreteFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); From a9a1764136cbc8f61612731e549e41548fa9f441 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 16:36:55 -0500 Subject: [PATCH 05/13] fix assertion --- gtsam/hybrid/tests/testHybridSmoother.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 2a3be124e..145f44d1e 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -166,7 +166,7 @@ TEST(HybridSmoother, ValidPruningError) { HybridValues delta = smoother.hybridBayesNet().optimize(); auto errorTree = smoother.hybridBayesNet().errorTree(delta.continuous()); - EXPECT_DOUBLES_EQUAL(0.0, errorTree(delta.discrete()), 1e-8); + EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); } /* ************************************************************************* */ From ff9a56c055b691c338a365395d35b6cde7890985 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 20:22:41 -0500 Subject: [PATCH 06/13] new removeModes method --- gtsam/hybrid/HybridConditional.cpp | 37 ++++++++++++++++++++++++++++++ gtsam/hybrid/HybridConditional.h | 9 ++++++++ 2 files changed, 46 insertions(+) diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 97ec1a1f8..77a0c781e 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } +/* ************************************************************************ */ +void HybridConditional::removeModes(const DiscreteValues &given) { + if (this->isDiscrete()) { + auto d = this->asDiscrete(); + + AlgebraicDecisionTree tree(*d); + for (auto [key, value] : given) { + tree = tree.choose(key, value); + } + + // Get the leftover DiscreteKeys + DiscreteKeys dkeys; + for (DiscreteKey dkey : d->discreteKeys()) { + if (given.count(dkey.first) == 0) { + dkeys.emplace_back(dkey); + } + } + inner_ = std::make_shared(dkeys.size(), dkeys, tree); + + } else if (this->isHybrid()) { + auto d = this->asHybrid(); + HybridGaussianFactor::FactorValuePairs tree = d->factors(); + for (auto [key, value] : given) { + tree = tree.choose(key, value); + } + + // Get the leftover DiscreteKeys + DiscreteKeys dkeys; + for (DiscreteKey dkey : d->discreteKeys()) { + if (given.count(dkey.first) == 0) { + dkeys.emplace_back(dkey); + } + } + inner_ = std::make_shared(dkeys, tree); + } +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3cf5b80e5..a33ef7327 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,6 +215,15 @@ class GTSAM_EXPORT HybridConditional return true; } + /** + * @brief Remove the modes whose assignments are given to us. + * + * Imperative method so we can update nodes in the Bayes net or Bayes tree. + * + * @param given The discrete modes whose assignments we know. 
+ */ + void removeModes(const DiscreteValues& given); + /// @} private: From 22bf9df39abd5b9f643ed630641d111562c21143 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 20:44:03 -0500 Subject: [PATCH 07/13] remove dead modes in HybridBayesNet --- gtsam/hybrid/HybridBayesNet.cpp | 52 +++++++++++++++++++++-- gtsam/hybrid/HybridBayesNet.h | 3 +- gtsam/hybrid/tests/testHybridBayesNet.cpp | 31 +++++++++++++- 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 8668bedd6..a665f6f92 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -46,7 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, + bool removeDeadModes) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -66,6 +68,30 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // we can prune HybridGaussianConditionals. DiscreteConditional pruned = *result.back()->asDiscrete(); + DiscreteValues deadModesValues; + if (removeDeadModes) { + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); + for (auto dkey : pruned.discreteKeys()) { + Vector probabilities = marginals.marginalProbabilities(dkey); + + int index = -1; + auto threshold = (probabilities.array() > 0.99); + // If atleast 1 value is non-zero, then we can find the index + // Else if all are zero, index would be set to 0 which is incorrect + if (!threshold.isZero()) { + threshold.maxCoeff(&index); + } + + if (index >= 0) { + deadModesValues.insert(std::make_pair(dkey.first, index)); + } + } + + // Remove the modes (imperative) + result.back()->removeModes(deadModesValues); + pruned = *result.back()->asDiscrete(); + } + /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. @@ -80,8 +106,28 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - // Type-erase and add to the pruned Bayes Net fragment. - result.push_back(prunedHybridGaussianConditional); + if (removeDeadModes) { + KeyVector deadKeys, conditionalDiscreteKeys; + for (const auto &kv : deadModesValues) { + deadKeys.push_back(kv.first); + } + for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { + conditionalDiscreteKeys.push_back(dkey.first); + } + // The discrete keys in the conditional are the same as the keys in the + // dead modes, then we just get the corresponding Gaussian conditional. + if (deadKeys == conditionalDiscreteKeys) { + result.push_back( + prunedHybridGaussianConditional->choose(deadModesValues)); + } else { + // Add as-is + result.push_back(prunedHybridGaussianConditional); + } + } else { + // Type-erase and add to the pruned Bayes Net fragment. 
+ result.push_back(prunedHybridGaussianConditional); + } + } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 3e07c71ce..f114a6fa0 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -209,9 +209,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. + * @param removeDeadModes * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves) const; + HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 6a108b941..d2f67aa47 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -407,7 +407,7 @@ TEST(HybridBayesNet, Prune) { HybridBayesNet::shared_ptr posterior = s.linearizedFactorGraph().eliminateSequential(); - EXPECT_LONGS_EQUAL(7, posterior->size()); + EXPECT_LONGS_EQUAL(5, posterior->size()); // Call Max-Product to get MAP HybridValues delta = posterior->optimize(); @@ -421,6 +421,35 @@ TEST(HybridBayesNet, Prune) { EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); } +/* ****************************************************************************/ +// Test Bayes net pruning and dead node removal +TEST(HybridBayesNet, RemoveDeadNodes) { + Switching s(3); + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(5, posterior->size()); + + // Call Max-Product to get MAP + HybridValues delta = posterior->optimize(); + + // Prune the Bayes net + const bool pruneDeadVariables = true; + auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); + + // Check that discrete joint only has M0 and not (M0, M1) + // since M0 is removed + KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); + EXPECT(KeyVector{M(0)} == actual_keys); + + // Check that hybrid conditionals that only depend on M1 are no longer hybrid + EXPECT(prunedBayesNet.at(0)->isDiscrete()); + EXPECT(prunedBayesNet.at(1)->isHybrid()); + // Only P(X2 | X1, M1) depends on M1, so it is Gaussian + EXPECT(prunedBayesNet.at(2)->isContinuous()); + EXPECT(prunedBayesNet.at(3)->isHybrid()); +} + /* ****************************************************************************/ // Test Bayes net error and log-probability after pruning TEST(HybridBayesNet, ErrorAfterPruning) { From abbbde980fae5189ef6ab3caf83b8f6565d417e4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 20:48:20 -0500 Subject: [PATCH 08/13] make removeDiscreteModes only apply to discrete conditionals --- gtsam/hybrid/HybridBayesNet.cpp | 2 +- gtsam/hybrid/HybridConditional.cpp | 18 +----------------- gtsam/hybrid/HybridConditional.h | 5 +++-- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index a665f6f92..0cdc09336 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -88,7 +88,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } // Remove the modes (imperative) - result.back()->removeModes(deadModesValues); + 
result.back()->removeDiscreteModes(deadModesValues); pruned = *result.back()->asDiscrete(); } diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 77a0c781e..1e5a851e6 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -170,7 +170,7 @@ double HybridConditional::evaluate(const HybridValues &values) const { } /* ************************************************************************ */ -void HybridConditional::removeModes(const DiscreteValues &given) { +void HybridConditional::removeDiscreteModes(const DiscreteValues &given) { if (this->isDiscrete()) { auto d = this->asDiscrete(); @@ -187,22 +187,6 @@ void HybridConditional::removeModes(const DiscreteValues &given) { } } inner_ = std::make_shared(dkeys.size(), dkeys, tree); - - } else if (this->isHybrid()) { - auto d = this->asHybrid(); - HybridGaussianFactor::FactorValuePairs tree = d->factors(); - for (auto [key, value] : given) { - tree = tree.choose(key, value); - } - - // Get the leftover DiscreteKeys - DiscreteKeys dkeys; - for (DiscreteKey dkey : d->discreteKeys()) { - if (given.count(dkey.first) == 0) { - dkeys.emplace_back(dkey); - } - } - inner_ = std::make_shared(dkeys, tree); } } diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index a33ef7327..480a6481b 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -216,13 +216,14 @@ class GTSAM_EXPORT HybridConditional } /** - * @brief Remove the modes whose assignments are given to us. + * @brief Remove the discrete modes whose assignments are given to us. + * Only applies to discrete conditionals. * * Imperative method so we can update nodes in the Bayes net or Bayes tree. * * @param given The discrete modes whose assignments we know. */ - void removeModes(const DiscreteValues& given); + void removeDiscreteModes(const DiscreteValues& given); /// @} From 600a87bbbcf43cdcca7e7cbd0eb8c2c9ce50e5ff Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 20:49:37 -0500 Subject: [PATCH 09/13] update docstring --- gtsam/hybrid/HybridBayesNet.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index f114a6fa0..7e374caff 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -209,7 +209,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param removeDeadModes + * @param removeDeadModes Flag to enable removal of modes which only have a + * single possible assignment. 
* @return A pruned HybridBayesNet */ HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; From 6ed18a94a3f08f95db9ffe352bb70d0d9336aa9b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 22 Jan 2025 10:22:17 -0500 Subject: [PATCH 10/13] use emplace for deadModesValues --- gtsam/hybrid/HybridBayesNet.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 0cdc09336..268c865b0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -83,7 +83,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } if (index >= 0) { - deadModesValues.insert(std::make_pair(dkey.first, index)); + deadModesValues.emplace(dkey.first, index); } } From 3a58adbd8a6227485c592290ab63abd96cdc6abe Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 22 Jan 2025 10:22:36 -0500 Subject: [PATCH 11/13] update comments --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index d2f67aa47..56e93499b 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -442,10 +442,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) { KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); EXPECT(KeyVector{M(0)} == actual_keys); - // Check that hybrid conditionals that only depend on M1 are no longer hybrid + // Check that hybrid conditionals that only depend on M1 + // are now Gaussian and not Hybrid EXPECT(prunedBayesNet.at(0)->isDiscrete()); EXPECT(prunedBayesNet.at(1)->isHybrid()); - // Only P(X2 | X1, M1) depends on M1, so it is Gaussian + // Only P(X2 | X1, M1) depends on M1, + // so it gets convert to a Gaussian P(X2 | X1) EXPECT(prunedBayesNet.at(2)->isContinuous()); EXPECT(prunedBayesNet.at(3)->isHybrid()); } From 7ecf978683a3ab974631307fff2129228f8fde3e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 22 Jan 2025 11:09:26 -0500 Subject: [PATCH 12/13] move removeDiscreteModes to DiscreteConditional --- gtsam/discrete/DiscreteConditional.cpp | 34 ++++++++++++++++++++++++++ gtsam/discrete/DiscreteConditional.h | 10 ++++++++ gtsam/hybrid/HybridBayesNet.cpp | 2 +- gtsam/hybrid/HybridConditional.cpp | 21 ---------------- gtsam/hybrid/HybridConditional.h | 10 -------- 5 files changed, 45 insertions(+), 32 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index f5ad2b98a..4c9e06c70 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -499,6 +499,40 @@ void DiscreteConditional::prune(size_t maxNrAssignments) { this->root_ = pruned.root_; } +/* ************************************************************************ */ +void DiscreteConditional::removeDiscreteModes(const DiscreteValues& given) { + AlgebraicDecisionTree tree(*this); + for (auto [key, value] : given) { + tree = tree.choose(key, value); + } + + // Get the leftover DiscreteKey frontals + DiscreteKeys frontals; + for (Key key : this->frontals()) { + // Check if frontal key exists in given, if not add to new frontals + if (given.count(key) == 0) { + frontals.emplace_back(key, cardinalities_.at(key)); + } + } + // Get the leftover DiscreteKey parents + DiscreteKeys parents; + for (Key key : this->parents()) { + // Check if parent key exists in given, if not add to new parents + if (given.count(key) == 0) { + 
parents.emplace_back(key, cardinalities_.at(key)); + } + } + + DiscreteKeys allDkeys(frontals); + allDkeys.insert(allDkeys.end(), parents.begin(), parents.end()); + + // Update the conditional + this->keys_ = allDkeys.indices(); + this->cardinalities_ = allDkeys.cardinalities(); + this->root_ = tree.root_; + this->nrFrontals_ = frontals.size(); +} + /* ************************************************************************* */ double DiscreteConditional::negLogConstant() const { return 0.0; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 1bca0b09f..c22fcdf85 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -279,6 +279,16 @@ class GTSAM_EXPORT DiscreteConditional /// Prune the conditional virtual void prune(size_t maxNrAssignments); + /** + * @brief Remove the discrete modes whose assignments are given to us. + * Only applies to discrete conditionals. + * + * Imperative method so we can update nodes in the Bayes net or Bayes tree. + * + * @param given The discrete modes whose assignments we know. + */ + void removeDiscreteModes(const DiscreteValues& given); + /// @} protected: diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 268c865b0..f548efcbb 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -88,7 +88,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } // Remove the modes (imperative) - result.back()->removeDiscreteModes(deadModesValues); + result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); pruned = *result.back()->asDiscrete(); } diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 1e5a851e6..97ec1a1f8 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,25 +169,4 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } -/* ************************************************************************ */ -void HybridConditional::removeDiscreteModes(const DiscreteValues &given) { - if (this->isDiscrete()) { - auto d = this->asDiscrete(); - - AlgebraicDecisionTree tree(*d); - for (auto [key, value] : given) { - tree = tree.choose(key, value); - } - - // Get the leftover DiscreteKeys - DiscreteKeys dkeys; - for (DiscreteKey dkey : d->discreteKeys()) { - if (given.count(dkey.first) == 0) { - dkeys.emplace_back(dkey); - } - } - inner_ = std::make_shared(dkeys.size(), dkeys, tree); - } -} - } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 480a6481b..3cf5b80e5 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,16 +215,6 @@ class GTSAM_EXPORT HybridConditional return true; } - /** - * @brief Remove the discrete modes whose assignments are given to us. - * Only applies to discrete conditionals. - * - * Imperative method so we can update nodes in the Bayes net or Bayes tree. - * - * @param given The discrete modes whose assignments we know. 
- */ - void removeDiscreteModes(const DiscreteValues& given); - /// @} private: From 49aa510d591e38fc998d1788d8e5e7b9d43aea7a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 22 Jan 2025 11:15:48 -0500 Subject: [PATCH 13/13] replace for loop with std::for_each --- gtsam/discrete/DiscreteConditional.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 4c9e06c70..9a3b6cf46 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -508,20 +508,20 @@ void DiscreteConditional::removeDiscreteModes(const DiscreteValues& given) { // Get the leftover DiscreteKey frontals DiscreteKeys frontals; - for (Key key : this->frontals()) { + std::for_each(this->frontals().begin(), this->frontals().end(), [&](Key key) { // Check if frontal key exists in given, if not add to new frontals if (given.count(key) == 0) { - frontals.emplace_back(key, cardinalities_.at(key)); + frontals.emplace_back(key, this->cardinalities_.at(key)); } - } + }); // Get the leftover DiscreteKey parents DiscreteKeys parents; - for (Key key : this->parents()) { + std::for_each(this->parents().begin(), this->parents().end(), [&](Key key) { // Check if parent key exists in given, if not add to new parents if (given.count(key) == 0) { - parents.emplace_back(key, cardinalities_.at(key)); + parents.emplace_back(key, this->cardinalities_.at(key)); } - } + }); DiscreteKeys allDkeys(frontals); allDkeys.insert(allDkeys.end(), parents.begin(), parents.end());