diff --git a/cpp/BayesChain.h b/cpp/BayesChain.h index bdedf6698..143d791db 100644 --- a/cpp/BayesChain.h +++ b/cpp/BayesChain.h @@ -41,15 +41,18 @@ namespace gtsam { /** check equality */ bool equals(const BayesChain& other, double tol = 1e-9) const; - /** size is the number of nodes */ - inline size_t size() const {return nodes_.size();} - /** insert: use reverse topological sort (i.e. parents last) */ void insert(const std::string& key, boost::shared_ptr node); /** delete */ void erase(const std::string& key); + /** size is the number of nodes */ + inline size_t size() const {return nodes_.size();} + + /** return keys in topological sort order (parents first), i.e., reverse elimination order */ + inline std::list keys() const { return keys_;} + inline boost::shared_ptr operator[](const std::string& key) const { const_iterator cg = nodes_.find(key); // get node assert( cg != nodes_.end() ); diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index 7eed3d8c7..1497956c1 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -4,26 +4,64 @@ * @author Frank Dellaert */ +#include #include "BayesTree.h" namespace gtsam { using namespace std; + /* ************************************************************************* */ + template + Front::Front(string key, cond_ptr conditional) { + add(key, conditional); + separator_ = conditional->parents(); + } + + /* ************************************************************************* */ + template + void Front::print(const string& s) const { + cout << s; + BOOST_FOREACH(string key, keys_) cout << " " << key; + if (!separator_.empty()) { + cout << " :"; + BOOST_FOREACH(string key, separator_) + cout << " " << key; + } + cout << endl; + } + + /* ************************************************************************* */ + template + bool Front::equals(const Front& other, double tol) const { + return (keys_ == other.keys_) && + equal(conditionals_.begin(),conditionals_.end(),other.conditionals_.begin(),equals_star); + } + + /* ************************************************************************* */ + template + void Front::add(string key, cond_ptr conditional) { + keys_.push_front(key); + conditionals_.push_front(conditional); + } + /* ************************************************************************* */ template BayesTree::BayesTree() { } /* ************************************************************************* */ + // TODO: traversal is O(n*log(n)) but could be O(n) with better bayesChain template - BayesTree::BayesTree(BayesChain& bayesChain) { - list ordering;// = bayesChain.ordering(); + BayesTree::BayesTree(BayesChain& bayesChain, bool verbose) { + list reverseOrdering = bayesChain.keys(); + BOOST_FOREACH(string key, reverseOrdering) + insert(key,bayesChain[key],verbose); } /* ************************************************************************* */ template - void BayesTree::print(const std::string& s) const { + void BayesTree::print(const string& s) const { cout << s << ": size == " << nodes_.size() << endl; if (nodes_.empty()) return; nodes_[0]->printTree(""); @@ -34,19 +72,24 @@ namespace gtsam { bool BayesTree::equals(const BayesTree& other, double tol) const { return size()==other.size() && - equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) && - equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star); + equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) && + equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star); } /* ************************************************************************* */ template - void BayesTree::insert(string key, conditional_ptr conditional) { + void BayesTree::insert(string key, conditional_ptr conditional, bool verbose) { - // get any parent + if (verbose) cout << "Inserting " << key << "| "; + + // get parents list parents = conditional->parents(); + if (verbose) BOOST_FOREACH(string p, parents) cout << p << " "; + if (verbose) cout << endl; // if no parents, start a new root clique if (parents.empty()) { + if (verbose) cout << "Creating root clique" << endl; node_ptr root(new Node(key, conditional)); nodes_.push_back(root); nodeMap_.insert(make_pair(key, 0)); @@ -57,18 +100,20 @@ namespace gtsam { string parent = parents.front(); NodeMap::const_iterator it = nodeMap_.find(parent); if (it == nodeMap_.end()) throw(invalid_argument( - "BayesTree::insert: parent with key " + key + "was not yet inserted")); + "BayesTree::insert('"+key+"'): parent '" + parent + "' was not yet inserted")); int index = it->second; node_ptr parent_clique = nodes_[index]; // if the parents and parent clique have the same size, add to parent clique if (parent_clique->size() == parents.size()) { + if (verbose) cout << "Adding to clique " << index << endl; nodeMap_.insert(make_pair(key, index)); parent_clique->add(key, conditional); return; } // otherwise, start a new clique and add it to the tree + if (verbose) cout << "Starting new clique" << endl; node_ptr new_clique(new Node(key, conditional)); new_clique->parent_ = parent_clique; parent_clique->children_.push_back(new_clique); @@ -76,6 +121,6 @@ namespace gtsam { nodes_.push_back(new_clique); } -/* ************************************************************************* */ + /* ************************************************************************* */ } /// namespace gtsam diff --git a/cpp/BayesTree.h b/cpp/BayesTree.h index 8ee46e47c..ffd6ace1a 100644 --- a/cpp/BayesTree.h +++ b/cpp/BayesTree.h @@ -13,7 +13,6 @@ #include #include #include -#include // TODO: make cpp file #include "Testable.h" #include "BayesChain.h" @@ -30,35 +29,16 @@ namespace gtsam { public: /** constructor */ - Front(std::string key, cond_ptr conditional) { - add(key, conditional); - separator_ = conditional->parents(); - } + Front(std::string key, cond_ptr conditional); /** print */ - void print(const std::string& s = "") const { - std::cout << s; - BOOST_FOREACH(std::string key, keys_) - std::cout << " " << key; - if (!separator_.empty()) { - std::cout << " :"; - BOOST_FOREACH(std::string key, separator_) - std::cout << " " << key; - } - std::cout << std::endl; - } + void print(const std::string& s = "") const; - /** check equality. TODO: only keys */ - bool equals(const Front& other, double tol = 1e-9) const { - return (keys_ == other.keys_) && - equal(conditionals_.begin(),conditionals_.end(),other.conditionals_.begin(),equals_star); - } + /** check equality */ + bool equals(const Front& other, double tol = 1e-9) const; /** add a frontal node */ - void add(std::string key, cond_ptr conditional) { - keys_.push_front(key); - conditionals_.push_front(conditional); - } + void add(std::string key, cond_ptr conditional); /** return size of the clique */ inline size_t size() const {return keys_.size() + separator_.size();} @@ -109,7 +89,7 @@ namespace gtsam { BayesTree(); /** Create a Bayes Tree from a SymbolicBayesChain */ - BayesTree(BayesChain& bayesChain); + BayesTree(BayesChain& bayesChain, bool verbose=false); /** Destructor */ virtual ~BayesTree() {} @@ -121,7 +101,7 @@ namespace gtsam { bool equals(const BayesTree& other, double tol = 1e-9) const; /** insert a new conditional */ - void insert(std::string key, conditional_ptr conditional); + void insert(std::string key, conditional_ptr conditional, bool verbose=false); /** number of cliques */ inline size_t size() const { return nodes_.size();} diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index f6f9fd118..d20fb4f2e 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -12,23 +12,23 @@ using namespace boost::assign; #include "SymbolicBayesChain.h" #include "BayesTree-inl.h" +#include "SmallExample.h" using namespace gtsam; // Conditionals for ASIA example from the tutorial with A and D evidence -SymbolicConditional::shared_ptr - B(new SymbolicConditional()), - L(new SymbolicConditional("B")), - E(new SymbolicConditional("L","B")), - S(new SymbolicConditional("L","B")), - T(new SymbolicConditional("L","E")), +SymbolicConditional::shared_ptr B(new SymbolicConditional()), L( + new SymbolicConditional("B")), E(new SymbolicConditional("L", "B")), S( + new SymbolicConditional("L", "B")), T(new SymbolicConditional("E", "L")), X(new SymbolicConditional("E")); /* ************************************************************************* */ TEST( BayesTree, Front ) { - Front f1("B",B); f1.add("L",L); - Front f2("L",L); f2.add("B",B); + Front f1("B", B); + f1.add("L", L); + Front f2("L", L); + f2.add("B", B); CHECK(f1.equals(f1)); CHECK(!f1.equals(f2)); } @@ -38,31 +38,82 @@ TEST( BayesTree, constructor ) { // Create using insert BayesTree bayesTree; - bayesTree.insert("B",B); - bayesTree.insert("L",L); - bayesTree.insert("E",E); - bayesTree.insert("S",S); - bayesTree.insert("T",T); - bayesTree.insert("X",X); + bayesTree.insert("B", B); + bayesTree.insert("L", L); + bayesTree.insert("E", E); + bayesTree.insert("S", S); + bayesTree.insert("T", T); + bayesTree.insert("X", X); // Check Size LONGS_EQUAL(4,bayesTree.size()); // Check root - Front expected_root("B",B); - expected_root.add("L",L); - expected_root.add("E",E); + Front expected_root("B", B); + expected_root.add("L", L); + expected_root.add("E", E); Front actual_root = bayesTree.root(); CHECK(assert_equal(expected_root,actual_root)); // Create from symbolic Bayes chain in which we want to discover cliques - map nodes; - insert(nodes)("B",B)("L",L)("E",E)("S",S)("T",T)("X",X); - SymbolicBayesChain ASIA(nodes); + SymbolicBayesChain ASIA; + ASIA.insert("X", X); + ASIA.insert("T", T); + ASIA.insert("S", S); + ASIA.insert("E", E); + ASIA.insert("L", L); + ASIA.insert("B", B); BayesTree bayesTree2(ASIA); // Check whether the same - //CHECK(assert_equal(bayesTree,bayesTree2)); + CHECK(assert_equal(bayesTree,bayesTree2)); +} + +/* ************************************************************************* * + Bayes tree for smoother with "natural" ordering: + x6 x7 + x5 : x6 + x4 : x5 + x3 : x4 + x2 : x3 + x1 : x2 +/* ************************************************************************* */ +TEST( BayesTree, smoother ) +{ + // Create smoother with 7 nodes + LinearFactorGraph smoother = createSmoother(7); + Ordering ordering; + for (int t = 1; t <= 7; t++) + ordering.push_back(symbol('x', t)); + + // eliminate using the "natural" ordering + ChordalBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); + + // Create the Bayes tree + BayesTree bayesTree(*chordalBayesNet,false); + LONGS_EQUAL(6,bayesTree.size()); +} + +/* ************************************************************************* * + Bayes tree for smoother with "nested dissection" ordering: + x5 x6 x4 + x3 x2 : x4 + x1 : x2 + x7 : x6 +/* ************************************************************************* */ +TEST( BayesTree, balanced_smoother ) +{ + // Create smoother with 7 nodes + LinearFactorGraph smoother = createSmoother(7); + Ordering ordering; + ordering += "x1","x3","x5","x7","x2","x6","x4"; + + // eliminate using a "nested dissection" ordering + ChordalBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); + + // Create the Bayes tree + BayesTree bayesTree(*chordalBayesNet,false); + LONGS_EQUAL(4,bayesTree.size()); } /* ************************************************************************* */