diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index ddf5caf8f..796020a52 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -5,6 +5,9 @@ */ #include +#include // for operator += +using namespace boost::assign; + #include "BayesTree.h" #include "FactorGraph-inl.h" #include "BayesNet-inl.h" @@ -128,8 +131,32 @@ namespace gtsam { p_FSR->push_back(*R); // Find marginal on the keys we are interested in - BayesNet marginal = marginals(*p_FSR,keys()); - return marginal; + return marginals(*p_FSR,keys()); + } + + /* ************************************************************************* */ + // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) + /* ************************************************************************* */ + template + template + BayesNet + BayesTree::Clique::joint(shared_ptr C2, shared_ptr R) { + // For now, assume neither is the root + + // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) + sharedBayesNet p_FSR = this->shortcut(R); + p_FSR->push_front(*this); + p_FSR->push_front(*C2->shortcut(R)); + p_FSR->push_front(*C2); + p_FSR->push_back(*R); + + // Find the keys of both C1 and C2 + Ordering keys12 = keys(); + BOOST_FOREACH(string key,C2->keys()) keys12.push_back(key); + keys12.unique(); + + // Calculate the marginal + return marginals(*p_FSR,keys12); } /* ************************************************************************* */ @@ -205,9 +232,27 @@ namespace gtsam { BayesNet cliqueMarginal = clique->marginal(root_); // Get the marginal on the single key - BayesNet marginal = marginals(cliqueMarginal,Ordering(key)); + return marginals(cliqueMarginal,Ordering(key)); + } - return marginal; + /* ************************************************************************* */ + // Find two cliques, their joint, then marginalizes + /* ************************************************************************* */ + template + template + BayesNet + BayesTree::joint(const std::string& key1, const std::string& key2) const { + + // get clique C1 and C2 + sharedClique C1 = (*this)[key1], C2 = (*this)[key2]; + + // calculate joint + BayesNet p_C1C2 = C1->joint(C2,root_); + + // Get the marginal on the two keys + Ordering ordering; + ordering += key1, key2; + return marginals(p_C1C2,ordering); } /* ************************************************************************* */ diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index 5e37b7c8c..370a38779 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -70,6 +70,10 @@ namespace gtsam { /** return the marginal P(C) of the clique */ template BayesNet marginal(shared_ptr root); + + /** return the joint P(C1,C2), where C1==this. TODO: not a method? */ + template + BayesNet joint(shared_ptr C2, shared_ptr root); }; typedef boost::shared_ptr sharedClique; @@ -80,7 +84,7 @@ namespace gtsam { typedef std::map Nodes; Nodes nodes_; - /** Roor clique */ + /** Root clique */ sharedClique root_; /** add a clique */ @@ -138,6 +142,11 @@ namespace gtsam { /** return marginal on any variable */ template BayesNet marginal(const std::string& key) const; + + /** return joint on two variables */ + template + BayesNet joint(const std::string& key1, const std::string& key2) const; + }; // BayesTree } /// namespace gtsam diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index d26e25fb9..e6da23ea9 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -160,9 +160,6 @@ TEST( BayesTree, balanced_smoother_marginals ) // eliminate using a "nested dissection" ordering GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); -// SymbolicBayesNet symbolic(*chordalBayesNet); -// symbolic.print("chordalBayesNet"); - VectorConfig expectedSolution; Vector delta = zero(2); BOOST_FOREACH(string key, ordering) @@ -174,8 +171,6 @@ TEST( BayesTree, balanced_smoother_marginals ) Gaussian bayesTree(*chordalBayesNet); LONGS_EQUAL(7,bayesTree.size()); - // Marginals - // Check marginal on x1 GaussianBayesNet expected1("x1", delta, 0.786153); BayesNet actual1 = bayesTree.marginal("x1"); @@ -200,16 +195,13 @@ TEST( BayesTree, balanced_smoother_shortcuts ) Ordering ordering; ordering += "x1","x3","x5","x7","x2","x6","x4"; - // eliminate using a "nested dissection" ordering - GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); - boost::shared_ptr actualSolution = chordalBayesNet->optimize(); - // Create the Bayes tree + GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); Gaussian bayesTree(*chordalBayesNet); - Gaussian::sharedClique R = bayesTree.root(); // Check the conditional P(Root|Root) BayesNet empty; + Gaussian::sharedClique R = bayesTree.root(); Gaussian::sharedBayesNet actual1 = R->shortcut(R); CHECK(assert_equal(empty,*actual1,1e-4)); @@ -234,23 +226,50 @@ TEST( BayesTree, balanced_smoother_clique_marginals ) Ordering ordering; ordering += "x1","x3","x5","x7","x2","x6","x4"; - // eliminate using a "nested dissection" ordering + // Create the Bayes tree GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); - boost::shared_ptr actualSolution = chordalBayesNet->optimize(); + Gaussian bayesTree(*chordalBayesNet); + + // Check the clique marginal P(C3) + GaussianBayesNet expected("x2",zero(2),0.687131); + Vector sigma = repeat(2, 0.707107); + Matrix A12 = (-0.5)*eye(2); + ConditionalGaussian::shared_ptr cg(new ConditionalGaussian("x1", zero(2), eye(2), "x2", A12, sigma)); + expected.push_front(cg); + Gaussian::sharedClique R = bayesTree.root(), C3 = bayesTree["x1"]; + BayesNet actual = C3->marginal(R); + CHECK(assert_equal((BayesNet)expected,actual,1e-4)); +} + +/* ************************************************************************* */ +TEST( BayesTree, balanced_smoother_joint ) +{ + // Create smoother with 7 nodes + LinearFactorGraph smoother = createSmoother(7); + Ordering ordering; + ordering += "x1","x3","x5","x7","x2","x6","x4"; // Create the Bayes tree + GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); Gaussian bayesTree(*chordalBayesNet); - Gaussian::sharedClique R = bayesTree.root(); - // Check the conditional P(C3|Root), which should be equal to P(x2|x4) - GaussianBayesNet expected3("x2",zero(2),0.687131); - Vector sigma3 = repeat(2, 0.707107); - Matrix A12 = (-0.5)*eye(2); - ConditionalGaussian::shared_ptr cg3(new ConditionalGaussian("x1", zero(2), eye(2), "x2", A12, sigma3)); - expected3.push_front(cg3); - Gaussian::sharedClique C3 = bayesTree["x1"]; - BayesNet actual3 = C3->marginal(R); - CHECK(assert_equal((BayesNet)expected3,actual3,1e-4)); + // Conditional density elements reused by both tests + Vector sigma = repeat(2, 0.786146); + Matrix A = (-0.00429185)*eye(2); + + // Check the joint density P(x1,x7) factored as P(x1|x7)P(x7) + GaussianBayesNet expected1("x7", zero(2), 0.786153); + ConditionalGaussian::shared_ptr cg1(new ConditionalGaussian("x1", zero(2), eye(2), "x7", A, sigma)); + expected1.push_front(cg1); + BayesNet actual1 = bayesTree.joint("x1","x7"); + CHECK(assert_equal((BayesNet)expected1,actual1,1e-4)); + + // Check the joint density P(x7,x1) factored as P(x7|x1)P(x1) + GaussianBayesNet expected2("x1", zero(2), 0.786153); + ConditionalGaussian::shared_ptr cg2(new ConditionalGaussian("x7", zero(2), eye(2), "x1", A, sigma)); + expected2.push_front(cg2); + BayesNet actual2 = bayesTree.joint("x7","x1"); + CHECK(assert_equal((BayesNet)expected2,actual2,1e-4)); } /* ************************************************************************* */