diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index e3e1447c1..cc074e2a7 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -6,7 +6,7 @@ #include #include "BayesTree.h" -#include "FactorGraph.h" +#include "FactorGraph-inl.h" namespace gtsam { @@ -120,6 +120,7 @@ namespace gtsam { /* ************************************************************************* */ template + template boost::shared_ptr BayesTree::marginal(const string& key) const { // find the clique to which key belongs @@ -128,16 +129,39 @@ namespace gtsam { "BayesTree::marginal('"+key+"'): key not found")); // find all cliques on the path to the root and turn into factor graph - // FactorGraph node_ptr node = it->second; - int i=0; + Ordering ordering; + FactorGraph graph; while (node!=NULL) { - //node->print("node"); + + // extend ordering + Ordering cliqueOrdering = node->ordering(); + ordering.splice (ordering.end(), cliqueOrdering); + + // extend factor graph + boost::shared_ptr > bayesNet = node; + FactorGraph cliqueGraph(*bayesNet); + typename FactorGraph::const_iterator factor=cliqueGraph.begin(); + for(; factor!=cliqueGraph.end(); factor++) + graph.push_back(*factor); + + // move up the tree node = node->parent_; } - boost::shared_ptr result(new Conditional); - return result; + //graph.print(); + ordering.reverse(); + //ordering.print(); + + // eliminate to get marginal + boost::shared_ptr > bayesNet; + typename boost::shared_ptr > chordalBayesNet = + graph.eliminate(bayesNet,ordering); + + //chordalBayesNet->print("chordalBayesNet"); + + boost::shared_ptr marginal = chordalBayesNet->back(); + return marginal; } /* ************************************************************************* */ diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index e9bd2c3b9..eeed835f0 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -94,6 +94,7 @@ namespace gtsam { boost::shared_ptr > root() const {return root_;} /** return marginal on any variable */ + template boost::shared_ptr marginal(const std::string& key) const; }; // BayesTree diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index 74520b5bc..0fe41898b 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -127,9 +127,14 @@ TEST( BayesTree, balanced_smoother_marginals ) //CHECK(assert_equal(expected_root,actual_root)); // Check marginal on x1 - ConditionalGaussian expected; - ConditionalGaussian::shared_ptr actual = bayesTree.marginal("x1"); - CHECK(assert_equal(expected,*actual)); + double data1[] = { 1.0, 0.0, + 0.0, 1.0}; + Matrix R1 = Matrix_(2,2, data1); + Vector d1(2); d1(0) = -0.615385; d1(1) = 0; + Vector tau1(2); tau1(0) = 1.61803; tau1(1) = 1.61803; + ConditionalGaussian expected("x1",d1, R1, tau1); + ConditionalGaussian::shared_ptr actual = bayesTree.marginal("x1"); + CHECK(assert_equal(expected,*actual,1e-4)); // JunctionTree is an undirected tree of cliques // JunctionTree marginals = bayesTree.marginals();