New compilation unit that collects generic, templated inference methods that go between factor graphs and Bayes nets. These used to be in BayesNet-inl.h and FactorGraph-inl.h

release/4.3a0
Frank Dellaert 2009-11-12 04:52:40 +00:00
parent 4a7e05fffa
commit a38ebece1e
3 changed files with 193 additions and 0 deletions

87
cpp/inference-inl.h Normal file
View File

@ -0,0 +1,87 @@
/**
* @file inference-inl.h
* @brief inference template definitions
* @author Frank Dellaert
*/
#include "inference.h"
#include "FactorGraph-inl.h"
#include "BayesNet-inl.h"
using namespace std;
namespace gtsam {
/* ************************************************************************* */
/* eliminate one node from the factor graph */
/* ************************************************************************* */
template<class Factor,class Conditional>
boost::shared_ptr<Conditional> eliminateOne(FactorGraph<Factor>& graph, const string& key) {
// combine the factors of all nodes connected to the variable to be eliminated
// if no factors are connected to key, returns an empty factor
boost::shared_ptr<Factor> joint_factor = removeAndCombineFactors(graph,key);
// eliminate that joint factor
boost::shared_ptr<Factor> factor;
boost::shared_ptr<Conditional> conditional;
boost::tie(conditional, factor) = joint_factor->eliminate(key);
// add new factor on separator back into the graph
if (!factor->empty()) graph.push_back(factor);
// return the conditional Gaussian
return conditional;
}
/* ************************************************************************* */
// This doubly templated function is generic. There is a LinearFactorGraph
// version that returns a more specific GaussianBayesNet.
// Note, you will need to include this file to instantiate the function.
/* ************************************************************************* */
template<class Factor,class Conditional>
BayesNet<Conditional> eliminate(FactorGraph<Factor>& factorGraph, const Ordering& ordering)
{
BayesNet<Conditional> bayesNet; // empty
BOOST_FOREACH(string key, ordering) {
boost::shared_ptr<Conditional> cg = eliminateOne<Factor,Conditional>(factorGraph,key);
bayesNet.push_back(cg);
}
return bayesNet;
}
/* ************************************************************************* */
template<class Factor, class Conditional>
pair< BayesNet<Conditional>, FactorGraph<Factor> >
factor(const BayesNet<Conditional>& bn, const Ordering& keys) {
// Convert to factor graph
FactorGraph<Factor> factorGraph(bn);
// Get the keys of all variables and remove all keys we want the marginal for
Ordering ord = bn.ordering();
BOOST_FOREACH(string key, keys) ord.remove(key); // TODO: O(n*k), faster possible?
// eliminate partially,
BayesNet<Conditional> conditional = eliminate<Factor,Conditional>(factorGraph,ord);
// at this moment, the factor graph only encodes P(keys)
return make_pair(conditional,factorGraph);
}
/* ************************************************************************* */
template<class Factor, class Conditional>
FactorGraph<Factor> marginalize(const BayesNet<Conditional>& bn, const Ordering& keys) {
// factor P(X,Y) as P(X|Y)P(Y), where Y corresponds to keys
pair< BayesNet<Conditional>, FactorGraph<Factor> > factors =
gtsam::factor<Factor,Conditional>(bn,keys);
// throw away conditional, return marginal P(Y)
return factors.second;
}
/* ************************************************************************* */
} // namespace gtsam

52
cpp/inference.h Normal file
View File

@ -0,0 +1,52 @@
/**
* @file inference.h
* @brief Contains *generic* inference algorithms that convert between templated
* graphical models, i.e., factor graphs, Bayes nets, and Bayes trees
* @author Frank Dellaert
*/
#pragma once
#include "FactorGraph.h"
#include "BayesNet.h"
namespace gtsam {
class Ordering;
// ELIMINATE: FACTOR GRAPH -> BAYES NET
/**
* Eliminate a single node yielding a Conditional
* Eliminates the factors from the factor graph through findAndRemoveFactors
* and adds a new factor on the separator to the factor graph
*/
template<class Factor, class Conditional>
boost::shared_ptr<Conditional>
eliminateOne(FactorGraph<Factor>& factorGraph, const std::string& key);
/**
* eliminate factor graph using the given (not necessarily complete)
* ordering, yielding a chordal Bayes net and (partially eliminated) FG
*/
template<class Factor, class Conditional>
BayesNet<Conditional> eliminate(FactorGraph<Factor>& factorGraph, const Ordering& ordering);
// FACTOR/MARGINALIZE: BAYES NET -> FACTOR GRAPH
/**
* Factor P(X) as P(not keys|keys) P(keys)
* @return P(not keys|keys) as an incomplete BayesNet, and P(keys) as a factor graph
*/
template<class Factor, class Conditional>
std::pair< BayesNet<Conditional>, FactorGraph<Factor> >
factor(const BayesNet<Conditional>& bn, const Ordering& keys);
/**
* integrate out all except ordering, might be inefficient as the ordering
* will simply be the current ordering with the keys put in the back
*/
template<class Factor, class Conditional>
FactorGraph<Factor> marginalize(const BayesNet<Conditional>& bn, const Ordering& keys);
} /// namespace gtsam

54
cpp/testInference.cpp Normal file
View File

@ -0,0 +1,54 @@
/**
* @file testInference.cpp
* @brief Unit tests for functionality declared in inference.h
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include "Ordering.h"
#include "smallExample.h"
#include "inference-inl.h"
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
// The tests below test the *generic* inference algorithms. Some of these have
// specialized versions in the derived classes LinearFactorGraph etc...
/* ************************************************************************* */
/* ************************************************************************* */
TEST(LinearFactorGraph, createSmoother)
{
LinearFactorGraph fg2 = createSmoother(3);
LONGS_EQUAL(5,fg2.size());
// eliminate
Ordering ordering;
GaussianBayesNet bayesNet = fg2.eliminate(ordering);
bayesNet.print("bayesNet");
FactorGraph<LinearFactor> p_x3 = marginalize<LinearFactor,ConditionalGaussian>(bayesNet, Ordering("x3"));
FactorGraph<LinearFactor> p_x1 = marginalize<LinearFactor,ConditionalGaussian>(bayesNet, Ordering("x1"));
CHECK(assert_equal(p_x1,p_x3)); // should be the same because of symmetry
}
/* ************************************************************************* */
TEST( Inference, marginals )
{
// create and marginalize a small Bayes net on "x"
GaussianBayesNet cbn = createSmallGaussianBayesNet();
Ordering keys("x");
FactorGraph<LinearFactor> fg = marginalize<LinearFactor, ConditionalGaussian>(cbn,keys);
// turn into Bayes net to test easily
BayesNet<ConditionalGaussian> actual = eliminate<LinearFactor,ConditionalGaussian>(fg,keys);
// expected is just scalar Gaussian on x
GaussianBayesNet expected = scalarGaussian("x",4,sqrt(2));
CHECK(assert_equal(expected,actual));
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
/* ************************************************************************* */