update model selection code and docs to match the math
parent
f62805f8b3
commit
6e8e2579da
|
|
@ -265,16 +265,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
||||||
/*
|
/*
|
||||||
To perform model selection, we need:
|
To perform model selection, we need: q(mu; M, Z) = exp(-error)
|
||||||
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
|
|
||||||
|
|
||||||
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
|
||||||
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
|
|
||||||
= exp(log(q) - log(k)) = exp(-error - log(k))
|
|
||||||
= exp(-(error + log(k))),
|
|
||||||
where error is computed at the corresponding MAP point, gbn.error(mu).
|
where error is computed at the corresponding MAP point, gbn.error(mu).
|
||||||
|
|
||||||
So we compute (error + log(k)) and exponentiate later
|
So we compute (-error) and exponentiate later
|
||||||
*/
|
*/
|
||||||
|
|
||||||
GaussianBayesNetValTree bnTree = assembleTree();
|
GaussianBayesNetValTree bnTree = assembleTree();
|
||||||
|
|
@ -301,13 +295,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
||||||
auto trees = unzip(bn_error);
|
auto trees = unzip(bn_error);
|
||||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||||
|
|
||||||
// Only compute logNormalizationConstant
|
|
||||||
AlgebraicDecisionTree<Key> log_norm_constants =
|
|
||||||
computeLogNormConstants(bnTree);
|
|
||||||
|
|
||||||
// Compute model selection term (with help from ADT methods)
|
// Compute model selection term (with help from ADT methods)
|
||||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
|
||||||
computeModelSelectionTerm(errorTree, log_norm_constants);
|
|
||||||
|
// Exponentiate using our scheme
|
||||||
|
double max_log = modelSelectionTerm.max();
|
||||||
|
modelSelectionTerm = DecisionTree<Key, double>(
|
||||||
|
modelSelectionTerm,
|
||||||
|
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||||
|
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||||
|
|
||||||
return modelSelectionTerm;
|
return modelSelectionTerm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -531,20 +528,4 @@ AlgebraicDecisionTree<Key> computeLogNormConstants(
|
||||||
return log_norm_constants;
|
return log_norm_constants;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
|
|
||||||
const AlgebraicDecisionTree<Key> &errorTree,
|
|
||||||
const AlgebraicDecisionTree<Key> &log_norm_constants) {
|
|
||||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
|
||||||
(errorTree + log_norm_constants) * -1;
|
|
||||||
|
|
||||||
double max_log = modelSelectionTerm.max();
|
|
||||||
modelSelectionTerm = DecisionTree<Key, double>(
|
|
||||||
modelSelectionTerm,
|
|
||||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
|
||||||
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
|
||||||
|
|
||||||
return modelSelectionTerm;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -128,22 +128,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
GaussianBayesNetValTree assembleTree() const;
|
GaussianBayesNetValTree assembleTree() const;
|
||||||
|
|
||||||
/*
|
/**
|
||||||
Compute L(M;Z), the likelihood of the discrete model M
|
* @brief Compute the model selection term q(μ_X; M, Z)
|
||||||
given the measurements Z.
|
* given the error for each discrete assignment.
|
||||||
This is called the model selection term.
|
*
|
||||||
|
* The q(μ) terms are obtained as a result of elimination
|
||||||
To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M).
|
* as part of the separator factor.
|
||||||
|
*
|
||||||
By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
|
* Perform normalization to handle underflow issues.
|
||||||
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
|
*
|
||||||
the joint Gaussian distribution.
|
* @return AlgebraicDecisionTree<Key>
|
||||||
|
*/
|
||||||
This can be computed by multiplying all the exponentiated errors
|
|
||||||
of each of the conditionals.
|
|
||||||
|
|
||||||
Return a tree where each leaf value is L(M_i;Z).
|
|
||||||
*/
|
|
||||||
AlgebraicDecisionTree<Key> modelSelection() const;
|
AlgebraicDecisionTree<Key> modelSelection() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -301,18 +296,4 @@ GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
|
||||||
AlgebraicDecisionTree<Key> computeLogNormConstants(
|
AlgebraicDecisionTree<Key> computeLogNormConstants(
|
||||||
const GaussianBayesNetValTree &bnTree);
|
const GaussianBayesNetValTree &bnTree);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Compute the model selection term L(M; Z, X) given the error
|
|
||||||
* and log normalization constants.
|
|
||||||
*
|
|
||||||
* Perform normalization to handle underflow issues.
|
|
||||||
*
|
|
||||||
* @param errorTree
|
|
||||||
* @param log_norm_constants
|
|
||||||
* @return AlgebraicDecisionTree<Key>
|
|
||||||
*/
|
|
||||||
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
|
|
||||||
const AlgebraicDecisionTree<Key> &errorTree,
|
|
||||||
const AlgebraicDecisionTree<Key> &log_norm_constants);
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -111,13 +111,15 @@ AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
|
||||||
auto trees = unzip(bn_error);
|
auto trees = unzip(bn_error);
|
||||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||||
|
|
||||||
// Only compute logNormalizationConstant
|
|
||||||
AlgebraicDecisionTree<Key> log_norm_constants =
|
|
||||||
computeLogNormConstants(bnTree);
|
|
||||||
|
|
||||||
// Compute model selection term (with help from ADT methods)
|
// Compute model selection term (with help from ADT methods)
|
||||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
|
||||||
computeModelSelectionTerm(errorTree, log_norm_constants);
|
|
||||||
|
// Exponentiate using our scheme
|
||||||
|
double max_log = modelSelectionTerm.max();
|
||||||
|
modelSelectionTerm = DecisionTree<Key, double>(
|
||||||
|
modelSelectionTerm,
|
||||||
|
[&max_log](const double& x) { return std::exp(x - max_log); });
|
||||||
|
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||||
|
|
||||||
return modelSelectionTerm;
|
return modelSelectionTerm;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue