Added missing methods

release/4.3a0
Frank Dellaert 2023-01-10 23:13:27 -08:00
parent 426a49dc72
commit eb37d080d4
4 changed files with 32 additions and 2 deletions

View File

@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol); return Base::equals(bn, tol);
} }
/* ************************************************************************* */
double DiscreteBayesNet::logProbability(const DiscreteValues& values) const {
// evaluate all conditionals and add
double result = 0.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result += conditional->logProbability(values);
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply

View File

@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
return evaluate(values); return evaluate(values);
} }
//** log(evaluate(values)) for given DiscreteValues */
double logProbability(const DiscreteValues & values) const;
/** /**
* @brief do ancestral sampling * @brief do ancestral sampling
* *
@ -136,7 +139,15 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
///@} /// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
using Base::evaluate; // Expose evaluate(const HybridValues&) method..
using Base::logProbability; // Expose logProbability(const HybridValues&)
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality /// @name Deprecated functionality

View File

@ -222,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
/// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph

View File

@ -25,7 +25,6 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
@ -101,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// Check evaluate and logProbability
auto result = chordal->optimize();
EXPECT_DOUBLES_EQUAL(asia.logProbability(result),
std::log(asia.evaluate(result)), 1e-9);
// add evidence, we were in Asia and we have dyspnea // add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1"); fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");