use std::optional for specifying dead mode threshold
							parent
							
								
									d2b9eb5df6
								
							
						
					
					
						commit
						1764b58e8c
					
				|  | @ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { | |||
| // TODO(Frank): This can be quite expensive *unless* the factors have already
 | ||||
| // been pruned before. Another, possibly faster approach is branch and bound
 | ||||
| // search to find the K-best leaves and then create a single pruned conditional.
 | ||||
| HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, | ||||
|                                      double deadModeThreshold) const { | ||||
| HybridBayesNet HybridBayesNet::prune( | ||||
|     size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const { | ||||
|   // Collect all the discrete conditionals. Could be small if already pruned.
 | ||||
|   const DiscreteBayesNet marginal = discreteMarginal(); | ||||
| 
 | ||||
|  | @ -66,13 +66,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, | |||
|   pruned.prune(maxNrLeaves); | ||||
| 
 | ||||
|   DiscreteValues deadModesValues; | ||||
|   if (removeDeadModes) { | ||||
|   if (deadModeThreshold.has_value()) { | ||||
|     DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); | ||||
|     for (auto dkey : pruned.discreteKeys()) { | ||||
|       Vector probabilities = marginals.marginalProbabilities(dkey); | ||||
| 
 | ||||
|       int index = -1; | ||||
|       auto threshold = (probabilities.array() > deadModeThreshold); | ||||
|       auto threshold = (probabilities.array() > *deadModeThreshold); | ||||
|       // If atleast 1 value is non-zero, then we can find the index
 | ||||
|       // Else if all are zero, index would be set to 0 which is incorrect
 | ||||
|       if (!threshold.isZero()) { | ||||
|  | @ -121,7 +121,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, | |||
|       // Prune the hybrid Gaussian conditional!
 | ||||
|       auto prunedHybridGaussianConditional = hgc->prune(pruned); | ||||
| 
 | ||||
|       if (removeDeadModes) { | ||||
|       if (deadModeThreshold.has_value()) { | ||||
|         KeyVector deadKeys, conditionalDiscreteKeys; | ||||
|         for (const auto &kv : deadModesValues) { | ||||
|           deadKeys.push_back(kv.first); | ||||
|  |  | |||
|  | @ -217,15 +217,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | |||
|    * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. | ||||
|    * | ||||
|    * @param maxNrLeaves Continuous values at which to compute the error. | ||||
|    * @param removeDeadModes Flag to enable removal of modes which only have a | ||||
|    * single possible assignment. | ||||
|    * @param deadModeThreshold The threshold to check the mode marginals against. | ||||
|    * If greater than this threshold, the mode gets assigned that value and is | ||||
|    * considered "dead" for hybrid elimination. | ||||
|    * The mode can then be removed since it only has a single possible | ||||
|    * assignment. | ||||
|    * @return A pruned HybridBayesNet | ||||
|    */ | ||||
|   HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false, | ||||
|                        double deadModeThreshold = 0.99) const; | ||||
|   HybridBayesNet prune(size_t maxNrLeaves, | ||||
|                        const std::optional<double> &deadModeThreshold) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Error method using HybridValues which returns specific error for | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue