Compute the log-normalization constant as the max of the individual log-normalization constants.

release/4.3a0
Frank Dellaert 2023-01-16 15:33:14 -08:00
parent 191e6140b3
commit 57e59d1237
2 changed files with 46 additions and 10 deletions
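In the notation of the header doc comment changed below, every Gaussian component m satisfies

    log(probability_m(x;y)) = K_m - error_m(x;y)

Previously the mixture resolved the shared constant as K == 0.0; with this commit, K becomes the maximum of the individual constants and the mixture error is shifted accordingly:

    K = max_m(K_m),    error(x;y,m) = error_m(x;y) + K - K_m >= 0

As a made-up numeric example: with K_1 = 2.0 and K_2 = 5.0, the shared constant is K = 5.0, so component 1's error is shifted up by 3.0 while component 2's error is unchanged.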


@@ -35,7 +35,16 @@ GaussianMixture::GaussianMixture(
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents),
BaseConditional(continuousFrontals.size()),
-      conditionals_(conditionals) {}
+      conditionals_(conditionals) {
+  // Calculate logConstant_ as the maximum of the log constants of the
+  // conditionals, by visiting the decision tree:
+  logConstant_ = -std::numeric_limits<double>::infinity();
+  conditionals_.visit(
+      [this](const GaussianConditional::shared_ptr &conditional) {
+        this->logConstant_ = std::max(this->logConstant_,
+                                      conditional->logNormalizationConstant());
+      });
+}
/* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
@@ -203,8 +212,7 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
-        return GaussianMixtureFactor::sharedFactor{
-            conditional->likelihood(given)};
+        return conditional->likelihood(given);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
@@ -307,11 +315,23 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return DecisionTree<Key, double>(conditionals_, errorFunc);
}
/* *******************************************************************************/
+AlgebraicDecisionTree<Key> GaussianMixture::error(
+    const VectorValues &continuousValues) const {
+  auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
+    return logConstant_ + conditional->error(continuousValues) -
+           conditional->logNormalizationConstant();
+  };
+  DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
+  return errorTree;
+}
/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
-  return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
+  return logConstant_ + conditional->error(values.continuous()) -
+         conditional->logNormalizationConstant();
}
/* *******************************************************************************/
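To make the arithmetic in the two error() overloads above concrete, here is a minimal standalone sketch of the same bookkeeping. It uses no GTSAM types; the component log constants K_m and errors error_m below are made-up numbers, and the running-max accumulation mirrors what the constructor does over the decision tree.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  // Per-component log-normalization constants K_m and unnormalized errors
  // error_m(x;y) at some fixed continuous assignment (made-up numbers).
  const std::vector<double> K_m = {2.0, 5.0, 3.5};
  const std::vector<double> error_m = {0.7, 0.0, 1.2};

  // K is the max of the individual log constants, accumulated with the same
  // running-max pattern the constructor uses (start at -inf, then std::max).
  double K = -std::numeric_limits<double>::infinity();
  for (double k : K_m) K = std::max(K, k);

  // Shifted mixture error: error(x;y,m) = error_m(x;y) + K - K_m,
  // non-negative because K >= K_m for every m.
  for (std::size_t m = 0; m < K_m.size(); ++m) {
    const double shifted = error_m[m] + K - K_m[m];
    std::printf("m=%zu: error_m=%.2f  K-K_m=%.2f  error=%.2f\n",
                m, error_m[m], K - K_m[m], shifted);
  }
  return 0;
}

Every shifted error printed is >= 0, which is exactly the property the updated header comment states.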


@@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
private:
-  Conditionals conditionals_;
+  Conditionals conditionals_;  ///< a decision tree of Gaussian conditionals.
+  double logConstant_;         ///< log of the normalization constant.
/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
@@ -155,6 +156,10 @@ class GTSAM_EXPORT GaussianMixture
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;
+  /// The log normalization constant is the max of the individual
+  /// log-normalization constants.
+  double logNormalizationConstant() const override { return logConstant_; }
/// Return a discrete factor with possibly varying normalization constants.
/// If there is no variation, return nullptr.
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
@@ -192,18 +197,29 @@ class GTSAM_EXPORT GaussianMixture
* in Conditional.h, should not depend on x, y, or m, only on the parameters
* of the density. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
-* We resolve by having K == 0.0 and
-*
-*      error(x;y,m) = error_m(x;y) - K_m
-*
+* We resolve by having K == max(K_m) and
+*
+*      error(x;y,m) = error_m(x;y) + K - K_m
+*
+* which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
+/**
+ * @brief Compute error of the GaussianMixture as a tree.
+ *
+ * @param continuousValues The continuous VectorValues.
+ * @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
+ * only, with the leaf values as the error for each assignment.
+ */
+AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute the logProbability of this Gaussian Mixture.
*