From 6f66d04f1425219f51b284c4b4094a15dc2a5791 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 27 Dec 2023 15:46:31 -0500 Subject: [PATCH] handle pruning in model selection --- gtsam/hybrid/HybridBayesNet.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 027bd75d4..0352d7962 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -319,6 +319,13 @@ HybridValues HybridBayesNet::optimize() const { const std::pair &gbnAndValue) { // Compute the X* of each assignment VectorValues mu = gbnAndValue.first.optimize(); + + // mu is empty if gbn had nullptrs + if (mu.size() == 0) { + return std::make_pair(gbnAndValue.first, + std::numeric_limits::max()); + } + // Compute the error for X* and the assignment double error = this->error(HybridValues(mu, DiscreteValues(assignment))); @@ -343,18 +350,12 @@ HybridValues HybridBayesNet::optimize() const { AlgebraicDecisionTree model_selection_term = (errorTree + log_norm_constants) * -1; - // std::cout << "model selection term" << std::endl; - // model_selection_term.print("", DefaultKeyFormatter); - double max_log = model_selection_term.max(); AlgebraicDecisionTree model_selection = DecisionTree( model_selection_term, [&max_log](const double &x) { return std::exp(x - max_log); }); model_selection = model_selection.normalize(model_selection.sum()); - // std::cout << "normalized model selection" << std::endl; - // model_selection.print("", DefaultKeyFormatter); - for (auto &&conditional : *this) { if (conditional->isDiscrete()) { discrete_fg.push_back(conditional->asDiscrete());