diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index ccc52585e..7cd190cc2 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const { return Base::equals(bn, tol); } +/* ************************************************************************* */ +double DiscreteBayesNet::logProbability(const DiscreteValues& values) const { + // evaluate all conditionals and add + double result = 0.0; + for (const DiscreteConditional::shared_ptr& conditional : *this) + result += conditional->logProbability(values); + return result; +} + /* ************************************************************************* */ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { // evaluate all conditionals and multiply diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 8ded827ce..700f81d7e 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { return evaluate(values); } + //** log(evaluate(values)) for given DiscreteValues */ + double logProbability(const DiscreteValues & values) const; + /** * @brief do ancestral sampling * @@ -136,7 +139,15 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const DiscreteFactor::Names& names = {}) const; - ///@} + /// @} + /// @name HybridValues methods. + /// @{ + + using Base::error; // Expose error(const HybridValues&) method.. + using Base::evaluate; // Expose evaluate(const HybridValues&) method.. + using Base::logProbability; // Expose logProbability(const HybridValues&) + + /// @} #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// @name Deprecated functionality diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index e665ea88b..3fbb64d50 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -222,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const DiscreteFactor::Names& names = {}) const; + /// @} + /// @name HybridValues methods. + /// @{ + + using Base::error; // Expose error(const HybridValues&) method.. + /// @} }; // \ DiscreteFactorGraph diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 6303a5487..38c64b13b 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -25,7 +25,6 @@ #include - #include #include #include @@ -101,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) { DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); + // Check evaluate and logProbability + auto result = chordal->optimize(); + EXPECT_DOUBLES_EQUAL(asia.logProbability(result), + std::log(asia.evaluate(result)), 1e-9); + // add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1"); fg.add(Dyspnea, "0 1");