update model selection code and docs to match the math

release/4.3a0
Varun Agrawal 2024-03-06 16:01:30 -05:00
parent f62805f8b3
commit 6e8e2579da
3 changed files with 30 additions and 66 deletions

gtsam/hybrid/HybridBayesNet.cpp

@@ -265,16 +265,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
 /* ************************************************************************* */
 AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
   /*
-    To perform model selection, we need:
-    q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
-
-    If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
-    thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
-           = exp(log(q) - log(k)) = exp(-error - log(k))
-           = exp(-(error + log(k))),
+    To perform model selection, we need: q(mu; M, Z) = exp(-error)
     where error is computed at the corresponding MAP point, gbn.error(mu).
-    So we compute (error + log(k)) and exponentiate later
+    So we compute (-error) and exponentiate later
   */
   GaussianBayesNetValTree bnTree = assembleTree();
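
As a reading aid (not part of the diff), the quantity the revised comment describes in display math; E_M is shorthand introduced here for gbn.error(mu) under discrete assignment M, and E_min is its minimum over assignments (so max_log in the code below equals -E_min):

% Sketch of the normalized model selection weights the new code computes.
\[
  q(\mu; M, Z) = \exp(-E_M), \qquad
  w_M = \frac{\exp\left(-(E_M - E_{\min})\right)}
             {\sum_{M'} \exp\left(-(E_{M'} - E_{\min})\right)}
      = \frac{\exp(-E_M)}{\sum_{M'} \exp(-E_{M'})}
\]

The E_min shift cancels after normalization, so it changes nothing mathematically; it only keeps the intermediate exponentials in floating-point range.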
@@ -301,13 +295,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
   auto trees = unzip(bn_error);
   AlgebraicDecisionTree<Key> errorTree = trees.second;
 
-  // Only compute logNormalizationConstant
-  AlgebraicDecisionTree<Key> log_norm_constants =
-      computeLogNormConstants(bnTree);
-
   // Compute model selection term (with help from ADT methods)
-  AlgebraicDecisionTree<Key> modelSelectionTerm =
-      computeModelSelectionTerm(errorTree, log_norm_constants);
+  AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
+
+  // Exponentiate using our scheme
+  double max_log = modelSelectionTerm.max();
+  modelSelectionTerm = DecisionTree<Key, double>(
+      modelSelectionTerm,
+      [&max_log](const double &x) { return std::exp(x - max_log); });
+  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
 
   return modelSelectionTerm;
 }
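
The lambda above is a standard max-shift (log-sum-exp) exponentiation. A minimal self-contained sketch of the same scheme on a plain std::vector, with made-up error values; nothing here is GTSAM API:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical -error values (log q) for three discrete assignments.
  // A raw exp() of any of these underflows double precision.
  std::vector<double> log_q = {-1050.0, -1052.3, -1049.1};

  // Shift by the maximum so the largest term becomes exp(0) = 1.
  const double max_log = *std::max_element(log_q.begin(), log_q.end());

  std::vector<double> q(log_q.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < log_q.size(); ++i) {
    q[i] = std::exp(log_q[i] - max_log);  // safe: argument lies in [-3.2, 0]
    sum += q[i];
  }

  // Normalize; the max_log shift cancels here, so the weights are exactly
  // exp(-E_M) / sum over M' of exp(-E_M'). Output is roughly
  // 0.281, 0.028, 0.691.
  for (double &w : q) w /= sum;
  for (double w : q) std::cout << w << "\n";
}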
@@ -531,20 +528,4 @@ AlgebraicDecisionTree<Key> computeLogNormConstants(
   return log_norm_constants;
 }
-
-/* ************************************************************************* */
-AlgebraicDecisionTree<Key> computeModelSelectionTerm(
-    const AlgebraicDecisionTree<Key> &errorTree,
-    const AlgebraicDecisionTree<Key> &log_norm_constants) {
-  AlgebraicDecisionTree<Key> modelSelectionTerm =
-      (errorTree + log_norm_constants) * -1;
-
-  double max_log = modelSelectionTerm.max();
-  modelSelectionTerm = DecisionTree<Key, double>(
-      modelSelectionTerm,
-      [&max_log](const double &x) { return std::exp(x - max_log); });
-  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
-
-  return modelSelectionTerm;
-}
-
 } // namespace gtsam

gtsam/hybrid/HybridBayesNet.h

@@ -128,22 +128,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   */
   GaussianBayesNetValTree assembleTree() const;
 
-  /*
-  Compute L(M;Z), the likelihood of the discrete model M
-  given the measurements Z.
-  This is called the model selection term.
-
-  To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M).
-
-  By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
-  hence L(X;M,Z)P(X|M) is the unnormalized probability of
-  the joint Gaussian distribution.
-
-  This can be computed by multiplying all the exponentiated errors
-  of each of the conditionals.
-
-  Return a tree where each leaf value is L(M_i;Z).
-  */
+  /**
+   * @brief Compute the model selection term q(μ_X; M, Z)
+   * given the error for each discrete assignment.
+   *
+   * The q(μ) terms are obtained as a result of elimination
+   * as part of the separator factor.
+   *
+   * Perform normalization to handle underflow issues.
+   *
+   * @return AlgebraicDecisionTree<Key>
+   */
   AlgebraicDecisionTree<Key> modelSelection() const;
 
   /**
@@ -301,18 +296,4 @@ GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
 AlgebraicDecisionTree<Key> computeLogNormConstants(
     const GaussianBayesNetValTree &bnTree);
-
-/**
- * @brief Compute the model selection term L(M; Z, X) given the error
- * and log normalization constants.
- *
- * Perform normalization to handle underflow issues.
- *
- * @param errorTree
- * @param log_norm_constants
- * @return AlgebraicDecisionTree<Key>
- */
-AlgebraicDecisionTree<Key> computeModelSelectionTerm(
-    const AlgebraicDecisionTree<Key> &errorTree,
-    const AlgebraicDecisionTree<Key> &log_norm_constants);
-
 } // namespace gtsam
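
For orientation, a hedged usage sketch of the modelSelection() declaration above; the surrounding setup (a HybridBayesNet from hybrid elimination) and the decision tree's default print formatter are assumed, not shown in this diff:

#include <gtsam/hybrid/HybridBayesNet.h>

// Given a HybridBayesNet, modelSelection() now returns normalized
// weights exp(-error), one leaf per discrete assignment.
void inspectModelWeights(const gtsam::HybridBayesNet &hbn) {
  const gtsam::AlgebraicDecisionTree<gtsam::Key> weights =
      hbn.modelSelection();
  weights.print("model selection weights");
}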

gtsam/hybrid/HybridBayesTree.cpp

@@ -111,13 +111,15 @@ AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
   auto trees = unzip(bn_error);
   AlgebraicDecisionTree<Key> errorTree = trees.second;
 
-  // Only compute logNormalizationConstant
-  AlgebraicDecisionTree<Key> log_norm_constants =
-      computeLogNormConstants(bnTree);
-
   // Compute model selection term (with help from ADT methods)
-  AlgebraicDecisionTree<Key> modelSelectionTerm =
-      computeModelSelectionTerm(errorTree, log_norm_constants);
+  AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
+
+  // Exponentiate using our scheme
+  double max_log = modelSelectionTerm.max();
+  modelSelectionTerm = DecisionTree<Key, double>(
+      modelSelectionTerm,
+      [&max_log](const double& x) { return std::exp(x - max_log); });
+  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
 
   return modelSelectionTerm;
 }