use std::optional for specifying dead mode threshold

release/4.3a0
Varun Agrawal 2025-01-24 23:11:42 -05:00
parent d2b9eb5df6
commit 1764b58e8c
2 changed files with 9 additions and 9 deletions

View File

@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
// TODO(Frank): This can be quite expensive *unless* the factors have already // TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound // been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional. // search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, HybridBayesNet HybridBayesNet::prune(
double deadModeThreshold) const { size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
// Collect all the discrete conditionals. Could be small if already pruned. // Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal(); const DiscreteBayesNet marginal = discreteMarginal();
@ -66,13 +66,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes,
pruned.prune(maxNrLeaves); pruned.prune(maxNrLeaves);
DiscreteValues deadModesValues; DiscreteValues deadModesValues;
if (removeDeadModes) { if (deadModeThreshold.has_value()) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) { for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey); Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1; int index = -1;
auto threshold = (probabilities.array() > deadModeThreshold); auto threshold = (probabilities.array() > *deadModeThreshold);
// If atleast 1 value is non-zero, then we can find the index // If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect // Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) { if (!threshold.isZero()) {
@ -121,7 +121,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes,
// Prune the hybrid Gaussian conditional! // Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned); auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (removeDeadModes) { if (deadModeThreshold.has_value()) {
KeyVector deadKeys, conditionalDiscreteKeys; KeyVector deadKeys, conditionalDiscreteKeys;
for (const auto &kv : deadModesValues) { for (const auto &kv : deadModesValues) {
deadKeys.push_back(kv.first); deadKeys.push_back(kv.first);

View File

@ -217,15 +217,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
* *
* @param maxNrLeaves Continuous values at which to compute the error. * @param maxNrLeaves Continuous values at which to compute the error.
* @param removeDeadModes Flag to enable removal of modes which only have a
* single possible assignment.
* @param deadModeThreshold The threshold to check the mode marginals against. * @param deadModeThreshold The threshold to check the mode marginals against.
* If greater than this threshold, the mode gets assigned that value and is * If greater than this threshold, the mode gets assigned that value and is
* considered "dead" for hybrid elimination. * considered "dead" for hybrid elimination.
* The mode can then be removed since it only has a single possible
* assignment.
* @return A pruned HybridBayesNet * @return A pruned HybridBayesNet
*/ */
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false, HybridBayesNet prune(size_t maxNrLeaves,
double deadModeThreshold = 0.99) const; const std::optional<double> &deadModeThreshold) const;
/** /**
* @brief Error method using HybridValues which returns specific error for * @brief Error method using HybridValues which returns specific error for