almost working

release/4.3a0
Varun Agrawal 2023-12-26 00:20:44 -05:00
parent b4f07a0162
commit 6f4343ca94
1 changed files with 20 additions and 25 deletions

View File

@ -329,34 +329,29 @@ HybridValues HybridBayesNet::optimize() const {
return gbn.logNormalizationConstant(); return gbn.logNormalizationConstant();
}); });
// Compute unnormalized error term // Compute errors as VectorValues
std::vector<DiscreteKey> labels; DecisionTree<Key, VectorValues> errorVectors = x_map.apply(
for (auto &&key : x_map.labels()) { [this](const Assignment<Key> &assignment, const VectorValues &mu) {
labels.push_back(std::make_pair(key, 2)); double error = 0.0;
} for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
std::vector<double> errors; } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
x_map.visitWith([this, &errors](const Assignment<Key> &assignment, if (auto gm = hc->asMixture()) {
const VectorValues &mu) { error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
double error = 0.0;
for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) { } else if (auto g = hc->asGaussian()) {
if (auto gm = hc->asMixture()) { error += g->error(mu);
error += gm->error(HybridValues(mu, DiscreteValues(assignment))); }
}
} else if (auto g = hc->asGaussian()) {
error += g->error(mu);
} }
} VectorValues e;
} e.insert(0, Vector1(error));
errors.push_back(error); return e;
}); });
AlgebraicDecisionTree<Key> errorTree = DecisionTree<Key, double>(
AlgebraicDecisionTree<Key> errorTree = errorVectors, [](const VectorValues &v) { return v[0](0); });
DecisionTree<Key, double>(labels, errors);
// Compute model selection term (with help from ADT methods) // Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> model_selection_term = AlgebraicDecisionTree<Key> model_selection_term =