diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f57cda28d..e5748366c 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -197,6 +197,35 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( return result; } +/* ************************************************************************* */ +double HybridBayesNet::negLogConstant() const { + double negLogNormConst = 0.0; + // Iterate over each conditional. + for (auto &&conditional : *this) { + negLogNormConst += conditional->negLogConstant(); + } + return negLogNormConst; +} + +/* ************************************************************************* */ +double HybridBayesNet::negLogConstant(const DiscreteValues &discrete) const { + double negLogNormConst = 0.0; + // Iterate over each conditional. + for (auto &&conditional : *this) { + if (auto gm = conditional->asHybrid()) { + negLogNormConst += gm->choose(discrete)->negLogConstant(); + } else if (auto gc = conditional->asGaussian()) { + negLogNormConst += gc->negLogConstant(); + } else if (auto dc = conditional->asDiscrete()) { + negLogNormConst += dc->choose(discrete)->negLogConstant(); + } else { + throw std::runtime_error( + "Unknown conditional type when computing negLogConstant"); + } + } + return negLogNormConst; +} + /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::discretePosterior( const VectorValues &continuousValues) const { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index bba301be2..451f7f675 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -237,6 +237,23 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using BayesNet::logProbability; // expose HybridValues version + /** + * @brief Get the negative log of the normalization constant corresponding + * to the joint density represented by this Bayes net. + * + * @return double + */ + double negLogConstant() const; + + /** + * @brief Get the negative log of the normalization constant + * corresponding to the joint Gaussian density represented by + * this Bayes net indexed by `discrete`. + * + * @return double + */ + double negLogConstant(const DiscreteValues &discrete) const; + /** * @brief Compute normalized posterior P(M|X=x) and return as a tree. *