update model selection code and docs to match the math

release/4.3a0
Varun Agrawal 2024-03-06 16:01:30 -05:00
parent f62805f8b3
commit 6e8e2579da
3 changed files with 30 additions and 66 deletions

View File

@ -265,16 +265,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
/*
To perform model selection, we need:
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
= exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k))),
To perform model selection, we need: q(mu; M, Z) = exp(-error)
where error is computed at the corresponding MAP point, gbn.error(mu).
So we compute (error + log(k)) and exponentiate later
So we compute (-error) and exponentiate later
*/
GaussianBayesNetValTree bnTree = assembleTree();
@ -301,13 +295,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
auto trees = unzip(bn_error);
AlgebraicDecisionTree<Key> errorTree = trees.second;
// Only compute logNormalizationConstant
AlgebraicDecisionTree<Key> log_norm_constants =
computeLogNormConstants(bnTree);
// Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> modelSelectionTerm =
computeModelSelectionTerm(errorTree, log_norm_constants);
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
// Exponentiate using our scheme
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm;
}
@ -531,20 +528,4 @@ AlgebraicDecisionTree<Key> computeLogNormConstants(
return log_norm_constants;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
const AlgebraicDecisionTree<Key> &errorTree,
const AlgebraicDecisionTree<Key> &log_norm_constants) {
AlgebraicDecisionTree<Key> modelSelectionTerm =
(errorTree + log_norm_constants) * -1;
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm;
}
} // namespace gtsam

View File

@ -128,21 +128,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
GaussianBayesNetValTree assembleTree() const;
/*
Compute L(M;Z), the likelihood of the discrete model M
given the measurements Z.
This is called the model selection term.
To do so, we perform the integration of L(M;Z) L(X;M,Z)P(X|M).
By Bayes' rule, P(X|M,Z) L(X;M,Z)P(X|M),
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
the joint Gaussian distribution.
This can be computed by multiplying all the exponentiated errors
of each of the conditionals.
Return a tree where each leaf value is L(M_i;Z).
/**
* @brief Compute the model selection term q(μ_X; M, Z)
* given the error for each discrete assignment.
*
* The q(μ) terms are obtained as a result of elimination
* as part of the separator factor.
*
* Perform normalization to handle underflow issues.
*
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> modelSelection() const;
@ -301,18 +296,4 @@ GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
AlgebraicDecisionTree<Key> computeLogNormConstants(
const GaussianBayesNetValTree &bnTree);
/**
* @brief Compute the model selection term L(M; Z, X) given the error
* and log normalization constants.
*
* Perform normalization to handle underflow issues.
*
* @param errorTree
* @param log_norm_constants
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
const AlgebraicDecisionTree<Key> &errorTree,
const AlgebraicDecisionTree<Key> &log_norm_constants);
} // namespace gtsam

View File

@ -111,13 +111,15 @@ AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
auto trees = unzip(bn_error);
AlgebraicDecisionTree<Key> errorTree = trees.second;
// Only compute logNormalizationConstant
AlgebraicDecisionTree<Key> log_norm_constants =
computeLogNormConstants(bnTree);
// Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> modelSelectionTerm =
computeModelSelectionTerm(errorTree, log_norm_constants);
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
// Exponentiate using our scheme
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double& x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm;
}