From 53890c4ba6eca8f51b3c59c761137efa361530ff Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 31 Oct 2009 05:12:39 +0000 Subject: [PATCH] Symbolic Bayes Tree successfully constructed --- cpp/BayesTree-inl.h | 58 ++++++++++++++- cpp/BayesTree.h | 124 ++++++++++++++++++++++++++++----- cpp/ConditionalGaussian.cpp | 21 +++--- cpp/ConditionalGaussian.h | 13 ++-- cpp/SymbolicConditional.h | 4 ++ cpp/testBayesTree.cpp | 65 ++++++++++++++--- cpp/testSymbolicBayesChain.cpp | 23 +++--- 7 files changed, 256 insertions(+), 52 deletions(-) diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index 4d5dcf742..0c0c7caa1 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -6,22 +6,74 @@ #include "BayesTree.h" -using namespace std; - namespace gtsam { + using namespace std; + + /* ************************************************************************* */ + template + BayesTree::BayesTree() { + } + + /* ************************************************************************* */ template BayesTree::BayesTree(BayesChain& bayesChain) { + list ordering;// = bayesChain.ordering(); } + /* ************************************************************************* */ template - void BayesTree::print(const string& s) const { + void BayesTree::print(const std::string& s) const { + cout << s << ": size == " << nodes_.size() << endl; + if (nodes_.empty()) return; + nodes_[0]->printTree(""); } + /* ************************************************************************* */ template bool BayesTree::equals(const BayesTree& other, double tol) const { return false; } + /* ************************************************************************* */ + template + void BayesTree::insert(string key, conditional_ptr conditional) { + + // get any parent + list parents = conditional->parents(); + + // if no parents, start a new root clique + if (parents.empty()) { + node_ptr root(new Node(key, conditional)); + nodes_.push_back(root); + nodeMap_.insert(make_pair(key, 0)); + return; + } + + // otherwise, find the parent clique + string parent = parents.front(); + NodeMap::const_iterator it = nodeMap_.find(parent); + if (it == nodeMap_.end()) throw(invalid_argument( + "BayesTree::insert: parent with key " + key + "was not yet inserted")); + int index = it->second; + node_ptr parent_clique = nodes_[index]; + + // if the parents and parent clique have the same size, add to parent clique + if (parent_clique->size() == parents.size()) { + nodeMap_.insert(make_pair(key, index)); + parent_clique->add(key, conditional); + return; + } + + // otherwise, start a new clique and add it to the tree + node_ptr new_clique(new Node(key, conditional)); + new_clique->parent_ = parent_clique; + parent_clique->children_.push_back(new_clique); + nodeMap_.insert(make_pair(key, nodes_.size())); + nodes_.push_back(new_clique); + } + +/* ************************************************************************* */ + } /// namespace gtsam diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index 257fdc3bc..0af1d2175 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -8,37 +8,123 @@ #pragma once +#include #include +#include #include #include - +#include // TODO: make cpp file #include "Testable.h" #include "BayesChain.h" namespace gtsam { -/** - * Bayes tree - * Templated on the Conditional class, the type of node in the underlying Bayes chain. - * This could be a ConditionalProbabilityTable, a ConditionalGaussian, or a SymbolicConditional - */ -template -class BayesTree : public Testable > -{ -public: + /** A clique in a Bayes tree consisting of frontal nodes and conditionals */ + template + class Front: Testable > { + private: + typedef boost::shared_ptr cond_ptr; + std::list keys_; /** frontal keys */ + std::list nodes_; /** conditionals */ + std::list separator_; /** separator keys */ + public: - /** Create a Bayes Tree from a SymbolicBayesChain */ - BayesTree(BayesChain& bayesChain); + /** constructor */ + Front(std::string key, cond_ptr conditional) { + add(key, conditional); + separator_ = conditional->parents(); + } - /** Destructor */ - virtual ~BayesTree() {} + /** print */ + void print(const std::string& s = "") const { + std::cout << s; + BOOST_FOREACH(std::string key, keys_) + std::cout << " " << key; + if (!separator_.empty()) { + std::cout << " :"; + BOOST_FOREACH(std::string key, separator_) + std::cout << " " << key; + } + std::cout << std::endl; + } - /** print */ - void print(const std::string& s = "") const; + /** check equality. TODO: only keys */ + bool equals(const Front& other, double tol = 1e-9) const { + return (keys_ == other.keys_); + } - /** check equality */ - bool equals(const BayesTree& other, double tol = 1e-9) const; + /** add a frontal node */ + void add(std::string key, cond_ptr conditional) { + keys_.push_front(key); + nodes_.push_front(conditional); + } -}; // BayesTree + /** return size of the clique */ + inline size_t size() const {return keys_.size() + separator_.size();} + }; + + /** + * Bayes tree + * Templated on the Conditional class, the type of node in the underlying Bayes chain. + * This could be a ConditionalProbabilityTable, a ConditionalGaussian, or a SymbolicConditional + */ + template + class BayesTree: public Testable > { + + public: + + typedef boost::shared_ptr conditional_ptr; + + private: + + /** A Node in the tree is a Front with tree connectivity */ + struct Node : public Front { + typedef boost::shared_ptr shared_ptr; + shared_ptr parent_; + std::list children_; + + Node(std::string key, conditional_ptr conditional):Front(key,conditional) {} + + /** print this node and entire subtree below it*/ + void printTree(const std::string& indent) const { + print(indent); + BOOST_FOREACH(shared_ptr child, children_) + child->printTree(indent+" "); + } + }; + + /** vector of Nodes */ + typedef boost::shared_ptr node_ptr; + typedef std::vector Nodes; + Nodes nodes_; + + /** Map from keys to Node index */ + typedef std::map NodeMap; + NodeMap nodeMap_; + + public: + + /** Create an empty Bayes Tree */ + BayesTree(); + + /** Create a Bayes Tree from a SymbolicBayesChain */ + BayesTree(BayesChain& bayesChain); + + /** Destructor */ + virtual ~BayesTree() {} + + /** print */ + void print(const std::string& s = "") const; + + /** check equality */ + bool equals(const BayesTree& other, double tol = 1e-9) const; + + /** insert a new conditional */ + void insert(std::string key, conditional_ptr conditional); + + /** return root clique */ + const Front& root() const {return *(nodes_[0]);} + + }; // BayesTree } /// namespace gtsam diff --git a/cpp/ConditionalGaussian.cpp b/cpp/ConditionalGaussian.cpp index efac31236..f456c68ef 100644 --- a/cpp/ConditionalGaussian.cpp +++ b/cpp/ConditionalGaussian.cpp @@ -43,7 +43,7 @@ ConditionalGaussian::ConditionalGaussian(Vector d, /* ************************************************************************* */ ConditionalGaussian::ConditionalGaussian(const Vector& d, const Matrix& R, - const map& parents) + const map& parents) : R_(R), d_(d), parents_(parents) { } @@ -53,7 +53,7 @@ void ConditionalGaussian::print(const string &s) const { cout << s << ":" << endl; gtsam::print(R_,"R"); - for(map::const_iterator it = parents_.begin() ; it != parents_.end() ; it++ ) { + for(Parents::const_iterator it = parents_.begin() ; it != parents_.end() ; it++ ) { const string& j = it->first; const Matrix& Aj = it->second; gtsam::print(Aj, "A["+j+"]"); @@ -63,7 +63,7 @@ void ConditionalGaussian::print(const string &s) const /* ************************************************************************* */ bool ConditionalGaussian::equals(const ConditionalGaussian &cg, double tol) const { - map::const_iterator it = parents_.begin(); + Parents::const_iterator it = parents_.begin(); // check if the size of the parents_ map is the same if (parents_.size() != cg.parents_.size()) return false; @@ -77,21 +77,26 @@ bool ConditionalGaussian::equals(const ConditionalGaussian &cg, double tol) cons // check if the matrices are the same // iterate over the parents_ map for (it = parents_.begin(); it != parents_.end(); it++) { - map::const_iterator it2 = cg.parents_.find( - it->first.c_str()); + Parents::const_iterator it2 = cg.parents_.find(it->first.c_str()); if (it2 != cg.parents_.end()) { if (!(equal_with_abs_tol(it->second, it2->second, tol))) return false; - } else { + } else return false; - } } return true; } +/* ************************************************************************* */ +list ConditionalGaussian::parents() { + list result; + for (Parents::const_iterator it = parents_.begin(); it != parents_.end(); it++) + result.push_back(it->first); +} + /* ************************************************************************* */ Vector ConditionalGaussian::solve(const VectorConfig& x) const { Vector rhs = d_; - for (map::const_iterator it = parents_.begin(); it + for (Parents::const_iterator it = parents_.begin(); it != parents_.end(); it++) { const string& j = it->first; const Matrix& Aj = it->second; diff --git a/cpp/ConditionalGaussian.h b/cpp/ConditionalGaussian.h index 806032af8..9ba1d1ce9 100644 --- a/cpp/ConditionalGaussian.h +++ b/cpp/ConditionalGaussian.h @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include #include @@ -29,7 +30,9 @@ namespace gtsam { class ConditionalGaussian : boost::noncopyable, public Testable { public: - typedef std::map::const_iterator const_iterator; + typedef std::map Parents; + typedef Parents::const_iterator const_iterator; + typedef boost::shared_ptr shared_ptr; protected: @@ -37,13 +40,12 @@ namespace gtsam { Matrix R_; /** the names and the matrices connecting to parent nodes */ - std::map parents_; + Parents parents_; /** the RHS vector */ Vector d_; public: - typedef boost::shared_ptr shared_ptr; /** constructor */ ConditionalGaussian() {}; @@ -84,7 +86,7 @@ namespace gtsam { */ ConditionalGaussian(const Vector& d, const Matrix& R, - const std::map& parents); + const Parents& parents); /** deconstructor */ virtual ~ConditionalGaussian() {}; @@ -98,6 +100,9 @@ namespace gtsam { /** dimension of multivariate variable */ size_t dim() const {return R_.size2();} + /** return all parents */ + std::list parents(); + /** return stuff contained in ConditionalGaussian */ const Vector& get_d() const {return d_;} const Matrix& get_R() const {return R_;} diff --git a/cpp/SymbolicConditional.h b/cpp/SymbolicConditional.h index 5088781d2..e07ff4dad 100644 --- a/cpp/SymbolicConditional.h +++ b/cpp/SymbolicConditional.h @@ -20,6 +20,7 @@ namespace gtsam { * Conditional node for use in a Bayes net */ class SymbolicConditional: Testable { + private: std::list parents_; @@ -68,6 +69,9 @@ namespace gtsam { return parents_ == other.parents_; } + /** return any parent */ + std::list parents() { return parents_;} + }; } /// namespace gtsam diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index a4bba9cc7..aadbefd88 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -4,24 +4,73 @@ * @author Frank Dellaert */ +#include // for 'insert()' +#include // for operator += #include #include "SymbolicBayesChain.h" -#include "smallExample.h" #include "BayesTree-inl.h" +//using namespace std; using namespace gtsam; +using namespace boost::assign; + +// Conditionals for ASIA example from the tutorial with A and D evidence +SymbolicConditional::shared_ptr + B(new SymbolicConditional()), + L(new SymbolicConditional("B")), + E(new SymbolicConditional("L","B")), + S(new SymbolicConditional("L","B")), + T(new SymbolicConditional("L","E")), + X(new SymbolicConditional("E")); + +/* ************************************************************************* */ +TEST( BayesTree, Front ) +{ + Front f1("B",B); f1.add("L",L); + Front f2("L",L); f2.add("B",B); + CHECK(f1.equals(f1)); + CHECK(!f1.equals(f2)); +} + +/* ************************************************************************* */ +TEST( BayesTree, insert ) +{ + // Insert + BayesTree bayesTree; + bayesTree.insert("B",B); + bayesTree.insert("L",L); + bayesTree.insert("E",E); + bayesTree.insert("S",S); + bayesTree.insert("T",T); + bayesTree.insert("X",X); + //bayesTree.print("bayesTree"); + + //LONGS_EQUAL(1,bayesTree.size()); + + // Check root + Front expected_root("B",B); + //CHECK(assert_equal(expected_root,bayesTree.root())); +} /* ************************************************************************* */ TEST( BayesTree, constructor ) { - LinearFactorGraph factorGraph = createLinearFactorGraph(); - Ordering ordering; - ordering.push_back("x2"); - ordering.push_back("l1"); - ordering.push_back("x1"); - SymbolicBayesChain symbolicBayesChain(factorGraph,ordering); - BayesTree bayesTree(symbolicBayesChain); + // Create Symbolic Bayes Chain in which we want to discover cliques + map nodes; + insert(nodes)("B",B)("L",L)("E",E)("S",S)("T",T)("X",X); + SymbolicBayesChain ASIA(nodes); + + // Create Bayes Tree from Symbolic Bayes Chain + BayesTree bayesTree(ASIA); + bayesTree.insert("B",B); + //bayesTree.print("bayesTree"); + + //LONGS_EQUAL(1,bayesTree.size()); + + // Check root + Front expected_root("B",B); + //CHECK(assert_equal(expected_root,bayesTree.root())); } /* ************************************************************************* */ diff --git a/cpp/testSymbolicBayesChain.cpp b/cpp/testSymbolicBayesChain.cpp index 974031bd8..44fafc59d 100644 --- a/cpp/testSymbolicBayesChain.cpp +++ b/cpp/testSymbolicBayesChain.cpp @@ -4,6 +4,10 @@ * @author Frank Dellaert */ +#include // for 'insert()' +#include // for operator += +using namespace boost::assign; + #include #include "smallExample.h" @@ -16,23 +20,22 @@ using namespace gtsam; TEST( SymbolicBayesChain, constructor ) { // Create manually - SymbolicConditional::shared_ptr x2(new SymbolicConditional("x1", "l1")); - SymbolicConditional::shared_ptr l1(new SymbolicConditional("x1")); - SymbolicConditional::shared_ptr x1(new SymbolicConditional()); + SymbolicConditional::shared_ptr + x2(new SymbolicConditional("x1", "l1")), + l1(new SymbolicConditional("x1")), + x1(new SymbolicConditional()); map nodes; - nodes.insert(make_pair("x2", x2)); - nodes.insert(make_pair("l1", l1)); - nodes.insert(make_pair("x1", x1)); + insert(nodes)("x2", x2)("l1", l1)("x1", x1); SymbolicBayesChain expected(nodes); // Create from a factor graph Ordering ordering; - ordering.push_back("x2"); - ordering.push_back("l1"); - ordering.push_back("x1"); + ordering += "x2","l1","x1"; LinearFactorGraph factorGraph = createLinearFactorGraph(); SymbolicBayesChain actual(factorGraph, ordering); - //CHECK(assert_equal(expected, actual)); + CHECK(assert_equal(expected, actual)); + + //bayesChain.ordering(); } /* ************************************************************************* */