Added missing methods

release/4.3a0
Frank Dellaert 2023-01-10 23:13:27 -08:00
parent 426a49dc72
commit eb37d080d4
4 changed files with 32 additions and 2 deletions

View File

@ -33,6 +33,15 @@ bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
double DiscreteBayesNet::logProbability(const DiscreteValues& values) const {
// evaluate all conditionals and add
double result = 0.0;
for (const DiscreteConditional::shared_ptr& conditional : *this)
result += conditional->logProbability(values);
return result;
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply

View File

@ -103,6 +103,9 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
return evaluate(values);
}
//** log(evaluate(values)) for given DiscreteValues */
double logProbability(const DiscreteValues & values) const;
/**
* @brief do ancestral sampling
*
@ -137,6 +140,14 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
const DiscreteFactor::Names& names = {}) const;
/// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
using Base::evaluate; // Expose evaluate(const HybridValues&) method..
using Base::logProbability; // Expose logProbability(const HybridValues&)
/// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality

View File

@ -222,6 +222,12 @@ class GTSAM_EXPORT DiscreteFactorGraph
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;
/// @}
/// @name HybridValues methods.
/// @{
using Base::error; // Expose error(const HybridValues&) method..
/// @}
}; // \ DiscreteFactorGraph

View File

@ -25,7 +25,6 @@
#include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <string>
#include <vector>
@ -101,6 +100,11 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));
// Check evaluate and logProbability
auto result = chordal->optimize();
EXPECT_DOUBLES_EQUAL(asia.logProbability(result),
std::log(asia.evaluate(result)), 1e-9);
// add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1");