diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index b5af0bf7f..e690af51c 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -35,7 +35,16 @@ GaussianMixture::GaussianMixture( : BaseFactor(CollectKeys(continuousFrontals, continuousParents), discreteParents), BaseConditional(continuousFrontals.size()), - conditionals_(conditionals) {} + conditionals_(conditionals) { + // Calculate logConstant_ as the maximum of the log constants of the + // conditionals, by visiting the decision tree: + logConstant_ = -std::numeric_limits::infinity(); + conditionals_.visit( + [this](const GaussianConditional::shared_ptr &conditional) { + this->logConstant_ = std::max(this->logConstant_, + conditional->logNormalizationConstant()); + }); +} /* *******************************************************************************/ const GaussianMixture::Conditionals &GaussianMixture::conditionals() const { @@ -203,8 +212,7 @@ boost::shared_ptr GaussianMixture::likelihood( const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { - return GaussianMixtureFactor::sharedFactor{ - conditional->likelihood(given)}; + return conditional->likelihood(given); }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); @@ -307,11 +315,23 @@ AlgebraicDecisionTree GaussianMixture::logProbability( return DecisionTree(conditionals_, errorFunc); } +/* *******************************************************************************/ +AlgebraicDecisionTree GaussianMixture::error( + const VectorValues &continuousValues) const { + auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { + return logConstant_ + conditional->error(continuousValues) - + conditional->logNormalizationConstant(); + }; + DecisionTree errorTree(conditionals_, errorFunc); + return errorTree; +} + /* *******************************************************************************/ double GaussianMixture::error(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); - return conditional->error(values.continuous()) - conditional->logNormalizationConstant(); + return logConstant_ + conditional->error(values.continuous()) - + conditional->logNormalizationConstant(); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 1b4e9126e..64eda218e 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture using Conditionals = DecisionTree; private: - Conditionals conditionals_; + Conditionals conditionals_; ///< a decision tree of Gaussian conditionals. + double logConstant_; ///< log of the normalization constant. /** * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. @@ -155,6 +156,10 @@ class GTSAM_EXPORT GaussianMixture /// Returns the continuous keys among the parents. KeyVector continuousParents() const; + /// The log normalization constant is max of the the individual + /// log-normalization constants. + double logNormalizationConstant() const override { return logConstant_; } + /// Return a discrete factor with possibly varying normalization constants. /// If there is no variation, return nullptr. boost::shared_ptr normalizationConstants() const; @@ -192,18 +197,29 @@ class GTSAM_EXPORT GaussianMixture * in Conditional.h, should not depend on x, y, or m, only on the parameters * of the density. Hence, we delegate to the underlying Gaussian * conditionals, indexed by m, which do satisfy: - * + * * log(probability_m(x;y)) = K_m - error_m(x;y) - * - * We resolve by having K == 0.0 and - * - * error(x;y,m) = error_m(x;y) - K_m + * + * We resolve by having K == max(K_m) and + * + * error(x;y,m) = error_m(x;y) + K - K_m + * + * which also makes error(x;y,m) >= 0 for all x,y,m. * * @param values Continuous values and discrete assignment. * @return double */ double error(const HybridValues &values) const override; + /** + * @brief Compute error of the GaussianMixture as a tree. + * + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree on the discrete keys + * only, with the leaf values as the error for each assignment. + */ + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + /** * @brief Compute the logProbability of this Gaussian Mixture. *