diff --git a/cpp/BayesNet-inl.h b/cpp/BayesNet-inl.h index fa9116ea5..dde1a4cac 100644 --- a/cpp/BayesNet-inl.h +++ b/cpp/BayesNet-inl.h @@ -13,6 +13,7 @@ using namespace boost::assign; #include "Ordering.h" #include "BayesNet.h" +#include "FactorGraph-inl.h" using namespace std; @@ -34,6 +35,20 @@ namespace gtsam { return equal(conditionals_.begin(),conditionals_.end(),cbn.conditionals_.begin(),equals_star(tol)); } + /* ************************************************************************* */ + template + void BayesNet::push_back(const BayesNet bn) { + BOOST_FOREACH(sharedConditional conditional,bn.conditionals_) + push_back(conditional); + } + + /* ************************************************************************* */ + template + void BayesNet::push_front(const BayesNet bn) { + BOOST_FOREACH(sharedConditional conditional,bn.conditionals_) + push_front(conditional); + } + /* ************************************************************************* */ template Ordering BayesNet::ordering() const { @@ -53,6 +68,31 @@ namespace gtsam { "BayesNet::operator['"+key+"']: not found")); return *it; } + + /* ************************************************************************* */ + template + BayesNet marginals(const BayesNet& bn, const Ordering& keys) { + // Convert to factor graph + FactorGraph factorGraph(bn); + + // Get the keys of all variables and remove all keys we want the marginal for + Ordering ord = bn.ordering(); + BOOST_FOREACH(string key, keys) ord.remove(key); // TODO: O(n*k), faster possible? + + // add marginal keys at end + BOOST_FOREACH(string key, keys) ord.push_back(key); + + // eliminate to get joint + typename BayesNet::shared_ptr joint = _eliminate(factorGraph,ord); + + // remove all integrands, P(K) = \int_I P(I|K) P(K) + size_t nrIntegrands = ord.size()-keys.size(); + for(int i=0;ipop_front(); + + // joint is now only on keys, return it + return *joint; + } + /* ************************************************************************* */ } // namespace gtsam diff --git a/cpp/BayesNet.h b/cpp/BayesNet.h index ea965ac69..c88d0f038 100644 --- a/cpp/BayesNet.h +++ b/cpp/BayesNet.h @@ -30,6 +30,8 @@ namespace gtsam { public: + typedef typename boost::shared_ptr >shared_ptr; + /** We store shared pointers to Conditional densities */ typedef typename boost::shared_ptr sharedConditional; typedef typename std::list Conditionals; @@ -64,6 +66,12 @@ namespace gtsam { conditionals_.push_front(conditional); } + // push_back an entire Bayes net */ + void push_back(const BayesNet bn); + + // push_front an entire Bayes net */ + void push_front(const BayesNet bn); + /** * pop_front: remove node at the bottom, used in marginalization * For example P(ABC)=P(A|BC)P(B|C)P(C) becomes P(BC)=P(B|C)P(C) @@ -81,6 +89,7 @@ namespace gtsam { /** SLOW O(n) random access to Conditional by key */ sharedConditional operator[](const std::string& key) const; + /** return last node in ordering */ inline sharedConditional back() { return conditionals_.back(); } /** return iterators. FD: breaks encapsulation? */ @@ -96,6 +105,15 @@ namespace gtsam { void serialize(Archive & ar, const unsigned int version) { ar & BOOST_SERIALIZATION_NVP(conditionals_); } - }; + }; // BayesNet + + /** doubly templated functions */ + + /** + * integrate out all except ordering, might be inefficient as the ordering + * will simply be the current ordering with the keys put in the back + */ + template + BayesNet marginals(const BayesNet& bn, const Ordering& keys); } /// namespace gtsam diff --git a/cpp/GaussianBayesNet.cpp b/cpp/GaussianBayesNet.cpp index ffb27ff06..62522e171 100644 --- a/cpp/GaussianBayesNet.cpp +++ b/cpp/GaussianBayesNet.cpp @@ -22,6 +22,21 @@ template class BayesNet; #define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL) #define REVERSE_FOREACH_PAIR( KEY, VAL, COL) BOOST_REVERSE_FOREACH (boost::tie(KEY,VAL),COL) +/* ************************************************************************* */ +GaussianBayesNet::GaussianBayesNet(const string& key, double mu, double sigma) { + ConditionalGaussian::shared_ptr + conditional(new ConditionalGaussian(key, Vector_(1,mu), eye(1), Vector_(1,sigma))); + push_back(conditional); +} + +/* ************************************************************************* */ +GaussianBayesNet::GaussianBayesNet(const string& key, const Vector& mu, double sigma) { + size_t n = mu.size(); + ConditionalGaussian::shared_ptr + conditional(new ConditionalGaussian(key, mu, eye(n), repeat(n,sigma))); + push_back(conditional); +} + /* ************************************************************************* */ boost::shared_ptr GaussianBayesNet::optimize() const { diff --git a/cpp/GaussianBayesNet.h b/cpp/GaussianBayesNet.h index 41e4a7684..39f29f81a 100644 --- a/cpp/GaussianBayesNet.h +++ b/cpp/GaussianBayesNet.h @@ -25,9 +25,11 @@ public: /** Construct an empty net */ GaussianBayesNet() {} - /** Copy Constructor */ -// GaussianBayesNet(const GaussianBayesNet& cbn_in) : -// keys_(cbn_in.keys_), nodes_(cbn_in.nodes_) {} + /** Create a scalar Gaussian */ + GaussianBayesNet(const std::string& key, double mu=0.0, double sigma=1.0); + + /** Create a simple Gaussian on a single multivariate variable */ + GaussianBayesNet(const std::string& key, const Vector& mu, double sigma=1.0); /** Destructor */ virtual ~GaussianBayesNet() {} diff --git a/cpp/testGaussianBayesNet.cpp b/cpp/testGaussianBayesNet.cpp index 314bbad13..07e47c506 100644 --- a/cpp/testGaussianBayesNet.cpp +++ b/cpp/testGaussianBayesNet.cpp @@ -4,7 +4,6 @@ * @author Frank Dellaert */ - // STL/C++ #include #include @@ -12,13 +11,18 @@ #include #include +#include // for operator += +using namespace boost::assign; + #ifdef HAVE_BOOST_SERIALIZATION #include #include #endif //HAVE_BOOST_SERIALIZATION #include "GaussianBayesNet.h" +#include "BayesNet-inl.h" #include "smallExample.h" +#include "Ordering.h" using namespace std; using namespace gtsam; @@ -34,11 +38,11 @@ TEST( GaussianBayesNet, constructor ) Matrix R22 = Matrix_(1,1,1.0); Vector d1(1), d2(1); d1(0) = 9; d2(0) = 5; - Vector tau(1); - tau(0) = 1.; + Vector sigmas(1); + sigmas(0) = 1.; // define nodes and specify in reverse topological sort (i.e. parents last) - ConditionalGaussian x("x",d1,R11,"y",S12, tau), y("y",d2,R22, tau); + ConditionalGaussian x("x",d1,R11,"y",S12, sigmas), y("y",d2,R22, sigmas); // check small example which uses constructor GaussianBayesNet cbn = createSmallGaussianBayesNet(); @@ -68,7 +72,6 @@ TEST( GaussianBayesNet, matrix ) /* ************************************************************************* */ TEST( GaussianBayesNet, optimize ) { - // optimize small Bayes Net GaussianBayesNet cbn = createSmallGaussianBayesNet(); boost::shared_ptr actual = cbn.optimize(); @@ -81,6 +84,19 @@ TEST( GaussianBayesNet, optimize ) CHECK(actual->equals(expected)); } +/* ************************************************************************* */ +TEST( GaussianBayesNet, marginals ) +{ + // create and marginalize a small Bayes net on "x" + GaussianBayesNet cbn = createSmallGaussianBayesNet(); + Ordering keys("x"); + BayesNet actual = marginals(cbn,keys); + + // expected is just scalar Gaussian on x + GaussianBayesNet expected("x",4,sqrt(2)); + CHECK(assert_equal((BayesNet)expected,actual)); +} + /* ************************************************************************* */ #ifdef HAVE_BOOST_SERIALIZATION TEST( GaussianBayesNet, serialize ) diff --git a/cpp/testSymbolicBayesNet.cpp b/cpp/testSymbolicBayesNet.cpp index 6f9bac8b1..c704ea485 100644 --- a/cpp/testSymbolicBayesNet.cpp +++ b/cpp/testSymbolicBayesNet.cpp @@ -66,6 +66,34 @@ TEST( SymbolicBayesNet, pop_front ) CHECK(assert_equal(expected,actual)); } +/* ************************************************************************* */ +TEST( SymbolicBayesNet, combine ) +{ + SymbolicConditional::shared_ptr + A(new SymbolicConditional("A","B","C")), + B(new SymbolicConditional("B","C")), + C(new SymbolicConditional("C")); + + // p(A|BC) + SymbolicBayesNet p_ABC; + p_ABC.push_back(A); + + // P(BC)=P(B|C)P(C) + SymbolicBayesNet p_BC; + p_BC.push_back(B); + p_BC.push_back(C); + + // P(ABC) = P(A|BC) P(BC) + p_ABC.push_back(p_BC); + + SymbolicBayesNet expected; + expected.push_back(A); + expected.push_back(B); + expected.push_back(C); + + CHECK(assert_equal(expected,p_ABC)); +} + /* ************************************************************************* */ int main() { TestResult tr;