support for varying normalizers in GaussianMixtureFactor
							parent
							
								
									eef9765e4a
								
							
						
					
					
						commit
						ea104c4b83
					
				|  | @ -28,11 +28,86 @@ | |||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Helper function to correct the [A|b] matrices in the factor components | ||||
|  * with the normalizer values. | ||||
|  * This is done by storing the normalizer value in | ||||
|  * the `b` vector as an additional row. | ||||
|  * | ||||
|  * @param factors DecisionTree of GaussianFactor shared pointers. | ||||
|  * @param varyingNormalizers Flag indicating the normalizers are different for | ||||
|  * each component. | ||||
|  * @return GaussianMixtureFactor::Factors | ||||
|  */ | ||||
| GaussianMixtureFactor::Factors correct( | ||||
|     const GaussianMixtureFactor::Factors &factors, bool varyingNormalizers) { | ||||
|   if (!varyingNormalizers) { | ||||
|     return factors; | ||||
|   } | ||||
| 
 | ||||
|   // First compute all the sqrt(|2 pi Sigma|) terms
 | ||||
|   auto computeNormalizers = [](const GaussianMixtureFactor::sharedFactor &gf) { | ||||
|     auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf); | ||||
|     // If we have, say, a Hessian factor, then no need to do anything
 | ||||
|     if (!jf) return 0.0; | ||||
| 
 | ||||
|     auto model = jf->get_model(); | ||||
|     // If there is no noise model, there is nothing to do.
 | ||||
|     if (!model) { | ||||
|       return 0.0; | ||||
|     } | ||||
|     // Since noise models are Gaussian, we can get the logDeterminant using the
 | ||||
|     // same trick as in GaussianConditional
 | ||||
|     double logDetR = | ||||
|         model->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); | ||||
|     double logDeterminantSigma = -2.0 * logDetR; | ||||
| 
 | ||||
|     size_t n = model->dim(); | ||||
|     constexpr double log2pi = 1.8378770664093454835606594728112; | ||||
|     return n * log2pi + logDeterminantSigma; | ||||
|   }; | ||||
| 
 | ||||
|   AlgebraicDecisionTree<Key> log_normalizers = | ||||
|       DecisionTree<Key, double>(factors, computeNormalizers); | ||||
| 
 | ||||
|   // Find the minimum value so we can "proselytize" to positive values.
 | ||||
|   // Done because we can't have sqrt of negative numbers.
 | ||||
|   double min_log_normalizer = log_normalizers.min(); | ||||
|   log_normalizers = log_normalizers.apply( | ||||
|       [&min_log_normalizer](double n) { return n - min_log_normalizer; }); | ||||
| 
 | ||||
|   // Finally, update the [A|b] matrices.
 | ||||
|   auto update = [&log_normalizers]( | ||||
|                     const Assignment<Key> &assignment, | ||||
|                     const GaussianMixtureFactor::sharedFactor &gf) { | ||||
|     auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf); | ||||
|     if (!jf) return gf; | ||||
|     // If there is no noise model, there is nothing to do.
 | ||||
|     if (!jf->get_model()) return gf; | ||||
|     // If the log_normalizer is 0, do nothing
 | ||||
|     if (log_normalizers(assignment) == 0.0) return gf; | ||||
| 
 | ||||
|     GaussianFactorGraph gfg; | ||||
|     gfg.push_back(jf); | ||||
| 
 | ||||
|     Vector c(1); | ||||
|     c << std::sqrt(log_normalizers(assignment)); | ||||
|     auto constantFactor = std::make_shared<JacobianFactor>(c); | ||||
| 
 | ||||
|     gfg.push_back(constantFactor); | ||||
|     return std::dynamic_pointer_cast<GaussianFactor>( | ||||
|         std::make_shared<JacobianFactor>(gfg)); | ||||
|   }; | ||||
|   return factors.apply(update); | ||||
| } | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||
|                                              const DiscreteKeys &discreteKeys, | ||||
|                                              const Factors &factors) | ||||
|     : Base(continuousKeys, discreteKeys), factors_(factors) {} | ||||
|                                              const Factors &factors, | ||||
|                                              bool varyingNormalizers) | ||||
|     : Base(continuousKeys, discreteKeys), | ||||
|       factors_(correct(factors, varyingNormalizers)) {} | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | ||||
|  | @ -54,7 +129,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { | |||
| /* *******************************************************************************/ | ||||
| void GaussianMixtureFactor::print(const std::string &s, | ||||
|                                   const KeyFormatter &formatter) const { | ||||
|   HybridFactor::print(s, formatter); | ||||
|   std::cout << (s.empty() ? "" : s + "\n"); | ||||
|   std::cout << "GaussianMixtureFactor" << std::endl; | ||||
|   HybridFactor::print("", formatter); | ||||
|   std::cout << "{\n"; | ||||
|   if (factors_.empty()) { | ||||
|     std::cout << "  empty" << std::endl; | ||||
|  |  | |||
|  | @ -82,10 +82,13 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
|    * their cardinalities. | ||||
|    * @param factors The decision tree of Gaussian factors stored as the mixture | ||||
|    * density. | ||||
|    * @param varyingNormalizers Flag indicating factor components have varying | ||||
|    * normalizer values. | ||||
|    */ | ||||
|   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||
|                         const DiscreteKeys &discreteKeys, | ||||
|                         const Factors &factors); | ||||
|                         const Factors &factors, | ||||
|                         bool varyingNormalizers = false); | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct a new GaussianMixtureFactor object using a vector of | ||||
|  | @ -94,12 +97,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
|    * @param continuousKeys Vector of keys for continuous factors. | ||||
|    * @param discreteKeys Vector of discrete keys. | ||||
|    * @param factors Vector of gaussian factor shared pointers. | ||||
|    * @param varyingNormalizers Flag indicating factor components have varying | ||||
|    * normalizer values. | ||||
|    */ | ||||
|   GaussianMixtureFactor(const KeyVector &continuousKeys, | ||||
|                         const DiscreteKeys &discreteKeys, | ||||
|                         const std::vector<sharedFactor> &factors) | ||||
|                         const std::vector<sharedFactor> &factors, | ||||
|                         bool varyingNormalizers = false) | ||||
|       : GaussianMixtureFactor(continuousKeys, discreteKeys, | ||||
|                               Factors(discreteKeys, factors)) {} | ||||
|                               Factors(discreteKeys, factors), | ||||
|                               varyingNormalizers) {} | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Testable
 | ||||
|  | @ -107,9 +114,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { | |||
| 
 | ||||
|   bool equals(const HybridFactor &lf, double tol = 1e-9) const override; | ||||
| 
 | ||||
|   void print( | ||||
|       const std::string &s = "GaussianMixtureFactor\n", | ||||
|       const KeyFormatter &formatter = DefaultKeyFormatter) const override; | ||||
|   void print(const std::string &s = "", const KeyFormatter &formatter = | ||||
|                                             DefaultKeyFormatter) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Standard API
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue