From 4b2a22eaa564e4593cc1c28f12629f35e218598e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Tue, 16 Jan 2024 15:01:31 -0500
Subject: [PATCH] model selection for HybridBayesTree

---
 gtsam/hybrid/HybridBayesTree.cpp | 105 +++++++++++++++++++++++++++++--
 gtsam/hybrid/HybridBayesTree.h   |  45 +++++++++++++
 2 files changed, 146 insertions(+), 4 deletions(-)

diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp
index ae8fa0378..cdd30d398 100644
--- a/gtsam/hybrid/HybridBayesTree.cpp
+++ b/gtsam/hybrid/HybridBayesTree.cpp
@@ -38,19 +38,116 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
   return Base::equals(other, tol);
 }
 
+GaussianBayesNetTree& HybridBayesTree::addCliqueToTree(
+    const sharedClique& clique, GaussianBayesNetTree& result) const {
+  // Perform bottom-up inclusion
+  for (sharedClique child : clique->children) {
+    result = addCliqueToTree(child, result);
+  }
+
+  auto f = clique->conditional();
+
+  if (auto hc = std::dynamic_pointer_cast<HybridConditional>(f)) {
+    if (auto gm = hc->asMixture()) {
+      result = gm->add(result);
+    } else if (auto g = hc->asGaussian()) {
+      result = addGaussian(result, g);
+    } else {
+      // Has to be discrete, which we don't add.
+    }
+  }
+  return result;
+}
+
+/* ************************************************************************ */
+GaussianBayesNetValTree HybridBayesTree::assembleTree() const {
+  GaussianBayesNetTree result;
+  for (auto&& root : roots_) {
+    result = addCliqueToTree(root, result);
+  }
+
+  GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet& gbn) {
+    return std::make_pair(gbn, 0.0);
+  });
+  return resultTree;
+}
+
+/* ************************************************************************* */
+AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
+  /*
+    To perform model selection, we need:
+    q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
+
+    If q(mu; M, Z) = exp(-error) and k = 1.0 / sqrt((2*pi)^n*det(Sigma)),
+    then q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
+    = exp(log(q) - log(k)) = exp(-error - log(k))
+    = exp(-(error + log(k))),
+    where error is computed at the corresponding MAP point, gbt.error(mu).
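+
+    A small worked example (hypothetical numbers, for illustration only):
+    if error = 2.0 and log(k) = -1.5, the value stored for that assignment
+    is 2.0 + (-1.5) = 0.5, and the model selection weight recovered later
+    is exp(-0.5) ≈ 0.607.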
+
+    So we compute (error + log(k)) and exponentiate later.
+  */
+
+  GaussianBayesNetValTree bnTree = assembleTree();
+
+  GaussianBayesNetValTree bn_error = bnTree.apply(
+      [this](const Assignment<Key>& assignment,
+             const std::pair<GaussianBayesNet, double>& gbnAndValue) {
+        // Compute the X* of each assignment
+        VectorValues mu = gbnAndValue.first.optimize();
+
+        // mu is empty if gbn had nullptrs
+        if (mu.size() == 0) {
+          return std::make_pair(gbnAndValue.first,
+                                std::numeric_limits<double>::max());
+        }
+
+        // Compute the error for X* and the assignment
+        double error =
+            this->error(HybridValues(mu, DiscreteValues(assignment)));
+
+        return std::make_pair(gbnAndValue.first, error);
+      });
+
+  auto trees = unzip(bn_error);
+  AlgebraicDecisionTree<Key> errorTree = trees.second;
+
+  // Only compute logNormalizationConstant
+  AlgebraicDecisionTree<Key> log_norm_constants =
+      computeLogNormConstants(bnTree);
+
+  // Compute model selection term (with help from ADT methods)
+  AlgebraicDecisionTree<Key> modelSelectionTerm =
+      computeModelSelectionTerm(errorTree, log_norm_constants);
+
+  return modelSelectionTerm;
+}
+
 /* ************************************************************************* */
 HybridValues HybridBayesTree::optimize() const {
-  DiscreteBayesNet dbn;
+  DiscreteFactorGraph discrete_fg;
   DiscreteValues mpe;
 
+  // Compute model selection term
+  AlgebraicDecisionTree<Key> modelSelectionTerm = modelSelection();
+
   auto root = roots_.at(0);
   // Access the clique and get the underlying hybrid conditional
   HybridConditional::shared_ptr root_conditional = root->conditional();
 
-  // The root should be discrete only, we compute the MPE
+  // Get the set of all discrete keys involved in model selection
+  std::set<DiscreteKey> discreteKeySet;
+
+  // The root should be discrete only, so we compute the MPE
   if (root_conditional->isDiscrete()) {
-    dbn.push_back(root_conditional->asDiscrete());
-    mpe = DiscreteFactorGraph(dbn).optimize();
+    discrete_fg.push_back(root_conditional->asDiscrete());
+
+    // Only add model_selection if we have discrete keys
+    if (discreteKeySet.size() > 0) {
+      discrete_fg.push_back(DecisionTreeFactor(
+          DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
+          modelSelectionTerm));
+    }
+    mpe = discrete_fg.optimize();
   } else {
     throw std::runtime_error(
         "HybridBayesTree root is not discrete-only. Please check elimination "

diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h
index f91e16cbf..8327b7f31 100644
--- a/gtsam/hybrid/HybridBayesTree.h
+++ b/gtsam/hybrid/HybridBayesTree.h
@@ -84,6 +84,51 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
    */
   GaussianBayesTree choose(const DiscreteValues& assignment) const;
 
+  /** Error for all conditionals. */
+  double error(const HybridValues& values) const {
+    return HybridGaussianFactorGraph(*this).error(values);
+  }
+
+  /**
+   * @brief Helper function to add a clique of hybrid conditionals to the
+   * passed-in GaussianBayesNetTree. Operates recursively on the clique in a
+   * bottom-up fashion, adding the children first.
+   *
+   * @param clique The clique of hybrid conditionals to add.
+   * @param result The GaussianBayesNetTree to add the conditionals to.
+   * @return GaussianBayesNetTree&
+   */
+  GaussianBayesNetTree& addCliqueToTree(const sharedClique& clique,
+                                        GaussianBayesNetTree& result) const;
+
+  /**
+   * @brief Assemble a DecisionTree of (GaussianBayesNet, double) leaves for
+   * each discrete assignment.
+   * The included double value is used to make
+   * constructing the model selection term cleaner and more efficient.
+ * + * @return GaussianBayesNetValTree + */ + GaussianBayesNetValTree assembleTree() const; + + /* + Compute L(M;Z), the likelihood of the discrete model M + given the measurements Z. + This is called the model selection term. + + To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M). + + By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M), + hence L(X;M,Z)P(X|M) is the unnormalized probabilty of + the joint Gaussian distribution. + + This can be computed by multiplying all the exponentiated errors + of each of the conditionals. + + Return a tree where each leaf value is L(M_i;Z). + */ + AlgebraicDecisionTree modelSelection() const; + /** * @brief Optimize the hybrid Bayes tree by computing the MPE for the current * set of discrete variables and using it to compute the best continuous