diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index b0581fe73..6a507e9de 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -14,16 +14,16 @@ namespace gtsam { /* ************************************************************************* */ template - BayesTree::Node::Node(const boost::shared_ptr& conditional) { + BayesTree::Clique::Clique(const boost::shared_ptr& conditional) { separator_ = conditional->parents(); this->push_back(conditional); } /* ************************************************************************* */ template - void BayesTree::Node::print(const string& s) const { + void BayesTree::Clique::print(const string& s) const { cout << s; - BOOST_REVERSE_FOREACH(const conditional_ptr& conditional, this->conditionals_) + BOOST_REVERSE_FOREACH(const sharedConditional& conditional, this->conditionals_) cout << " " << conditional->key(); if (!separator_.empty()) { cout << " :"; @@ -35,19 +35,37 @@ namespace gtsam { /* ************************************************************************* */ template - void BayesTree::Node::printTree(const string& indent) const { + void BayesTree::Clique::printTree(const string& indent) const { print(indent); BOOST_FOREACH(shared_ptr child, children_) child->printTree(indent+" "); } + /* ************************************************************************* */ + template + typename BayesTree::sharedBayesNet + BayesTree::Clique::shortcut() { + // The shortcut density is a conditional P(S|R) of the separator of this + // clique on the root. We can compute it recursively from the parent shortcut + // P(S_p|R) as \int P(F_p|S_p) P(S_p|R), where F_p are the frontal nodes in p + + // The base case is when we are the root or the parent is the root, + // in which case we return an empty Bayes net + sharedBayesNet p_S_R(new BayesNet); + if (parent_==NULL || parent_->parent_==NULL) return p_S_R; + + // If not, calculate the parent shortcut P(S_p|R) + sharedBayesNet p_Sp_R = parent_->shortcut(); + + return p_S_R; + } + /* ************************************************************************* */ template BayesTree::BayesTree() { } /* ************************************************************************* */ - // TODO: traversal is O(n*log(n)) but could be O(n) with better bayesNet template BayesTree::BayesTree(const BayesNet& bayesNet) { typename BayesNet::const_reverse_iterator rit; @@ -68,21 +86,7 @@ namespace gtsam { bool BayesTree::equals(const BayesTree& other, double tol) const { return size()==other.size(); - //&& equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star(tol)); - } - - /* ************************************************************************* */ - template - boost::shared_ptr::Node> BayesTree::addClique - (const boost::shared_ptr& conditional, node_ptr parent_clique) - { - node_ptr new_clique(new Node(conditional)); - nodes_.insert(make_pair(conditional->key(), new_clique)); - if (parent_clique!=NULL) { - new_clique->parent_ = parent_clique; - parent_clique->children_.push_back(new_clique); - } - return new_clique; + //&& equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star(tol)); } /* ************************************************************************* */ @@ -102,7 +106,7 @@ namespace gtsam { // otherwise, find the parent clique string parent = parents.front(); - node_ptr parent_clique = (*this)[parent]; + sharedClique parent_clique = (*this)[parent]; // if the parents and parent clique have the same size, add to parent clique if (parent_clique->size() == parents.size()) { @@ -128,7 +132,7 @@ namespace gtsam { boost::shared_ptr BayesTree::marginal(const string& key) const { // get clique containing key, and remove all factors below key - node_ptr clique = (*this)[key]; + sharedClique clique = (*this)[key]; Ordering ordering = clique->ordering(); FactorGraph graph(*clique); while(ordering.front()!=key) { diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index 6d343b2d4..f898bf4a9 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -29,21 +29,22 @@ namespace gtsam { public: - typedef boost::shared_ptr conditional_ptr; + typedef boost::shared_ptr sharedConditional; + typedef boost::shared_ptr > sharedBayesNet; - /** A Node in the tree is an incomplete Bayes net: the variables + /** A Clique in the tree is an incomplete Bayes net: the variables * in the Bayes net are the frontal nodes, and the variables conditioned * on is the separator. We also have pointers up and down the tree. */ - struct Node: public BayesNet { + struct Clique: public BayesNet { - typedef boost::shared_ptr shared_ptr; + typedef boost::shared_ptr shared_ptr; shared_ptr parent_; - std::list separator_; /** separator keys */ std::list children_; + std::list separator_; /** separator keys */ //* Constructor */ - Node(const conditional_ptr& conditional); + Clique(const sharedConditional& conditional); /** The size *includes* the separator */ size_t size() const { @@ -53,24 +54,35 @@ namespace gtsam { /** print this node */ void print(const std::string& s = "Bayes tree node") const; - /** print this node and entire subtree below it*/ + /** print this node and entire subtree below it */ void printTree(const std::string& indent) const; + + /** return the conditional P(S|Root) on the separator given the root */ + sharedBayesNet shortcut(); }; - typedef boost::shared_ptr node_ptr; + typedef boost::shared_ptr sharedClique; private: - /** Map from keys to Node */ - typedef std::map Nodes; + /** Map from keys to Clique */ + typedef std::map Nodes; Nodes nodes_; /** Roor clique */ - node_ptr root_; + sharedClique root_; /** add a clique */ - node_ptr addClique(const conditional_ptr& conditional, - node_ptr parent_clique = node_ptr()); + sharedClique addClique(const sharedConditional& conditional, + sharedClique parent_clique = sharedClique()) { + sharedClique new_clique(new Clique(conditional)); + nodes_.insert(make_pair(conditional->key(), new_clique)); + if (parent_clique != NULL) { + new_clique->parent_ = parent_clique; + parent_clique->children_.push_back(new_clique); + } + return new_clique; + } public: @@ -91,7 +103,7 @@ namespace gtsam { bool equals(const BayesTree& other, double tol = 1e-9) const; /** insert a new conditional */ - void insert(const conditional_ptr& conditional); + void insert(const sharedConditional& conditional); /** number of cliques */ inline size_t size() const { @@ -99,23 +111,22 @@ namespace gtsam { } /** return root clique */ - node_ptr root() const { + sharedClique root() const { return root_; } /** find the clique to which key belongs */ - node_ptr operator[](const std::string& key) const { + sharedClique operator[](const std::string& key) const { typename Nodes::const_iterator it = nodes_.find(key); if (it == nodes_.end()) throw(std::invalid_argument( "BayesTree::operator['" + key + "'): key not found")); - node_ptr clique = it->second; + sharedClique clique = it->second; return clique; } /** return marginal on any variable */ template - conditional_ptr marginal(const std::string& key) const; - + sharedConditional marginal(const std::string& key) const; }; // BayesTree } /// namespace gtsam diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index 837be47c7..ca4bc399e 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -17,6 +17,8 @@ using namespace boost::assign; using namespace gtsam; +typedef BayesTree Gaussian; + // Conditionals for ASIA example from the tutorial with A and D evidence SymbolicConditional::shared_ptr B(new SymbolicConditional("B")), L( new SymbolicConditional("L", "B")), E( @@ -77,12 +79,12 @@ TEST( BayesTree, constructor ) /* ************************************************************************* * Bayes tree for smoother with "natural" ordering: - x6 x7 - x5 : x6 - x4 : x5 - x3 : x4 - x2 : x3 - x1 : x2 +C1 x6 x7 +C2 x5 : x6 +C3 x4 : x5 +C4 x3 : x4 +C5 x2 : x3 +C6 x1 : x2 /* ************************************************************************* */ TEST( BayesTree, smoother ) { @@ -96,8 +98,16 @@ TEST( BayesTree, smoother ) GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); // Create the Bayes tree - BayesTree bayesTree(*chordalBayesNet); + Gaussian bayesTree(*chordalBayesNet); LONGS_EQUAL(7,bayesTree.size()); + + // Get the conditional P(S6|Root) + Vector sigma1 = repeat(2, 0.786153); + ConditionalGaussian::shared_ptr cg(new ConditionalGaussian("x2", zero(2), eye(2), sigma1)); + BayesNet expected; expected.push_back(cg); + Gaussian::sharedClique C6 = bayesTree["x1"]; + Gaussian::sharedBayesNet actual = C6->shortcut(); + //CHECK(assert_equal(expected,*actual,1e-4)); } /* ************************************************************************* * @@ -130,7 +140,7 @@ TEST( BayesTree, balanced_smoother_marginals ) CHECK(assert_equal(expectedSolution,*actualSolution,1e-4)); // Create the Bayes tree - BayesTree bayesTree(*chordalBayesNet); + Gaussian bayesTree(*chordalBayesNet); LONGS_EQUAL(7,bayesTree.size()); // Check root clique @@ -158,9 +168,6 @@ TEST( BayesTree, balanced_smoother_marginals ) ConditionalGaussian expected3("x3", delta, R, sigma3); ConditionalGaussian::shared_ptr actual3 = bayesTree.marginal("x3"); CHECK(assert_equal(expected3,*actual3,1e-4)); - - // JunctionTree is an undirected tree of cliques - // JunctionTree marginals = bayesTree.marginals(); } /* ************************************************************************* */