add deadModeThreshold argument to HybridBayesNet::prune

release/4.3a0
Varun Agrawal 2025-01-24 20:08:09 -05:00
parent 8725361fd2
commit 1b79e8800f
3 changed files with 14 additions and 4 deletions

View File

@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
// TODO(Frank): This can be quite expensive *unless* the factors have already // TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound // been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional. // search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes,
bool removeDeadModes) const { double deadModeThreshold) const {
// Collect all the discrete conditionals. Could be small if already pruned. // Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal(); const DiscreteBayesNet marginal = discreteMarginal();

View File

@ -219,9 +219,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param maxNrLeaves Continuous values at which to compute the error. * @param maxNrLeaves Continuous values at which to compute the error.
* @param removeDeadModes Flag to enable removal of modes which only have a * @param removeDeadModes Flag to enable removal of modes which only have a
* single possible assignment. * single possible assignment.
* @param deadModeThreshold The threshold to check the mode marginals against.
* If greater than this threshold, the mode gets assigned that value and is
* considered "dead" for hybrid elimination.
* @return A pruned HybridBayesNet * @return A pruned HybridBayesNet
*/ */
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false,
double deadModeThreshold = 0.99) const;
/** /**
* @brief Error method using HybridValues which returns specific error for * @brief Error method using HybridValues which returns specific error for

View File

@ -74,6 +74,11 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
ordering = *given_ordering; ordering = *given_ordering;
} }
// graph.print("Original GRAPH");
// GTSAM_PRINT(updatedGraph);
// GTSAM_PRINT(hybridBayesNet_);
// GTSAM_PRINT(ordering);
// Eliminate. // Eliminate.
HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering); HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering);
@ -81,7 +86,8 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
if (maxNrLeaves) { if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment. // all the conditionals with the same keys in bayesNetFragment.
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_); bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_,
deadModeThreshold_);
} }
// Add the partial bayes net to the posterior bayes net. // Add the partial bayes net to the posterior bayes net.