diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index a3db16d04..e519cefe6 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -28,11 +28,86 @@ namespace gtsam { +/** + * @brief Helper function to correct the [A|b] matrices in the factor components + * with the normalizer values. + * This is done by storing the normalizer value in + * the `b` vector as an additional row. + * + * @param factors DecisionTree of GaussianFactor shared pointers. + * @param varyingNormalizers Flag indicating the normalizers are different for + * each component. + * @return GaussianMixtureFactor::Factors + */ +GaussianMixtureFactor::Factors correct( + const GaussianMixtureFactor::Factors &factors, bool varyingNormalizers) { + if (!varyingNormalizers) { + return factors; + } + + // First compute all the sqrt(|2 pi Sigma|) terms + auto computeNormalizers = [](const GaussianMixtureFactor::sharedFactor &gf) { + auto jf = std::dynamic_pointer_cast(gf); + // If we have, say, a Hessian factor, then no need to do anything + if (!jf) return 0.0; + + auto model = jf->get_model(); + // If there is no noise model, there is nothing to do. + if (!model) { + return 0.0; + } + // Since noise models are Gaussian, we can get the logDeterminant using the + // same trick as in GaussianConditional + double logDetR = + model->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); + double logDeterminantSigma = -2.0 * logDetR; + + size_t n = model->dim(); + constexpr double log2pi = 1.8378770664093454835606594728112; + return n * log2pi + logDeterminantSigma; + }; + + AlgebraicDecisionTree log_normalizers = + DecisionTree(factors, computeNormalizers); + + // Find the minimum value so we can "proselytize" to positive values. + // Done because we can't have sqrt of negative numbers. + double min_log_normalizer = log_normalizers.min(); + log_normalizers = log_normalizers.apply( + [&min_log_normalizer](double n) { return n - min_log_normalizer; }); + + // Finally, update the [A|b] matrices. + auto update = [&log_normalizers]( + const Assignment &assignment, + const GaussianMixtureFactor::sharedFactor &gf) { + auto jf = std::dynamic_pointer_cast(gf); + if (!jf) return gf; + // If there is no noise model, there is nothing to do. + if (!jf->get_model()) return gf; + // If the log_normalizer is 0, do nothing + if (log_normalizers(assignment) == 0.0) return gf; + + GaussianFactorGraph gfg; + gfg.push_back(jf); + + Vector c(1); + c << std::sqrt(log_normalizers(assignment)); + auto constantFactor = std::make_shared(c); + + gfg.push_back(constantFactor); + return std::dynamic_pointer_cast( + std::make_shared(gfg)); + }; + return factors.apply(update); +} + /* *******************************************************************************/ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors) - : Base(continuousKeys, discreteKeys), factors_(factors) {} + const Factors &factors, + bool varyingNormalizers) + : Base(continuousKeys, discreteKeys), + factors_(correct(factors, varyingNormalizers)) {} /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { @@ -54,7 +129,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { /* *******************************************************************************/ void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { - HybridFactor::print(s, formatter); + std::cout << (s.empty() ? "" : s + "\n"); + std::cout << "GaussianMixtureFactor" << std::endl; + HybridFactor::print("", formatter); std::cout << "{\n"; if (factors_.empty()) { std::cout << " empty" << std::endl; diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 67d12ddb0..588501bbe 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -82,10 +82,13 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * their cardinalities. * @param factors The decision tree of Gaussian factors stored as the mixture * density. + * @param varyingNormalizers Flag indicating factor components have varying + * normalizer values. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors); + const Factors &factors, + bool varyingNormalizers = false); /** * @brief Construct a new GaussianMixtureFactor object using a vector of @@ -94,12 +97,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @param continuousKeys Vector of keys for continuous factors. * @param discreteKeys Vector of discrete keys. * @param factors Vector of gaussian factor shared pointers. + * @param varyingNormalizers Flag indicating factor components have varying + * normalizer values. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors) + const std::vector &factors, + bool varyingNormalizers = false) : GaussianMixtureFactor(continuousKeys, discreteKeys, - Factors(discreteKeys, factors)) {} + Factors(discreteKeys, factors), + varyingNormalizers) {} /// @} /// @name Testable @@ -107,9 +114,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { bool equals(const HybridFactor &lf, double tol = 1e-9) const override; - void print( - const std::string &s = "GaussianMixtureFactor\n", - const KeyFormatter &formatter = DefaultKeyFormatter) const override; + void print(const std::string &s = "", const KeyFormatter &formatter = + DefaultKeyFormatter) const override; /// @} /// @name Standard API