From cd9ee0845765767e98c8f334e3a2cf4a4bed6a7e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 22 Aug 2024 14:43:23 -0400 Subject: [PATCH] remove augment method in favor of conversion --- gtsam/hybrid/GaussianMixture.cpp | 35 ++++++++--------- gtsam/hybrid/GaussianMixtureFactor.cpp | 52 ++------------------------ gtsam/hybrid/GaussianMixtureFactor.h | 14 ++----- 3 files changed, 23 insertions(+), 78 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 6c92f0252..0a0332af8 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -200,27 +200,24 @@ std::shared_ptr GaussianMixture::likelihood( const GaussianMixtureFactor::Factors likelihoods( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { const auto likelihood_m = conditional->likelihood(given); - return likelihood_m; + const double Cgm_Kgcm = + logConstant_ - conditional->logNormalizationConstant(); + if (Cgm_Kgcm == 0.0) { + return likelihood_m; + } else { + // Add a constant factor to the likelihood in case the noise models + // are not all equal. + GaussianFactorGraph gfg; + gfg.push_back(likelihood_m); + Vector c(1); + c << std::sqrt(2.0 * Cgm_Kgcm); + auto constantFactor = std::make_shared(c); + gfg.push_back(constantFactor); + return std::make_shared(gfg); + } }); - - // First compute all the sqrt(|2 pi Sigma|) terms - auto computeLogNormalizers = [](const GaussianFactor::shared_ptr &gf) { - auto jf = std::dynamic_pointer_cast(gf); - // If we have, say, a Hessian factor, then no need to do anything - if (!jf) return 0.0; - - auto model = jf->get_model(); - // If there is no noise model, there is nothing to do. - if (!model) { - return 0.0; - } - return ComputeLogNormalizer(model); - }; - - AlgebraicDecisionTree log_normalizers = - DecisionTree(likelihoods, computeLogNormalizers); return std::make_shared( - continuousParentKeys, discreteParentKeys, likelihoods, log_normalizers); + continuousParentKeys, discreteParentKeys, likelihoods); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 0427eef7b..b4ee89cb0 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -28,55 +28,11 @@ namespace gtsam { -/** - * @brief Helper function to augment the [A|b] matrices in the factor components - * with the normalizer values. - * This is done by storing the normalizer value in - * the `b` vector as an additional row. - * - * @param factors DecisionTree of GaussianFactor shared pointers. - * @param logNormalizers Tree of log-normalizers corresponding to each - * Gaussian factor in factors. - * @return GaussianMixtureFactor::Factors - */ -GaussianMixtureFactor::Factors augment( - const GaussianMixtureFactor::Factors &factors, - const AlgebraicDecisionTree &logNormalizers) { - // Find the minimum value so we can "proselytize" to positive values. - // Done because we can't have sqrt of negative numbers. - double min_log_normalizer = logNormalizers.min(); - AlgebraicDecisionTree log_normalizers = logNormalizers.apply( - [&min_log_normalizer](double n) { return n - min_log_normalizer; }); - - // Finally, update the [A|b] matrices. - auto update = [&log_normalizers]( - const Assignment &assignment, - const GaussianMixtureFactor::sharedFactor &gf) { - auto jf = std::dynamic_pointer_cast(gf); - if (!jf) return gf; - // If the log_normalizer is 0, do nothing - if (log_normalizers(assignment) == 0.0) return gf; - - GaussianFactorGraph gfg; - gfg.push_back(jf); - - Vector c(1); - c << std::sqrt(log_normalizers(assignment)); - auto constantFactor = std::make_shared(c); - - gfg.push_back(constantFactor); - return std::dynamic_pointer_cast( - std::make_shared(gfg)); - }; - return factors.apply(update); -} - /* *******************************************************************************/ -GaussianMixtureFactor::GaussianMixtureFactor( - const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors, const AlgebraicDecisionTree &logNormalizers) - : Base(continuousKeys, discreteKeys), - factors_(augment(factors, logNormalizers)) {} +GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, + const Factors &factors) + : Base(continuousKeys, discreteKeys), factors_(factors) {} /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 6680abb54..4570268f1 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -82,14 +82,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * their cardinalities. * @param factors The decision tree of Gaussian factors stored as the mixture * density. - * @param logNormalizers Tree of log-normalizers corresponding to each - * Gaussian factor in factors. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors, - const AlgebraicDecisionTree &logNormalizers = - AlgebraicDecisionTree(0.0)); + const Factors &factors); /** * @brief Construct a new GaussianMixtureFactor object using a vector of @@ -98,16 +94,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @param continuousKeys Vector of keys for continuous factors. * @param discreteKeys Vector of discrete keys. * @param factors Vector of gaussian factor shared pointers. - * @param logNormalizers Tree of log-normalizers corresponding to each - * Gaussian factor in factors. */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors, - const AlgebraicDecisionTree &logNormalizers = - AlgebraicDecisionTree(0.0)) + const std::vector &factors) : GaussianMixtureFactor(continuousKeys, discreteKeys, - Factors(discreteKeys, factors), logNormalizers) {} + Factors(discreteKeys, factors)) {} /// @} /// @name Testable