handle pruning in model selection

release/4.3a0
Varun Agrawal 2023-12-27 15:46:31 -05:00
parent 3a89653e91
commit 6f66d04f14
1 changed files with 7 additions and 6 deletions

View File

@ -319,6 +319,13 @@ HybridValues HybridBayesNet::optimize() const {
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
// Compute the X* of each assignment
VectorValues mu = gbnAndValue.first.optimize();
// mu is empty if gbn had nullptrs
if (mu.size() == 0) {
return std::make_pair(gbnAndValue.first,
std::numeric_limits<double>::max());
}
// Compute the error for X* and the assignment
double error =
this->error(HybridValues(mu, DiscreteValues(assignment)));
@ -343,18 +350,12 @@ HybridValues HybridBayesNet::optimize() const {
AlgebraicDecisionTree<Key> model_selection_term =
(errorTree + log_norm_constants) * -1;
// std::cout << "model selection term" << std::endl;
// model_selection_term.print("", DefaultKeyFormatter);
double max_log = model_selection_term.max();
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
model_selection_term,
[&max_log](const double &x) { return std::exp(x - max_log); });
model_selection = model_selection.normalize(model_selection.sum());
// std::cout << "normalized model selection" << std::endl;
// model_selection.print("", DefaultKeyFormatter);
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_fg.push_back(conditional->asDiscrete());