Compute log-normalization constant as the max of the individual normalization constants.
parent
191e6140b3
commit
57e59d1237
|
|
@ -35,7 +35,16 @@ GaussianMixture::GaussianMixture(
|
||||||
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
|
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
|
||||||
discreteParents),
|
discreteParents),
|
||||||
BaseConditional(continuousFrontals.size()),
|
BaseConditional(continuousFrontals.size()),
|
||||||
conditionals_(conditionals) {}
|
conditionals_(conditionals) {
|
||||||
|
// Calculate logConstant_ as the maximum of the log constants of the
|
||||||
|
// conditionals, by visiting the decision tree:
|
||||||
|
logConstant_ = -std::numeric_limits<double>::infinity();
|
||||||
|
conditionals_.visit(
|
||||||
|
[this](const GaussianConditional::shared_ptr &conditional) {
|
||||||
|
this->logConstant_ = std::max(this->logConstant_,
|
||||||
|
conditional->logNormalizationConstant());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
|
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
|
||||||
|
|
@ -203,8 +212,7 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const GaussianMixtureFactor::Factors likelihoods(
|
const GaussianMixtureFactor::Factors likelihoods(
|
||||||
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
return GaussianMixtureFactor::sharedFactor{
|
return conditional->likelihood(given);
|
||||||
conditional->likelihood(given)};
|
|
||||||
});
|
});
|
||||||
return boost::make_shared<GaussianMixtureFactor>(
|
return boost::make_shared<GaussianMixtureFactor>(
|
||||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||||
|
|
@ -307,11 +315,23 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
||||||
return DecisionTree<Key, double>(conditionals_, errorFunc);
|
return DecisionTree<Key, double>(conditionals_, errorFunc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||||
|
const VectorValues &continuousValues) const {
|
||||||
|
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
|
return logConstant_ + conditional->error(continuousValues) -
|
||||||
|
conditional->logNormalizationConstant();
|
||||||
|
};
|
||||||
|
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||||
|
return errorTree;
|
||||||
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::error(const HybridValues &values) const {
|
double GaussianMixture::error(const HybridValues &values) const {
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
|
return logConstant_ + conditional->error(values.continuous()) -
|
||||||
|
conditional->logNormalizationConstant();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Conditionals conditionals_;
|
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
||||||
|
double logConstant_; ///< log of the normalization constant.
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
||||||
|
|
@ -155,6 +156,10 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// Returns the continuous keys among the parents.
|
/// Returns the continuous keys among the parents.
|
||||||
KeyVector continuousParents() const;
|
KeyVector continuousParents() const;
|
||||||
|
|
||||||
|
/// The log normalization constant is max of the the individual
|
||||||
|
/// log-normalization constants.
|
||||||
|
double logNormalizationConstant() const override { return logConstant_; }
|
||||||
|
|
||||||
/// Return a discrete factor with possibly varying normalization constants.
|
/// Return a discrete factor with possibly varying normalization constants.
|
||||||
/// If there is no variation, return nullptr.
|
/// If there is no variation, return nullptr.
|
||||||
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
|
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
|
||||||
|
|
@ -195,15 +200,26 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*
|
*
|
||||||
* log(probability_m(x;y)) = K_m - error_m(x;y)
|
* log(probability_m(x;y)) = K_m - error_m(x;y)
|
||||||
*
|
*
|
||||||
* We resolve by having K == 0.0 and
|
* We resolve by having K == max(K_m) and
|
||||||
*
|
*
|
||||||
* error(x;y,m) = error_m(x;y) - K_m
|
* error(x;y,m) = error_m(x;y) + K - K_m
|
||||||
|
*
|
||||||
|
* which also makes error(x;y,m) >= 0 for all x,y,m.
|
||||||
*
|
*
|
||||||
* @param values Continuous values and discrete assignment.
|
* @param values Continuous values and discrete assignment.
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute error of the GaussianMixture as a tree.
|
||||||
|
*
|
||||||
|
* @param continuousValues The continuous VectorValues.
|
||||||
|
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
|
||||||
|
* only, with the leaf values as the error for each assignment.
|
||||||
|
*/
|
||||||
|
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the logProbability of this Gaussian Mixture.
|
* @brief Compute the logProbability of this Gaussian Mixture.
|
||||||
*
|
*
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue