cleaner model selection computation

release/4.3a0
Varun Agrawal 2023-12-25 23:11:49 -05:00
parent 1e298be3b3
commit b4f07a0162
1 changed files with 5 additions and 5 deletions

View File

@ -342,9 +342,11 @@ HybridValues HybridBayesNet::optimize() const {
for (auto &&f : *this) { for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) { if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment))); error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) { } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) { if (auto gm = hc->asMixture()) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment))); error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto g = hc->asGaussian()) { } else if (auto g = hc->asGaussian()) {
error += g->error(mu); error += g->error(mu);
} }
@ -356,11 +358,9 @@ HybridValues HybridBayesNet::optimize() const {
AlgebraicDecisionTree<Key> errorTree = AlgebraicDecisionTree<Key> errorTree =
DecisionTree<Key, double>(labels, errors); DecisionTree<Key, double>(labels, errors);
// Compute model selection term // Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> model_selection_term = errorTree.apply( AlgebraicDecisionTree<Key> model_selection_term =
[&log_norm_constants](const Assignment<Key> assignment, double err) { (errorTree + log_norm_constants) * -1;
return -(err + log_norm_constants(assignment));
});
// std::cout << "model selection term" << std::endl; // std::cout << "model selection term" << std::endl;
// model_selection_term.print("", DefaultKeyFormatter); // model_selection_term.print("", DefaultKeyFormatter);