From a38ebece1e1279a5ea9fb30e25d51fd9479d2b4b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 12 Nov 2009 04:52:40 +0000 Subject: [PATCH] New compilation unit that collects generic, templated inference methods that go between factor graphs and Bayes nets. These used to be in BayesNet-inl.h and FactorGraph-inl.h --- cpp/inference-inl.h | 87 +++++++++++++++++++++++++++++++++++++++++++ cpp/inference.h | 52 ++++++++++++++++++++++++++ cpp/testInference.cpp | 54 +++++++++++++++++++++++++++ 3 files changed, 193 insertions(+) create mode 100644 cpp/inference-inl.h create mode 100644 cpp/inference.h create mode 100644 cpp/testInference.cpp diff --git a/cpp/inference-inl.h b/cpp/inference-inl.h new file mode 100644 index 000000000..5c8f90934 --- /dev/null +++ b/cpp/inference-inl.h @@ -0,0 +1,87 @@ +/** + * @file inference-inl.h + * @brief inference template definitions + * @author Frank Dellaert + */ + +#include "inference.h" +#include "FactorGraph-inl.h" +#include "BayesNet-inl.h" + +using namespace std; + +namespace gtsam { + + /* ************************************************************************* */ + /* eliminate one node from the factor graph */ + /* ************************************************************************* */ + template + boost::shared_ptr eliminateOne(FactorGraph& graph, const string& key) { + + // combine the factors of all nodes connected to the variable to be eliminated + // if no factors are connected to key, returns an empty factor + boost::shared_ptr joint_factor = removeAndCombineFactors(graph,key); + + // eliminate that joint factor + boost::shared_ptr factor; + boost::shared_ptr conditional; + boost::tie(conditional, factor) = joint_factor->eliminate(key); + + // add new factor on separator back into the graph + if (!factor->empty()) graph.push_back(factor); + + // return the conditional Gaussian + return conditional; + } + + /* ************************************************************************* */ + // This doubly templated function is generic. There is a LinearFactorGraph + // version that returns a more specific GaussianBayesNet. + // Note, you will need to include this file to instantiate the function. + /* ************************************************************************* */ + template + BayesNet eliminate(FactorGraph& factorGraph, const Ordering& ordering) + { + BayesNet bayesNet; // empty + + BOOST_FOREACH(string key, ordering) { + boost::shared_ptr cg = eliminateOne(factorGraph,key); + bayesNet.push_back(cg); + } + + return bayesNet; + } + + /* ************************************************************************* */ + template + pair< BayesNet, FactorGraph > + factor(const BayesNet& bn, const Ordering& keys) { + // Convert to factor graph + FactorGraph factorGraph(bn); + + // Get the keys of all variables and remove all keys we want the marginal for + Ordering ord = bn.ordering(); + BOOST_FOREACH(string key, keys) ord.remove(key); // TODO: O(n*k), faster possible? + + // eliminate partially, + BayesNet conditional = eliminate(factorGraph,ord); + + // at this moment, the factor graph only encodes P(keys) + return make_pair(conditional,factorGraph); + } + + /* ************************************************************************* */ + template + FactorGraph marginalize(const BayesNet& bn, const Ordering& keys) { + + // factor P(X,Y) as P(X|Y)P(Y), where Y corresponds to keys + pair< BayesNet, FactorGraph > factors = + gtsam::factor(bn,keys); + + // throw away conditional, return marginal P(Y) + return factors.second; + } + + /* ************************************************************************* */ + +} // namespace gtsam diff --git a/cpp/inference.h b/cpp/inference.h new file mode 100644 index 000000000..e3321ba47 --- /dev/null +++ b/cpp/inference.h @@ -0,0 +1,52 @@ +/** + * @file inference.h + * @brief Contains *generic* inference algorithms that convert between templated + * graphical models, i.e., factor graphs, Bayes nets, and Bayes trees + * @author Frank Dellaert + */ + +#pragma once + +#include "FactorGraph.h" +#include "BayesNet.h" + +namespace gtsam { + + class Ordering; + + // ELIMINATE: FACTOR GRAPH -> BAYES NET + + /** + * Eliminate a single node yielding a Conditional + * Eliminates the factors from the factor graph through findAndRemoveFactors + * and adds a new factor on the separator to the factor graph + */ + template + boost::shared_ptr + eliminateOne(FactorGraph& factorGraph, const std::string& key); + + /** + * eliminate factor graph using the given (not necessarily complete) + * ordering, yielding a chordal Bayes net and (partially eliminated) FG + */ + template + BayesNet eliminate(FactorGraph& factorGraph, const Ordering& ordering); + + // FACTOR/MARGINALIZE: BAYES NET -> FACTOR GRAPH + + /** + * Factor P(X) as P(not keys|keys) P(keys) + * @return P(not keys|keys) as an incomplete BayesNet, and P(keys) as a factor graph + */ + template + std::pair< BayesNet, FactorGraph > + factor(const BayesNet& bn, const Ordering& keys); + + /** + * integrate out all except ordering, might be inefficient as the ordering + * will simply be the current ordering with the keys put in the back + */ + template + FactorGraph marginalize(const BayesNet& bn, const Ordering& keys); + +} /// namespace gtsam diff --git a/cpp/testInference.cpp b/cpp/testInference.cpp new file mode 100644 index 000000000..1aa88b4af --- /dev/null +++ b/cpp/testInference.cpp @@ -0,0 +1,54 @@ +/** + * @file testInference.cpp + * @brief Unit tests for functionality declared in inference.h + * @author Frank Dellaert + */ + +#include + +#include "Ordering.h" +#include "smallExample.h" +#include "inference-inl.h" + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +// The tests below test the *generic* inference algorithms. Some of these have +// specialized versions in the derived classes LinearFactorGraph etc... +/* ************************************************************************* */ + +/* ************************************************************************* */ +TEST(LinearFactorGraph, createSmoother) +{ + LinearFactorGraph fg2 = createSmoother(3); + LONGS_EQUAL(5,fg2.size()); + + // eliminate + Ordering ordering; + GaussianBayesNet bayesNet = fg2.eliminate(ordering); + bayesNet.print("bayesNet"); + FactorGraph p_x3 = marginalize(bayesNet, Ordering("x3")); + FactorGraph p_x1 = marginalize(bayesNet, Ordering("x1")); + CHECK(assert_equal(p_x1,p_x3)); // should be the same because of symmetry +} + +/* ************************************************************************* */ +TEST( Inference, marginals ) +{ + // create and marginalize a small Bayes net on "x" + GaussianBayesNet cbn = createSmallGaussianBayesNet(); + Ordering keys("x"); + FactorGraph fg = marginalize(cbn,keys); + + // turn into Bayes net to test easily + BayesNet actual = eliminate(fg,keys); + + // expected is just scalar Gaussian on x + GaussianBayesNet expected = scalarGaussian("x",4,sqrt(2)); + CHECK(assert_equal(expected,actual)); +} + +/* ************************************************************************* */ +int main() { TestResult tr; return TestRegistry::runAllTests(tr);} +/* ************************************************************************* */