clean up model selection
parent
d4e5a9be5d
commit
fef929f266
|
@ -284,15 +284,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
||||||
/*
|
/*
|
||||||
To perform model selection, we need:
|
To perform model selection, we need:
|
||||||
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
|
q(mu; M, Z) = exp(-error)
|
||||||
|
|
||||||
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
|
||||||
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
|
|
||||||
= exp(log(q) - log(k)) = exp(-error - log(k))
|
|
||||||
= exp(-(error + log(k))),
|
|
||||||
where error is computed at the corresponding MAP point, gbn.error(mu).
|
where error is computed at the corresponding MAP point, gbn.error(mu).
|
||||||
|
|
||||||
So we compute (error + log(k)) and exponentiate later
|
So we compute `error` and exponentiate after.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
GaussianBayesNetValTree bnTree = assembleTree();
|
GaussianBayesNetValTree bnTree = assembleTree();
|
||||||
|
@ -316,22 +311,10 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
||||||
return std::make_pair(gbnAndValue.first, error);
|
return std::make_pair(gbnAndValue.first, error);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Compute model selection term (with help from ADT methods)
|
||||||
auto trees = unzip(bn_error);
|
auto trees = unzip(bn_error);
|
||||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||||
|
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
|
||||||
// Only compute logNormalizationConstant
|
|
||||||
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
|
|
||||||
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
|
||||||
GaussianBayesNet gbn = gbnAndValue.first;
|
|
||||||
if (gbn.size() == 0) {
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
return gbn.logNormalizationConstant();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Compute model selection term (with help from ADT methods)
|
|
||||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
|
||||||
(errorTree + log_norm_constants) * -1;
|
|
||||||
|
|
||||||
double max_log = modelSelectionTerm.max();
|
double max_log = modelSelectionTerm.max();
|
||||||
modelSelectionTerm = DecisionTree<Key, double>(
|
modelSelectionTerm = DecisionTree<Key, double>(
|
||||||
|
|
Loading…
Reference in New Issue