Added logProbability/evaluate in conditionals/Bayes net
parent
69398d0e60
commit
8e29140ff7
|
|
@ -88,6 +88,22 @@ void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
|||
of.close();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
double BayesNet<CONDITIONAL>::logProbability(const HybridValues& x) const {
|
||||
double sum = 0.;
|
||||
for (const auto& gc : *this) {
|
||||
if (gc) sum += gc->logProbability(x);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
double BayesNet<CONDITIONAL>::evaluate(const HybridValues& x) const {
|
||||
return exp(-logProbability(x));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A BayesNet is a tree of conditionals, stored in elimination order.
|
||||
* @ingroup inference
|
||||
|
|
@ -68,7 +70,6 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
|
|||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
|
|
@ -86,6 +87,13 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
|
|||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DotWriter& writer = DotWriter()) const;
|
||||
|
||||
/// @}
|
||||
/// @name HybridValues methods
|
||||
/// @{
|
||||
|
||||
double logProbability(const HybridValues& x) const;
|
||||
double evaluate(const HybridValues& c) const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -18,30 +18,42 @@
|
|||
// \callgraph
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class FACTOR, class DERIVEDFACTOR>
|
||||
void Conditional<FACTOR,DERIVEDFACTOR>::print(const std::string& s, const KeyFormatter& formatter) const {
|
||||
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||
void Conditional<FACTOR, DERIVEDCONDITIONAL>::print(
|
||||
const std::string& s, const KeyFormatter& formatter) const {
|
||||
std::cout << s << " P(";
|
||||
for(Key key: frontals())
|
||||
std::cout << " " << formatter(key);
|
||||
if (nrParents() > 0)
|
||||
std::cout << " |";
|
||||
for(Key parent: parents())
|
||||
std::cout << " " << formatter(parent);
|
||||
for (Key key : frontals()) std::cout << " " << formatter(key);
|
||||
if (nrParents() > 0) std::cout << " |";
|
||||
for (Key parent : parents()) std::cout << " " << formatter(parent);
|
||||
std::cout << ")" << std::endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class FACTOR, class DERIVEDFACTOR>
|
||||
bool Conditional<FACTOR,DERIVEDFACTOR>::equals(const This& c, double tol) const
|
||||
{
|
||||
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::equals(const This& c,
|
||||
double tol) const {
|
||||
return nrFrontals_ == c.nrFrontals_;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logProbability(
|
||||
const HybridValues& c) const {
|
||||
throw std::runtime_error("Conditional::logProbability is not implemented");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||
double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
|
||||
const HybridValues& c) const {
|
||||
return exp(static_cast<const DERIVEDCONDITIONAL*>(this)->logProbability(c));
|
||||
}
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -24,13 +24,37 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues; // forward declaration.
|
||||
|
||||
/**
|
||||
* Base class for conditional densities. This class iterators and
|
||||
* access to the frontal and separator keys.
|
||||
* This is the base class for all conditional distributions/densities,
|
||||
* which are implemented as specialized factors. This class does not store any
|
||||
* data other than its keys. Derived classes store data such as matrices and
|
||||
* probability tables.
|
||||
*
|
||||
* The `evaluate` method is used to evaluate the factor, and together with
|
||||
* `logProbability` is the main methods that need to be implemented in derived
|
||||
* classes. These two methods relate to the `error` method in the factor by:
|
||||
* probability(x) = k exp(-error(x))
|
||||
* where k is a normalization constant making \int probability(x) == 1.0, and
|
||||
* logProbability(x) = K - error(x)
|
||||
* i.e., K = log(K).
|
||||
*
|
||||
* There are four broad classes of conditionals that derive from Conditional:
|
||||
*
|
||||
* - \b Gaussian conditionals, implemented in \class GaussianConditional, a
|
||||
* Gaussian density over a set of continuous variables.
|
||||
* - \b Discrete conditionals, implemented in \class DiscreteConditional, which
|
||||
* represent a discrete conditional distribution over discrete variables.
|
||||
* - \b Hybrid conditional densities, such as \class GaussianMixture, which is
|
||||
* a density over continuous variables given discrete/continuous parents.
|
||||
* - \b Symbolic factors, used to represent a graph structure, implemented in
|
||||
* \class SymbolicConditional. Only used for symbolic elimination etc.
|
||||
*
|
||||
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
|
||||
* to the associated factor type and shared_ptr type of the derived class. See
|
||||
* SymbolicConditional and GaussianConditional for examples.
|
||||
*
|
||||
* \nosubgrouping
|
||||
*/
|
||||
template<class FACTOR, class DERIVEDCONDITIONAL>
|
||||
|
|
@ -78,6 +102,8 @@ namespace gtsam {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
virtual ~Conditional() {}
|
||||
|
||||
/** return the number of frontals */
|
||||
size_t nrFrontals() const { return nrFrontals_; }
|
||||
|
||||
|
|
@ -98,6 +124,27 @@ namespace gtsam {
|
|||
/** return a view of the parent keys */
|
||||
Parents parents() const { return boost::make_iterator_range(beginParents(), endParents()); }
|
||||
|
||||
/**
|
||||
* All conditional types need to implement a `logProbability` function, for which
|
||||
* exp(logProbability(x)) = evaluate(x).
|
||||
*/
|
||||
virtual double logProbability(const HybridValues& c) const;
|
||||
|
||||
/**
|
||||
* All conditional types need to implement an `evaluate` function, that yields
|
||||
* a true probability. The default implementation just exponentiates logProbability.
|
||||
*/
|
||||
virtual double evaluate(const HybridValues& c) const;
|
||||
|
||||
/// Evaluate probability density, sugar.
|
||||
double operator()(const HybridValues& x) const {
|
||||
return evaluate(x);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/** Iterator pointing to first frontal key. */
|
||||
typename FACTOR::const_iterator beginFrontals() const { return asFactor().begin(); }
|
||||
|
||||
|
|
@ -110,10 +157,6 @@ namespace gtsam {
|
|||
/** Iterator pointing past the last parent key. */
|
||||
typename FACTOR::const_iterator endParents() const { return asFactor().end(); }
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/** Mutable version of nrFrontals */
|
||||
size_t& nrFrontals() { return nrFrontals_; }
|
||||
|
||||
|
|
|
|||
|
|
@ -43,4 +43,10 @@ namespace gtsam {
|
|||
return keys_ == other.keys_;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double Factor::error(const HybridValues& c) const {
|
||||
throw std::runtime_error("Factor::error is not implemented");
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,10 +44,7 @@ namespace gtsam {
|
|||
*
|
||||
* The `error` method is used to evaluate the factor, and is the only method
|
||||
* that is required to be implemented in derived classes, although it has a
|
||||
* default implementation that throws an exception. The meaning of the error
|
||||
* is slightly different for factors and conditionals: in the former it is the
|
||||
* negative log-likelihood, and in the latter it is the negative log of the
|
||||
* properly normalized conditional distribution or density.
|
||||
* default implementation that throws an exception.
|
||||
*
|
||||
* There are five broad classes of factors that derive from Factor:
|
||||
*
|
||||
|
|
@ -55,15 +52,12 @@ namespace gtsam {
|
|||
* represent a nonlinear likelihood function over a set of variables.
|
||||
* - \b Gaussian factors, such as \class JacobianFactor and \class HessianFactor, which
|
||||
* represent a Gaussian likelihood over a set of variables.
|
||||
* A \class GaussianConditional, which represent a Gaussian density over a set of
|
||||
* variables conditioned on another set of variables.
|
||||
* - \b Discrete factors, such as \class DiscreteFactor and \class DiscreteConditional, which
|
||||
* - \b Discrete factors, such as \class DiscreteFactor and \class DecisionTreeFactor, which
|
||||
* represent a discrete distribution over a set of variables.
|
||||
* - \b Hybrid factors, such as \class HybridFactor, which represent a mixture of
|
||||
* Gaussian and discrete distributions over a set of variables.
|
||||
* - \b Symbolic factors, used to represent a graph structure, such as
|
||||
* \class SymbolicFactor and \class SymbolicConditional. They do not override the
|
||||
* `error` method, and are used only for symbolic elimination etc.
|
||||
* \class SymbolicFactor, only used for symbolic elimination etc.
|
||||
*
|
||||
* Note that derived classes must also redefine the `This` and `shared_ptr`
|
||||
* typedefs. See JacobianFactor, etc. for examples.
|
||||
|
|
@ -154,11 +148,8 @@ namespace gtsam {
|
|||
/**
|
||||
* All factor types need to implement an error function.
|
||||
* In factor graphs, this is the negative log-likelihood.
|
||||
* In Bayes nets, it is the negative log density, i.e., properly normalized.
|
||||
*/
|
||||
virtual double error(const HybridValues& c) const {
|
||||
throw std::runtime_error("Factor::error is not implemented");
|
||||
}
|
||||
virtual double error(const HybridValues& c) const;
|
||||
|
||||
/**
|
||||
* @return the number of variables involved in this factor
|
||||
|
|
|
|||
Loading…
Reference in New Issue