diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index 0442b82de..003de7989 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -6,6 +6,7 @@ #include #include "BayesTree.h" +#include "FactorGraph.h" namespace gtsam { @@ -27,7 +28,7 @@ namespace gtsam { if (!separator_.empty()) { cout << " :"; BOOST_FOREACH(string key, separator_) - cout << " " << key; + cout << " " << key; } cout << endl; } @@ -48,10 +49,10 @@ namespace gtsam { /* ************************************************************************* */ // TODO: traversal is O(n*log(n)) but could be O(n) with better bayesNet template - BayesTree::BayesTree(const BayesNet& bayesNet, bool verbose) { + BayesTree::BayesTree(const BayesNet& bayesNet) { typename BayesNet::const_reverse_iterator rit; for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) - insert(*rit,verbose); + insert(*rit); } /* ************************************************************************* */ @@ -73,22 +74,29 @@ namespace gtsam { /* ************************************************************************* */ template - void BayesTree::insert(const boost::shared_ptr& conditional, bool verbose) { + void BayesTree::addClique + (const boost::shared_ptr& conditional, node_ptr parent_clique) + { + node_ptr new_clique(new Node(conditional)); + nodeMap_.insert(make_pair(conditional->key(), nodes_.size())); + nodes_.push_back(new_clique); + if (parent_clique==NULL) return; + new_clique->parent_ = parent_clique; + parent_clique->children_.push_back(new_clique); + } - string key = conditional->key(); - if (verbose) cout << "Inserting " << key << "| "; - - // get parents + /* ************************************************************************* */ + template + void BayesTree::insert + (const boost::shared_ptr& conditional) + { + // get key and parents + string key = conditional->key(); list parents = conditional->parents(); - if (verbose) BOOST_FOREACH(string p, parents) cout << p << " "; - if (verbose) cout << endl; // if no parents, start a new root clique if (parents.empty()) { - if (verbose) cout << "Creating root clique" << endl; - node_ptr root(new Node(conditional)); - nodes_.push_back(root); - nodeMap_.insert(make_pair(key, 0)); + addClique(conditional); return; } @@ -96,26 +104,35 @@ namespace gtsam { string parent = parents.front(); NodeMap::const_iterator it = nodeMap_.find(parent); if (it == nodeMap_.end()) throw(invalid_argument( - "BayesTree::insert('"+key+"'): parent '" + parent + "' was not yet inserted")); - int index = it->second; - node_ptr parent_clique = nodes_[index]; - if (verbose) cout << "Parent clique " << index << " of size " << parent_clique->size() << endl; + "BayesTree::insert('"+key+"'): parent '" + parent + "' not yet inserted")); + int parent_index = it->second; + node_ptr parent_clique = nodes_[parent_index]; // if the parents and parent clique have the same size, add to parent clique if (parent_clique->size() == parents.size()) { - if (verbose) cout << "Adding to clique " << index << endl; - nodeMap_.insert(make_pair(key, index)); + nodeMap_.insert(make_pair(key, parent_index)); parent_clique->push_front(conditional); return; } // otherwise, start a new clique and add it to the tree - if (verbose) cout << "Starting new clique" << endl; - node_ptr new_clique(new Node(conditional)); - new_clique->parent_ = parent_clique; - parent_clique->children_.push_back(new_clique); - nodeMap_.insert(make_pair(key, nodes_.size())); - nodes_.push_back(new_clique); + addClique(conditional,parent_clique); + } + + /* ************************************************************************* */ + template + boost::shared_ptr BayesTree::marginal(const string& key) const { + + // find the clique to which key belongs + NodeMap::const_iterator it = nodeMap_.find(key); + if (it == nodeMap_.end()) throw(invalid_argument( + "BayesTree::marginal('"+key+"'): key not found")); + + // find all cliques on the path to the root + // FactorGraph + + boost::shared_ptr result(new Conditional); + return result; } /* ************************************************************************* */ diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index 7d4e9ec22..1ab94fe71 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -30,7 +30,6 @@ namespace gtsam { public: typedef boost::shared_ptr conditional_ptr; - typedef std::pair NamedConditional; private: @@ -66,13 +65,16 @@ namespace gtsam { typedef std::map NodeMap; NodeMap nodeMap_; + /** add a clique */ + void addClique(const conditional_ptr& conditional, node_ptr parent_clique=node_ptr()); + public: /** Create an empty Bayes Tree */ BayesTree(); /** Create a Bayes Tree from a Bayes Net */ - BayesTree(const BayesNet& bayesNet, bool verbose=false); + BayesTree(const BayesNet& bayesNet); /** Destructor */ virtual ~BayesTree() {} @@ -84,7 +86,7 @@ namespace gtsam { bool equals(const BayesTree& other, double tol = 1e-9) const; /** insert a new conditional */ - void insert(const boost::shared_ptr& conditional, bool verbose=false); + void insert(const boost::shared_ptr& conditional); /** number of cliques */ inline size_t size() const { return nodes_.size();} @@ -92,6 +94,9 @@ namespace gtsam { /** return root clique */ const BayesNet& root() const {return *(nodes_[0]);} + /** return marginal on any variable */ + boost::shared_ptr marginal(const std::string& key) const; + }; // BayesTree } /// namespace gtsam diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index bde5c044c..8b76306b7 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -27,10 +27,10 @@ SymbolicConditional::shared_ptr B(new SymbolicConditional("B")), L( /* ************************************************************************* */ TEST( BayesTree, Front ) { - BayesNet f1; + SymbolicBayesNet f1; f1.push_back(B); f1.push_back(L); - BayesNet f2; + SymbolicBayesNet f2; f2.push_back(L); f2.push_back(B); CHECK(f1.equals(f1)); @@ -68,9 +68,8 @@ TEST( BayesTree, constructor ) ASIA.push_back(E); ASIA.push_back(L); ASIA.push_back(B); - bool verbose = false; - BayesTree bayesTree2(ASIA,verbose); - if (verbose) bayesTree2.print("bayesTree2"); + BayesTree bayesTree2(ASIA); + //bayesTree2.print("bayesTree2"); // Check whether the same CHECK(assert_equal(bayesTree,bayesTree2)); @@ -97,7 +96,7 @@ TEST( BayesTree, smoother ) GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); // Create the Bayes tree - BayesTree bayesTree(*chordalBayesNet,false); + BayesTree bayesTree(*chordalBayesNet); LONGS_EQUAL(6,bayesTree.size()); } @@ -108,7 +107,7 @@ TEST( BayesTree, smoother ) x1 : x2 x7 : x6 /* ************************************************************************* */ -TEST( BayesTree, balanced_smoother ) +TEST( BayesTree, balanced_smoother_marginals ) { // Create smoother with 7 nodes LinearFactorGraph smoother = createSmoother(7); @@ -119,8 +118,21 @@ TEST( BayesTree, balanced_smoother ) GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); // Create the Bayes tree - BayesTree bayesTree(*chordalBayesNet,false); + BayesTree bayesTree(*chordalBayesNet); LONGS_EQUAL(4,bayesTree.size()); + + // Check root clique + //BayesNet expected_root; + //BayesNet actual_root = bayesTree.root(); + //CHECK(assert_equal(expected_root,actual_root)); + + // Check marginal on x1 + ConditionalGaussian expected; + ConditionalGaussian::shared_ptr actual = bayesTree.marginal("x1"); + CHECK(assert_equal(expected,*actual)); + + // JunctionTree is an undirected tree of cliques + // JunctionTree marginals = bayesTree.marginals(); } /* ************************************************************************* */