Change threshold name
							parent
							
								
									3c10913c70
								
							
						
					
					
						commit
						9bae03a6fa
					
				|  | @ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { | ||||||
| // been pruned before. Another, possibly faster approach is branch and bound
 | // been pruned before. Another, possibly faster approach is branch and bound
 | ||||||
| // search to find the K-best leaves and then create a single pruned conditional.
 | // search to find the K-best leaves and then create a single pruned conditional.
 | ||||||
| DiscreteBayesNet DiscreteBayesNet::prune( | DiscreteBayesNet DiscreteBayesNet::prune( | ||||||
|     size_t maxNrLeaves, const std::optional<double>& deadModeThreshold, |     size_t maxNrLeaves, const std::optional<double>& marginalThreshold, | ||||||
|     DiscreteValues* fixedValues) const { |     DiscreteValues* fixedValues) const { | ||||||
|   // Multiply into one big conditional. NOTE: possibly quite expensive.
 |   // Multiply into one big conditional. NOTE: possibly quite expensive.
 | ||||||
|   DiscreteConditional joint; |   DiscreteConditional joint; | ||||||
|  | @ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune( | ||||||
|   DiscreteValues deadModesValues; |   DiscreteValues deadModesValues; | ||||||
|   // If we have a dead mode threshold and discrete variables left after pruning,
 |   // If we have a dead mode threshold and discrete variables left after pruning,
 | ||||||
|   // then we run dead mode removal.
 |   // then we run dead mode removal.
 | ||||||
|   if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { |   if (marginalThreshold.has_value() && pruned.keys().size() > 0) { | ||||||
|     DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); |     DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); | ||||||
|     for (auto dkey : pruned.discreteKeys()) { |     for (auto dkey : pruned.discreteKeys()) { | ||||||
|       const Vector probabilities = marginals.marginalProbabilities(dkey); |       const Vector probabilities = marginals.marginalProbabilities(dkey); | ||||||
| 
 | 
 | ||||||
|       int index = -1; |       int index = -1; | ||||||
|       auto threshold = (probabilities.array() > *deadModeThreshold); |       auto threshold = (probabilities.array() > *marginalThreshold); | ||||||
|       // If atleast 1 value is non-zero, then we can find the index
 |       // If atleast 1 value is non-zero, then we can find the index
 | ||||||
|       // Else if all are zero, index would be set to 0 which is incorrect
 |       // Else if all are zero, index would be set to 0 which is incorrect
 | ||||||
|       if (!threshold.isZero()) { |       if (!threshold.isZero()) { | ||||||
|  |  | ||||||
|  | @ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> { | ||||||
|      * @brief Prune the Bayes net |      * @brief Prune the Bayes net | ||||||
|      * |      * | ||||||
|      * @param maxNrLeaves The maximum number of leaves to keep. |      * @param maxNrLeaves The maximum number of leaves to keep. | ||||||
|      * @param deadModeThreshold If given, threshold on marginals to prune variables. |      * @param marginalThreshold If given, threshold on marginals to prune variables. | ||||||
|      * @param fixedValues If given, return the fixed values removed. |      * @param fixedValues If given, return the fixed values removed. | ||||||
|      * @return A new DiscreteBayesNet with pruned conditionals. |      * @return A new DiscreteBayesNet with pruned conditionals. | ||||||
|      */ |      */ | ||||||
|     DiscreteBayesNet prune(size_t maxNrLeaves, |     DiscreteBayesNet prune(size_t maxNrLeaves, | ||||||
|                            const std::optional<double>& deadModeThreshold = {}, |                            const std::optional<double>& marginalThreshold = {}, | ||||||
|                            DiscreteValues* fixedValues = nullptr) const; |                            DiscreteValues* fixedValues = nullptr) const; | ||||||
| 
 | 
 | ||||||
|     ///@}
 |     ///@}
 | ||||||
|  |  | ||||||
|  | @ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| HybridBayesNet HybridBayesNet::prune( | HybridBayesNet HybridBayesNet::prune( | ||||||
|     size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const { |     size_t maxNrLeaves, const std::optional<double> &marginalThreshold, | ||||||
|  |     DiscreteValues *fixedValues) const { | ||||||
| #if GTSAM_HYBRID_TIMING | #if GTSAM_HYBRID_TIMING | ||||||
|   gttic_(HybridPruning); |   gttic_(HybridPruning); | ||||||
| #endif | #endif | ||||||
|  | @ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune( | ||||||
| 
 | 
 | ||||||
|   // Prune discrete Bayes net
 |   // Prune discrete Bayes net
 | ||||||
|   DiscreteValues fixed; |   DiscreteValues fixed; | ||||||
|   auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); |   auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); | ||||||
| 
 | 
 | ||||||
|   // Multiply into one big conditional. NOTE: possibly quite expensive.
 |   // Multiply into one big conditional. NOTE: possibly quite expensive.
 | ||||||
|   DiscreteConditional pruned; |   DiscreteConditional pruned; | ||||||
|   for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); |   for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); | ||||||
| 
 | 
 | ||||||
|   // Set the fixed values if requested.
 |   // Set the fixed values if requested.
 | ||||||
|   if (deadModeThreshold && fixedValues) { |   if (marginalThreshold && fixedValues) { | ||||||
|     *fixedValues = fixed; |     *fixedValues = fixed; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune( | ||||||
|     if (conditional->isDiscrete()) continue; |     if (conditional->isDiscrete()) continue; | ||||||
| 
 | 
 | ||||||
|     // No-op if not a HybridGaussianConditional.
 |     // No-op if not a HybridGaussianConditional.
 | ||||||
|     if (deadModeThreshold) conditional = conditional->restrict(fixed); |     if (marginalThreshold) conditional = conditional->restrict(fixed); | ||||||
| 
 | 
 | ||||||
|     // Now decide on type what to do:
 |     // Now decide on type what to do:
 | ||||||
|     if (auto hgc = conditional->asHybrid()) { |     if (auto hgc = conditional->asHybrid()) { | ||||||
|  |  | ||||||
|  | @ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|    * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. |    * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. | ||||||
|    * |    * | ||||||
|    * @param maxNrLeaves Continuous values at which to compute the error. |    * @param maxNrLeaves Continuous values at which to compute the error. | ||||||
|    * @param deadModeThreshold The threshold to check the mode marginals against. |    * @param marginalThreshold The threshold to check the mode marginals against. | ||||||
|    * If greater than this threshold, the mode gets assigned that value and is |    * @param fixedValues The fixed values resulting from dead mode removal. | ||||||
|    * considered "dead" for hybrid elimination. |    * | ||||||
|    * The mode can then be removed since it only has a single possible |    * @note If marginal greater than this threshold, the mode gets assigned that | ||||||
|    * assignment. |    * value and is considered "dead" for hybrid elimination. The mode can then be | ||||||
|  |    * removed since it only has a single possible assignment. | ||||||
|  |     | ||||||
|    * @return A pruned HybridBayesNet |    * @return A pruned HybridBayesNet | ||||||
|    */ |    */ | ||||||
|   HybridBayesNet prune( |   HybridBayesNet prune(size_t maxNrLeaves, | ||||||
|       size_t maxNrLeaves, |                        const std::optional<double> &marginalThreshold = {}, | ||||||
|       const std::optional<double> &deadModeThreshold = {}) const; |                        DiscreteValues *fixedValues = nullptr) const; | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Error method using HybridValues which returns specific error for |    * @brief Error method using HybridValues which returns specific error for | ||||||
|  |  | ||||||
|  | @ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, | ||||||
|   if (maxNrLeaves) { |   if (maxNrLeaves) { | ||||||
|     // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
 |     // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
 | ||||||
|     // all the conditionals with the same keys in bayesNetFragment.
 |     // all the conditionals with the same keys in bayesNetFragment.
 | ||||||
|     bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); |     bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // Add the partial bayes net to the posterior bayes net.
 |   // Add the partial bayes net to the posterior bayes net.
 | ||||||
|  |  | ||||||
|  | @ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother { | ||||||
|   HybridGaussianFactorGraph remainingFactorGraph_; |   HybridGaussianFactorGraph remainingFactorGraph_; | ||||||
| 
 | 
 | ||||||
|   /// The threshold above which we make a decision about a mode.
 |   /// The threshold above which we make a decision about a mode.
 | ||||||
|   std::optional<double> deadModeThreshold_; |   std::optional<double> marginalThreshold_; | ||||||
|  |   DiscreteValues fixedValues_; | ||||||
| 
 | 
 | ||||||
|  public: |  public: | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Constructor |    * @brief Constructor | ||||||
|    * |    * | ||||||
|    * @param removeDeadModes Flag indicating whether to remove dead modes. |    * @param removeDeadModes Flag indicating whether to remove dead modes. | ||||||
|    * @param deadModeThreshold The threshold above which a mode gets assigned a |    * @param marginalThreshold The threshold above which a mode gets assigned a | ||||||
|    * value and is considered "dead". 0.99 is a good starting value. |    * value and is considered "dead". 0.99 is a good starting value. | ||||||
|    */ |    */ | ||||||
|   HybridSmoother(const std::optional<double> deadModeThreshold = {}) |   HybridSmoother(const std::optional<double> marginalThreshold = {}) | ||||||
|       : deadModeThreshold_(deadModeThreshold) {} |       : marginalThreshold_(marginalThreshold) {} | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * Given new factors, perform an incremental update. |    * Given new factors, perform an incremental update. | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue