diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 23a1c7787..89be86056 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -285,8 +285,6 @@ HybridValues HybridBayesNet::optimize() const { DiscreteFactorGraph discrete_fg; VectorValues continuousValues; - // Error values for each hybrid factor - AlgebraicDecisionTree error(0.0); std::set discreteKeySet; // this->print(); @@ -313,7 +311,8 @@ HybridValues HybridBayesNet::optimize() const { If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)) thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k)) = exp(log(q) - log(k)) = exp(-error - log(k)) - = exp(-(error + log(k))) + = exp(-(error + log(k))), + where error is computed at the corresponding MAP point, gbn.error(mu). So we compute (error + log(k)) and exponentiate later */ @@ -325,29 +324,45 @@ HybridValues HybridBayesNet::optimize() const { AlgebraicDecisionTree log_norm_constants = DecisionTree(bnTree, [](const GaussianBayesNet &gbn) { if (gbn.size() == 0) { - return -std::numeric_limits::max(); + return 0.0; } - return -gbn.logNormalizationConstant(); + return gbn.logNormalizationConstant(); }); - // Compute unnormalized error term and compute model selection term - AlgebraicDecisionTree model_selection_term = log_norm_constants.apply( - [this, &x_map](const Assignment &assignment, double x) { - double error = 0.0; - for (auto &&f : *this) { - if (auto gm = dynamic_pointer_cast(f)) { - error += gm->error( - HybridValues(x_map(assignment), DiscreteValues(assignment))); - } else if (auto hc = dynamic_pointer_cast(f)) { - if (auto gm = hc->asMixture()) { - error += gm->error( - HybridValues(x_map(assignment), DiscreteValues(assignment))); - } else if (auto g = hc->asGaussian()) { - error += g->error(x_map(assignment)); - } - } + + // Compute unnormalized error term + std::vector labels; + for (auto &&key : x_map.labels()) { + labels.push_back(std::make_pair(key, 2)); + } + + std::vector errors; + x_map.visitWith([this, &errors](const Assignment &assignment, + const VectorValues &mu) { + double error = 0.0; + for (auto &&f : *this) { + if (auto gm = dynamic_pointer_cast(f)) { + error += gm->error(HybridValues(mu, DiscreteValues(assignment))); + } else if (auto hc = dynamic_pointer_cast(f)) { + if (auto gm = hc->asMixture()) { + error += gm->error(HybridValues(mu, DiscreteValues(assignment))); + } else if (auto g = hc->asGaussian()) { + error += g->error(mu); } - return -(error + x); + } + } + errors.push_back(error); + }); + + AlgebraicDecisionTree errorTree = + DecisionTree(labels, errors); + + // Compute model selection term + AlgebraicDecisionTree model_selection_term = errorTree.apply( + [&log_norm_constants](const Assignment assignment, double err) { + return -(err + log_norm_constants(assignment)); }); + + // std::cout << "model selection term" << std::endl; // model_selection_term.print("", DefaultKeyFormatter); double max_log = model_selection_term.max(); @@ -355,6 +370,7 @@ HybridValues HybridBayesNet::optimize() const { model_selection_term, [&max_log](const double &x) { return std::exp(x - max_log); }); model_selection = model_selection.normalize(model_selection.sum()); + // std::cout << "normalized model selection" << std::endl; // model_selection.print("", DefaultKeyFormatter); @@ -363,17 +379,11 @@ HybridValues HybridBayesNet::optimize() const { discrete_fg.push_back(conditional->asDiscrete()); } else { if (conditional->isContinuous()) { - // /* - // If we are here, it means there are no discrete variables in - // the Bayes net (due to strong elimination ordering). - // This is a continuous-only problem hence model selection doesn't matter. - // */ - // auto gc = conditional->asGaussian(); - // for (GaussianConditional::const_iterator frontal = gc->beginFrontals(); - // frontal != gc->endFrontals(); ++frontal) { - // continuousValues.insert_or_assign(*frontal, - // Vector::Zero(gc->getDim(frontal))); - // } + /* + If we are here, it means there are no discrete variables in + the Bayes net (due to strong elimination ordering). + This is a continuous-only problem hence model selection doesn't matter. + */ } else if (conditional->isHybrid()) { auto gm = conditional->asMixture();