From be5d30c57a6c27192173073d42c3744146ec2340 Mon Sep 17 00:00:00 2001
From: Frank Dellaert
Date: Sat, 25 Jan 2025 15:10:30 -0500
Subject: [PATCH 01/18] rvalue constructor

---
 gtsam/linear/JacobianFactor-inl.h | 41 ++++++++++++++++++++++---------
 gtsam/linear/JacobianFactor.h     | 22 +++++++++++------
 2 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/gtsam/linear/JacobianFactor-inl.h b/gtsam/linear/JacobianFactor-inl.h
index 6c4cb969a..c1ef25520 100644
--- a/gtsam/linear/JacobianFactor-inl.h
+++ b/gtsam/linear/JacobianFactor-inl.h
@@ -30,26 +30,43 @@ namespace gtsam {
   }
 
   /* ************************************************************************* */
-  template<typename KEYS>
-  JacobianFactor::JacobianFactor(
-    const KEYS& keys, const VerticalBlockMatrix& augmentedMatrix, const SharedDiagonal& model) :
-    Base(keys), Ab_(augmentedMatrix)
-  {
+  template <typename KEYS>
+  JacobianFactor::JacobianFactor(const KEYS& keys,
+                                 const VerticalBlockMatrix& augmentedMatrix,
+                                 const SharedDiagonal& model)
+      : Base(keys), Ab_(augmentedMatrix) {
+    checkAndAssignModel(model, augmentedMatrix);
+  }
+
+  /* ************************************************************************* */
+  template <typename KEYS>
+  JacobianFactor::JacobianFactor(const KEYS& keys,
+                                 VerticalBlockMatrix&& augmentedMatrix,
+                                 const SharedDiagonal& model)
+      : Base(keys), Ab_(std::move(augmentedMatrix)) {
+    checkAndAssignModel(model, Ab_);
+  }
+
+  /* ************************************************************************* */
+  void JacobianFactor::checkAndAssignModel(
+      const SharedDiagonal& model, const VerticalBlockMatrix& augmentedMatrix) {
     // Check noise model dimension
-    if(model && (DenseIndex)model->dim() != augmentedMatrix.rows())
+    if (model && (DenseIndex)model->dim() != augmentedMatrix.rows())
       throw InvalidNoiseModel(augmentedMatrix.rows(), model->dim());
 
     // Check number of variables
-    if((DenseIndex)Base::keys_.size() != augmentedMatrix.nBlocks() - 1)
+    if ((DenseIndex)Base::keys_.size() != augmentedMatrix.nBlocks() - 1)
       throw std::invalid_argument(
-        "Error in JacobianFactor constructor input. Number of provided keys plus\n"
-        "one for the RHS vector must equal the number of provided matrix blocks.");
+          "Error in JacobianFactor constructor input. Number of provided keys "
+          "plus one for the RHS vector must equal the number of provided "
+          "matrix blocks.");
 
     // Check RHS dimension
-    if(augmentedMatrix(augmentedMatrix.nBlocks() - 1).cols() != 1)
+    if (augmentedMatrix(augmentedMatrix.nBlocks() - 1).cols() != 1)
       throw std::invalid_argument(
-        "Error in JacobianFactor constructor input. The last provided matrix block\n"
-        "must be the RHS vector, but the last provided block had more than one column.");
+          "Error in JacobianFactor constructor input. The last provided "
+          "matrix block must be the RHS vector, but the last provided block "
+          "had more than one column.");
 
     // Take noise model
     model_ = model;
diff --git a/gtsam/linear/JacobianFactor.h b/gtsam/linear/JacobianFactor.h
index a9933374f..1e82eb051 100644
--- a/gtsam/linear/JacobianFactor.h
+++ b/gtsam/linear/JacobianFactor.h
@@ -145,13 +145,17 @@ namespace gtsam {
     template<typename TERMS>
     JacobianFactor(const TERMS& terms, const Vector& b, const SharedDiagonal& model = SharedDiagonal());
 
-    /** Constructor with arbitrary number keys, and where the augmented matrix is given all together
-     * instead of in block terms. Note that only the active view of the provided augmented matrix
-     * is used, and that the matrix data is copied into a newly-allocated matrix in the constructed
-     * factor. */
-    template<typename KEYS>
-    JacobianFactor(
-      const KEYS& keys, const VerticalBlockMatrix& augmentedMatrix, const SharedDiagonal& sigmas = SharedDiagonal());
+    /** Constructor with an arbitrary number of keys, where the augmented matrix
+     * is given all together instead of in block terms.
+     */
+    template <typename KEYS>
+    JacobianFactor(const KEYS& keys, const VerticalBlockMatrix& augmentedMatrix,
+                   const SharedDiagonal& sigmas = SharedDiagonal());
+
+    /** Construct with an rvalue VerticalBlockMatrix, to allow std::move. */
+    template <typename KEYS>
+    JacobianFactor(const KEYS& keys, VerticalBlockMatrix&& augmentedMatrix,
+                   const SharedDiagonal& model);
 
     /**
      * Build a dense joint factor from all the factors in a factor graph. If a VariableSlots
@@ -398,6 +402,10 @@ namespace gtsam {
     template<typename TERMS>
     void fillTerms(const TERMS& terms, const Vector& b, const SharedDiagonal& noiseModel);
 
+    /// Common code between VerticalBlockMatrix constructors
+    void checkAndAssignModel(const SharedDiagonal& model,
+                             const VerticalBlockMatrix& augmentedMatrix);
+
   private:
 
     /**

From c60d257e801dfa8437938fb18421df679599dc49 Mon Sep 17 00:00:00 2001
From: Frank Dellaert
Date: Fri, 24 Jan 2025 15:02:18 -0500
Subject: [PATCH 02/18] Add prior and sanitize printing

---
 timing/timeBatch.cpp | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/timing/timeBatch.cpp b/timing/timeBatch.cpp
index f59039fa7..ec7222e61 100644
--- a/timing/timeBatch.cpp
+++ b/timing/timeBatch.cpp
@@ -29,18 +29,16 @@ int main(int argc, char *argv[]) {
     cout << "Loading data..." << endl;
     string datasetFile = findExampleDataFile("w10000");
-    std::pair<NonlinearFactorGraph::shared_ptr, Values::shared_ptr> data =
-        load2D(datasetFile);
-
-    NonlinearFactorGraph graph = *data.first;
-    Values initial = *data.second;
+    auto [graph, initial] = load2D(datasetFile);
+    graph->addPrior(0, initial->at<Pose2>(0), noiseModel::Unit::Create(3));
 
     cout << "Optimizing..."
<< endl; gttic_(Create_optimizer); - LevenbergMarquardtOptimizer optimizer(graph, initial); + LevenbergMarquardtOptimizer optimizer(*graph, *initial); gttoc_(Create_optimizer); tictoc_print_(); + tictoc_reset_(); double lastError = optimizer.error(); do { gttic_(Iterate_optimizer); @@ -53,19 +51,19 @@ int main(int argc, char *argv[]) { } while(!checkConvergence(optimizer.params().relativeErrorTol, optimizer.params().absoluteErrorTol, optimizer.params().errorTol, lastError, optimizer.error(), optimizer.params().verbosity)); + tictoc_reset_(); // Compute marginals - Marginals marginals(graph, optimizer.values()); - int i=0; - for(Key key: initial.keys()) { + gttic_(ConstructMarginals); + Marginals marginals(*graph, optimizer.values()); + gttoc_(ConstructMarginals); + for(Key key: initial->keys()) { gttic_(marginalInformation); Matrix info = marginals.marginalInformation(key); gttoc_(marginalInformation); tictoc_finishedIteration_(); - if(i % 1000 == 0) - tictoc_print_(); - ++i; } + tictoc_print_(); } catch(std::exception& e) { cout << e.what() << endl; From d3cd876cf957f96c90d9bc4c1a8b26b4dc0a49ca Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 25 Jan 2025 15:31:40 -0500 Subject: [PATCH 03/18] make check const method --- gtsam/linear/JacobianFactor-inl.h | 39 ++++++------------------------- gtsam/linear/JacobianFactor.cpp | 22 +++++++++++++++++ gtsam/linear/JacobianFactor.h | 6 ++--- 3 files changed, 32 insertions(+), 35 deletions(-) diff --git a/gtsam/linear/JacobianFactor-inl.h b/gtsam/linear/JacobianFactor-inl.h index c1ef25520..aee0b7563 100644 --- a/gtsam/linear/JacobianFactor-inl.h +++ b/gtsam/linear/JacobianFactor-inl.h @@ -23,9 +23,9 @@ namespace gtsam { /* ************************************************************************* */ - template - JacobianFactor::JacobianFactor(const TERMS&terms, const Vector &b, const SharedDiagonal& model) - { + template + JacobianFactor::JacobianFactor(const TERMS& terms, const Vector& b, + const SharedDiagonal& model) { fillTerms(terms, b, model); } @@ -34,8 +34,8 @@ namespace gtsam { JacobianFactor::JacobianFactor(const KEYS& keys, const VerticalBlockMatrix& augmentedMatrix, const SharedDiagonal& model) - : Base(keys), Ab_(augmentedMatrix) { - checkAndAssignModel(model, augmentedMatrix); + : Base(keys), Ab_(augmentedMatrix), model_(model) { + checkAb(model, augmentedMatrix); } /* ************************************************************************* */ @@ -43,33 +43,8 @@ namespace gtsam { JacobianFactor::JacobianFactor(const KEYS& keys, VerticalBlockMatrix&& augmentedMatrix, const SharedDiagonal& model) - : Base(keys), Ab_(std::move(augmentedMatrix)) { - checkAndAssignModel(model, Ab_); - } - - /* ************************************************************************* */ - void JacobianFactor::checkAndAssignModel( - const SharedDiagonal& model, const VerticalBlockMatrix& augmentedMatrix) { - // Check noise model dimension - if (model && (DenseIndex)model->dim() != augmentedMatrix.rows()) - throw InvalidNoiseModel(augmentedMatrix.rows(), model->dim()); - - // Check number of variables - if ((DenseIndex)Base::keys_.size() != augmentedMatrix.nBlocks() - 1) - throw std::invalid_argument( - "Error in JacobianFactor constructor input. Number of provided keys " - "plus one for the RHS vector must equal the number of provided " - "matrix blocks."); - - // Check RHS dimension - if (augmentedMatrix(augmentedMatrix.nBlocks() - 1).cols() != 1) - throw std::invalid_argument( - "Error in JacobianFactor constructor input. 
The last provided " - "matrix block must be the RHS vector, but the last provided block " - "had more than one column."); - - // Take noise model - model_ = model; + : Base(keys), Ab_(std::move(augmentedMatrix)), model_(model) { + checkAb(model, Ab_); } /* ************************************************************************* */ diff --git a/gtsam/linear/JacobianFactor.cpp b/gtsam/linear/JacobianFactor.cpp index 51d513e33..9052cc460 100644 --- a/gtsam/linear/JacobianFactor.cpp +++ b/gtsam/linear/JacobianFactor.cpp @@ -112,6 +112,28 @@ JacobianFactor::JacobianFactor(const HessianFactor& factor) } } + /* ************************************************************************* */ +void JacobianFactor::checkAb(const SharedDiagonal& model, + const VerticalBlockMatrix& augmentedMatrix) const { + // Check noise model dimension + if (model && (DenseIndex)model->dim() != augmentedMatrix.rows()) + throw InvalidNoiseModel(augmentedMatrix.rows(), model->dim()); + + // Check number of variables + if ((DenseIndex)Base::keys_.size() != augmentedMatrix.nBlocks() - 1) + throw std::invalid_argument( + "Error in JacobianFactor constructor input. Number of provided keys " + "plus one for the RHS vector must equal the number of provided " + "matrix blocks."); + + // Check RHS dimension + if (augmentedMatrix(augmentedMatrix.nBlocks() - 1).cols() != 1) + throw std::invalid_argument( + "Error in JacobianFactor constructor input. The last provided " + "matrix block must be the RHS vector, but the last provided block " + "had more than one column."); +} + /* ************************************************************************* */ // Helper functions for combine constructor namespace { diff --git a/gtsam/linear/JacobianFactor.h b/gtsam/linear/JacobianFactor.h index 1e82eb051..33f7183a6 100644 --- a/gtsam/linear/JacobianFactor.h +++ b/gtsam/linear/JacobianFactor.h @@ -403,10 +403,10 @@ namespace gtsam { void fillTerms(const TERMS& terms, const Vector& b, const SharedDiagonal& noiseModel); /// Common code between VerticalBlockMatrix constructors - void checkAndAssignModel(const SharedDiagonal& model, - const VerticalBlockMatrix& augmentedMatrix); + void checkAb(const SharedDiagonal& model, + const VerticalBlockMatrix& augmentedMatrix) const; - private: + private: /** * Helper function for public constructors: From 85b457f1e39e27979ec7a2e493cd8be168b8da4d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 25 Jan 2025 15:31:55 -0500 Subject: [PATCH 04/18] Define and use move constructor --- gtsam/linear/GaussianConditional-inl.h | 6 +++++ gtsam/linear/GaussianConditional.h | 37 ++++++++++++++++++++------ gtsam/linear/HessianFactor.cpp | 2 +- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/gtsam/linear/GaussianConditional-inl.h b/gtsam/linear/GaussianConditional-inl.h index fe5b1e0d7..2756b690d 100644 --- a/gtsam/linear/GaussianConditional-inl.h +++ b/gtsam/linear/GaussianConditional-inl.h @@ -33,4 +33,10 @@ namespace gtsam { const KEYS& keys, size_t nrFrontals, const VerticalBlockMatrix& augmentedMatrix, const SharedDiagonal& sigmas) : BaseFactor(keys, augmentedMatrix, sigmas), BaseConditional(nrFrontals) {} + /* ************************************************************************* */ + template + GaussianConditional::GaussianConditional( + const KEYS& keys, size_t nrFrontals, VerticalBlockMatrix&& augmentedMatrix, const SharedDiagonal& sigmas) : + BaseFactor(keys, std::move(augmentedMatrix), sigmas), BaseConditional(nrFrontals) {} + } // gtsam diff --git 
a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 14b1ce87f..d71119d6a 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -75,14 +75,35 @@ namespace gtsam { size_t nrFrontals, const Vector& d, const SharedDiagonal& sigmas = SharedDiagonal()); - /** Constructor with arbitrary number keys, and where the augmented matrix is given all together - * instead of in block terms. Note that only the active view of the provided augmented matrix - * is used, and that the matrix data is copied into a newly-allocated matrix in the constructed - * factor. */ - template - GaussianConditional( - const KEYS& keys, size_t nrFrontals, const VerticalBlockMatrix& augmentedMatrix, - const SharedDiagonal& sigmas = SharedDiagonal()); + /** + * @brief Constructor with an arbitrary number of keys, where the augmented matrix + * is given all together instead of in block terms. + * + * @tparam KEYS Type of the keys container. + * @param keys Container of keys. + * @param nrFrontals Number of frontal variables. + * @param augmentedMatrix The augmented matrix containing the coefficients. + * @param sigmas Optional noise model (default is an empty SharedDiagonal). + */ + template + GaussianConditional(const KEYS& keys, size_t nrFrontals, + const VerticalBlockMatrix& augmentedMatrix, + const SharedDiagonal& sigmas = SharedDiagonal()); + + /** + * @brief Constructor with an arbitrary number of keys, where the augmented matrix + * is given all together instead of in block terms, using move semantics for efficiency. + * + * @tparam KEYS Type of the keys container. + * @param keys Container of keys. + * @param nrFrontals Number of frontal variables. + * @param augmentedMatrix The augmented matrix containing the coefficients (moved). + * @param sigmas Optional noise model (default is an empty SharedDiagonal). + */ + template + GaussianConditional(const KEYS& keys, size_t nrFrontals, + VerticalBlockMatrix&& augmentedMatrix, + const SharedDiagonal& sigmas = SharedDiagonal()); /// Construct from mean `mu` and standard deviation `sigma`. static GaussianConditional FromMeanAndStddev(Key key, const Vector& mu, diff --git a/gtsam/linear/HessianFactor.cpp b/gtsam/linear/HessianFactor.cpp index 1172dc281..701e79f41 100644 --- a/gtsam/linear/HessianFactor.cpp +++ b/gtsam/linear/HessianFactor.cpp @@ -470,7 +470,7 @@ std::shared_ptr HessianFactor::eliminateCholesky(const Orde // TODO(frank): pre-allocate GaussianConditional and write into it const VerticalBlockMatrix Ab = info_.split(nFrontals); - conditional = std::make_shared(keys_, nFrontals, Ab); + conditional = std::make_shared(keys_, nFrontals, std::move(Ab)); // Erase the eliminated keys in this factor keys_.erase(begin(), begin() + nFrontals); From d8b75f6bd0032456a3ee7d6a5a21646f65ef8e25 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 25 Jan 2025 15:50:51 -0500 Subject: [PATCH 05/18] Move constructors --- gtsam/base/VerticalBlockMatrix.h | 43 ++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/gtsam/base/VerticalBlockMatrix.h b/gtsam/base/VerticalBlockMatrix.h index 0ff8a8ae2..741d0e4bf 100644 --- a/gtsam/base/VerticalBlockMatrix.h +++ b/gtsam/base/VerticalBlockMatrix.h @@ -57,13 +57,46 @@ namespace gtsam { DenseIndex blockStart_; ///< Changes apparent matrix view, see main class comment. 
public: - /** Construct an empty VerticalBlockMatrix */ - VerticalBlockMatrix() : - rowStart_(0), rowEnd_(0), blockStart_(0) - { + VerticalBlockMatrix() : rowStart_(0), rowEnd_(0), blockStart_(0) { variableColOffsets_.push_back(0); - assertInvariants(); + } + + // Destructor + ~VerticalBlockMatrix() = default; + + // Copy constructor (default) + VerticalBlockMatrix(const VerticalBlockMatrix& other) = default; + + // Copy assignment operator (default) + VerticalBlockMatrix& operator=(const VerticalBlockMatrix& other) = default; + + // Move constructor + VerticalBlockMatrix(VerticalBlockMatrix&& other) noexcept + : matrix_(std::move(other.matrix_)), + variableColOffsets_(std::move(other.variableColOffsets_)), + rowStart_(other.rowStart_), + rowEnd_(other.rowEnd_), + blockStart_(other.blockStart_) { + other.rowStart_ = 0; + other.rowEnd_ = 0; + other.blockStart_ = 0; + } + + // Move assignment operator + VerticalBlockMatrix& operator=(VerticalBlockMatrix&& other) noexcept { + if (this != &other) { + matrix_ = std::move(other.matrix_); + variableColOffsets_ = std::move(other.variableColOffsets_); + rowStart_ = other.rowStart_; + rowEnd_ = other.rowEnd_; + blockStart_ = other.blockStart_; + + other.rowStart_ = 0; + other.rowEnd_ = 0; + other.blockStart_ = 0; + } + return *this; } /** Construct from a container of the sizes of each vertical block. */ From 02c8f02a104a57f59c5793bb5c78f1470da249c2 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 18:19:35 -0500 Subject: [PATCH 06/18] 5% speedup by using faster setAllZero --- gtsam/base/SymmetricBlockMatrix.h | 7 ++++++- gtsam/linear/HessianFactor.cpp | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/gtsam/base/SymmetricBlockMatrix.h b/gtsam/base/SymmetricBlockMatrix.h index 378e91144..f6d4fa9c2 100644 --- a/gtsam/base/SymmetricBlockMatrix.h +++ b/gtsam/base/SymmetricBlockMatrix.h @@ -256,11 +256,16 @@ namespace gtsam { full().triangularView() = xpr.template triangularView(); } - /// Set the entire active matrix zero. + /// Set the entire *active* matrix zero. void setZero() { full().triangularView().setZero(); } + /// Set entire matrix zero. + void setAllZero() { + matrix_.setZero(); + } + /// Negate the entire active matrix. void negate(); diff --git a/gtsam/linear/HessianFactor.cpp b/gtsam/linear/HessianFactor.cpp index 701e79f41..f2130f068 100644 --- a/gtsam/linear/HessianFactor.cpp +++ b/gtsam/linear/HessianFactor.cpp @@ -245,7 +245,7 @@ HessianFactor::HessianFactor(const GaussianFactorGraph& factors, // Form A' * A gttic(update); - info_.setZero(); + info_.setAllZero(); for(const auto& factor: factors) if (factor) factor->updateHessian(keys_, &info_); From 003730f8441180c5807cf1a9855ac50fc146e0af Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 19:33:48 -0500 Subject: [PATCH 07/18] Pre-compute slots outside of loop. 
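
The idea behind this change, as a minimal standalone sketch (illustrative
only: updateSketch and slotOf below are hypothetical stand-ins rather than
GTSAM code, and it is assumed here that the Slot() lookup scans infoKeys
linearly): look up the slot of every key once, before the doubly-nested
block loops, instead of repeating the lookup inside them.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Hypothetical stand-in for Slot(infoKeys, key): a linear search.
    std::ptrdiff_t slotOf(const std::vector<int>& infoKeys, int key) {
      return std::find(infoKeys.begin(), infoKeys.end(), key) - infoKeys.begin();
    }

    void updateSketch(const std::vector<int>& infoKeys,
                      const std::vector<int>& keys) {
      // Pre-compute all slots once, outside the loops.
      std::vector<std::ptrdiff_t> slots(keys.size());
      for (std::size_t j = 0; j < keys.size(); ++j)
        slots[j] = slotOf(infoKeys, keys[j]);

      // The nested update loops now only index into slots.
      for (std::size_t j = 0; j < keys.size(); ++j)
        for (std::size_t i = 0; i <= j; ++i) {
          // ... update block (slots[i], slots[j]) here ...
          (void)slots[i];
          (void)slots[j];
        }
    }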
--- gtsam/linear/HessianFactor.cpp | 32 +++++++++++++++----------------- gtsam/linear/JacobianFactor.cpp | 11 +++++++---- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/gtsam/linear/HessianFactor.cpp b/gtsam/linear/HessianFactor.cpp index f2130f068..0a75cc336 100644 --- a/gtsam/linear/HessianFactor.cpp +++ b/gtsam/linear/HessianFactor.cpp @@ -348,28 +348,26 @@ double HessianFactor::error(const VectorValues& c) const { /* ************************************************************************* */ void HessianFactor::updateHessian(const KeyVector& infoKeys, SymmetricBlockMatrix* info) const { - gttic(updateHessian_HessianFactor); assert(info); - // Apply updates to the upper triangle - DenseIndex nrVariablesInThisFactor = size(), nrBlocksInInfo = info->nBlocks() - 1; + gttic(updateHessian_HessianFactor); + const DenseIndex nrVariablesInThisFactor = size(); + vector slots(nrVariablesInThisFactor + 1); + for (DenseIndex j = 0; j < nrVariablesInThisFactor; ++j) + slots[j] = Slot(infoKeys, keys_[j]); + slots[nrVariablesInThisFactor] = info->nBlocks() - 1; + + // Apply updates to the upper triangle // Loop over this factor's blocks with indices (i,j) // For every block (i,j), we determine the block (I,J) in info. for (DenseIndex j = 0; j <= nrVariablesInThisFactor; ++j) { - const bool rhs = (j == nrVariablesInThisFactor); - const DenseIndex J = rhs ? nrBlocksInInfo : Slot(infoKeys, keys_[j]); - slots[j] = J; - for (DenseIndex i = 0; i <= j; ++i) { - const DenseIndex I = slots[i]; // because i<=j, slots[i] is valid. - - if (i == j) { - assert(I == J); - info->updateDiagonalBlock(I, info_.diagonalBlock(i)); - } else { - assert(i < j); - assert(I != J); - info->updateOffDiagonalBlock(I, J, info_.aboveDiagonalBlock(i, j)); - } + const DenseIndex J = slots[j]; + info->updateDiagonalBlock(J, info_.diagonalBlock(j)); + for (DenseIndex i = 0; i < j; ++i) { + const DenseIndex I = slots[i]; + assert(i < j); + assert(I != J); + info->updateOffDiagonalBlock(I, J, info_.aboveDiagonalBlock(i, j)); } } } diff --git a/gtsam/linear/JacobianFactor.cpp b/gtsam/linear/JacobianFactor.cpp index 9052cc460..1802475eb 100644 --- a/gtsam/linear/JacobianFactor.cpp +++ b/gtsam/linear/JacobianFactor.cpp @@ -602,16 +602,19 @@ void JacobianFactor::updateHessian(const KeyVector& infoKeys, // Ab_ is the augmented Jacobian matrix A, and we perform I += A'*A below DenseIndex n = Ab_.nBlocks() - 1, N = info->nBlocks() - 1; + // Pre-calculate slots + vector slots(n + 1); + for (DenseIndex j = 0; j < n; ++j) slots[j] = Slot(infoKeys, keys_[j]); + slots[n] = N; + // Apply updates to the upper triangle // Loop over blocks of A, including RHS with j==n - vector slots(n+1); for (DenseIndex j = 0; j <= n; ++j) { Eigen::Block Ab_j = Ab_(j); - const DenseIndex J = (j == n) ? 
N : Slot(infoKeys, keys_[j]); - slots[j] = J; + const DenseIndex J = slots[j]; // Fill off-diagonal blocks with Ai'*Aj for (DenseIndex i = 0; i < j; ++i) { - const DenseIndex I = slots[i]; // because iupdateOffDiagonalBlock(I, J, Ab_(i).transpose() * Ab_j); } // Fill diagonal block with Aj'*Aj From 98cdf1193facd7715cd37de002fb50762a3d2e2a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 17:54:29 -0500 Subject: [PATCH 08/18] Fix pruning --- gtsam/hybrid/HybridBayesNet.cpp | 90 +++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e64284a94..2efb8030e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -49,6 +49,9 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // search to find the K-best leaves and then create a single pruned conditional. HybridBayesNet HybridBayesNet::prune( size_t maxNrLeaves, const std::optional &deadModeThreshold) const { +#if GTSAM_HYBRID_TIMING + gttic_(HybridPruning); +#endif // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -69,6 +72,10 @@ HybridBayesNet HybridBayesNet::prune( // If we have a dead mode threshold and discrete variables left after pruning, // then we run dead mode removal. if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { +#if GTSAM_HYBRID_TIMING + gttic_(DeadModeRemoval); +#endif + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { Vector probabilities = marginals.marginalProbabilities(dkey); @@ -89,24 +96,11 @@ HybridBayesNet HybridBayesNet::prune( // Remove the modes (imperative) pruned.removeDiscreteModes(deadModesValues); - /* - If the pruned discrete conditional has any keys left, - we add it to the HybridBayesNet. - If not, it means it is an orphan so we don't add this pruned joint, - and instead add only the marginals below. - */ - if (pruned.keys().size() > 0) { - result.emplace_shared(pruned); - } + GTSAM_PRINT(deadModesValues); - // Add the marginals for future factors - for (auto &&[key, _] : deadModesValues) { - result.push_back( - std::dynamic_pointer_cast(marginals(key))); - } - - } else { - result.emplace_shared(pruned); +#if GTSAM_HYBRID_TIMING + gttoc_(DeadModeRemoval); +#endif } /* To prune, we visitWith every leaf in the HybridGaussianConditional. @@ -122,20 +116,37 @@ HybridBayesNet HybridBayesNet::prune( if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! 
auto prunedHybridGaussianConditional = hgc->prune(pruned); + if (!prunedHybridGaussianConditional) { + GTSAM_PRINT(marginal); + GTSAM_PRINT(pruned); + throw std::runtime_error( + "A HybridGaussianConditional had all its conditionals pruned"); + } if (deadModeThreshold.has_value()) { - KeyVector deadKeys, conditionalDiscreteKeys; - for (const auto &kv : deadModesValues) { - deadKeys.push_back(kv.first); + const auto &discreteParents = + prunedHybridGaussianConditional->discreteKeys(); + DiscreteValues deadParentValues; + DiscreteKeys liveParents; + for (const auto &key : discreteParents) { + auto it = deadModesValues.find(key.first); + if (it != deadModesValues.end()) + deadParentValues[key.first] = it->second; + else + liveParents.emplace_back(key); } - for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { - conditionalDiscreteKeys.push_back(dkey.first); - } - // The discrete keys in the conditional are the same as the keys in the - // dead modes, then we just get the corresponding Gaussian conditional. - if (deadKeys == conditionalDiscreteKeys) { + // If so then we just get the corresponding Gaussian conditional: + if (deadParentValues.size() == discreteParents.size()) { + // print on how many discreteParents we are choosing: result.push_back( - prunedHybridGaussianConditional->choose(deadModesValues)); + prunedHybridGaussianConditional->choose(deadParentValues)); + } else if (liveParents.size() > 0) { + auto newTree = prunedHybridGaussianConditional->factors(); + for (auto &&[key, value] : deadModesValues) { + newTree = newTree.choose(key, value); + } + result.emplace_shared(liveParents, + newTree); } else { // Add as-is result.push_back(prunedHybridGaussianConditional); @@ -152,6 +163,31 @@ HybridBayesNet HybridBayesNet::prune( // We ignore DiscreteConditional as they are already pruned and added. } +#if GTSAM_HYBRID_TIMING + gttoc_(HybridPruning); +#endif + + if (deadModeThreshold.has_value()) { + /* + If the pruned discrete conditional has any keys left, + we add it to the HybridBayesNet. + If not, it means it is an orphan so we don't add this pruned joint, + and instead add only the marginals below. 
+ */ + if (pruned.keys().size() > 0) { + result.emplace_shared(pruned); + } + + // Add the marginals for future factors + // for (auto &&[key, _] : deadModesValues) { + // result.push_back( + // std::dynamic_pointer_cast(marginals(key))); + // } + + } else { + result.emplace_shared(pruned); + } + return result; } From d17215c69b0837023e7069f7fd2cbfdc1359e3fb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 21:36:51 -0500 Subject: [PATCH 09/18] prototype choose --- .../tests/testHybridGaussianConditional.cpp | 95 +++++++++++++++---- 1 file changed, 78 insertions(+), 17 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 8bb83cac4..7c7fc3ea5 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -238,22 +238,27 @@ TEST(HybridGaussianConditional, Likelihood2) { EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); } +/* ************************************************************************* */ +namespace two_mode_measurement { +// Create a two key conditional: +const DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; +const std::vector gcs = { + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)}; +const HybridGaussianConditional::Conditionals conditionals(modes, gcs); +const auto hgc = + std::make_shared(modes, conditionals); +} // namespace two_mode_measurement + /* ************************************************************************* */ // Test pruning a HybridGaussianConditional with two discrete keys, based on a // DecisionTreeFactor with 3 keys: TEST(HybridGaussianConditional, Prune) { - // Create a two key conditional: - DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; - std::vector gcs; - for (size_t i = 0; i < 4; i++) { - gcs.push_back( - GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1)); - } - auto empty = std::make_shared(); - HybridGaussianConditional::Conditionals conditionals(modes, gcs); - HybridGaussianConditional hgc(modes, conditionals); + using two_mode_measurement::hgc; - DiscreteKeys keys = modes; + DiscreteKeys keys = two_mode_measurement::modes; keys.push_back({M(3), 2}); { for (size_t i = 0; i < 8; i++) { @@ -262,7 +267,7 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); // Prune the HybridGaussianConditional const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 1 conditional EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); } @@ -273,14 +278,14 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); // Check that the minimum negLogConstant is set correctly EXPECT_DOUBLES_EQUAL( - hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), + hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), pruned->negLogConstant(), 1e-9); } { @@ -289,18 +294,74 
@@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); // Check that the minimum negLogConstant is correct - EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9); + EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9); } } -/* ************************************************************************* +/* ************************************************************************* */ + +#include + +/** + * Return a HybridConditional by choosing branches based on the given discrete + * values. If all discrete parents are specified, return a HybridConditional + * which is just a GaussianConditional. */ +HybridConditional::shared_ptr choose( + const HybridGaussianConditional::shared_ptr &self, + const DiscreteValues &discreteValues) { + const auto &discreteParents = self->discreteKeys(); + DiscreteValues deadParentValues; + DiscreteKeys liveParents; + for (const auto &key : discreteParents) { + auto it = discreteValues.find(key.first); + if (it != discreteValues.end()) + deadParentValues[key.first] = it->second; + else + liveParents.emplace_back(key); + } + // If so then we just get the corresponding Gaussian conditional: + if (deadParentValues.size() == discreteParents.size()) { + // print on how many discreteParents we are choosing: + return std::make_shared(self->choose(deadParentValues)); + } else if (liveParents.size() > 0) { + auto newTree = self->factors(); + for (auto &&[key, value] : discreteValues) { + newTree = newTree.choose(key, value); + } + return std::make_shared( + std::make_shared(liveParents, newTree)); + } else { + // Add as-is + return std::make_shared(self); + } +} + +/* ************************************************************************* */ +// Test the pruning and dead-mode removal. 
+TEST(HybridGaussianConditional, PrunePlus) { + using two_mode_measurement::hgc; // two discrete parents + + const HybridConditional::shared_ptr same = choose(hgc, {}); + EXPECT(same->isHybrid()); + EXPECT(same->asHybrid()->nrComponents() == 4); + + const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}}); + EXPECT(oneParent->isHybrid()); + EXPECT(oneParent->asHybrid()->nrComponents() == 2); + + const HybridConditional::shared_ptr gaussian = + choose(hgc, {{M(1), 0}, {M(2), 1}}); + EXPECT(gaussian->asGaussian()); +} + +/* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); From 803dae75f31e9b9222a783602678457d303e1c82 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 22:04:56 -0500 Subject: [PATCH 10/18] filter and missingKeys --- gtsam/discrete/DiscreteValues.h | 82 ++++++++++++++++----- gtsam/discrete/tests/testDiscreteValues.cpp | 18 +++++ 2 files changed, 83 insertions(+), 17 deletions(-) diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index df4ecdbff..7c73da681 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -68,28 +68,74 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); // insert in base class; - std::pair insert( const value_type& value ){ + std::pair insert(const value_type& value) { return Base::insert(value); } /** - * Insert key-assignment pair. - * Throws an invalid_argument exception if - * any keys to be inserted are already used. */ + * @brief Insert key-assignment pair. + * + * @param assignment The key-assignment pair to insert. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::invalid_argument if any keys to be inserted are already used. + */ DiscreteValues& insert(const std::pair& assignment); - /** Insert all values from \c values. Throws an invalid_argument exception if - * any keys to be inserted are already used. */ + /** + * @brief Insert all values from another DiscreteValues object. + * + * @param values The DiscreteValues object containing values to insert. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::invalid_argument if any keys to be inserted are already used. + */ DiscreteValues& insert(const DiscreteValues& values); - /** For all key/value pairs in \c values, replace values with corresponding - * keys in this object with those in \c values. Throws std::out_of_range if - * any keys in \c values are not present in this object. */ + /** + * @brief Update values with corresponding keys from another DiscreteValues + * object. + * + * @param values The DiscreteValues object containing values to update. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::out_of_range if any keys in values are not present in this + * object. + */ DiscreteValues& update(const DiscreteValues& values); /** - * @brief Return a vector of DiscreteValues, one for each possible - * combination of values. + * @brief Filter values by keys. + * + * @param keys The keys to filter by. + * @return DiscreteValues The filtered DiscreteValues object. 
+ */ + DiscreteValues filter(const DiscreteKeys& keys) const { + DiscreteValues result; + for (const auto& [key, _] : keys) { + if (auto it = this->find(key); it != this->end()) + result[key] = it->second; + } + return result; + } + + /** + * @brief Return the keys that are not present in the DiscreteValues object. + * + * @param keys The keys to check for. + * @return DiscreteKeys Keys not present in the DiscreteValues object. + */ + DiscreteKeys missingKeys(const DiscreteKeys& keys) const { + DiscreteKeys result; + for (const auto& [key, cardinality] : keys) { + if (!this->contains(key)) result.emplace_back(key, cardinality); + } + return result; + } + + /** + * @brief Return a vector of DiscreteValues, one for each possible combination + * of values. + * + * @param keys The keys to generate the Cartesian product for. + * @return std::vector The vector of DiscreteValues. */ static std::vector CartesianProduct( const DiscreteKeys& keys) { @@ -135,14 +181,16 @@ inline std::vector cartesianProduct(const DiscreteKeys& keys) { } /// Free version of markdown. -std::string GTSAM_EXPORT markdown(const DiscreteValues& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const DiscreteValues::Names& names = {}); +std::string GTSAM_EXPORT +markdown(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); /// Free version of html. -std::string GTSAM_EXPORT html(const DiscreteValues& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const DiscreteValues::Names& names = {}); +std::string GTSAM_EXPORT +html(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); // traits template <> diff --git a/gtsam/discrete/tests/testDiscreteValues.cpp b/gtsam/discrete/tests/testDiscreteValues.cpp index 47989481c..4667bc5ea 100644 --- a/gtsam/discrete/tests/testDiscreteValues.cpp +++ b/gtsam/discrete/tests/testDiscreteValues.cpp @@ -40,6 +40,24 @@ TEST(DiscreteValues, Update) { DiscreteValues(kExample).update({{12, 2}}))); } +/* ************************************************************************* */ +// Test DiscreteValues::filter +TEST(DiscreteValues, Filter) { + DiscreteValues values = {{12, 1}, {5, 0}, {13, 2}}; + DiscreteKeys keys = {{12, 0}, {13, 0}, {99, 0}}; // 99 is missing in values + + EXPECT(assert_equal(DiscreteValues({{12, 1}, {13, 2}}), values.filter(keys))); +} + +/* ************************************************************************* */ +// Test DiscreteValues::missingKeys +TEST(DiscreteValues, MissingKeys) { + DiscreteValues values = {{12, 1}, {5, 0}}; + DiscreteKeys keys = {{12, 0}, {5, 0}, {99, 0}, {42, 0}}; // 99 and 42 are missing + + EXPECT(assert_equal(DiscreteKeys({{99, 0}, {42, 0}}), values.missingKeys(keys))); +} + /* ************************************************************************* */ // Check markdown representation with a value formatter. 
TEST(DiscreteValues, markdownWithValueFormatter) { From 8746b15a4a4188692956054d9b7bf4ad5d95269e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 22:08:12 -0500 Subject: [PATCH 11/18] Use two new methods --- .../tests/testHybridGaussianConditional.cpp | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 7c7fc3ea5..ba6eaf4cd 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -308,6 +308,14 @@ TEST(HybridGaussianConditional, Prune) { #include +// Helper function to apply discrete values to the tree +auto choose(auto tree, const DiscreteValues &discreteValues) { + for (const auto &[key, value] : discreteValues) { + tree = tree.choose(key, value); + } + return tree; +} + /** * Return a HybridConditional by choosing branches based on the given discrete * values. If all discrete parents are specified, return a HybridConditional @@ -316,31 +324,27 @@ TEST(HybridGaussianConditional, Prune) { HybridConditional::shared_ptr choose( const HybridGaussianConditional::shared_ptr &self, const DiscreteValues &discreteValues) { - const auto &discreteParents = self->discreteKeys(); - DiscreteValues deadParentValues; - DiscreteKeys liveParents; - for (const auto &key : discreteParents) { - auto it = discreteValues.find(key.first); - if (it != discreteValues.end()) - deadParentValues[key.first] = it->second; - else - liveParents.emplace_back(key); + auto parentValues = discreteValues.filter(self->discreteKeys()); + auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys()); + + // Case 1: Fully determined, return corresponding Gaussian conditional + if (parentValues.size() == self->discreteKeys().size()) { + return std::make_shared(self->choose(parentValues)); } - // If so then we just get the corresponding Gaussian conditional: - if (deadParentValues.size() == discreteParents.size()) { - // print on how many discreteParents we are choosing: - return std::make_shared(self->choose(deadParentValues)); - } else if (liveParents.size() > 0) { + + // Case 2: Some live parents remain, build a new tree + if (!unspecifiedParentKeys.empty()) { auto newTree = self->factors(); - for (auto &&[key, value] : discreteValues) { + for (const auto &[key, value] : parentValues) { newTree = newTree.choose(key, value); } return std::make_shared( - std::make_shared(liveParents, newTree)); - } else { - // Add as-is - return std::make_shared(self); + std::make_shared(unspecifiedParentKeys, + newTree)); } + + // Case 3: No changes needed, return original + return std::make_shared(self); } /* ************************************************************************* */ @@ -356,6 +360,11 @@ TEST(HybridGaussianConditional, PrunePlus) { EXPECT(oneParent->isHybrid()); EXPECT(oneParent->asHybrid()->nrComponents() == 2); + const HybridConditional::shared_ptr oneParent2 = + choose(hgc, {{M(7), 0}, {M(1), 0}}); + EXPECT(oneParent2->isHybrid()); + EXPECT(oneParent2->asHybrid()->nrComponents() == 2); + const HybridConditional::shared_ptr gaussian = choose(hgc, {{M(1), 0}, {M(2), 1}}); EXPECT(gaussian->asGaussian()); From 4d1a8e50572eb9da23ce729e3de592e3ee17db2c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 22:46:33 -0500 Subject: [PATCH 12/18] restrict method --- gtsam/hybrid/HybridConditional.cpp | 37 +++++++++++++ gtsam/hybrid/HybridConditional.h | 8 +++ 
.../tests/testHybridGaussianConditional.cpp | 53 ++++++------------- 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 97ec1a1f8..257eca314 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } +/* ************************************************************************ */ +HybridConditional::shared_ptr HybridConditional::restrict( + const DiscreteValues &discreteValues) const { + if (auto gc = asGaussian()) { + return std::make_shared(gc); + } else if (auto dc = asDiscrete()) { + return std::make_shared(dc); + }; + + auto hgc = asHybrid(); + if (!hgc) + throw std::runtime_error( + "HybridConditional::restrict: conditional type not handled"); + + // Case 1: Fully determined, return corresponding Gaussian conditional + auto parentValues = discreteValues.filter(discreteKeys_); + if (parentValues.size() == discreteKeys_.size()) { + return std::make_shared(hgc->choose(parentValues)); + } + + // Case 2: Some live parents remain, build a new tree + auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_); + if (!unspecifiedParentKeys.empty()) { + auto newTree = hgc->factors(); + for (const auto &[key, value] : parentValues) { + newTree = newTree.choose(key, value); + } + return std::make_shared( + std::make_shared(unspecifiedParentKeys, + newTree)); + } + + // Case 3: No changes needed, return original + return std::make_shared(hgc); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3cf5b80e5..075fbe411 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional return true; } + /** + * Return a HybridConditional by choosing branches based on the given discrete + * values. If all discrete parents are specified, return a HybridConditional + * which is just a GaussianConditional. If this conditional is *not* a hybrid + * conditional, just return that. + */ + shared_ptr restrict(const DiscreteValues& discreteValues) const; + /// @} private: diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index ba6eaf4cd..032be5a78 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -316,57 +316,34 @@ auto choose(auto tree, const DiscreteValues &discreteValues) { return tree; } -/** - * Return a HybridConditional by choosing branches based on the given discrete - * values. If all discrete parents are specified, return a HybridConditional - * which is just a GaussianConditional. +/* ************************************************************************* + * This test verifies the behavior of the restrict method in different + * scenarios: + * - When no restrictions are applied. + * - When one parent is restricted. + * - When two parents are restricted. + * - When the restriction results in a Gaussian conditional. 
*/ -HybridConditional::shared_ptr choose( - const HybridGaussianConditional::shared_ptr &self, - const DiscreteValues &discreteValues) { - auto parentValues = discreteValues.filter(self->discreteKeys()); - auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys()); +TEST(HybridGaussianConditional, Restrict) { + // Create a HybridConditional with two discrete parents P(z0|m0,m1) + const auto hc = + std::make_shared(two_mode_measurement::hgc); - // Case 1: Fully determined, return corresponding Gaussian conditional - if (parentValues.size() == self->discreteKeys().size()) { - return std::make_shared(self->choose(parentValues)); - } - - // Case 2: Some live parents remain, build a new tree - if (!unspecifiedParentKeys.empty()) { - auto newTree = self->factors(); - for (const auto &[key, value] : parentValues) { - newTree = newTree.choose(key, value); - } - return std::make_shared( - std::make_shared(unspecifiedParentKeys, - newTree)); - } - - // Case 3: No changes needed, return original - return std::make_shared(self); -} - -/* ************************************************************************* */ -// Test the pruning and dead-mode removal. -TEST(HybridGaussianConditional, PrunePlus) { - using two_mode_measurement::hgc; // two discrete parents - - const HybridConditional::shared_ptr same = choose(hgc, {}); + const HybridConditional::shared_ptr same = hc->restrict({}); EXPECT(same->isHybrid()); EXPECT(same->asHybrid()->nrComponents() == 4); - const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}}); + const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}}); EXPECT(oneParent->isHybrid()); EXPECT(oneParent->asHybrid()->nrComponents() == 2); const HybridConditional::shared_ptr oneParent2 = - choose(hgc, {{M(7), 0}, {M(1), 0}}); + hc->restrict({{M(7), 0}, {M(1), 0}}); EXPECT(oneParent2->isHybrid()); EXPECT(oneParent2->asHybrid()->nrComponents() == 2); const HybridConditional::shared_ptr gaussian = - choose(hgc, {{M(1), 0}, {M(2), 1}}); + hc->restrict({{M(1), 0}, {M(2), 1}}); EXPECT(gaussian->asGaussian()); } From 05ad198ca6097627469833d3a9080789aa9a55d5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 23:37:58 -0500 Subject: [PATCH 13/18] Use restrict inside prune --- gtsam/hybrid/HybridBayesNet.cpp | 80 ++++++++------------------------- 1 file changed, 19 insertions(+), 61 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 2efb8030e..a911a047a 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -96,71 +96,36 @@ HybridBayesNet HybridBayesNet::prune( // Remove the modes (imperative) pruned.removeDiscreteModes(deadModesValues); - GTSAM_PRINT(deadModesValues); - #if GTSAM_HYBRID_TIMING gttoc_(DeadModeRemoval); #endif } - /* To prune, we visitWith every leaf in the HybridGaussianConditional. - * For each leaf, using the assignment we can check the discrete decision tree - * for 0.0 probability, then just set the leaf to a nullptr. - * - * We can later check the HybridGaussianConditional for just nullptrs. - */ - - // Go through all the Gaussian conditionals in the Bayes Net and prune them as - // per pruned discrete joint. + // Go through all the Gaussian conditionals, restrict them according to + // deadModesValues, and then prune further. for (auto &&conditional : *this) { - if (auto hgc = conditional->asHybrid()) { + if (conditional->isDiscrete()) continue; + + // Restrict conditional using deadModesValues. 
+ // No-op if not a HybridGaussianConditional or deadModesValues empty. + auto restricted = conditional->restrict(deadModesValues); + + // Now decide on type what to do: + if (auto hgc = restricted->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); if (!prunedHybridGaussianConditional) { - GTSAM_PRINT(marginal); - GTSAM_PRINT(pruned); throw std::runtime_error( "A HybridGaussianConditional had all its conditionals pruned"); } - - if (deadModeThreshold.has_value()) { - const auto &discreteParents = - prunedHybridGaussianConditional->discreteKeys(); - DiscreteValues deadParentValues; - DiscreteKeys liveParents; - for (const auto &key : discreteParents) { - auto it = deadModesValues.find(key.first); - if (it != deadModesValues.end()) - deadParentValues[key.first] = it->second; - else - liveParents.emplace_back(key); - } - // If so then we just get the corresponding Gaussian conditional: - if (deadParentValues.size() == discreteParents.size()) { - // print on how many discreteParents we are choosing: - result.push_back( - prunedHybridGaussianConditional->choose(deadParentValues)); - } else if (liveParents.size() > 0) { - auto newTree = prunedHybridGaussianConditional->factors(); - for (auto &&[key, value] : deadModesValues) { - newTree = newTree.choose(key, value); - } - result.emplace_shared(liveParents, - newTree); - } else { - // Add as-is - result.push_back(prunedHybridGaussianConditional); - } - } else { - // Type-erase and add to the pruned Bayes Net fragment. - result.push_back(prunedHybridGaussianConditional); - } - - } else if (auto gc = conditional->asGaussian()) { + // Type-erase and add to the pruned Bayes Net fragment. + result.push_back(prunedHybridGaussianConditional); + } else if (auto gc = restricted->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); - } - // We ignore DiscreteConditional as they are already pruned and added. + } else + throw std::runtime_error( + "HybrdiBayesNet::prune: Unknown HybridConditional type."); } #if GTSAM_HYBRID_TIMING @@ -169,21 +134,14 @@ HybridBayesNet HybridBayesNet::prune( if (deadModeThreshold.has_value()) { /* - If the pruned discrete conditional has any keys left, - we add it to the HybridBayesNet. - If not, it means it is an orphan so we don't add this pruned joint, - and instead add only the marginals below. + If the pruned discrete conditional has any keys left, we add it to the + HybridBayesNet. If not, it means it is an orphan so we don't add this + pruned joint, and instead add only the marginals below. 
*/ if (pruned.keys().size() > 0) { result.emplace_shared(pruned); } - // Add the marginals for future factors - // for (auto &&[key, _] : deadModesValues) { - // result.push_back( - // std::dynamic_pointer_cast(marginals(key))); - // } - } else { result.emplace_shared(pruned); } From e6662b820658e397c85e3e8e3ca4e95d60084a0b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 23:38:06 -0500 Subject: [PATCH 14/18] Fix unit tests --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 37 ++++++++++++----------- gtsam/hybrid/tests/testHybridSmoother.cpp | 4 +-- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 86dcd48e4..3ddad23ff 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) { // prune auto pruned = bayesNet.prune(1); - CHECK(pruned.at(1)->asHybrid()); - EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents()); + CHECK(pruned.at(0)->asHybrid()); + EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); EXPECT(!pruned.equals(bayesNet)); // error @@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) { const double pruneDeadVariables = 0.99; auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); + // First conditional is still the same: P( x0 | x1 m0) + EXPECT(prunedBayesNet.at(0)->isHybrid()); + + // Check that hybrid conditional that only depend on M1 + // is now Gaussian and not Hybrid + EXPECT(prunedBayesNet.at(1)->isContinuous()); + + // Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0) + EXPECT(prunedBayesNet.at(0)->isHybrid()); + // Check that discrete joint only has M0 and not (M0, M1) // since M0 is removed - KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); - EXPECT(KeyVector{M(0)} == actual_keys); - - // Check that hybrid conditionals that only depend on M1 - // are now Gaussian and not Hybrid - EXPECT(prunedBayesNet.at(0)->isDiscrete()); - EXPECT(prunedBayesNet.at(1)->isDiscrete()); - EXPECT(prunedBayesNet.at(2)->isHybrid()); - // Only P(X2 | X1, M1) depends on M1, - // so it gets convert to a Gaussian P(X2 | X1) - EXPECT(prunedBayesNet.at(3)->isContinuous()); - EXPECT(prunedBayesNet.at(4)->isHybrid()); + auto joint = prunedBayesNet.at(3)->asDiscrete(); + EXPECT(joint); + EXPECT(joint->keys() == KeyVector{M(0)}); } /* ****************************************************************************/ @@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) { const HybridValues hybridValues{delta.continuous(), discrete_values}; double pruned_logProbability = 0; pruned_logProbability += - prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); + prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues); pruned_logProbability += - prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues); + prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues); double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values); @@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - CHECK(pruned.at(0)->asDiscrete()); - auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete(); + CHECK(pruned.at(4)->asDiscrete()); + auto 
pruned_discrete_conditionals = pruned.at(4)->asDiscrete(); auto discrete_conditional_tree = std::dynamic_pointer_cast( pruned_discrete_conditionals); diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 3a0f376cc..97a302faf 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) { } EXPECT_LONGS_EQUAL(11, - smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. @@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) { } EXPECT_LONGS_EQUAL(14, - smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. From 57bcb9f4e6b54f3086355bce64906fca977da227 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 23:50:43 -0500 Subject: [PATCH 15/18] Add contains --- gtsam/discrete/DiscreteValues.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index 7c73da681..0644e0c16 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -101,6 +101,14 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { */ DiscreteValues& update(const DiscreteValues& values); + /** + * @brief Check if the DiscreteValues contains the given key. + * + * @param key The key to check for. + * @return True if the key is present, false otherwise. + */ + bool contains(Key key) const { return this->find(key) != this->end(); } + /** * @brief Filter values by keys. 
* From 957c967d0c186b09e61a1bcac40b3661d37b8beb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 30 Jan 2025 00:02:34 -0500 Subject: [PATCH 16/18] remove old remnant --- .../hybrid/tests/testHybridGaussianConditional.cpp | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 032be5a78..88a4fa485 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -304,18 +305,6 @@ TEST(HybridGaussianConditional, Prune) { } } -/* ************************************************************************* */ - -#include - -// Helper function to apply discrete values to the tree -auto choose(auto tree, const DiscreteValues &discreteValues) { - for (const auto &[key, value] : discreteValues) { - tree = tree.choose(key, value); - } - return tree; -} - /* ************************************************************************* * This test verifies the behavior of the restrict method in different * scenarios: From 3c10913c7042ee196f2147fa6347826523c62a41 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 30 Jan 2025 07:57:42 -0500 Subject: [PATCH 17/18] DiscreteNet::prune --- gtsam/discrete/DiscreteBayesNet.cpp | 54 +++++++++++++++++ gtsam/discrete/DiscreteBayesNet.h | 12 ++++ gtsam/hybrid/HybridBayesNet.cpp | 89 ++++++----------------------- 3 files changed, 85 insertions(+), 70 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 56265b0a4..7d929f12c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,7 @@ #include #include +#include #include namespace gtsam { @@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { return result; } +/* ************************************************************************* */ +// The implementation is: build the entire joint into one factor and then prune. +// TODO(Frank): This can be quite expensive *unless* the factors have already +// been pruned before. Another, possibly faster approach is branch and bound +// search to find the K-best leaves and then create a single pruned conditional. +DiscreteBayesNet DiscreteBayesNet::prune( + size_t maxNrLeaves, const std::optional& deadModeThreshold, + DiscreteValues* fixedValues) const { + // Multiply into one big conditional. NOTE: possibly quite expensive. + DiscreteConditional joint; + for (const DiscreteConditional::shared_ptr& conditional : *this) + joint = joint * (*conditional); + + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + pruned.prune(maxNrLeaves); + + DiscreteValues deadModesValues; + // If we have a dead mode threshold and discrete variables left after pruning, + // then we run dead mode removal. 
+ if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); + for (auto dkey : pruned.discreteKeys()) { + const Vector probabilities = marginals.marginalProbabilities(dkey); + + int index = -1; + auto threshold = (probabilities.array() > *deadModeThreshold); + // If atleast 1 value is non-zero, then we can find the index + // Else if all are zero, index would be set to 0 which is incorrect + if (!threshold.isZero()) { + threshold.maxCoeff(&index); + } + + if (index >= 0) { + deadModesValues.emplace(dkey.first, index); + } + } + + // Remove the modes (imperative) + pruned.removeDiscreteModes(deadModesValues); + + // Set the fixed values if requested. + if (fixedValues) { + *fixedValues = deadModesValues; + } + } + + // Return the resulting DiscreteBayesNet. + DiscreteBayesNet result; + if (pruned.keys().size() > 0) result.push_back(pruned); + return result; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 738b91aa5..01b452865 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { */ DiscreteValues sample(DiscreteValues given) const; + /** + * @brief Prune the Bayes net + * + * @param maxNrLeaves The maximum number of leaves to keep. + * @param deadModeThreshold If given, threshold on marginals to prune variables. + * @param fixedValues If given, return the fixed values removed. + * @return A new DiscreteBayesNet with pruned conditionals. + */ + DiscreteBayesNet prune(size_t maxNrLeaves, + const std::optional& deadModeThreshold = {}, + DiscreteValues* fixedValues = nullptr) const; + ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index a911a047a..ea4c3e80b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -43,10 +42,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { } /* ************************************************************************* */ -// The implementation is: build the entire joint into one factor and then prune. -// TODO(Frank): This can be quite expensive *unless* the factors have already -// been pruned before. Another, possibly faster approach is branch and bound -// search to find the K-best leaves and then create a single pruned conditional. HybridBayesNet HybridBayesNet::prune( size_t maxNrLeaves, const std::optional &deadModeThreshold) const { #if GTSAM_HYBRID_TIMING @@ -55,63 +50,31 @@ HybridBayesNet HybridBayesNet::prune( // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); + // Prune discrete Bayes net + DiscreteValues fixed; + auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); + // Multiply into one big conditional. NOTE: possibly quite expensive. - DiscreteConditional joint; - for (auto &&conditional : marginal) { - joint = joint * (*conditional); + DiscreteConditional pruned; + for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); + + // Set the fixed values if requested. + if (deadModeThreshold && fixedValues) { + *fixedValues = fixed; } - // Initialize the resulting HybridBayesNet. 
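On the "possibly quite expensive" notes in this series: folding every discrete conditional into a single joint yields a decision tree with up to the product of the mode cardinalities in leaves, which is why pruning first (or branch and bound, per the TODO) matters. A minimal sketch of the pattern, using the same operator* as the patch:

    #include <gtsam/discrete/DiscreteBayesNet.h>
    #include <gtsam/discrete/DiscreteConditional.h>

    gtsam::DiscreteConditional jointOf(const gtsam::DiscreteBayesNet& bn) {
      gtsam::DiscreteConditional joint;
      for (const auto& conditional : bn)
        joint = joint * (*conditional);  // leaf count grows multiplicatively
      return joint;  // up to prod_i |mode_i| leaves before prune(maxNrLeaves)
    }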
HybridBayesNet result; - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. - DiscreteConditional pruned = joint; - pruned.prune(maxNrLeaves); - - DiscreteValues deadModesValues; - // If we have a dead mode threshold and discrete variables left after pruning, - // then we run dead mode removal. - if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { -#if GTSAM_HYBRID_TIMING - gttic_(DeadModeRemoval); -#endif - - DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); - for (auto dkey : pruned.discreteKeys()) { - Vector probabilities = marginals.marginalProbabilities(dkey); - - int index = -1; - auto threshold = (probabilities.array() > *deadModeThreshold); - // If atleast 1 value is non-zero, then we can find the index - // Else if all are zero, index would be set to 0 which is incorrect - if (!threshold.isZero()) { - threshold.maxCoeff(&index); - } - - if (index >= 0) { - deadModesValues.emplace(dkey.first, index); - } - } - - // Remove the modes (imperative) - pruned.removeDiscreteModes(deadModesValues); - -#if GTSAM_HYBRID_TIMING - gttoc_(DeadModeRemoval); -#endif - } - // Go through all the Gaussian conditionals, restrict them according to - // deadModesValues, and then prune further. - for (auto &&conditional : *this) { + // fixed values, and then prune further. + for (std::shared_ptr conditional : *this) { if (conditional->isDiscrete()) continue; - // Restrict conditional using deadModesValues. - // No-op if not a HybridGaussianConditional or deadModesValues empty. - auto restricted = conditional->restrict(deadModesValues); + // No-op if not a HybridGaussianConditional. + if (deadModeThreshold) conditional = conditional->restrict(fixed); // Now decide on type what to do: - if (auto hgc = restricted->asHybrid()) { + if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); if (!prunedHybridGaussianConditional) { @@ -120,7 +83,7 @@ HybridBayesNet HybridBayesNet::prune( } // Type-erase and add to the pruned Bayes Net fragment. result.push_back(prunedHybridGaussianConditional); - } else if (auto gc = restricted->asGaussian()) { + } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); } else @@ -128,23 +91,9 @@ HybridBayesNet HybridBayesNet::prune( "HybrdiBayesNet::prune: Unknown HybridConditional type."); } -#if GTSAM_HYBRID_TIMING - gttoc_(HybridPruning); -#endif - - if (deadModeThreshold.has_value()) { - /* - If the pruned discrete conditional has any keys left, we add it to the - HybridBayesNet. If not, it means it is an orphan so we don't add this - pruned joint, and instead add only the marginals below. - */ - if (pruned.keys().size() > 0) { - result.emplace_shared(pruned); - } - - } else { - result.emplace_shared(pruned); - } + // Add the pruned discrete conditionals to the result. 
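The loop that follows appends the pruned discrete fragment and completes the result. For orientation, a hedged usage sketch against this commit's two-argument signature (the next commit adds a fixedValues output parameter); the function name and constants are made up:

    #include <gtsam/hybrid/HybridBayesNet.h>

    gtsam::HybridBayesNet pruneForNextStep(const gtsam::HybridBayesNet& posterior) {
      // Keep at most 16 leaves; fix any mode whose marginal exceeds 0.99.
      return posterior.prune(16, 0.99);
    }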
+ for (const DiscreteConditional::shared_ptr &discrete : prunedBN) + result.push_back(discrete); return result; } From 9bae03a6fa89b931eff1e9c558ebe682ccd2b050 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 30 Jan 2025 08:57:04 -0500 Subject: [PATCH 18/18] Change threshold name --- gtsam/discrete/DiscreteBayesNet.cpp | 6 +++--- gtsam/discrete/DiscreteBayesNet.h | 4 ++-- gtsam/hybrid/HybridBayesNet.cpp | 9 +++++---- gtsam/hybrid/HybridBayesNet.h | 18 ++++++++++-------- gtsam/hybrid/HybridSmoother.cpp | 2 +- gtsam/hybrid/HybridSmoother.h | 9 +++++---- 6 files changed, 26 insertions(+), 22 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 7d929f12c..8c04cb91c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. DiscreteBayesNet DiscreteBayesNet::prune( - size_t maxNrLeaves, const std::optional& deadModeThreshold, + size_t maxNrLeaves, const std::optional& marginalThreshold, DiscreteValues* fixedValues) const { // Multiply into one big conditional. NOTE: possibly quite expensive. DiscreteConditional joint; @@ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune( DiscreteValues deadModesValues; // If we have a dead mode threshold and discrete variables left after pruning, // then we run dead mode removal. - if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { + if (marginalThreshold.has_value() && pruned.keys().size() > 0) { DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { const Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > *deadModeThreshold); + auto threshold = (probabilities.array() > *marginalThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 01b452865..eea1739f6 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { * @brief Prune the Bayes net * * @param maxNrLeaves The maximum number of leaves to keep. - * @param deadModeThreshold If given, threshold on marginals to prune variables. + * @param marginalThreshold If given, threshold on marginals to prune variables. * @param fixedValues If given, return the fixed values removed. * @return A new DiscreteBayesNet with pruned conditionals. 
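A minimal usage sketch of this signature (the wrapper function and constants below are made up for illustration):

    #include <gtsam/discrete/DiscreteBayesNet.h>
    #include <gtsam/discrete/DiscreteValues.h>

    gtsam::DiscreteBayesNet pruneWithFixed(const gtsam::DiscreteBayesNet& bn,
                                           gtsam::DiscreteValues* fixed) {
      // Keep at most 16 leaves; any mode whose marginal exceeds 0.99 is removed
      // and its now-fixed assignment is written into *fixed.
      return bn.prune(16, 0.99, fixed);
    }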
*/ DiscreteBayesNet prune(size_t maxNrLeaves, - const std::optional& deadModeThreshold = {}, + const std::optional& marginalThreshold = {}, DiscreteValues* fixedValues = nullptr) const; ///@} diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index ea4c3e80b..5353fe2e0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune( - size_t maxNrLeaves, const std::optional &deadModeThreshold) const { + size_t maxNrLeaves, const std::optional &marginalThreshold, + DiscreteValues *fixedValues) const { #if GTSAM_HYBRID_TIMING gttic_(HybridPruning); #endif @@ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune( // Prune discrete Bayes net DiscreteValues fixed; - auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); + auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); // Multiply into one big conditional. NOTE: possibly quite expensive. DiscreteConditional pruned; for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); // Set the fixed values if requested. - if (deadModeThreshold && fixedValues) { + if (marginalThreshold && fixedValues) { *fixedValues = fixed; } @@ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune( if (conditional->isDiscrete()) continue; // No-op if not a HybridGaussianConditional. - if (deadModeThreshold) conditional = conditional->restrict(fixed); + if (marginalThreshold) conditional = conditional->restrict(fixed); // Now decide on type what to do: if (auto hgc = conditional->asHybrid()) { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index fb05e2407..0840cb381 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param deadModeThreshold The threshold to check the mode marginals against. - * If greater than this threshold, the mode gets assigned that value and is - * considered "dead" for hybrid elimination. - * The mode can then be removed since it only has a single possible - * assignment. + * @param marginalThreshold The threshold to check the mode marginals against. + * @param fixedValues The fixed values resulting from dead mode removal. + * + * @note If marginal greater than this threshold, the mode gets assigned that + * value and is considered "dead" for hybrid elimination. The mode can then be + * removed since it only has a single possible assignment. 
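A sketch of how a caller might consume the fixedValues output together with the new DiscreteValues::contains(); the Symbol('m', 0) key naming and the constants are assumptions, not part of the patch:

    #include <gtsam/discrete/DiscreteValues.h>
    #include <gtsam/hybrid/HybridBayesNet.h>
    #include <gtsam/inference/Symbol.h>
    #include <iostream>

    void reportFixedMode(const gtsam::HybridBayesNet& posterior) {
      gtsam::DiscreteValues fixed;
      const gtsam::HybridBayesNet pruned = posterior.prune(16, 0.99, &fixed);
      std::cout << "pruned fragment has " << pruned.size() << " conditionals\n";

      const gtsam::Key m0 = gtsam::Symbol('m', 0).key();
      if (fixed.contains(m0))
        std::cout << "mode m0 fixed to " << fixed.at(m0) << std::endl;
    }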
+ * @return A pruned HybridBayesNet */ - HybridBayesNet prune( - size_t maxNrLeaves, - const std::optional &deadModeThreshold = {}) const; + HybridBayesNet prune(size_t maxNrLeaves, + const std::optional &marginalThreshold = {}, + DiscreteValues *fixedValues = nullptr) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 45320896a..594c12825 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_); } // Add the partial bayes net to the posterior bayes net. diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index c3f022c62..2f7bfcebb 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother { HybridGaussianFactorGraph remainingFactorGraph_; /// The threshold above which we make a decision about a mode. - std::optional deadModeThreshold_; + std::optional marginalThreshold_; + DiscreteValues fixedValues_; public: /** * @brief Constructor * * @param removeDeadModes Flag indicating whether to remove dead modes. - * @param deadModeThreshold The threshold above which a mode gets assigned a + * @param marginalThreshold The threshold above which a mode gets assigned a * value and is considered "dead". 0.99 is a good starting value. */ - HybridSmoother(const std::optional deadModeThreshold = {}) - : deadModeThreshold_(deadModeThreshold) {} + HybridSmoother(const std::optional marginalThreshold = {}) + : marginalThreshold_(marginalThreshold) {} /** * Given new factors, perform an incremental update.
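Finally, a hedged sketch of supplying the renamed threshold at the smoother level; the update() call is left as a comment because its full signature is not shown in this hunk (see testHybridSmoother.cpp for end-to-end usage):

    #include <gtsam/hybrid/HybridBayesNet.h>
    #include <gtsam/hybrid/HybridSmoother.h>

    int main() {
      // Any mode whose marginal exceeds 0.99 is fixed ("dead") during pruning.
      gtsam::HybridSmoother smoother(0.99);

      // ... calls to smoother.update(...) go here; see testHybridSmoother.cpp ...

      const gtsam::HybridBayesNet& posterior = smoother.hybridBayesNet();
      return static_cast<int>(posterior.size());  // 0: nothing added yet
    }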