diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index d343decc8..4c41e952b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -265,16 +265,10 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const { /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::modelSelection() const { /* - To perform model selection, we need: - q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma)) - - If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)) - thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k)) - = exp(log(q) - log(k)) = exp(-error - log(k)) - = exp(-(error + log(k))), + To perform model selection, we need: q(mu; M, Z) = exp(-error) where error is computed at the corresponding MAP point, gbn.error(mu). - So we compute (error + log(k)) and exponentiate later + So we compute (-error) and exponentiate later */ GaussianBayesNetValTree bnTree = assembleTree(); @@ -301,13 +295,16 @@ AlgebraicDecisionTree HybridBayesNet::modelSelection() const { auto trees = unzip(bn_error); AlgebraicDecisionTree errorTree = trees.second; - // Only compute logNormalizationConstant - AlgebraicDecisionTree log_norm_constants = - computeLogNormConstants(bnTree); - // Compute model selection term (with help from ADT methods) - AlgebraicDecisionTree modelSelectionTerm = - computeModelSelectionTerm(errorTree, log_norm_constants); + AlgebraicDecisionTree modelSelectionTerm = errorTree * -1; + + // Exponentiate using our scheme + double max_log = modelSelectionTerm.max(); + modelSelectionTerm = DecisionTree( + modelSelectionTerm, + [&max_log](const double &x) { return std::exp(x - max_log); }); + modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); + return modelSelectionTerm; } @@ -531,20 +528,4 @@ AlgebraicDecisionTree computeLogNormConstants( return log_norm_constants; } -/* ************************************************************************* */ -AlgebraicDecisionTree computeModelSelectionTerm( - const AlgebraicDecisionTree &errorTree, - const AlgebraicDecisionTree &log_norm_constants) { - AlgebraicDecisionTree modelSelectionTerm = - (errorTree + log_norm_constants) * -1; - - double max_log = modelSelectionTerm.max(); - modelSelectionTerm = DecisionTree( - modelSelectionTerm, - [&max_log](const double &x) { return std::exp(x - max_log); }); - modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); - - return modelSelectionTerm; -} - } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index d4fd3db71..c678c418c 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -128,22 +128,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ GaussianBayesNetValTree assembleTree() const; - /* - Compute L(M;Z), the likelihood of the discrete model M - given the measurements Z. - This is called the model selection term. - - To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M). - - By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M), - hence L(X;M,Z)P(X|M) is the unnormalized probabilty of - the joint Gaussian distribution. - - This can be computed by multiplying all the exponentiated errors - of each of the conditionals. - - Return a tree where each leaf value is L(M_i;Z). - */ + /** + * @brief Compute the model selection term q(μ_X; M, Z) + * given the error for each discrete assignment. + * + * The q(μ) terms are obtained as a result of elimination + * as part of the separator factor. + * + * Perform normalization to handle underflow issues. + * + * @return AlgebraicDecisionTree + */ AlgebraicDecisionTree modelSelection() const; /** @@ -301,18 +296,4 @@ GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree, AlgebraicDecisionTree computeLogNormConstants( const GaussianBayesNetValTree &bnTree); -/** - * @brief Compute the model selection term L(M; Z, X) given the error - * and log normalization constants. - * - * Perform normalization to handle underflow issues. - * - * @param errorTree - * @param log_norm_constants - * @return AlgebraicDecisionTree - */ -AlgebraicDecisionTree computeModelSelectionTerm( - const AlgebraicDecisionTree &errorTree, - const AlgebraicDecisionTree &log_norm_constants); - } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index cdd30d398..394c88928 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -111,13 +111,15 @@ AlgebraicDecisionTree HybridBayesTree::modelSelection() const { auto trees = unzip(bn_error); AlgebraicDecisionTree errorTree = trees.second; - // Only compute logNormalizationConstant - AlgebraicDecisionTree log_norm_constants = - computeLogNormConstants(bnTree); - // Compute model selection term (with help from ADT methods) - AlgebraicDecisionTree modelSelectionTerm = - computeModelSelectionTerm(errorTree, log_norm_constants); + AlgebraicDecisionTree modelSelectionTerm = errorTree * -1; + + // Exponentiate using our scheme + double max_log = modelSelectionTerm.max(); + modelSelectionTerm = DecisionTree( + modelSelectionTerm, + [&max_log](const double& x) { return std::exp(x - max_log); }); + modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); return modelSelectionTerm; }