Compute log-normalization constant as the max of the individual normalization constants.

release/4.3a0
Frank Dellaert 2023-01-16 15:33:14 -08:00
parent 191e6140b3
commit 57e59d1237
2 changed files with 46 additions and 10 deletions

View File

@ -35,7 +35,16 @@ GaussianMixture::GaussianMixture(
: BaseFactor(CollectKeys(continuousFrontals, continuousParents), : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents), discreteParents),
BaseConditional(continuousFrontals.size()), BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {} conditionals_(conditionals) {
// Calculate logConstant_ as the maximum of the log constants of the
// conditionals, by visiting the decision tree:
logConstant_ = -std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
this->logConstant_ = std::max(this->logConstant_,
conditional->logNormalizationConstant());
});
}
/* *******************************************************************************/ /* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const { const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
@ -203,8 +212,7 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods( const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::sharedFactor{ return conditional->likelihood(given);
conditional->likelihood(given)};
}); });
return boost::make_shared<GaussianMixtureFactor>( return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods); continuousParentKeys, discreteParentKeys, likelihoods);
@ -307,11 +315,23 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return DecisionTree<Key, double>(conditionals_, errorFunc); return DecisionTree<Key, double>(conditionals_, errorFunc);
} }
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return logConstant_ + conditional->error(continuousValues) -
conditional->logNormalizationConstant();
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const { double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree. // Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete()); auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) - conditional->logNormalizationConstant(); return logConstant_ + conditional->error(values.continuous()) -
conditional->logNormalizationConstant();
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
private: private:
Conditionals conditionals_; Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
double logConstant_; ///< log of the normalization constant.
/** /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
@ -155,6 +156,10 @@ class GTSAM_EXPORT GaussianMixture
/// Returns the continuous keys among the parents. /// Returns the continuous keys among the parents.
KeyVector continuousParents() const; KeyVector continuousParents() const;
/// The log normalization constant is max of the the individual
/// log-normalization constants.
double logNormalizationConstant() const override { return logConstant_; }
/// Return a discrete factor with possibly varying normalization constants. /// Return a discrete factor with possibly varying normalization constants.
/// If there is no variation, return nullptr. /// If there is no variation, return nullptr.
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const; boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
@ -195,15 +200,26 @@ class GTSAM_EXPORT GaussianMixture
* *
* log(probability_m(x;y)) = K_m - error_m(x;y) * log(probability_m(x;y)) = K_m - error_m(x;y)
* *
* We resolve by having K == 0.0 and * We resolve by having K == max(K_m) and
* *
* error(x;y,m) = error_m(x;y) - K_m * error(x;y,m) = error_m(x;y) + K - K_m
*
* which also makes error(x;y,m) >= 0 for all x,y,m.
* *
* @param values Continuous values and discrete assignment. * @param values Continuous values and discrete assignment.
* @return double * @return double
*/ */
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/** /**
* @brief Compute the logProbability of this Gaussian Mixture. * @brief Compute the logProbability of this Gaussian Mixture.
* *