New function marginals. Also: combine entire Bayes nets with push_back and push_front. And finally: some convenience constructors in GaussianBayesNet.

release/4.3a0
Frank Dellaert 2009-11-08 22:50:26 +00:00
parent 58007a8167
commit 10e618f360
6 changed files with 128 additions and 9 deletions

View File

@ -13,6 +13,7 @@ using namespace boost::assign;
#include "Ordering.h"
#include "BayesNet.h"
#include "FactorGraph-inl.h"
using namespace std;
@ -34,6 +35,20 @@ namespace gtsam {
return equal(conditionals_.begin(),conditionals_.end(),cbn.conditionals_.begin(),equals_star<Conditional>(tol));
}
/* ************************************************************************* */
template<class Conditional>
void BayesNet<Conditional>::push_back(const BayesNet<Conditional> bn) {
BOOST_FOREACH(sharedConditional conditional,bn.conditionals_)
push_back(conditional);
}
/* ************************************************************************* */
template<class Conditional>
void BayesNet<Conditional>::push_front(const BayesNet<Conditional> bn) {
BOOST_FOREACH(sharedConditional conditional,bn.conditionals_)
push_front(conditional);
}
/* ************************************************************************* */
template<class Conditional>
Ordering BayesNet<Conditional>::ordering() const {
@ -53,6 +68,31 @@ namespace gtsam {
"BayesNet::operator['"+key+"']: not found"));
return *it;
}
/* ************************************************************************* */
template<class Factor, class Conditional>
BayesNet<Conditional> marginals(const BayesNet<Conditional>& bn, const Ordering& keys) {
// Convert to factor graph
FactorGraph<Factor> factorGraph(bn);
// Get the keys of all variables and remove all keys we want the marginal for
Ordering ord = bn.ordering();
BOOST_FOREACH(string key, keys) ord.remove(key); // TODO: O(n*k), faster possible?
// add marginal keys at end
BOOST_FOREACH(string key, keys) ord.push_back(key);
// eliminate to get joint
typename BayesNet<Conditional>::shared_ptr joint = _eliminate<Factor,Conditional>(factorGraph,ord);
// remove all integrands, P(K) = \int_I P(I|K) P(K)
size_t nrIntegrands = ord.size()-keys.size();
for(int i=0;i<nrIntegrands;i++) joint->pop_front();
// joint is now only on keys, return it
return *joint;
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -30,6 +30,8 @@ namespace gtsam {
public:
typedef typename boost::shared_ptr<BayesNet<Conditional> >shared_ptr;
/** We store shared pointers to Conditional densities */
typedef typename boost::shared_ptr<Conditional> sharedConditional;
typedef typename std::list<sharedConditional> Conditionals;
@ -64,6 +66,12 @@ namespace gtsam {
conditionals_.push_front(conditional);
}
// push_back an entire Bayes net */
void push_back(const BayesNet<Conditional> bn);
// push_front an entire Bayes net */
void push_front(const BayesNet<Conditional> bn);
/**
* pop_front: remove node at the bottom, used in marginalization
* For example P(ABC)=P(A|BC)P(B|C)P(C) becomes P(BC)=P(B|C)P(C)
@ -81,6 +89,7 @@ namespace gtsam {
/** SLOW O(n) random access to Conditional by key */
sharedConditional operator[](const std::string& key) const;
/** return last node in ordering */
inline sharedConditional back() { return conditionals_.back(); }
/** return iterators. FD: breaks encapsulation? */
@ -96,6 +105,15 @@ namespace gtsam {
void serialize(Archive & ar, const unsigned int version) {
ar & BOOST_SERIALIZATION_NVP(conditionals_);
}
};
}; // BayesNet
/** doubly templated functions */
/**
* integrate out all except ordering, might be inefficient as the ordering
* will simply be the current ordering with the keys put in the back
*/
template<class Factor, class Conditional>
BayesNet<Conditional> marginals(const BayesNet<Conditional>& bn, const Ordering& keys);
} /// namespace gtsam

View File

