diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 23ee72215..eeacc929b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -227,24 +227,6 @@ GaussianBayesNet HybridBayesNet::choose( return gbn; } -/* ************************************************************************ */ -static GaussianBayesNetTree addGaussian( - const GaussianBayesNetTree &gfgTree, - const GaussianConditional::shared_ptr &factor) { - // If the decision tree is not initialized, then initialize it. - if (gfgTree.empty()) { - GaussianBayesNet result{factor}; - return GaussianBayesNetTree(result); - } else { - auto add = [&factor](const GaussianBayesNet &graph) { - auto result = graph; - result.push_back(factor); - return result; - }; - return gfgTree.apply(add); - } -} - /* ************************************************************************ */ GaussianBayesNetValTree HybridBayesNet::assembleTree() const { GaussianBayesNetTree result; @@ -320,25 +302,12 @@ AlgebraicDecisionTree HybridBayesNet::modelSelection() const { AlgebraicDecisionTree errorTree = trees.second; // Only compute logNormalizationConstant - AlgebraicDecisionTree log_norm_constants = DecisionTree( - bnTree, [](const std::pair &gbnAndValue) { - GaussianBayesNet gbn = gbnAndValue.first; - if (gbn.size() == 0) { - return 0.0; - } - return gbn.logNormalizationConstant(); - }); + AlgebraicDecisionTree log_norm_constants = + computeLogNormConstants(bnTree); // Compute model selection term (with help from ADT methods) AlgebraicDecisionTree modelSelectionTerm = - (errorTree + log_norm_constants) * -1; - - double max_log = modelSelectionTerm.max(); - modelSelectionTerm = DecisionTree( - modelSelectionTerm, - [&max_log](const double &x) { return std::exp(x - max_log); }); - modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); - + computeModelSelectionTerm(errorTree, log_norm_constants); return modelSelectionTerm; } @@ -530,4 +499,52 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph( return fg; } +/* ************************************************************************ */ +GaussianBayesNetTree addGaussian( + const GaussianBayesNetTree &gbnTree, + const GaussianConditional::shared_ptr &factor) { + // If the decision tree is not initialized, then initialize it. + if (gbnTree.empty()) { + GaussianBayesNet result{factor}; + return GaussianBayesNetTree(result); + } else { + auto add = [&factor](const GaussianBayesNet &graph) { + auto result = graph; + result.push_back(factor); + return result; + }; + return gbnTree.apply(add); + } +} + +/* ************************************************************************* */ +AlgebraicDecisionTree computeLogNormConstants( + const GaussianBayesNetValTree &bnTree) { + AlgebraicDecisionTree log_norm_constants = DecisionTree( + bnTree, [](const std::pair &gbnAndValue) { + GaussianBayesNet gbn = gbnAndValue.first; + if (gbn.size() == 0) { + return 0.0; + } + return gbn.logNormalizationConstant(); + }); + return log_norm_constants; +} + +/* ************************************************************************* */ +AlgebraicDecisionTree computeModelSelectionTerm( + const AlgebraicDecisionTree &errorTree, + const AlgebraicDecisionTree &log_norm_constants) { + AlgebraicDecisionTree modelSelectionTerm = + (errorTree + log_norm_constants) * -1; + + double max_log = modelSelectionTerm.max(); + modelSelectionTerm = DecisionTree( + modelSelectionTerm, + [&max_log](const double &x) { return std::exp(x - max_log); }); + modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); + + return modelSelectionTerm; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index fe9010b1f..d4fd3db71 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -141,7 +141,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { This can be computed by multiplying all the exponentiated errors of each of the conditionals. - + Return a tree where each leaf value is L(M_i;Z). */ AlgebraicDecisionTree modelSelection() const; @@ -280,4 +280,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { template <> struct traits : public Testable {}; +/** + * @brief Add a Gaussian conditional to each node of the GaussianBayesNetTree + * + * @param gbnTree + * @param factor + * @return GaussianBayesNetTree + */ +GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree, + const GaussianConditional::shared_ptr &factor); + +/** + * @brief Compute the (logarithmic) normalization constant for each Bayes + * network in the tree. + * + * @param bnTree A tree of Bayes networks in each leaf. The tree encodes a + * discrete assignment yielding the Bayes net. + * @return AlgebraicDecisionTree + */ +AlgebraicDecisionTree computeLogNormConstants( + const GaussianBayesNetValTree &bnTree); + +/** + * @brief Compute the model selection term L(M; Z, X) given the error + * and log normalization constants. + * + * Perform normalization to handle underflow issues. + * + * @param errorTree + * @param log_norm_constants + * @return AlgebraicDecisionTree + */ +AlgebraicDecisionTree computeModelSelectionTerm( + const AlgebraicDecisionTree &errorTree, + const AlgebraicDecisionTree &log_norm_constants); + } // namespace gtsam