From 86173b66af106ede026023dd61f8e452c255eb21 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 8 Nov 2009 22:51:12 +0000 Subject: [PATCH] Clique marginal and dramatically simplified single variable marginal. --- cpp/BayesTree-inl.h | 99 +++++++++++++++++++++---------------------- cpp/BayesTree.h | 11 ++++- cpp/testBayesTree.cpp | 53 ++++++++++++++++------- 3 files changed, 94 insertions(+), 69 deletions(-) diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index 1943cef4d..ddf5caf8f 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -7,6 +7,7 @@ #include #include "BayesTree.h" #include "FactorGraph-inl.h" +#include "BayesNet-inl.h" namespace gtsam { @@ -19,6 +20,14 @@ namespace gtsam { this->push_back(conditional); } + /* ************************************************************************* */ + template + Ordering BayesTree::Clique::keys() const { + Ordering frontal_keys = this->ordering(), keys = separator_; + keys.splice(keys.begin(),frontal_keys); + return keys; + } + /* ************************************************************************* */ template void BayesTree::Clique::print(const string& s) const { @@ -41,30 +50,27 @@ namespace gtsam { child->printTree(indent+" "); } + /* ************************************************************************* */ + // The shortcut density is a conditional P(S|R) of the separator of this + // clique on the root. We can compute it recursively from the parent shortcut + // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p + // TODO, why do we actually return a shared pointer, why does eliminate? /* ************************************************************************* */ template template typename BayesTree::sharedBayesNet BayesTree::Clique::shortcut(shared_ptr R) { - // The shortcut density is a conditional P(S|R) of the separator of this - // clique on the root. We can compute it recursively from the parent shortcut - // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p - // A first base case is when this clique or its parent is the root, // in which case we return an empty Bayes net. - if (R.get()==this || parent_==R) { - sharedBayesNet empty(new BayesNet); - return empty; - } + if (R.get()==this || parent_==R) + return sharedBayesNet(new BayesNet); // The parent clique has a Conditional for each frontal node in Fp // so we can obtain P(Fp|Sp) in factor graph form FactorGraph p_Fp_Sp(*parent_); - //p_Fp_Sp.print("p_Fp_Sp"); // If not the base case, obtain the parent shortcut P(Sp|R) as factors FactorGraph p_Sp_R(*parent_->shortcut(R)); - //p_Sp_R.print("p_Sp_R"); // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R) FactorGraph p_Cp_R = combine(p_Fp_Sp, p_Sp_R); @@ -76,12 +82,7 @@ namespace gtsam { // Keys corresponding to the root should not be added to the ordering at all. // Get the key list Cp=Fp+Sp, which will form the basis for the integrands - Ordering integrands; - { - Ordering Fp = parent_->ordering(), Sp = parent_->separator_; - integrands.splice(integrands.end(),Fp); - integrands.splice(integrands.end(),Sp); - } + Ordering integrands = parent_->keys(); // Start ordering with the separator Ordering ordering = separator_; @@ -108,6 +109,29 @@ namespace gtsam { return p_S_R; } + /* ************************************************************************* */ + // P(C) = \int_R P(F|S) P(S|R) P(R) + // TODO: Maybe we should integrate given parent marginal P(Cp), + // \int(Cp\S) P(F|S)P(S|Cp)P(Cp) + // Because the root clique could be very big. + /* ************************************************************************* */ + template + template + BayesNet + BayesTree::Clique::marginal(shared_ptr R) { + // If we are the root, just return this root + if (R.get()==this) return *R; + + // Combine P(F|S), P(S|R), and P(R) + sharedBayesNet p_FSR = this->shortcut(R); + p_FSR->push_front(*this); + p_FSR->push_back(*R); + + // Find marginal on the keys we are interested in + BayesNet marginal = marginals(*p_FSR,keys()); + return marginal; + } + /* ************************************************************************* */ template BayesTree::BayesTree() { @@ -167,50 +191,23 @@ namespace gtsam { } /* ************************************************************************* */ - // Desired: recursive, memoizing version - // Once we know the clique, can we do all with Nodes ? - // Sure, as P(x) = \int P(C|root) - // The natural cache is P(C|root), memoized, of course, in the clique C - // When any marginal is asked for, we calculate P(C|root) = P(C|Pi)P(Pi|root) - // Super-naturally recursive !!!!! + // First finds clique marginal then marginalizes that /* ************************************************************************* */ template template - typename BayesTree::sharedConditional + BayesNet BayesTree::marginal(const string& key) const { - // get clique containing key, and remove all factors below key + // get clique containing key sharedClique clique = (*this)[key]; - Ordering ordering = clique->ordering(); - FactorGraph graph(*clique); - while(ordering.front()!=key) { - graph.findAndRemoveFactors(ordering.front()); - ordering.pop_front(); - } - // find all cliques on the path to the root and turn into factor graph - while (clique->parent_!=NULL) { - // move up the tree - clique = clique->parent_; + // calculate or retrieve its marginal + BayesNet cliqueMarginal = clique->marginal(root_); - // extend ordering - Ordering cliqueOrdering = clique->ordering(); - ordering.splice (ordering.end(), cliqueOrdering); + // Get the marginal on the single key + BayesNet marginal = marginals(cliqueMarginal,Ordering(key)); - // extend factor graph - FactorGraph cliqueGraph(*clique); - typename FactorGraph::const_iterator factor=cliqueGraph.begin(); - for(; factor!=cliqueGraph.end(); factor++) - graph.push_back(*factor); - } - - // TODO: can we prove reverse ordering is efficient? - ordering.reverse(); - - // eliminate to get marginal - sharedBayesNet chordalBayesNet = _eliminate(graph,ordering); - - return chordalBayesNet->back(); // the root is the marginal + return marginal; } /* ************************************************************************* */ diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index a1442d157..5e37b7c8c 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -46,6 +46,9 @@ namespace gtsam { //* Constructor */ Clique(const sharedConditional& conditional); + /** return keys in frontal:separator order */ + Ordering keys() const; + /** print this node */ void print(const std::string& s = "Bayes tree node") const; @@ -62,7 +65,11 @@ namespace gtsam { /** return the conditional P(S|Root) on the separator given the root */ template - sharedBayesNet shortcut(shared_ptr R); + sharedBayesNet shortcut(shared_ptr root); + + /** return the marginal P(C) of the clique */ + template + BayesNet marginal(shared_ptr root); }; typedef boost::shared_ptr sharedClique; @@ -130,7 +137,7 @@ namespace gtsam { /** return marginal on any variable */ template - sharedConditional marginal(const std::string& key) const; + BayesNet marginal(const std::string& key) const; }; // BayesTree } /// namespace gtsam diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index 7a3da89dc..d26e25fb9 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -176,26 +176,20 @@ TEST( BayesTree, balanced_smoother_marginals ) // Marginals - // Marginal will always be axis-parallel Gaussian on delta=(0,0) - Matrix R = eye(2); - // Check marginal on x1 - Vector sigma1 = repeat(2, 0.786153); - ConditionalGaussian expected1("x1", delta, R, sigma1); - ConditionalGaussian::shared_ptr actual1 = bayesTree.marginal("x1"); - CHECK(assert_equal(expected1,*actual1,1e-4)); + GaussianBayesNet expected1("x1", delta, 0.786153); + BayesNet actual1 = bayesTree.marginal("x1"); + CHECK(assert_equal((BayesNet)expected1,actual1,1e-4)); // Check marginal on x2 - Vector sigma2 = repeat(2, 0.687131); - ConditionalGaussian expected2("x2", delta, R, sigma2); - ConditionalGaussian::shared_ptr actual2 = bayesTree.marginal("x2"); - CHECK(assert_equal(expected2,*actual2,1e-4)); + GaussianBayesNet expected2("x2", delta, 0.687131); + BayesNet actual2 = bayesTree.marginal("x2"); + CHECK(assert_equal((BayesNet)expected2,actual2,1e-4)); // Check marginal on x3 - Vector sigma3 = repeat(2, 0.671512); - ConditionalGaussian expected3("x3", delta, R, sigma3); - ConditionalGaussian::shared_ptr actual3 = bayesTree.marginal("x3"); - CHECK(assert_equal(expected3,*actual3,1e-4)); + GaussianBayesNet expected3("x3", delta, 0.671512); + BayesNet actual3 = bayesTree.marginal("x3"); + CHECK(assert_equal((BayesNet)expected3,actual3,1e-4)); } /* ************************************************************************* */ @@ -212,10 +206,10 @@ TEST( BayesTree, balanced_smoother_shortcuts ) // Create the Bayes tree Gaussian bayesTree(*chordalBayesNet); + Gaussian::sharedClique R = bayesTree.root(); // Check the conditional P(Root|Root) BayesNet empty; - Gaussian::sharedClique R = bayesTree.root(); Gaussian::sharedBayesNet actual1 = R->shortcut(R); CHECK(assert_equal(empty,*actual1,1e-4)); @@ -232,6 +226,33 @@ TEST( BayesTree, balanced_smoother_shortcuts ) CHECK(assert_equal(expected3,*actual3,1e-4)); } +/* ************************************************************************* */ +TEST( BayesTree, balanced_smoother_clique_marginals ) +{ + // Create smoother with 7 nodes + LinearFactorGraph smoother = createSmoother(7); + Ordering ordering; + ordering += "x1","x3","x5","x7","x2","x6","x4"; + + // eliminate using a "nested dissection" ordering + GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); + boost::shared_ptr actualSolution = chordalBayesNet->optimize(); + + // Create the Bayes tree + Gaussian bayesTree(*chordalBayesNet); + Gaussian::sharedClique R = bayesTree.root(); + + // Check the conditional P(C3|Root), which should be equal to P(x2|x4) + GaussianBayesNet expected3("x2",zero(2),0.687131); + Vector sigma3 = repeat(2, 0.707107); + Matrix A12 = (-0.5)*eye(2); + ConditionalGaussian::shared_ptr cg3(new ConditionalGaussian("x1", zero(2), eye(2), "x2", A12, sigma3)); + expected3.push_front(cg3); + Gaussian::sharedClique C3 = bayesTree["x1"]; + BayesNet actual3 = C3->marginal(R); + CHECK(assert_equal((BayesNet)expected3,actual3,1e-4)); +} + /* ************************************************************************* */ int main() { TestResult tr;