update model selection code and docs to match the math
parent
f62805f8b3
commit
6e8e2579da
|
|
@ -265,16 +265,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
|
|||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
||||
/*
|
||||
To perform model selection, we need:
|
||||
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
|
||||
|
||||
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
|
||||
= exp(log(q) - log(k)) = exp(-error - log(k))
|
||||
= exp(-(error + log(k))),
|
||||
To perform model selection, we need: q(mu; M, Z) = exp(-error)
|
||||
where error is computed at the corresponding MAP point, gbn.error(mu).
|
||||
|
||||
So we compute (error + log(k)) and exponentiate later
|
||||
So we compute (-error) and exponentiate later
|
||||
*/
|
||||
|
||||
GaussianBayesNetValTree bnTree = assembleTree();
|
||||
|
|
@ -301,13 +295,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
|||
auto trees = unzip(bn_error);
|
||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||
|
||||
// Only compute logNormalizationConstant
|
||||
AlgebraicDecisionTree<Key> log_norm_constants =
|
||||
computeLogNormConstants(bnTree);
|
||||
|
||||
// Compute model selection term (with help from ADT methods)
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
||||
computeModelSelectionTerm(errorTree, log_norm_constants);
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
|
||||
|
||||
// Exponentiate using our scheme
|
||||
double max_log = modelSelectionTerm.max();
|
||||
modelSelectionTerm = DecisionTree<Key, double>(
|
||||
modelSelectionTerm,
|
||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||
|
||||
return modelSelectionTerm;
|
||||
}
|
||||
|
||||
|
|
@ -531,20 +528,4 @@ AlgebraicDecisionTree<Key> computeLogNormConstants(
|
|||
return log_norm_constants;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
|
||||
const AlgebraicDecisionTree<Key> &errorTree,
|
||||
const AlgebraicDecisionTree<Key> &log_norm_constants) {
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
||||
(errorTree + log_norm_constants) * -1;
|
||||
|
||||
double max_log = modelSelectionTerm.max();
|
||||
modelSelectionTerm = DecisionTree<Key, double>(
|
||||
modelSelectionTerm,
|
||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||
|
||||
return modelSelectionTerm;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -128,21 +128,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
GaussianBayesNetValTree assembleTree() const;
|
||||
|
||||
/*
|
||||
Compute L(M;Z), the likelihood of the discrete model M
|
||||
given the measurements Z.
|
||||
This is called the model selection term.
|
||||
|
||||
To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M).
|
||||
|
||||
By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
|
||||
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
|
||||
the joint Gaussian distribution.
|
||||
|
||||
This can be computed by multiplying all the exponentiated errors
|
||||
of each of the conditionals.
|
||||
|
||||
Return a tree where each leaf value is L(M_i;Z).
|
||||
/**
|
||||
* @brief Compute the model selection term q(μ_X; M, Z)
|
||||
* given the error for each discrete assignment.
|
||||
*
|
||||
* The q(μ) terms are obtained as a result of elimination
|
||||
* as part of the separator factor.
|
||||
*
|
||||
* Perform normalization to handle underflow issues.
|
||||
*
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> modelSelection() const;
|
||||
|
||||
|
|
@ -301,18 +296,4 @@ GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
|
|||
AlgebraicDecisionTree<Key> computeLogNormConstants(
|
||||
const GaussianBayesNetValTree &bnTree);
|
||||
|
||||
/**
|
||||
* @brief Compute the model selection term L(M; Z, X) given the error
|
||||
* and log normalization constants.
|
||||
*
|
||||
* Perform normalization to handle underflow issues.
|
||||
*
|
||||
* @param errorTree
|
||||
* @param log_norm_constants
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
|
||||
const AlgebraicDecisionTree<Key> &errorTree,
|
||||
const AlgebraicDecisionTree<Key> &log_norm_constants);
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -111,13 +111,15 @@ AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
|
|||
auto trees = unzip(bn_error);
|
||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||
|
||||
// Only compute logNormalizationConstant
|
||||
AlgebraicDecisionTree<Key> log_norm_constants =
|
||||
computeLogNormConstants(bnTree);
|
||||
|
||||
// Compute model selection term (with help from ADT methods)
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
||||
computeModelSelectionTerm(errorTree, log_norm_constants);
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
|
||||
|
||||
// Exponentiate using our scheme
|
||||
double max_log = modelSelectionTerm.max();
|
||||
modelSelectionTerm = DecisionTree<Key, double>(
|
||||
modelSelectionTerm,
|
||||
[&max_log](const double& x) { return std::exp(x - max_log); });
|
||||
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||
|
||||
return modelSelectionTerm;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue