From 8a61c49bb3ca447af3d7bac86c1694a73c0b398e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Wed, 3 Jan 2024 16:32:21 -0500
Subject: [PATCH] add model_selection method to HybridBayesNet

---
 gtsam/hybrid/HybridBayesNet.cpp | 50 ++++++++++++++++++++++-----------------
 gtsam/hybrid/HybridBayesNet.h   | 13 +++++++++
 2 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index 0352d7962..81f4badea 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -283,35 +283,20 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
 }
 
 /* ************************************************************************* */
-HybridValues HybridBayesNet::optimize() const {
-  // Collect all the discrete factors to compute MPE
-  DiscreteFactorGraph discrete_fg;
-
+AlgebraicDecisionTree<Key> HybridBayesNet::model_selection() const {
   /*
-    Perform the integration of L(X;M,Z)P(X|M)
-    which is the model selection term.
+    To perform model selection, we need:
+    q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
 
-    By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
-    hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
-    the joint Gaussian distribution.
+    If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)),
+    then q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
+    = exp(log(q) - log(k)) = exp(-error - log(k))
+    = exp(-(error + log(k))),
+    where error is computed at the corresponding MAP point, gbn.error(mu).
 
     So we compute (error + log(k)) and exponentiate later
   */
 
-  std::set<DiscreteKey> discreteKeySet;
   GaussianBayesNetValTree bnTree = assembleTree();
   GaussianBayesNetValTree bn_error = bnTree.apply(
@@ -356,6 +341,19 @@ HybridValues HybridBayesNet::optimize() const {
       [&max_log](const double &x) { return std::exp(x - max_log); });
   model_selection = model_selection.normalize(model_selection.sum());
 
+  return model_selection;
+}
+
+/* ************************************************************************* */
+HybridValues HybridBayesNet::optimize() const {
+  // Collect all the discrete factors to compute MPE
+  DiscreteFactorGraph discrete_fg;
+
+  // Compute model selection term
+  AlgebraicDecisionTree<Key> model_selection_term = model_selection();
+
+  // Get the set of all discrete keys involved in model selection
+  std::set<DiscreteKey> discreteKeySet;
   for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
       discrete_fg.push_back(conditional->asDiscrete());
@@ -380,7 +378,7 @@ HybridValues HybridBayesNet::optimize() const {
   if (discreteKeySet.size() > 0) {
     discrete_fg.push_back(DecisionTreeFactor(
        DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
-        model_selection));
+        model_selection_term));
   }
 
   // Solve for the MPE
diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h
index f0cdbaaf9..8acdd5b1b 100644
--- a/gtsam/hybrid/HybridBayesNet.h
+++ b/gtsam/hybrid/HybridBayesNet.h
@@ -120,6 +120,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 
   GaussianBayesNetValTree assembleTree() const;
 
+  /*
+    Perform the integration of L(X;M,Z)P(X|M)
+    which is the model selection term.
+
+    By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
+    hence L(X;M,Z)P(X|M) is the unnormalized probability of
+    the joint Gaussian distribution.
+
+    This can be computed by multiplying all the exponentiated errors
+    of each of the conditionals.
+  */
+  AlgebraicDecisionTree<Key> model_selection() const;
+
   /**
    * @brief Solve the HybridBayesNet by first computing the MPE of all the
    * discrete variables and then optimizing the continuous variables based on
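
Note: the exp(-(error + log(k))) identity used in the comment of model_selection() is easy to sanity-check outside of GTSAM. The sketch below is not part of the patch and uses made-up toy values for n, det(Sigma), and error; it only confirms numerically that q * sqrt((2*pi)^n * det(Sigma)) equals exp(-(error + log(k))) when q = exp(-error) and k = 1 / sqrt((2*pi)^n * det(Sigma)).

// Standalone toy check of the identity from the model_selection() comment.
// All numbers are assumed values, not taken from the patch.
#include <cmath>
#include <cstdio>

int main() {
  const double pi = 3.14159265358979323846;
  const int n = 1;               // dimension of the toy Gaussian (assumed)
  const double detSigma = 0.25;  // det(Sigma) (assumed)
  const double error = 1.3;      // stands in for gbn.error(mu) (assumed)

  const double norm = std::sqrt(std::pow(2.0 * pi, n) * detSigma);
  const double k = 1.0 / norm;   // Gaussian normalization constant

  const double lhs = std::exp(-error) * norm;           // q * sqrt((2*pi)^n * det(Sigma))
  const double rhs = std::exp(-(error + std::log(k)));  // exp(-(error + log k))

  std::printf("lhs = %.12f  rhs = %.12f\n", lhs, rhs);  // agree up to rounding
  return 0;
}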
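
Note: a rough usage sketch of the API this patch introduces. It assumes a HybridBayesNet `hbn` has already been built elsewhere (for example by eliminating a hybrid factor graph); only model_selection() and optimize() come from this patch, and the surrounding accessors (print, discrete(), continuous()) are ordinary GTSAM calls used here for illustration.

// Illustrative only; assumes an existing HybridBayesNet `hbn`.
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

void inspectModelSelection(const HybridBayesNet& hbn) {
  // Normalized model-selection term: one weight per discrete assignment.
  AlgebraicDecisionTree<Key> probs = hbn.model_selection();
  probs.print("model selection term");

  // MPE over the discrete modes (now weighted by the model-selection term),
  // followed by the continuous solve conditioned on that assignment.
  HybridValues mpe = hbn.optimize();
  mpe.discrete().print("MPE modes");
  mpe.continuous().print("continuous solution");
}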