From d0a93ad9ddbe0fe5eb892cef1ff3a38f46aeced8 Mon Sep 17 00:00:00 2001 From: Kai Ni Date: Sat, 13 Feb 2010 07:09:27 +0000 Subject: [PATCH] insert bayes net as a clique --- cpp/BayesTree-inl.h | 44 +++++++++++++++++++++++++++++++++- cpp/BayesTree.h | 22 +++++++++++++---- cpp/testBayesTree.cpp | 56 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 6 deletions(-) diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index 5d6f4d3e6..10e194b3a 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -23,6 +23,10 @@ namespace gtsam { using namespace std; + /* ************************************************************************* */ + template + BayesTree::Clique::Clique() {} + /* ************************************************************************* */ template BayesTree::Clique::Clique(const sharedConditional& conditional) { @@ -298,6 +302,18 @@ namespace gtsam { return new_clique; } + /* ************************************************************************* */ + template + typename BayesTree::sharedClique BayesTree::addClique + (const sharedConditional& conditional, list& child_cliques) { + sharedClique new_clique(new Clique(conditional)); + nodes_.insert(make_pair(conditional->key(), new_clique)); + new_clique->children_ = child_cliques; + BOOST_FOREACH(sharedClique& child, child_cliques) + child->parent_ = new_clique; + return new_clique; + } + /* ************************************************************************* */ template void BayesTree::removeClique(sharedClique clique) { @@ -337,8 +353,9 @@ namespace gtsam { printf("WARNING: BayesTree.print encountered a forest...\n"); return; } - cout << s << ": size == " << size() << endl; + cout << s << ": clique size == " << size() << ", node size == " << nodes_.size() << endl; if (nodes_.empty()) return; + printf("printing tree!\n"); root_->printTree(""); } @@ -409,6 +426,31 @@ namespace gtsam { addClique(conditional,parent_clique); } + /* ************************************************************************* */ + template + typename BayesTree::sharedClique BayesTree::insert( + const BayesNet& bayesNet, list& children, bool isRootClique) + { + if (bayesNet.size() == 0) + throw invalid_argument("BayesTree::insert: empty bayes net!"); + + // create a new clique and add all the conditionals to the clique + sharedClique new_clique; + typename BayesNet::sharedConditional conditional; + BOOST_REVERSE_FOREACH(conditional, bayesNet) { + if (!new_clique.get()) { + new_clique = addClique(conditional,children); + } else { + nodes_.insert(make_pair(conditional->key(), new_clique)); + new_clique->push_front(conditional); + } + } + + if (isRootClique) root_ = new_clique; + + return new_clique; + } + /* ************************************************************************* */ // First finds clique marginal then marginalizes that /* ************************************************************************* */ diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index 64ea69c8b..6b6b462b3 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -51,6 +51,8 @@ namespace gtsam { //* Constructor */ Clique(const sharedConditional& conditional); + Clique(); + /** return keys in frontal:separator order */ Ordering keys() const; @@ -100,17 +102,21 @@ namespace gtsam { typedef SymbolMap Nodes; Nodes nodes_; + protected: + /** Root clique */ sharedClique root_; - /** add a clique */ + /** remove a clique: warning, can result in a forest */ + void removeClique(sharedClique clique); + + /** add a clique (top down) */ sharedClique addClique(const sharedConditional& conditional, sharedClique parent_clique = sharedClique()); - protected: - - /** remove a clique: warning, can result in a forest */ - void removeClique(sharedClique clique); + /** add a clique (bottom up) */ + sharedClique addClique(const sharedConditional& conditional, + std::list& child_cliques); public: @@ -146,6 +152,12 @@ namespace gtsam { /** insert a new conditional */ void insert(const sharedConditional& conditional, const IndexTable& index); + /** insert a new clique corresponding to the given bayes net. + * it is the caller's responsibility to decide whether the given bayes net is a valid clique, + * i.e. all the variables (frontal and separator) are connected */ + sharedClique insert(const BayesNet& bayesNet, + std::list& children, bool isRootClique = false); + /** number of cliques */ inline size_t size() const { if(root_) diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index 600a53173..d4cf87656 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -333,6 +333,62 @@ TEST( BayesTree, removeTop3 ) CHECK(orphans.size() == 0); } /* ************************************************************************* */ +/** + * x2 - x3 - x4 - x5 + * | / \ | + * x1 / \ x6 + */ +TEST( BayesTree, insert ) +{ + // construct bayes tree by split the graph along the separator x3 - x4 + Symbol _x4_('x', 4), _x5_('x', 5), _x6_('x', 6); + SymbolicFactorGraph fg1, fg2, fg3; + fg1.push_factor(_x3_, _x4_); + fg2.push_factor(_x1_, _x2_); + fg2.push_factor(_x2_, _x3_); + fg2.push_factor(_x1_, _x3_); + fg3.push_factor(_x4_, _x5_); + fg3.push_factor(_x5_, _x6_); + fg3.push_factor(_x4_, _x6_); + + Ordering ordering1; ordering1 += _x3_, _x4_; + Ordering ordering2; ordering2 += _x1_, _x2_; + Ordering ordering3; ordering3 += _x6_, _x5_; + + BayesNet bn1, bn2, bn3; + bn1 = fg1.eliminate(ordering1); + bn2 = fg2.eliminate(ordering2); + bn3 = fg3.eliminate(ordering3); + + // insert child cliques + SymbolicBayesTree actual; + list children; + SymbolicBayesTree::sharedClique r1 = actual.insert(bn2, children); + SymbolicBayesTree::sharedClique r2 = actual.insert(bn3, children); + + // insert root clique + children.push_back(r1); + children.push_back(r2); + actual.insert(bn1, children, true); + + // traditional way + SymbolicFactorGraph fg; + fg.push_factor(_x3_, _x4_); + fg.push_factor(_x1_, _x2_); + fg.push_factor(_x2_, _x3_); + fg.push_factor(_x1_, _x3_); + fg.push_factor(_x4_, _x5_); + fg.push_factor(_x5_, _x6_); + fg.push_factor(_x4_, _x6_); + + Ordering ordering; ordering += _x1_, _x2_, _x6_, _x5_, _x3_, _x4_; + BayesNet bn; + bn = fg.eliminate(ordering); + SymbolicBayesTree expected(bn); + CHECK(assert_equal(expected, actual)); + +} +/* ************************************************************************* */ int main() { TestResult tr;