Added logProbability/evaluate in conditionals/Bayes net

release/4.3a0
Frank Dellaert 2023-01-10 13:12:17 -08:00
parent 69398d0e60
commit 8e29140ff7
6 changed files with 118 additions and 42 deletions

View File

@ -88,6 +88,22 @@ void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
of.close(); of.close();
} }
/* ************************************************************************* */
template <class CONDITIONAL>
double BayesNet<CONDITIONAL>::logProbability(const HybridValues& x) const {
double sum = 0.;
for (const auto& gc : *this) {
if (gc) sum += gc->logProbability(x);
}
return sum;
}
/* ************************************************************************* */
template <class CONDITIONAL>
double BayesNet<CONDITIONAL>::evaluate(const HybridValues& x) const {
return exp(-logProbability(x));
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -25,6 +25,8 @@
namespace gtsam { namespace gtsam {
class HybridValues;
/** /**
* A BayesNet is a tree of conditionals, stored in elimination order. * A BayesNet is a tree of conditionals, stored in elimination order.
* @ingroup inference * @ingroup inference
@ -68,7 +70,6 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Graph Display /// @name Graph Display
/// @{ /// @{
@ -86,6 +87,13 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const; const DotWriter& writer = DotWriter()) const;
/// @}
/// @name HybridValues methods
/// @{
double logProbability(const HybridValues& x) const;
double evaluate(const HybridValues& c) const;
/// @} /// @}
}; };

View File

@ -18,30 +18,42 @@
// \callgraph // \callgraph
#pragma once #pragma once
#include <iostream>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <cmath>
#include <iostream>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR, class DERIVEDFACTOR> template <class FACTOR, class DERIVEDCONDITIONAL>
void Conditional<FACTOR,DERIVEDFACTOR>::print(const std::string& s, const KeyFormatter& formatter) const { void Conditional<FACTOR, DERIVEDCONDITIONAL>::print(
const std::string& s, const KeyFormatter& formatter) const {
std::cout << s << " P("; std::cout << s << " P(";
for(Key key: frontals()) for (Key key : frontals()) std::cout << " " << formatter(key);
std::cout << " " << formatter(key); if (nrParents() > 0) std::cout << " |";
if (nrParents() > 0) for (Key parent : parents()) std::cout << " " << formatter(parent);
std::cout << " |";
for(Key parent: parents())
std::cout << " " << formatter(parent);
std::cout << ")" << std::endl; std::cout << ")" << std::endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class FACTOR, class DERIVEDFACTOR> template <class FACTOR, class DERIVEDCONDITIONAL>
bool Conditional<FACTOR,DERIVEDFACTOR>::equals(const This& c, double tol) const bool Conditional<FACTOR, DERIVEDCONDITIONAL>::equals(const This& c,
{ double tol) const {
return nrFrontals_ == c.nrFrontals_; return nrFrontals_ == c.nrFrontals_;
} }
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logProbability(
const HybridValues& c) const {
throw std::runtime_error("Conditional::logProbability is not implemented");
} }
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
const HybridValues& c) const {
return exp(static_cast<const DERIVEDCONDITIONAL*>(this)->logProbability(c));
}
} // namespace gtsam

View File

@ -24,13 +24,37 @@
namespace gtsam { namespace gtsam {
class HybridValues; // forward declaration.
/** /**
* Base class for conditional densities. This class iterators and * This is the base class for all conditional distributions/densities,
* access to the frontal and separator keys. * which are implemented as specialized factors. This class does not store any
* data other than its keys. Derived classes store data such as matrices and
* probability tables.
*
* The `evaluate` method is used to evaluate the factor, and together with
* `logProbability` is the main methods that need to be implemented in derived
* classes. These two methods relate to the `error` method in the factor by:
* probability(x) = k exp(-error(x))
* where k is a normalization constant making \int probability(x) == 1.0, and
* logProbability(x) = K - error(x)
* i.e., K = log(K).
*
* There are four broad classes of conditionals that derive from Conditional:
*
* - \b Gaussian conditionals, implemented in \class GaussianConditional, a
* Gaussian density over a set of continuous variables.
* - \b Discrete conditionals, implemented in \class DiscreteConditional, which
* represent a discrete conditional distribution over discrete variables.
* - \b Hybrid conditional densities, such as \class GaussianMixture, which is
* a density over continuous variables given discrete/continuous parents.
* - \b Symbolic factors, used to represent a graph structure, implemented in
* \class SymbolicConditional. Only used for symbolic elimination etc.
* *
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
* to the associated factor type and shared_ptr type of the derived class. See * to the associated factor type and shared_ptr type of the derived class. See
* SymbolicConditional and GaussianConditional for examples. * SymbolicConditional and GaussianConditional for examples.
*
* \nosubgrouping * \nosubgrouping
*/ */
template<class FACTOR, class DERIVEDCONDITIONAL> template<class FACTOR, class DERIVEDCONDITIONAL>
@ -78,6 +102,8 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
virtual ~Conditional() {}
/** return the number of frontals */ /** return the number of frontals */
size_t nrFrontals() const { return nrFrontals_; } size_t nrFrontals() const { return nrFrontals_; }
@ -98,6 +124,27 @@ namespace gtsam {
/** return a view of the parent keys */ /** return a view of the parent keys */
Parents parents() const { return boost::make_iterator_range(beginParents(), endParents()); } Parents parents() const { return boost::make_iterator_range(beginParents(), endParents()); }
/**
* All conditional types need to implement a `logProbability` function, for which
* exp(logProbability(x)) = evaluate(x).
*/
virtual double logProbability(const HybridValues& c) const;
/**
* All conditional types need to implement an `evaluate` function, that yields
* a true probability. The default implementation just exponentiates logProbability.
*/
virtual double evaluate(const HybridValues& c) const;
/// Evaluate probability density, sugar.
double operator()(const HybridValues& x) const {
return evaluate(x);
}
/// @}
/// @name Advanced Interface
/// @{
/** Iterator pointing to first frontal key. */ /** Iterator pointing to first frontal key. */
typename FACTOR::const_iterator beginFrontals() const { return asFactor().begin(); } typename FACTOR::const_iterator beginFrontals() const { return asFactor().begin(); }
@ -110,10 +157,6 @@ namespace gtsam {
/** Iterator pointing past the last parent key. */ /** Iterator pointing past the last parent key. */
typename FACTOR::const_iterator endParents() const { return asFactor().end(); } typename FACTOR::const_iterator endParents() const { return asFactor().end(); }
/// @}
/// @name Advanced Interface
/// @{
/** Mutable version of nrFrontals */ /** Mutable version of nrFrontals */
size_t& nrFrontals() { return nrFrontals_; } size_t& nrFrontals() { return nrFrontals_; }

View File

@ -43,4 +43,10 @@ namespace gtsam {
return keys_ == other.keys_; return keys_ == other.keys_;
} }
/* ************************************************************************* */
double Factor::error(const HybridValues& c) const {
throw std::runtime_error("Factor::error is not implemented");
}
} }

View File

@ -44,10 +44,7 @@ namespace gtsam {
* *
* The `error` method is used to evaluate the factor, and is the only method * The `error` method is used to evaluate the factor, and is the only method
* that is required to be implemented in derived classes, although it has a * that is required to be implemented in derived classes, although it has a
* default implementation that throws an exception. The meaning of the error * default implementation that throws an exception.
* is slightly different for factors and conditionals: in the former it is the
* negative log-likelihood, and in the latter it is the negative log of the
* properly normalized conditional distribution or density.
* *
* There are five broad classes of factors that derive from Factor: * There are five broad classes of factors that derive from Factor:
* *
@ -55,15 +52,12 @@ namespace gtsam {
* represent a nonlinear likelihood function over a set of variables. * represent a nonlinear likelihood function over a set of variables.
* - \b Gaussian factors, such as \class JacobianFactor and \class HessianFactor, which * - \b Gaussian factors, such as \class JacobianFactor and \class HessianFactor, which
* represent a Gaussian likelihood over a set of variables. * represent a Gaussian likelihood over a set of variables.
* A \class GaussianConditional, which represent a Gaussian density over a set of * - \b Discrete factors, such as \class DiscreteFactor and \class DecisionTreeFactor, which
* variables conditioned on another set of variables.
* - \b Discrete factors, such as \class DiscreteFactor and \class DiscreteConditional, which
* represent a discrete distribution over a set of variables. * represent a discrete distribution over a set of variables.
* - \b Hybrid factors, such as \class HybridFactor, which represent a mixture of * - \b Hybrid factors, such as \class HybridFactor, which represent a mixture of
* Gaussian and discrete distributions over a set of variables. * Gaussian and discrete distributions over a set of variables.
* - \b Symbolic factors, used to represent a graph structure, such as * - \b Symbolic factors, used to represent a graph structure, such as
* \class SymbolicFactor and \class SymbolicConditional. They do not override the * \class SymbolicFactor, only used for symbolic elimination etc.
* `error` method, and are used only for symbolic elimination etc.
* *
* Note that derived classes must also redefine the `This` and `shared_ptr` * Note that derived classes must also redefine the `This` and `shared_ptr`
* typedefs. See JacobianFactor, etc. for examples. * typedefs. See JacobianFactor, etc. for examples.
@ -154,11 +148,8 @@ namespace gtsam {
/** /**
* All factor types need to implement an error function. * All factor types need to implement an error function.
* In factor graphs, this is the negative log-likelihood. * In factor graphs, this is the negative log-likelihood.
* In Bayes nets, it is the negative log density, i.e., properly normalized.
*/ */
virtual double error(const HybridValues& c) const { virtual double error(const HybridValues& c) const;
throw std::runtime_error("Factor::error is not implemented");
}
/** /**
* @return the number of variables involved in this factor * @return the number of variables involved in this factor