almost working

release/4.3a0
Varun Agrawal 2023-12-26 00:20:44 -05:00
parent b4f07a0162
commit 6f4343ca94
1 changed files with 20 additions and 25 deletions

View File

@ -329,34 +329,29 @@ HybridValues HybridBayesNet::optimize() const {
return gbn.logNormalizationConstant();
});
// Compute unnormalized error term
std::vector<DiscreteKey> labels;
for (auto &&key : x_map.labels()) {
labels.push_back(std::make_pair(key, 2));
}
// Compute errors as VectorValues
DecisionTree<Key, VectorValues> errorVectors = x_map.apply(
[this](const Assignment<Key> &assignment, const VectorValues &mu) {
double error = 0.0;
for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
std::vector<double> errors;
x_map.visitWith([this, &errors](const Assignment<Key> &assignment,
const VectorValues &mu) {
double error = 0.0;
for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto g = hc->asGaussian()) {
error += g->error(mu);
} else if (auto g = hc->asGaussian()) {
error += g->error(mu);
}
}
}
}
}
errors.push_back(error);
});
AlgebraicDecisionTree<Key> errorTree =
DecisionTree<Key, double>(labels, errors);
VectorValues e;
e.insert(0, Vector1(error));
return e;
});
AlgebraicDecisionTree<Key> errorTree = DecisionTree<Key, double>(
errorVectors, [](const VectorValues &v) { return v[0](0); });
// Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> model_selection_term =