diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index 2b0b11e36..23a1c7787 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -26,6 +26,18 @@ static std::mt19937_64 kRandomNumberGenerator(42);
 
 namespace gtsam {
 
+using std::dynamic_pointer_cast;
+
+/* ************************************************************************ */
+// Throw a runtime exception for the method specified in string s,
+// and conditional f:
+static void throwRuntimeError(const std::string &s,
+                              const std::shared_ptr<HybridConditional> &f) {
+  auto &fr = *f;
+  throw std::runtime_error(s + " not implemented for conditional type " +
+                           demangle(typeid(fr).name()) + ".");
+}
+
 /* ************************************************************************* */
 void HybridBayesNet::print(const std::string &s,
                            const KeyFormatter &formatter) const {
@@ -217,6 +229,56 @@ GaussianBayesNet HybridBayesNet::choose(
   return gbn;
 }
 
+/* ************************************************************************ */
+static GaussianBayesNetTree addGaussian(
+    const GaussianBayesNetTree &gfgTree,
+    const GaussianConditional::shared_ptr &factor) {
+  // If the decision tree is not initialized, then initialize it.
+  if (gfgTree.empty()) {
+    GaussianBayesNet result{factor};
+    return GaussianBayesNetTree(result);
+  } else {
+    auto add = [&factor](const GaussianBayesNet &graph) {
+      auto result = graph;
+      result.push_back(factor);
+      return result;
+    };
+    return gfgTree.apply(add);
+  }
+}
+
+/* ************************************************************************ */
+GaussianBayesNetTree HybridBayesNet::assembleTree() const {
+  GaussianBayesNetTree result;
+
+  for (auto &f : factors_) {
+    // TODO(dellaert): just use a virtual method defined in HybridFactor.
+    if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+      result = gm->add(result);
+    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+      if (auto gm = hc->asMixture()) {
+        result = gm->add(result);
+      } else if (auto g = hc->asGaussian()) {
+        result = addGaussian(result, g);
+      } else {
+        // Has to be discrete.
+        // TODO(dellaert): in C++20, we can use std::visit.
+        continue;
+      }
+    } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
+      // Don't do anything for discrete-only factors
+      // since we want to evaluate continuous values only.
+      continue;
+    } else {
+      // We need to handle the case where the object is actually a
+      // BayesTreeOrphanWrapper!
+      throwRuntimeError("HybridBayesNet::assembleTree", f);
+    }
+  }
+
+  return result;
+}
+
 /* ************************************************************************* */
 HybridValues HybridBayesNet::optimize() const {
   // Collect all the discrete factors to compute MPE
@@ -227,74 +289,94 @@ HybridValues HybridBayesNet::optimize() const {
   AlgebraicDecisionTree<Key> error(0.0);
   std::set<DiscreteKey> discreteKeySet;
 
+  // this->print();
+  GaussianBayesNetTree bnTree = assembleTree();
+  // bnTree.print("", DefaultKeyFormatter, [](const GaussianBayesNet &gbn) {
+  //   gbn.print();
+  //   return "";
+  // });
+  /*
+    Perform the integration of L(X;M,Z)P(X|M)
+    which is the model selection term.
+
+    By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
+    hence L(X;M,Z)P(X|M) is the unnormalized probability of
+    the joint Gaussian distribution.
+
+    This can be computed by multiplying all the exponentiated errors
+    of each of the conditionals, which we do below in the hybrid case.
+  */
+  /*
+    To perform model selection, we need:
+    q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
+
+    If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)),
+    then q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
+    = exp(log(q) - log(k)) = exp(-error - log(k))
+    = exp(-(error + log(k)))
+
+    So we compute (error + log(k)) and exponentiate later.
+  */
+
+  // Compute the X* of each assignment and use that as the MAP.
+  DecisionTree<Key, VectorValues> x_map(
+      bnTree, [](const GaussianBayesNet &gbn) { return gbn.optimize(); });
+
+  // Only compute logNormalizationConstant for now
+  AlgebraicDecisionTree<Key> log_norm_constants =
+      DecisionTree<Key, double>(bnTree, [](const GaussianBayesNet &gbn) {
+        if (gbn.size() == 0) {
+          return -std::numeric_limits<double>::max();
+        }
+        return -gbn.logNormalizationConstant();
+      });
+  // Compute unnormalized error term and compute model selection term
+  AlgebraicDecisionTree<Key> model_selection_term = log_norm_constants.apply(
+      [this, &x_map](const Assignment<Key> &assignment, double x) {
+        double error = 0.0;
+        for (auto &&f : *this) {
+          if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+            error += gm->error(
+                HybridValues(x_map(assignment), DiscreteValues(assignment)));
+          } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+            if (auto gm = hc->asMixture()) {
+              error += gm->error(
+                  HybridValues(x_map(assignment), DiscreteValues(assignment)));
+            } else if (auto g = hc->asGaussian()) {
+              error += g->error(x_map(assignment));
+            }
+          }
+        }
+        return -(error + x);
+      });
+  // model_selection_term.print("", DefaultKeyFormatter);
+
+  double max_log = model_selection_term.max();
+  AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
+      model_selection_term,
+      [&max_log](const double &x) { return std::exp(x - max_log); });
+  model_selection = model_selection.normalize(model_selection.sum());
+  // std::cout << "normalized model selection" << std::endl;
+  // model_selection.print("", DefaultKeyFormatter);
+
   for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
       discrete_fg.push_back(conditional->asDiscrete());
     } else {
-      /*
-        Perform the integration of L(X;M,Z)P(X|M)
-        which is the model selection term.
-
-        By Bayes' rule, P(X|M) ∝ L(X;M,Z)P(X|M),
-        hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
-        the joint Gaussian distribution.
-
-        This can be computed by multiplying all the exponentiated errors
-        of each of the conditionals, which we do below in hybrid case.
-      */
       if (conditional->isContinuous()) {
-        /*
-          If we are here, it means there are no discrete variables in
-          the Bayes net (due to strong elimination ordering).
-          This is a continuous-only problem hence model selection doesn't matter.
-        */
-        auto gc = conditional->asGaussian();
-        for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
-             frontal != gc->endFrontals(); ++frontal) {
-          continuousValues.insert_or_assign(*frontal,
-                                            Vector::Zero(gc->getDim(frontal)));
-        }
+        // /*
+        //   If we are here, it means there are no discrete variables in
+        //   the Bayes net (due to strong elimination ordering).
+        //   This is a continuous-only problem hence model selection doesn't matter.
+ // */ + // auto gc = conditional->asGaussian(); + // for (GaussianConditional::const_iterator frontal = gc->beginFrontals(); + // frontal != gc->endFrontals(); ++frontal) { + // continuousValues.insert_or_assign(*frontal, + // Vector::Zero(gc->getDim(frontal))); + // } } else if (conditional->isHybrid()) { auto gm = conditional->asMixture(); - gm->conditionals().apply( - [&continuousValues](const GaussianConditional::shared_ptr &gc) { - if (gc) { - for (GaussianConditional::const_iterator frontal = gc->begin(); - frontal != gc->end(); ++frontal) { - continuousValues.insert_or_assign( - *frontal, Vector::Zero(gc->getDim(frontal))); - } - } - return gc; - }); - - /* - To perform model selection, we need: - q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma)) - - If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)) - thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k)) - = exp(log(q) - log(k)) = exp(-error - log(k)) - = exp(-(error + log(k))) - - So we compute (error + log(k)) and exponentiate later - */ - - // Add the error and the logNormalization constant to the error - auto err = gm->error(continuousValues) + gm->logNormalizationConstant(); - - // Also compute the sum for discrete probability normalization - // (normalization trick for numerical stability) - double sum = 0.0; - auto absSum = [&sum](const double &e) { - sum += std::abs(e); - return e; - }; - err.visit(absSum); - // Normalize by the sum to prevent overflow - error = error + err.normalize(sum); - // Include the discrete keys std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(), std::inserter(discreteKeySet, discreteKeySet.end())); @@ -302,12 +384,6 @@ HybridValues HybridBayesNet::optimize() const { } } - error = error * -1; - double max_log = error.max(); - AlgebraicDecisionTree model_selection = DecisionTree( - error, [&max_log](const double &x) { return std::exp(x - max_log); }); - model_selection = model_selection.normalize(model_selection.sum()); - // Only add model_selection if we have discrete keys if (discreteKeySet.size() > 0) { discrete_fg.push_back(DecisionTreeFactor(