improved naming and documentation

release/4.3a0
Varun Agrawal 2024-01-07 15:49:33 -05:00
parent a80b5d4f5a
commit 0430fee377
2 changed files with 15 additions and 12 deletions

View File

@ -281,7 +281,7 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::model_selection() const { AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
/* /*
To perform model selection, we need: To perform model selection, we need:
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma)) q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
@ -330,16 +330,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::model_selection() const {
}); });
// Compute model selection term (with help from ADT methods) // Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> model_selection_term = AlgebraicDecisionTree<Key> modelSelectionTerm =
(errorTree + log_norm_constants) * -1; (errorTree + log_norm_constants) * -1;
double max_log = model_selection_term.max(); double max_log = modelSelectionTerm.max();
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>( modelSelectionTerm = DecisionTree<Key, double>(
model_selection_term, modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); }); [&max_log](const double &x) { return std::exp(x - max_log); });
model_selection = model_selection.normalize(model_selection.sum()); modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return model_selection; return modelSelectionTerm;
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -348,7 +348,7 @@ HybridValues HybridBayesNet::optimize() const {
DiscreteFactorGraph discrete_fg; DiscreteFactorGraph discrete_fg;
// Compute model selection term // Compute model selection term
AlgebraicDecisionTree<Key> model_selection_term = model_selection(); AlgebraicDecisionTree<Key> modelSelectionTerm = modelSelection();
// Get the set of all discrete keys involved in model selection // Get the set of all discrete keys involved in model selection
std::set<DiscreteKey> discreteKeySet; std::set<DiscreteKey> discreteKeySet;
@ -376,7 +376,7 @@ HybridValues HybridBayesNet::optimize() const {
if (discreteKeySet.size() > 0) { if (discreteKeySet.size() > 0) {
discrete_fg.push_back(DecisionTreeFactor( discrete_fg.push_back(DecisionTreeFactor(
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
model_selection_term)); modelSelectionTerm));
} }
// Solve for the MPE // Solve for the MPE

View File

@ -129,8 +129,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
GaussianBayesNetValTree assembleTree() const; GaussianBayesNetValTree assembleTree() const;
/* /*
Perform the integration of L(X;M,Z)P(X|M) Compute L(M;Z), the likelihood of the discrete model M
which is the model selection term. given the measurements Z.
This is called the model selection term.
To do so, we perform the integration of L(M;Z) L(X;M,Z)P(X|M).
By Bayes' rule, P(X|M,Z) L(X;M,Z)P(X|M), By Bayes' rule, P(X|M,Z) L(X;M,Z)P(X|M),
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
@ -139,7 +142,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
This can be computed by multiplying all the exponentiated errors This can be computed by multiplying all the exponentiated errors
of each of the conditionals. of each of the conditionals.
*/ */
AlgebraicDecisionTree<Key> model_selection() const; AlgebraicDecisionTree<Key> modelSelection() const;
/** /**
* @brief Solve the HybridBayesNet by first computing the MPE of all the * @brief Solve the HybridBayesNet by first computing the MPE of all the