diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 89be86056..90951b074 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -342,9 +342,11 @@ HybridValues HybridBayesNet::optimize() const { for (auto &&f : *this) { if (auto gm = dynamic_pointer_cast(f)) { error += gm->error(HybridValues(mu, DiscreteValues(assignment))); + } else if (auto hc = dynamic_pointer_cast(f)) { if (auto gm = hc->asMixture()) { error += gm->error(HybridValues(mu, DiscreteValues(assignment))); + } else if (auto g = hc->asGaussian()) { error += g->error(mu); } @@ -356,11 +358,9 @@ HybridValues HybridBayesNet::optimize() const { AlgebraicDecisionTree errorTree = DecisionTree(labels, errors); - // Compute model selection term - AlgebraicDecisionTree model_selection_term = errorTree.apply( - [&log_norm_constants](const Assignment assignment, double err) { - return -(err + log_norm_constants(assignment)); - }); + // Compute model selection term (with help from ADT methods) + AlgebraicDecisionTree model_selection_term = + (errorTree + log_norm_constants) * -1; // std::cout << "model selection term" << std::endl; // model_selection_term.print("", DefaultKeyFormatter);