Change threshold name
							parent
							
								
									3c10913c70
								
							
						
					
					
						commit
						9bae03a6fa
					
				|  | @ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { | |||
| // been pruned before. Another, possibly faster approach is branch and bound
 | ||||
| // search to find the K-best leaves and then create a single pruned conditional.
 | ||||
| DiscreteBayesNet DiscreteBayesNet::prune( | ||||
|     size_t maxNrLeaves, const std::optional<double>& deadModeThreshold, | ||||
|     size_t maxNrLeaves, const std::optional<double>& marginalThreshold, | ||||
|     DiscreteValues* fixedValues) const { | ||||
|   // Multiply into one big conditional. NOTE: possibly quite expensive.
 | ||||
|   DiscreteConditional joint; | ||||
|  | @ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune( | |||
|   DiscreteValues deadModesValues; | ||||
|   // If we have a dead mode threshold and discrete variables left after pruning,
 | ||||
|   // then we run dead mode removal.
 | ||||
|   if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { | ||||
|   if (marginalThreshold.has_value() && pruned.keys().size() > 0) { | ||||
|     DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); | ||||
|     for (auto dkey : pruned.discreteKeys()) { | ||||
|       const Vector probabilities = marginals.marginalProbabilities(dkey); | ||||
| 
 | ||||
|       int index = -1; | ||||
|       auto threshold = (probabilities.array() > *deadModeThreshold); | ||||
|       auto threshold = (probabilities.array() > *marginalThreshold); | ||||
|       // If atleast 1 value is non-zero, then we can find the index
 | ||||
|       // Else if all are zero, index would be set to 0 which is incorrect
 | ||||
|       if (!threshold.isZero()) { | ||||
|  |  | |||
|  | @ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> { | |||
|      * @brief Prune the Bayes net | ||||
|      * | ||||
|      * @param maxNrLeaves The maximum number of leaves to keep. | ||||
|      * @param deadModeThreshold If given, threshold on marginals to prune variables. | ||||
|      * @param marginalThreshold If given, threshold on marginals to prune variables. | ||||
|      * @param fixedValues If given, return the fixed values removed. | ||||
|      * @return A new DiscreteBayesNet with pruned conditionals. | ||||
|      */ | ||||
|     DiscreteBayesNet prune(size_t maxNrLeaves, | ||||
|                            const std::optional<double>& deadModeThreshold = {}, | ||||
|                            const std::optional<double>& marginalThreshold = {}, | ||||
|                            DiscreteValues* fixedValues = nullptr) const; | ||||
| 
 | ||||
|     ///@}
 | ||||
|  |  | |||
|  | @ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { | |||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| HybridBayesNet HybridBayesNet::prune( | ||||
|     size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const { | ||||
|     size_t maxNrLeaves, const std::optional<double> &marginalThreshold, | ||||
|     DiscreteValues *fixedValues) const { | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(HybridPruning); | ||||
| #endif | ||||
|  | @ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune( | |||
| 
 | ||||
|   // Prune discrete Bayes net
 | ||||
|   DiscreteValues fixed; | ||||
|   auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); | ||||
|   auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); | ||||
| 
 | ||||
|   // Multiply into one big conditional. NOTE: possibly quite expensive.
 | ||||
|   DiscreteConditional pruned; | ||||
|   for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); | ||||
| 
 | ||||
|   // Set the fixed values if requested.
 | ||||
|   if (deadModeThreshold && fixedValues) { | ||||
|   if (marginalThreshold && fixedValues) { | ||||
|     *fixedValues = fixed; | ||||
|   } | ||||
| 
 | ||||
|  | @ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune( | |||
|     if (conditional->isDiscrete()) continue; | ||||
| 
 | ||||
|     // No-op if not a HybridGaussianConditional.
 | ||||
|     if (deadModeThreshold) conditional = conditional->restrict(fixed); | ||||
|     if (marginalThreshold) conditional = conditional->restrict(fixed); | ||||
| 
 | ||||
|     // Now decide on type what to do:
 | ||||
|     if (auto hgc = conditional->asHybrid()) { | ||||
|  |  | |||
|  | @ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | |||
|    * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. | ||||
|    * | ||||
|    * @param maxNrLeaves Continuous values at which to compute the error. | ||||
|    * @param deadModeThreshold The threshold to check the mode marginals against. | ||||
|    * If greater than this threshold, the mode gets assigned that value and is | ||||
|    * considered "dead" for hybrid elimination. | ||||
|    * The mode can then be removed since it only has a single possible | ||||
|    * assignment. | ||||
|    * @param marginalThreshold The threshold to check the mode marginals against. | ||||
|    * @param fixedValues The fixed values resulting from dead mode removal. | ||||
|    * | ||||
|    * @note If marginal greater than this threshold, the mode gets assigned that | ||||
|    * value and is considered "dead" for hybrid elimination. The mode can then be | ||||
|    * removed since it only has a single possible assignment. | ||||
|     | ||||
|    * @return A pruned HybridBayesNet | ||||
|    */ | ||||
|   HybridBayesNet prune( | ||||
|       size_t maxNrLeaves, | ||||
|       const std::optional<double> &deadModeThreshold = {}) const; | ||||
|   HybridBayesNet prune(size_t maxNrLeaves, | ||||
|                        const std::optional<double> &marginalThreshold = {}, | ||||
|                        DiscreteValues *fixedValues = nullptr) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Error method using HybridValues which returns specific error for | ||||
|  |  | |||
|  | @ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, | |||
|   if (maxNrLeaves) { | ||||
|     // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
 | ||||
|     // all the conditionals with the same keys in bayesNetFragment.
 | ||||
|     bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); | ||||
|     bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_); | ||||
|   } | ||||
| 
 | ||||
|   // Add the partial bayes net to the posterior bayes net.
 | ||||
|  |  | |||
|  | @ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother { | |||
|   HybridGaussianFactorGraph remainingFactorGraph_; | ||||
| 
 | ||||
|   /// The threshold above which we make a decision about a mode.
 | ||||
|   std::optional<double> deadModeThreshold_; | ||||
|   std::optional<double> marginalThreshold_; | ||||
|   DiscreteValues fixedValues_; | ||||
| 
 | ||||
|  public: | ||||
|   /**
 | ||||
|    * @brief Constructor | ||||
|    * | ||||
|    * @param removeDeadModes Flag indicating whether to remove dead modes. | ||||
|    * @param deadModeThreshold The threshold above which a mode gets assigned a | ||||
|    * @param marginalThreshold The threshold above which a mode gets assigned a | ||||
|    * value and is considered "dead". 0.99 is a good starting value. | ||||
|    */ | ||||
|   HybridSmoother(const std::optional<double> deadModeThreshold = {}) | ||||
|       : deadModeThreshold_(deadModeThreshold) {} | ||||
|   HybridSmoother(const std::optional<double> marginalThreshold = {}) | ||||
|       : marginalThreshold_(marginalThreshold) {} | ||||
| 
 | ||||
|   /**
 | ||||
|    * Given new factors, perform an incremental update. | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue