handle pruning in model selection
parent
3a89653e91
commit
6f66d04f14
|
|
@ -319,6 +319,13 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
||||
// Compute the X* of each assignment
|
||||
VectorValues mu = gbnAndValue.first.optimize();
|
||||
|
||||
// mu is empty if gbn had nullptrs
|
||||
if (mu.size() == 0) {
|
||||
return std::make_pair(gbnAndValue.first,
|
||||
std::numeric_limits<double>::max());
|
||||
}
|
||||
|
||||
// Compute the error for X* and the assignment
|
||||
double error =
|
||||
this->error(HybridValues(mu, DiscreteValues(assignment)));
|
||||
|
|
@ -343,18 +350,12 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
AlgebraicDecisionTree<Key> model_selection_term =
|
||||
(errorTree + log_norm_constants) * -1;
|
||||
|
||||
// std::cout << "model selection term" << std::endl;
|
||||
// model_selection_term.print("", DefaultKeyFormatter);
|
||||
|
||||
double max_log = model_selection_term.max();
|
||||
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
|
||||
model_selection_term,
|
||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||
model_selection = model_selection.normalize(model_selection.sum());
|
||||
|
||||
// std::cout << "normalized model selection" << std::endl;
|
||||
// model_selection.print("", DefaultKeyFormatter);
|
||||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_fg.push_back(conditional->asDiscrete());
|
||||
|
|
|
|||
Loading…
Reference in New Issue