clean up model selection

release/4.3a0
Varun Agrawal 2024-08-20 12:53:53 -04:00
parent d4e5a9be5d
commit fef929f266
1 changed files with 4 additions and 21 deletions

View File

@ -284,15 +284,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
/*
To perform model selection, we need:
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
= exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k))),
q(mu; M, Z) = exp(-error)
where error is computed at the corresponding MAP point, gbn.error(mu).
So we compute (error + log(k)) and exponentiate later
So we compute `error` and exponentiate after.
*/
GaussianBayesNetValTree bnTree = assembleTree();
@ -316,22 +311,10 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
return std::make_pair(gbnAndValue.first, error);
});
// Compute model selection term (with help from ADT methods)
auto trees = unzip(bn_error);
AlgebraicDecisionTree<Key> errorTree = trees.second;
// Only compute logNormalizationConstant
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
GaussianBayesNet gbn = gbnAndValue.first;
if (gbn.size() == 0) {
return 0.0;
}
return gbn.logNormalizationConstant();
});
// Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> modelSelectionTerm =
(errorTree + log_norm_constants) * -1;
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(