Compute log-normalization constant as the max of the individual normalization constants.
parent
191e6140b3
commit
57e59d1237
|
|
@ -35,7 +35,16 @@ GaussianMixture::GaussianMixture(
|
|||
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
|
||||
discreteParents),
|
||||
BaseConditional(continuousFrontals.size()),
|
||||
conditionals_(conditionals) {}
|
||||
conditionals_(conditionals) {
|
||||
// Calculate logConstant_ as the maximum of the log constants of the
|
||||
// conditionals, by visiting the decision tree:
|
||||
logConstant_ = -std::numeric_limits<double>::infinity();
|
||||
conditionals_.visit(
|
||||
[this](const GaussianConditional::shared_ptr &conditional) {
|
||||
this->logConstant_ = std::max(this->logConstant_,
|
||||
conditional->logNormalizationConstant());
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
|
||||
|
|
@ -203,8 +212,7 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
|||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const GaussianMixtureFactor::Factors likelihoods(
|
||||
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return GaussianMixtureFactor::sharedFactor{
|
||||
conditional->likelihood(given)};
|
||||
return conditional->likelihood(given);
|
||||
});
|
||||
return boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||
|
|
@ -307,11 +315,23 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
|||
return DecisionTree<Key, double>(conditionals_, errorFunc);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return logConstant_ + conditional->error(continuousValues) -
|
||||
conditional->logNormalizationConstant();
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||
return errorTree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const HybridValues &values) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto conditional = conditionals_(values.discrete());
|
||||
return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
|
||||
return logConstant_ + conditional->error(values.continuous()) -
|
||||
conditional->logNormalizationConstant();
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
|||
|
|
@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture
|
|||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||
|
||||
private:
|
||||
Conditionals conditionals_;
|
||||
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
||||
double logConstant_; ///< log of the normalization constant.
|
||||
|
||||
/**
|
||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
||||
|
|
@ -155,6 +156,10 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/// Returns the continuous keys among the parents.
|
||||
KeyVector continuousParents() const;
|
||||
|
||||
/// The log normalization constant is max of the the individual
|
||||
/// log-normalization constants.
|
||||
double logNormalizationConstant() const override { return logConstant_; }
|
||||
|
||||
/// Return a discrete factor with possibly varying normalization constants.
|
||||
/// If there is no variation, return nullptr.
|
||||
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
|
||||
|
|
@ -192,18 +197,29 @@ class GTSAM_EXPORT GaussianMixture
|
|||
* in Conditional.h, should not depend on x, y, or m, only on the parameters
|
||||
* of the density. Hence, we delegate to the underlying Gaussian
|
||||
* conditionals, indexed by m, which do satisfy:
|
||||
*
|
||||
*
|
||||
* log(probability_m(x;y)) = K_m - error_m(x;y)
|
||||
*
|
||||
* We resolve by having K == 0.0 and
|
||||
*
|
||||
* error(x;y,m) = error_m(x;y) - K_m
|
||||
*
|
||||
* We resolve by having K == max(K_m) and
|
||||
*
|
||||
* error(x;y,m) = error_m(x;y) + K - K_m
|
||||
*
|
||||
* which also makes error(x;y,m) >= 0 for all x,y,m.
|
||||
*
|
||||
* @param values Continuous values and discrete assignment.
|
||||
* @return double
|
||||
*/
|
||||
double error(const HybridValues &values) const override;
|
||||
|
||||
/**
|
||||
* @brief Compute error of the GaussianMixture as a tree.
|
||||
*
|
||||
* @param continuousValues The continuous VectorValues.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
|
||||
* only, with the leaf values as the error for each assignment.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the logProbability of this Gaussian Mixture.
|
||||
*
|
||||
|
|
|
|||
Loading…
Reference in New Issue