From b4f07a01629f5ccb8d9413a82c9b9453627be28d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 25 Dec 2023 23:11:49 -0500 Subject: [PATCH] cleaner model selection computation --- gtsam/hybrid/HybridBayesNet.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 89be86056..90951b074 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -342,9 +342,11 @@ HybridValues HybridBayesNet::optimize() const { for (auto &&f : *this) { if (auto gm = dynamic_pointer_cast(f)) { error += gm->error(HybridValues(mu, DiscreteValues(assignment))); + } else if (auto hc = dynamic_pointer_cast(f)) { if (auto gm = hc->asMixture()) { error += gm->error(HybridValues(mu, DiscreteValues(assignment))); + } else if (auto g = hc->asGaussian()) { error += g->error(mu); } @@ -356,11 +358,9 @@ HybridValues HybridBayesNet::optimize() const { AlgebraicDecisionTree errorTree = DecisionTree(labels, errors); - // Compute model selection term - AlgebraicDecisionTree model_selection_term = errorTree.apply( - [&log_norm_constants](const Assignment assignment, double err) { - return -(err + log_norm_constants(assignment)); - }); + // Compute model selection term (with help from ADT methods) + AlgebraicDecisionTree model_selection_term = + (errorTree + log_norm_constants) * -1; // std::cout << "model selection term" << std::endl; // model_selection_term.print("", DefaultKeyFormatter);