@ -22,6 +22,21 @@ template class BayesNet<ConditionalGaussian>;
#define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL)
#define REVERSE_FOREACH_PAIR( KEY, VAL, COL) BOOST_REVERSE_FOREACH (boost::tie(KEY,VAL),COL)
/* ************************************************************************* */
GaussianBayesNet::GaussianBayesNet(const string& key, double mu, double sigma) {
ConditionalGaussian::shared_ptr
conditional(new ConditionalGaussian(key, Vector_(1,mu), eye(1), Vector_(1,sigma)));
push_back(conditional);
}
/* ************************************************************************* */
GaussianBayesNet::GaussianBayesNet(const string& key, const Vector& mu, double sigma) {
size_t n = mu.size();
ConditionalGaussian::shared_ptr
conditional(new ConditionalGaussian(key, mu, eye(n), repeat(n,sigma)));
push_back(conditional);
}
/* ************************************************************************* */
boost::shared_ptr<VectorConfig> GaussianBayesNet::optimize() const
{

View File

@ -25,9 +25,11 @@ public:
/** Construct an empty net */
GaussianBayesNet() {}
/** Copy Constructor */
// GaussianBayesNet(const GaussianBayesNet& cbn_in) :
// keys_(cbn_in.keys_), nodes_(cbn_in.nodes_) {}
/** Create a scalar Gaussian */
GaussianBayesNet(const std::string& key, double mu=0.0, double sigma=1.0);
/** Create a simple Gaussian on a single multivariate variable */
GaussianBayesNet(const std::string& key, const Vector& mu, double sigma=1.0);
/** Destructor */
virtual ~GaussianBayesNet() {}

View File

@ -4,7 +4,6 @@
* @author Frank Dellaert
*/
// STL/C++
#include <iostream>
#include <sstream>
@ -12,13 +11,18 @@
#include <boost/tuple/tuple.hpp>
#include <boost/foreach.hpp>
#include <boost/assign/std/list.hpp> // for operator +=
using namespace boost::assign;
#ifdef HAVE_BOOST_SERIALIZATION
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/text_iarchive.hpp>
#endif //HAVE_BOOST_SERIALIZATION
#include "GaussianBayesNet.h"
#include "BayesNet-inl.h"
#include "smallExample.h"
#include "Ordering.h"
using namespace std;
using namespace gtsam;
@ -34,11 +38,11 @@ TEST( GaussianBayesNet, constructor )
Matrix R22 = Matrix_(1,1,1.0);
Vector d1(1), d2(1);
d1(0) = 9; d2(0) = 5;
Vector tau(1);
tau(0) = 1.;
Vector sigmas(1);
sigmas(0) = 1.;
// define nodes and specify in reverse topological sort (i.e. parents last)
ConditionalGaussian x("x",d1,R11,"y",S12, tau), y("y",d2,R22, tau);
ConditionalGaussian x("x",d1,R11,"y",S12, sigmas), y("y",d2,R22, sigmas);
// check small example which uses constructor
GaussianBayesNet cbn = createSmallGaussianBayesNet();
@ -68,7 +72,6 @@ TEST( GaussianBayesNet, matrix )
/* ************************************************************************* */
TEST( GaussianBayesNet, optimize )
{
// optimize small Bayes Net
GaussianBayesNet cbn = createSmallGaussianBayesNet();
boost::shared_ptr<VectorConfig> actual = cbn.optimize();
@ -81,6 +84,19 @@ TEST( GaussianBayesNet, optimize )
CHECK(actual->equals(expected));
}
/* ************************************************************************* */
TEST( GaussianBayesNet, marginals )
{
// create and marginalize a small Bayes net on "x"
GaussianBayesNet cbn = createSmallGaussianBayesNet();
Ordering keys("x");
BayesNet<ConditionalGaussian> actual = marginals<LinearFactor>(cbn,keys);
// expected is just scalar Gaussian on x
GaussianBayesNet expected("x",4,sqrt(2));
CHECK(assert_equal((BayesNet<ConditionalGaussian>)expected,actual));
}
/* ************************************************************************* */
#ifdef HAVE_BOOST_SERIALIZATION
TEST( GaussianBayesNet, serialize )

View File

@ -66,6 +66,34 @@ TEST( SymbolicBayesNet, pop_front )
CHECK(assert_equal(expected,actual));
}
/* ************************************************************************* */
TEST( SymbolicBayesNet, combine )
{
SymbolicConditional::shared_ptr
A(new SymbolicConditional("A","B","C")),
B(new SymbolicConditional("B","C")),
C(new SymbolicConditional("C"));
// p(A|BC)
SymbolicBayesNet p_ABC;
p_ABC.push_back(A);
// P(BC)=P(B|C)P(C)
SymbolicBayesNet p_BC;
p_BC.push_back(B);
p_BC.push_back(C);
// P(ABC) = P(A|BC) P(BC)
p_ABC.push_back(p_BC);
SymbolicBayesNet expected;
expected.push_back(A);
expected.push_back(B);
expected.push_back(C);
CHECK(assert_equal(expected,p_ABC));
}
/* ************************************************************************* */
int main() {
TestResult tr;