diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 90951b074..0ff4e342b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -329,34 +329,29 @@ HybridValues HybridBayesNet::optimize() const { return gbn.logNormalizationConstant(); }); - // Compute unnormalized error term - std::vector labels; - for (auto &&key : x_map.labels()) { - labels.push_back(std::make_pair(key, 2)); - } + // Compute errors as VectorValues + DecisionTree errorVectors = x_map.apply( + [this](const Assignment &assignment, const VectorValues &mu) { + double error = 0.0; + for (auto &&f : *this) { + if (auto gm = dynamic_pointer_cast(f)) { + error += gm->error(HybridValues(mu, DiscreteValues(assignment))); - std::vector errors; - x_map.visitWith([this, &errors](const Assignment &assignment, - const VectorValues &mu) { - double error = 0.0; - for (auto &&f : *this) { - if (auto gm = dynamic_pointer_cast(f)) { - error += gm->error(HybridValues(mu, DiscreteValues(assignment))); + } else if (auto hc = dynamic_pointer_cast(f)) { + if (auto gm = hc->asMixture()) { + error += gm->error(HybridValues(mu, DiscreteValues(assignment))); - } else if (auto hc = dynamic_pointer_cast(f)) { - if (auto gm = hc->asMixture()) { - error += gm->error(HybridValues(mu, DiscreteValues(assignment))); - - } else if (auto g = hc->asGaussian()) { - error += g->error(mu); + } else if (auto g = hc->asGaussian()) { + error += g->error(mu); + } + } } - } - } - errors.push_back(error); - }); - - AlgebraicDecisionTree errorTree = - DecisionTree(labels, errors); + VectorValues e; + e.insert(0, Vector1(error)); + return e; + }); + AlgebraicDecisionTree errorTree = DecisionTree( + errorVectors, [](const VectorValues &v) { return v[0](0); }); // Compute model selection term (with help from ADT methods) AlgebraicDecisionTree model_selection_term =