diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index f43b4025e..f06008c88 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -88,6 +88,22 @@ void BayesNet::saveGraph(const std::string& filename, of.close(); } +/* ************************************************************************* */ +template +double BayesNet::logProbability(const HybridValues& x) const { + double sum = 0.; + for (const auto& gc : *this) { + if (gc) sum += gc->logProbability(x); + } + return sum; +} + +/* ************************************************************************* */ +template +double BayesNet::evaluate(const HybridValues& x) const { + return exp(-logProbability(x)); +} + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 4704d2873..4d266df46 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -25,6 +25,8 @@ namespace gtsam { +class HybridValues; + /** * A BayesNet is a tree of conditionals, stored in elimination order. * @ingroup inference @@ -68,7 +70,6 @@ class BayesNet : public FactorGraph { const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// @} - /// @name Graph Display /// @{ @@ -86,6 +87,13 @@ class BayesNet : public FactorGraph { const KeyFormatter& keyFormatter = DefaultKeyFormatter, const DotWriter& writer = DotWriter()) const; + /// @} + /// @name HybridValues methods + /// @{ + + double logProbability(const HybridValues& x) const; + double evaluate(const HybridValues& c) const; + /// @} }; diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index 9879a582c..30433263c 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -18,30 +18,42 @@ // \callgraph #pragma once -#include - #include +#include +#include + namespace gtsam { - /* ************************************************************************* */ - template - void Conditional::print(const std::string& s, const KeyFormatter& formatter) const { - std::cout << s << " P("; - for(Key key: frontals()) - std::cout << " " << formatter(key); - if (nrParents() > 0) - std::cout << " |"; - for(Key parent: parents()) - std::cout << " " << formatter(parent); - std::cout << ")" << std::endl; - } - - /* ************************************************************************* */ - template - bool Conditional::equals(const This& c, double tol) const - { - return nrFrontals_ == c.nrFrontals_; - } - +/* ************************************************************************* */ +template +void Conditional::print( + const std::string& s, const KeyFormatter& formatter) const { + std::cout << s << " P("; + for (Key key : frontals()) std::cout << " " << formatter(key); + if (nrParents() > 0) std::cout << " |"; + for (Key parent : parents()) std::cout << " " << formatter(parent); + std::cout << ")" << std::endl; } + +/* ************************************************************************* */ +template +bool Conditional::equals(const This& c, + double tol) const { + return nrFrontals_ == c.nrFrontals_; +} + +/* ************************************************************************* */ +template +double Conditional::logProbability( + const HybridValues& c) const { + throw std::runtime_error("Conditional::logProbability is not implemented"); +} + +/* ************************************************************************* */ +template +double Conditional::evaluate( + const HybridValues& c) const { + return exp(static_cast(this)->logProbability(c)); +} +} // namespace gtsam diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 7594da78d..9083c5c1a 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -24,13 +24,37 @@ namespace gtsam { + class HybridValues; // forward declaration. + /** - * Base class for conditional densities. This class iterators and - * access to the frontal and separator keys. + * This is the base class for all conditional distributions/densities, + * which are implemented as specialized factors. This class does not store any + * data other than its keys. Derived classes store data such as matrices and + * probability tables. + * + * The `evaluate` method is used to evaluate the factor, and together with + * `logProbability` is the main methods that need to be implemented in derived + * classes. These two methods relate to the `error` method in the factor by: + * probability(x) = k exp(-error(x)) + * where k is a normalization constant making \int probability(x) == 1.0, and + * logProbability(x) = K - error(x) + * i.e., K = log(K). + * + * There are four broad classes of conditionals that derive from Conditional: + * + * - \b Gaussian conditionals, implemented in \class GaussianConditional, a + * Gaussian density over a set of continuous variables. + * - \b Discrete conditionals, implemented in \class DiscreteConditional, which + * represent a discrete conditional distribution over discrete variables. + * - \b Hybrid conditional densities, such as \class GaussianMixture, which is + * a density over continuous variables given discrete/continuous parents. + * - \b Symbolic factors, used to represent a graph structure, implemented in + * \class SymbolicConditional. Only used for symbolic elimination etc. * * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * to the associated factor type and shared_ptr type of the derived class. See * SymbolicConditional and GaussianConditional for examples. + * * \nosubgrouping */ template @@ -78,6 +102,8 @@ namespace gtsam { /// @name Standard Interface /// @{ + virtual ~Conditional() {} + /** return the number of frontals */ size_t nrFrontals() const { return nrFrontals_; } @@ -98,6 +124,27 @@ namespace gtsam { /** return a view of the parent keys */ Parents parents() const { return boost::make_iterator_range(beginParents(), endParents()); } + /** + * All conditional types need to implement a `logProbability` function, for which + * exp(logProbability(x)) = evaluate(x). + */ + virtual double logProbability(const HybridValues& c) const; + + /** + * All conditional types need to implement an `evaluate` function, that yields + * a true probability. The default implementation just exponentiates logProbability. + */ + virtual double evaluate(const HybridValues& c) const; + + /// Evaluate probability density, sugar. + double operator()(const HybridValues& x) const { + return evaluate(x); + } + + /// @} + /// @name Advanced Interface + /// @{ + /** Iterator pointing to first frontal key. */ typename FACTOR::const_iterator beginFrontals() const { return asFactor().begin(); } @@ -110,10 +157,6 @@ namespace gtsam { /** Iterator pointing past the last parent key. */ typename FACTOR::const_iterator endParents() const { return asFactor().end(); } - /// @} - /// @name Advanced Interface - /// @{ - /** Mutable version of nrFrontals */ size_t& nrFrontals() { return nrFrontals_; } diff --git a/gtsam/inference/Factor.cpp b/gtsam/inference/Factor.cpp index 6fe96c777..2590d7b59 100644 --- a/gtsam/inference/Factor.cpp +++ b/gtsam/inference/Factor.cpp @@ -43,4 +43,10 @@ namespace gtsam { return keys_ == other.keys_; } + /* ************************************************************************* */ + double Factor::error(const HybridValues& c) const { + throw std::runtime_error("Factor::error is not implemented"); + } + + } diff --git a/gtsam/inference/Factor.h b/gtsam/inference/Factor.h index 2fa5e3f88..f59a5972d 100644 --- a/gtsam/inference/Factor.h +++ b/gtsam/inference/Factor.h @@ -44,10 +44,7 @@ namespace gtsam { * * The `error` method is used to evaluate the factor, and is the only method * that is required to be implemented in derived classes, although it has a - * default implementation that throws an exception. The meaning of the error - * is slightly different for factors and conditionals: in the former it is the - * negative log-likelihood, and in the latter it is the negative log of the - * properly normalized conditional distribution or density. + * default implementation that throws an exception. * * There are five broad classes of factors that derive from Factor: * @@ -55,15 +52,12 @@ namespace gtsam { * represent a nonlinear likelihood function over a set of variables. * - \b Gaussian factors, such as \class JacobianFactor and \class HessianFactor, which * represent a Gaussian likelihood over a set of variables. - * A \class GaussianConditional, which represent a Gaussian density over a set of - * variables conditioned on another set of variables. - * - \b Discrete factors, such as \class DiscreteFactor and \class DiscreteConditional, which + * - \b Discrete factors, such as \class DiscreteFactor and \class DecisionTreeFactor, which * represent a discrete distribution over a set of variables. * - \b Hybrid factors, such as \class HybridFactor, which represent a mixture of * Gaussian and discrete distributions over a set of variables. * - \b Symbolic factors, used to represent a graph structure, such as - * \class SymbolicFactor and \class SymbolicConditional. They do not override the - * `error` method, and are used only for symbolic elimination etc. + * \class SymbolicFactor, only used for symbolic elimination etc. * * Note that derived classes must also redefine the `This` and `shared_ptr` * typedefs. See JacobianFactor, etc. for examples. @@ -154,11 +148,8 @@ namespace gtsam { /** * All factor types need to implement an error function. * In factor graphs, this is the negative log-likelihood. - * In Bayes nets, it is the negative log density, i.e., properly normalized. */ - virtual double error(const HybridValues& c) const { - throw std::runtime_error("Factor::error is not implemented"); - } + virtual double error(const HybridValues& c) const; /** * @return the number of variables involved in this factor