diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 0a0332af8..6c92f0252 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -200,24 +200,27 @@ std::shared_ptr GaussianMixture::likelihood( const GaussianMixtureFactor::Factors likelihoods( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { const auto likelihood_m = conditional->likelihood(given); - const double Cgm_Kgcm = - logConstant_ - conditional->logNormalizationConstant(); - if (Cgm_Kgcm == 0.0) { - return likelihood_m; - } else { - // Add a constant factor to the likelihood in case the noise models - // are not all equal. - GaussianFactorGraph gfg; - gfg.push_back(likelihood_m); - Vector c(1); - c << std::sqrt(2.0 * Cgm_Kgcm); - auto constantFactor = std::make_shared(c); - gfg.push_back(constantFactor); - return std::make_shared(gfg); - } + return likelihood_m; }); + + // First compute all the sqrt(|2 pi Sigma|) terms + auto computeLogNormalizers = [](const GaussianFactor::shared_ptr &gf) { + auto jf = std::dynamic_pointer_cast(gf); + // If we have, say, a Hessian factor, then no need to do anything + if (!jf) return 0.0; + + auto model = jf->get_model(); + // If there is no noise model, there is nothing to do. + if (!model) { + return 0.0; + } + return ComputeLogNormalizer(model); + }; + + AlgebraicDecisionTree log_normalizers = + DecisionTree(likelihoods, computeLogNormalizers); return std::make_shared( - continuousParentKeys, discreteParentKeys, likelihoods); + continuousParentKeys, discreteParentKeys, likelihoods, log_normalizers); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 8471cef33..0427eef7b 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -28,11 +28,55 @@ namespace gtsam { +/** + * @brief Helper function to augment the [A|b] matrices in the factor components + * with the normalizer values. + * This is done by storing the normalizer value in + * the `b` vector as an additional row. + * + * @param factors DecisionTree of GaussianFactor shared pointers. + * @param logNormalizers Tree of log-normalizers corresponding to each + * Gaussian factor in factors. + * @return GaussianMixtureFactor::Factors + */ +GaussianMixtureFactor::Factors augment( + const GaussianMixtureFactor::Factors &factors, + const AlgebraicDecisionTree &logNormalizers) { + // Find the minimum value so we can "proselytize" to positive values. + // Done because we can't have sqrt of negative numbers. + double min_log_normalizer = logNormalizers.min(); + AlgebraicDecisionTree log_normalizers = logNormalizers.apply( + [&min_log_normalizer](double n) { return n - min_log_normalizer; }); + + // Finally, update the [A|b] matrices. + auto update = [&log_normalizers]( + const Assignment &assignment, + const GaussianMixtureFactor::sharedFactor &gf) { + auto jf = std::dynamic_pointer_cast(gf); + if (!jf) return gf; + // If the log_normalizer is 0, do nothing + if (log_normalizers(assignment) == 0.0) return gf; + + GaussianFactorGraph gfg; + gfg.push_back(jf); + + Vector c(1); + c << std::sqrt(log_normalizers(assignment)); + auto constantFactor = std::make_shared(c); + + gfg.push_back(constantFactor); + return std::dynamic_pointer_cast( + std::make_shared(gfg)); + }; + return factors.apply(update); +} + /* *******************************************************************************/ -GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, - const DiscreteKeys &discreteKeys, - const Factors &factors) - : Base(continuousKeys, discreteKeys), factors_(factors) {} +GaussianMixtureFactor::GaussianMixtureFactor( + const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, + const Factors &factors, const AlgebraicDecisionTree &logNormalizers) + : Base(continuousKeys, discreteKeys), + factors_(augment(factors, logNormalizers)) {} /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { @@ -120,4 +164,20 @@ double GaussianMixtureFactor::error(const HybridValues &values) const { return gf->error(values.continuous()); } +/* *******************************************************************************/ +double ComputeLogNormalizer( + const noiseModel::Gaussian::shared_ptr &noise_model) { + // Since noise models are Gaussian, we can get the logDeterminant using + // the same trick as in GaussianConditional + double logDetR = noise_model->R() + .diagonal() + .unaryExpr([](double x) { return log(x); }) + .sum(); + double logDeterminantSigma = -2.0 * logDetR; + + size_t n = noise_model->dim(); + constexpr double log2pi = 1.8378770664093454835606594728112; + return n * log2pi + logDeterminantSigma; +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 4570268f1..6680abb54 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -82,10 +82,14 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * their cardinalities. * @param factors The decision tree of Gaussian factors stored as the mixture * density. + * @param logNormalizers Tree of log-normalizers corresponding to each + * Gaussian factor in factors. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors); + const Factors &factors, + const AlgebraicDecisionTree &logNormalizers = + AlgebraicDecisionTree(0.0)); /** * @brief Construct a new GaussianMixtureFactor object using a vector of @@ -94,12 +98,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @param continuousKeys Vector of keys for continuous factors. * @param discreteKeys Vector of discrete keys. * @param factors Vector of gaussian factor shared pointers. + * @param logNormalizers Tree of log-normalizers corresponding to each + * Gaussian factor in factors. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors) + const std::vector &factors, + const AlgebraicDecisionTree &logNormalizers = + AlgebraicDecisionTree(0.0)) : GaussianMixtureFactor(continuousKeys, discreteKeys, - Factors(discreteKeys, factors)) {} + Factors(discreteKeys, factors), logNormalizers) {} /// @} /// @name Testable