almost working

release/4.3a0
Varun Agrawal 2023-12-26 00:20:44 -05:00
parent b4f07a0162
commit 6f4343ca94
1 changed files with 20 additions and 25 deletions

View File

@ -329,15 +329,9 @@ HybridValues HybridBayesNet::optimize() const {
return gbn.logNormalizationConstant(); return gbn.logNormalizationConstant();
}); });
// Compute unnormalized error term // Compute errors as VectorValues
std::vector<DiscreteKey> labels; DecisionTree<Key, VectorValues> errorVectors = x_map.apply(
for (auto &&key : x_map.labels()) { [this](const Assignment<Key> &assignment, const VectorValues &mu) {
labels.push_back(std::make_pair(key, 2));
}
std::vector<double> errors;
x_map.visitWith([this, &errors](const Assignment<Key> &assignment,
const VectorValues &mu) {
double error = 0.0; double error = 0.0;
for (auto &&f : *this) { for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) { if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
@ -352,11 +346,12 @@ HybridValues HybridBayesNet::optimize() const {
} }
} }
} }
errors.push_back(error); VectorValues e;
e.insert(0, Vector1(error));
return e;
}); });
AlgebraicDecisionTree<Key> errorTree = DecisionTree<Key, double>(
AlgebraicDecisionTree<Key> errorTree = errorVectors, [](const VectorValues &v) { return v[0](0); });
DecisionTree<Key, double>(labels, errors);
// Compute model selection term (with help from ADT methods) // Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> model_selection_term = AlgebraicDecisionTree<Key> model_selection_term =