handle pruning in model selection
parent
3a89653e91
commit
6f66d04f14
|
|
@ -319,6 +319,13 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
||||||
// Compute the X* of each assignment
|
// Compute the X* of each assignment
|
||||||
VectorValues mu = gbnAndValue.first.optimize();
|
VectorValues mu = gbnAndValue.first.optimize();
|
||||||
|
|
||||||
|
// mu is empty if gbn had nullptrs
|
||||||
|
if (mu.size() == 0) {
|
||||||
|
return std::make_pair(gbnAndValue.first,
|
||||||
|
std::numeric_limits<double>::max());
|
||||||
|
}
|
||||||
|
|
||||||
// Compute the error for X* and the assignment
|
// Compute the error for X* and the assignment
|
||||||
double error =
|
double error =
|
||||||
this->error(HybridValues(mu, DiscreteValues(assignment)));
|
this->error(HybridValues(mu, DiscreteValues(assignment)));
|
||||||
|
|
@ -343,18 +350,12 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
AlgebraicDecisionTree<Key> model_selection_term =
|
AlgebraicDecisionTree<Key> model_selection_term =
|
||||||
(errorTree + log_norm_constants) * -1;
|
(errorTree + log_norm_constants) * -1;
|
||||||
|
|
||||||
// std::cout << "model selection term" << std::endl;
|
|
||||||
// model_selection_term.print("", DefaultKeyFormatter);
|
|
||||||
|
|
||||||
double max_log = model_selection_term.max();
|
double max_log = model_selection_term.max();
|
||||||
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
|
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
|
||||||
model_selection_term,
|
model_selection_term,
|
||||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||||
model_selection = model_selection.normalize(model_selection.sum());
|
model_selection = model_selection.normalize(model_selection.sum());
|
||||||
|
|
||||||
// std::cout << "normalized model selection" << std::endl;
|
|
||||||
// model_selection.print("", DefaultKeyFormatter);
|
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
discrete_fg.push_back(conditional->asDiscrete());
|
discrete_fg.push_back(conditional->asDiscrete());
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue