diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e9870e6bf..23ee72215 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -281,7 +281,7 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const { } /* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::model_selection() const { +AlgebraicDecisionTree HybridBayesNet::modelSelection() const { /* To perform model selection, we need: q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma)) @@ -330,16 +330,16 @@ AlgebraicDecisionTree HybridBayesNet::model_selection() const { }); // Compute model selection term (with help from ADT methods) - AlgebraicDecisionTree model_selection_term = + AlgebraicDecisionTree modelSelectionTerm = (errorTree + log_norm_constants) * -1; - double max_log = model_selection_term.max(); - AlgebraicDecisionTree model_selection = DecisionTree( - model_selection_term, + double max_log = modelSelectionTerm.max(); + modelSelectionTerm = DecisionTree( + modelSelectionTerm, [&max_log](const double &x) { return std::exp(x - max_log); }); - model_selection = model_selection.normalize(model_selection.sum()); + modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); - return model_selection; + return modelSelectionTerm; } /* ************************************************************************* */ @@ -348,7 +348,7 @@ HybridValues HybridBayesNet::optimize() const { DiscreteFactorGraph discrete_fg; // Compute model selection term - AlgebraicDecisionTree model_selection_term = model_selection(); + AlgebraicDecisionTree modelSelectionTerm = modelSelection(); // Get the set of all discrete keys involved in model selection std::set discreteKeySet; @@ -376,7 +376,7 @@ HybridValues HybridBayesNet::optimize() const { if (discreteKeySet.size() > 0) { discrete_fg.push_back(DecisionTreeFactor( DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), - model_selection_term)); + modelSelectionTerm)); } // Solve for the MPE diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 757e60aea..9d16a4e14 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -129,8 +129,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { GaussianBayesNetValTree assembleTree() const; /* - Perform the integration of L(X;M,Z)P(X|M) - which is the model selection term. + Compute L(M;Z), the likelihood of the discrete model M + given the measurements Z. + This is called the model selection term. + + To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M). By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M), hence L(X;M,Z)P(X|M) is the unnormalized probabilty of @@ -139,7 +142,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { This can be computed by multiplying all the exponentiated errors of each of the conditionals. */ - AlgebraicDecisionTree model_selection() const; + AlgebraicDecisionTree modelSelection() const; /** * @brief Solve the HybridBayesNet by first computing the MPE of all the