diff --git a/.cproject b/.cproject index ef2dd0bc9..248ce11a7 100644 --- a/.cproject +++ b/.cproject @@ -396,6 +396,14 @@ false true + + make + + check + true + true + true + make check @@ -578,6 +586,30 @@ true true + + make + + check + true + true + true + + + make + + testClusterTree.run + true + true + true + + + make + + testJunctionTree.run + true + true + true + make check diff --git a/inference/ClusterTree-inl.h b/inference/ClusterTree-inl.h new file mode 100644 index 000000000..488dac352 --- /dev/null +++ b/inference/ClusterTree-inl.h @@ -0,0 +1,159 @@ +/* + * ClusterTree-inl.h + * Created on: July 13, 2010 + * @author Kai Ni + * @author Frank Dellaert + * @brief: Collects factorgraph fragments defined on variable clusters, arranged in a tree + */ + +#pragma once + +#include + +#include "SymbolicFactorGraph.h" +#include "BayesTree-inl.h" +#include "ClusterTree.h" + +namespace gtsam { + + using namespace std; + + /* ************************************************************************* */ + template + bool ClusterTree::Clique::equals(const ClusterTree::Clique& other) const { + if (!frontal_.equals(other.frontal_)) + return false; + + if (!separator_.equals(other.separator_)) + return false; + + if (children_.size() != other.children_.size()) + return false; + + typename vector::const_iterator it1 = children_.begin(); + typename vector::const_iterator it2 = other.children_.begin(); + for(; it1!=children_.end(); it1++, it2++) + if (!(*it1)->equals(**it2)) return false; + + return true; + } + + /* ************************************************************************* */ + /** + * ClusterTree + */ + template + void ClusterTree::Clique::print(const string& indent) const { + // FG::print(indent); + cout << indent; + BOOST_FOREACH(const Symbol& key, frontal_) + cout << (string)key << " "; + cout << ":"; + BOOST_FOREACH(const Symbol& key, separator_) + cout << (string)key << " "; + cout << endl; + } + + /* ************************************************************************* */ + template + void ClusterTree::Clique::printTree(const string& indent) const { + print(indent); + BOOST_FOREACH(const shared_ptr& child, children_) + child->printTree(indent+" "); + } + + /* ************************************************************************* */ + template + ClusterTree::ClusterTree(FG& fg, const Ordering& ordering) { + // Symbolic factorization: GaussianFactorGraph -> SymbolicFactorGraph -> SymbolicBayesNet -> SymbolicBayesTree + SymbolicFactorGraph sfg(fg); + SymbolicBayesNet sbn = sfg.eliminate(ordering); + BayesTree sbt(sbn); + + // distribtue factors + root_ = distributeFactors(fg, sbt.root()); + } + + /* ************************************************************************* */ + template + typename ClusterTree::sharedClique ClusterTree::distributeFactors(FG& fg, + const BayesTree::sharedClique bayesClique) { + // create a new clique in the junction tree + sharedClique clique(new Clique()); + clique->frontal_ = bayesClique->ordering(); + clique->separator_.insert(bayesClique->separator_.begin(), bayesClique->separator_.end()); + + // recursively call the children + BOOST_FOREACH(const BayesTree::sharedClique bayesChild, bayesClique->children()) { + sharedClique child = distributeFactors(fg, bayesChild); + clique->children_.push_back(child); + child->parent_ = clique; + } + + // collect the factors + typedef vector Factors; + BOOST_FOREACH(const Symbol& frontal, clique->frontal_) { + Factors factors = fg.template findAndRemoveFactors(frontal); + BOOST_FOREACH(const typename FG::sharedFactor& factor_, factors) { + clique->push_back(factor_); + } + } + + return clique; + } + + /* ************************************************************************* */ + template template + pair > + ClusterTree::eliminateOneClique(sharedClique current) { + +// current->frontal_.print("current clique:"); + + typedef typename BayesTree::sharedClique sharedBtreeClique; + FG fg; // factor graph will be assembled from local factors and marginalized children + list > children; + fg.push_back(*current); // add the local factor graph + +// BOOST_FOREACH(const typename FG::sharedFactor& factor_, fg) +// Ordering(factor_->keys()).print("local factor:"); + + BOOST_FOREACH(sharedClique& child, current->children_) { + // receive the factors from the child and its clique point + FG fgChild; BayesTree childTree; + boost::tie(fgChild, childTree) = eliminateOneClique(child); + +// BOOST_FOREACH(const typename FG::sharedFactor& factor_, fgChild) +// Ordering(factor_->keys()).print("factor from child:"); + + fg.push_back(fgChild); + children.push_back(childTree); + } + + // eliminate the combined factors + // warning: fg is being eliminated in-place and will contain marginal afterwards + BayesNet bn = fg.eliminateFrontals(current->frontal_); + + // create a new clique corresponding the combined factors + BayesTree bayesTree(bn, children); + + return make_pair(fg, bayesTree); + } + + /* ************************************************************************* */ + template template + BayesTree ClusterTree::eliminate() { + pair > ret = this->eliminateOneClique(root_); +// ret.first.print("ret.first"); + if (ret.first.nrFactors() != 0) + throw runtime_error("JuntionTree::eliminate: elimination failed because of factors left over!"); + return ret.second; + } + + /* ************************************************************************* */ + template + bool ClusterTree::equals(const ClusterTree& other, double tol) const { + if (!root_ || !other.root_) return false; + return root_->equals(*other.root_); + } + +} //namespace gtsam diff --git a/inference/ClusterTree.h b/inference/ClusterTree.h new file mode 100644 index 000000000..926a944e1 --- /dev/null +++ b/inference/ClusterTree.h @@ -0,0 +1,103 @@ +/* + * ClusterTree.h + * Created on: July 13, 2010 + * @author Kai Ni + * @author Frank Dellaert + * @brief: Collects factorgraph fragments defined on variable clusters, arranged in a tree + */ + +#pragma once + +#include +#include +#include "BayesTree.h" +#include "SymbolicConditional.h" + +namespace gtsam { + + /* ************************************************************************* */ + template + class ClusterTree : public Testable > { + public: + // the class for subgraphs that also include the pointers to the parents and two children + class Clique : public FG { + private: + typedef typename boost::shared_ptr shared_ptr; + shared_ptr parent_; // the parent subgraph node + std::vector children_; // the child cliques + Ordering frontal_; // the frontal varaibles + Unordered separator_; // the separator variables + + friend class ClusterTree; + + public: + + // empty constructor + Clique() {} + + // constructor with all the information + Clique(const FG& fgLocal, const Ordering& frontal, const Unordered& separator, + const shared_ptr& parent) + : frontal_(frontal), separator_(separator), FG(fgLocal), parent_(parent) {} + + // constructor for an empty graph + Clique(const Ordering& frontal, const Unordered& separator, const shared_ptr& parent) + : frontal_(frontal), separator_(separator), parent_(parent) {} + + // return the members + const Ordering& frontal() const { return frontal_;} + const Unordered& separator() const { return separator_;} + const std::vector& children() { return children_; } + + // add a child node + void addChild(const shared_ptr& child) { children_.push_back(child); } + + // print the object + void print(const std::string& indent) const; + void printTree(const std::string& indent) const; + + // check equality + bool equals(const Clique& other) const; + }; + + // typedef for shared pointers to cliques + typedef typename Clique::shared_ptr sharedClique; + + protected: + // Root clique + sharedClique root_; + + private: + // distribute the factors along the Bayes tree + sharedClique distributeFactors(FG& fg, const BayesTree::sharedClique clique); + + // utility function called by eliminate + template + std::pair > eliminateOneClique(sharedClique fg_); + + public: + // constructor + ClusterTree() {} + + // constructor given a factor graph and the elimination ordering + ClusterTree(FG& fg, const Ordering& ordering); + + // return the root clique + sharedClique root() const { return root_; } + + // eliminate the factors in the subgraphs + template + BayesTree eliminate(); + + // print the object + void print(const std::string& str) const { + cout << str << endl; + if (root_.get()) root_->printTree(""); + } + + /** check equality */ + bool equals(const ClusterTree& other, double tol = 1e-9) const; + + }; // ClusterTree + +} // namespace gtsam diff --git a/inference/JunctionTree.h b/inference/JunctionTree.h index 69cb977c3..21720c25f 100644 --- a/inference/JunctionTree.h +++ b/inference/JunctionTree.h @@ -91,7 +91,7 @@ namespace gtsam { // print the object void print(const std::string& str) const { - cout << str << endl; + std::cout << str << std::endl; if (root_.get()) root_->printTree(""); } diff --git a/inference/Makefile.am b/inference/Makefile.am index 6e928f206..46632d97b 100644 --- a/inference/Makefile.am +++ b/inference/Makefile.am @@ -24,13 +24,13 @@ check_PROGRAMS += testSymbolicFactor testSymbolicFactorGraph testSymbolicBayesNe headers += inference.h inference-inl.h headers += graph.h graph-inl.h headers += FactorGraph.h FactorGraph-inl.h +headers += ClusterTree.h ClusterTree-inl.h headers += JunctionTree.h JunctionTree-inl.h headers += BayesNet.h BayesNet-inl.h headers += BayesTree.h BayesTree-inl.h headers += ISAM.h ISAM-inl.h headers += ISAM2.h ISAM2-inl.h -check_PROGRAMS += testFactorGraph testOrdering -check_PROGRAMS += testBayesTree testISAM +check_PROGRAMS += testFactorGraph testClusterTree testJunctionTree testBayesTree testISAM #---------------------------------------------------------------------------------------------------- # discrete diff --git a/inference/testClusterTree.cpp b/inference/testClusterTree.cpp new file mode 100644 index 000000000..914a68de9 --- /dev/null +++ b/inference/testClusterTree.cpp @@ -0,0 +1,25 @@ +/** + * @file testClusterTree.cpp + * @brief Unit tests for Bayes Tree + * @author Kai Ni + * @author Frank Dellaert + */ + +#include // for operator += +using namespace boost::assign; + +#include + +#include "SymbolicFactorGraph.h" +#include "ClusterTree-inl.h" + +using namespace gtsam; + +typedef ClusterTree SymbolicClusterTree; + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/inference/testJunctionTree.cpp b/inference/testJunctionTree.cpp new file mode 100644 index 000000000..618ee36b1 --- /dev/null +++ b/inference/testJunctionTree.cpp @@ -0,0 +1,57 @@ +/** + * @file testJunctionTree.cpp + * @brief Unit tests for Bayes Tree + * @author Kai Ni + * @author Frank Dellaert + */ + +#include // for operator += +#include // for operator += +using namespace boost::assign; + +#include + +#define GTSAM_MAGIC_KEY + +#include "Ordering.h" +#include "SymbolicFactorGraph.h" +#include "JunctionTree-inl.h" + +using namespace gtsam; + +typedef JunctionTree SymbolicJunctionTree; + +/* ************************************************************************* */ +/** + * x1 - x2 - x3 - x4 + * x3 x4 + * x2 x1 : x3 + */ +TEST( JunctionTree, constructor ) +{ + SymbolicFactorGraph fg; + fg.push_factor("x1","x2"); + fg.push_factor("x2","x3"); + fg.push_factor("x3","x4"); + + Ordering ordering; ordering += "x2","x1","x3","x4"; + SymbolicJunctionTree junctionTree(fg, ordering); + + Ordering frontal1; frontal1 += "x3", "x4"; + Ordering frontal2; frontal2 += "x2", "x1"; + Unordered sep1; + Unordered sep2; sep2 += "x3"; + CHECK(assert_equal(frontal1, junctionTree.root()->frontal())); + CHECK(assert_equal(sep1, junctionTree.root()->separator())); + LONGS_EQUAL(1, junctionTree.root()->size()); + CHECK(assert_equal(frontal2, junctionTree.root()->children()[0]->frontal())); + CHECK(assert_equal(sep2, junctionTree.root()->children()[0]->separator())); + LONGS_EQUAL(2, junctionTree.root()->children()[0]->size()); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/linear/GaussianFactorGraph.cpp b/linear/GaussianFactorGraph.cpp index 00dad4a11..48e8a2f6f 100644 --- a/linear/GaussianFactorGraph.cpp +++ b/linear/GaussianFactorGraph.cpp @@ -16,7 +16,7 @@ #include "FactorGraph-inl.h" #include "inference-inl.h" #include "iterative.h" -#include "GaussianJunctionTree-inl.h" +#include "GaussianJunctionTree.h" using namespace std; using namespace gtsam; @@ -121,8 +121,8 @@ set GaussianFactorGraph::find_separator(const Symbol& key) const /* ************************************************************************* */ GaussianConditional::shared_ptr -GaussianFactorGraph::eliminateOne(const Symbol& key, bool old) { - if (old) +GaussianFactorGraph::eliminateOne(const Symbol& key, bool enableJoinFactor) { + if (enableJoinFactor) return gtsam::eliminateOne(*this, key); else return eliminateOneMatrixJoin(key); @@ -241,11 +241,11 @@ GaussianFactorGraph::eliminateOneMatrixJoin(const Symbol& key) { /* ************************************************************************* */ GaussianBayesNet -GaussianFactorGraph::eliminate(const Ordering& ordering, bool old) +GaussianFactorGraph::eliminate(const Ordering& ordering, bool enableJoinFactor) { GaussianBayesNet chordalBayesNet; // empty BOOST_FOREACH(const Symbol& key, ordering) { - GaussianConditional::shared_ptr cg = eliminateOne(key, old); + GaussianConditional::shared_ptr cg = eliminateOne(key, enableJoinFactor); chordalBayesNet.push_back(cg); } return chordalBayesNet; @@ -299,7 +299,7 @@ VectorConfig GaussianFactorGraph::optimize(const Ordering& ordering, bool old) /* ************************************************************************* */ VectorConfig GaussianFactorGraph::optimizeMultiFrontals(const Ordering& ordering) { - GaussianJunctionTree junctionTree(*this, ordering); + GaussianJunctionTree junctionTree(*this, ordering); return junctionTree.optimize(); } diff --git a/linear/GaussianJunctionTree-inl.h b/linear/GaussianJunctionTree.cpp similarity index 59% rename from linear/GaussianJunctionTree-inl.h rename to linear/GaussianJunctionTree.cpp index 86c2a3562..16c1225d5 100644 --- a/linear/GaussianJunctionTree-inl.h +++ b/linear/GaussianJunctionTree.cpp @@ -1,13 +1,11 @@ /* - * GaussianJunctionTree-inl.h - * - * Created on: Jul 12, 2010 - * Author: nikai - * Description: the Gaussian junction tree + * GaussianJunctionTree.cpp + * Created on: Jul 12, 2010 + * @author Kai Ni + * @author Frank Dellaert + * @brief: the Gaussian junction tree */ -#pragma once - #include #include "JunctionTree-inl.h" @@ -15,35 +13,38 @@ namespace gtsam { + // explicit template instantiation + template class JunctionTree; + using namespace std; /* ************************************************************************* */ /** * GaussianJunctionTree */ - template - void GaussianJunctionTree::btreeBackSubstitue(typename BayesTree::sharedClique current, VectorConfig& config) { + void GaussianJunctionTree::btreeBackSubstitue( + BayesTree::sharedClique current, + VectorConfig& config) { // solve the bayes net in the current node - typename BayesNet::const_reverse_iterator it = current->rbegin(); + BayesNet::const_reverse_iterator it = current->rbegin(); for (; it!=current->rend(); it++) { Vector x = (*it)->solve(config); // Solve for that variable config.insert((*it)->key(),x); // store result in partial solution } // solve the bayes nets in the child nodes - typedef typename BayesTree::sharedClique sharedBayesClique; + typedef BayesTree::sharedClique sharedBayesClique; BOOST_FOREACH(sharedBayesClique child, current->children_) { btreeBackSubstitue(child, config); } } /* ************************************************************************* */ - template - VectorConfig GaussianJunctionTree::optimize() { + VectorConfig GaussianJunctionTree::optimize() { // eliminate from leaves to the root - typedef JunctionTree Base; + typedef JunctionTree Base; BayesTree bayesTree; - this->eliminate(); + this->eliminate(); // back-substitution VectorConfig result; diff --git a/linear/GaussianJunctionTree.h b/linear/GaussianJunctionTree.h index 742cf56bb..a83801ac2 100644 --- a/linear/GaussianJunctionTree.h +++ b/linear/GaussianJunctionTree.h @@ -1,9 +1,9 @@ /* * GaussianJunctionTree.h - * - * Created on: Jul 12, 2010 - * Author: nikai - * Description: the Gaussian junction tree + * Created on: Jul 12, 2010 + * @author Kai Ni + * @author Frank Dellaert + * @brief: the Gaussian junction tree */ #pragma once @@ -18,22 +18,21 @@ namespace gtsam { /** * GaussianJunctionTree that does the optimization */ - template - class GaussianJunctionTree: public JunctionTree { + class GaussianJunctionTree: public JunctionTree { public: - typedef JunctionTree Base; - typedef typename JunctionTree::sharedClique sharedClique; + typedef JunctionTree Base; + typedef Base::sharedClique sharedClique; protected: // back-substitute in topological sort order (parents first) - void btreeBackSubstitue(typename BayesTree::sharedClique current, VectorConfig& config); + void btreeBackSubstitue(BayesTree::sharedClique current, VectorConfig& config); public : GaussianJunctionTree() : Base() {} // constructor - GaussianJunctionTree(FG& fg, const Ordering& ordering) : Base(fg, ordering) {} + GaussianJunctionTree(GaussianFactorGraph& fg, const Ordering& ordering) : Base(fg, ordering) {} // optimize the linear graph VectorConfig optimize(); diff --git a/linear/Makefile.am b/linear/Makefile.am index 56ef4ea2c..43dbd22e5 100644 --- a/linear/Makefile.am +++ b/linear/Makefile.am @@ -19,7 +19,7 @@ check_PROGRAMS += testVectorMap testVectorBTree # Gaussian Factor Graphs headers += GaussianFactorSet.h Factorization.h sources += GaussianFactor.cpp GaussianFactorGraph.cpp -headers += GaussianJunctionTree.h GaussianJunctionTree-inl.h +sources += GaussianJunctionTree.cpp sources += GaussianConditional.cpp GaussianBayesNet.cpp sources += GaussianISAM.cpp check_PROGRAMS += testGaussianFactor testGaussianJunctionTree testGaussianConditional diff --git a/linear/testGaussianJunctionTree.cpp b/linear/testGaussianJunctionTree.cpp index 4955feafd..144f3dad7 100644 --- a/linear/testGaussianJunctionTree.cpp +++ b/linear/testGaussianJunctionTree.cpp @@ -1,9 +1,8 @@ /* - * testJunctionTree.cpp + * testGaussianJunctionTree.cpp * - * Created on: Jul 8, 2010 - * Author: nikai - * Description: + * Created on: Jul 8, 2010 + * @author Kai Ni */ #include @@ -16,45 +15,12 @@ using namespace boost::assign; #define GTSAM_MAGIC_KEY -#include "GaussianJunctionTree-inl.h" +#include "Ordering.h" +#include "GaussianJunctionTree.h" using namespace std; using namespace gtsam; -/* ************************************************************************* */ -/** - * x1 - x2 - x3 - x4 - * x3 x4 - * x2 x1 : x3 - */ -TEST( GaussianFactorGraph, constructor ) -{ - typedef GaussianFactorGraph::sharedFactor Factor; - SharedDiagonal model(Vector_(1, 0.2)); - Factor factor1(new GaussianFactor("x1", Matrix_(1,1,1.), "x2", Matrix_(1,1,1.), Vector_(1,1.), model)); - Factor factor2(new GaussianFactor("x2", Matrix_(1,1,1.), "x3", Matrix_(1,1,1.), Vector_(1,1.), model)); - Factor factor3(new GaussianFactor("x3", Matrix_(1,1,1.), "x4", Matrix_(1,1,1.), Vector_(1,1.), model)); - - GaussianFactorGraph fg; - fg.push_back(factor1); - fg.push_back(factor2); - fg.push_back(factor3); - - Ordering ordering; ordering += "x2","x1","x3","x4"; - GaussianJunctionTree junctionTree(fg, ordering); - - Ordering frontal1; frontal1 += "x3", "x4"; - Ordering frontal2; frontal2 += "x2", "x1"; - Unordered sep1; - Unordered sep2; sep2 += "x3"; - CHECK(assert_equal(frontal1, junctionTree.root()->frontal())); - CHECK(assert_equal(sep1, junctionTree.root()->separator())); - LONGS_EQUAL(1, junctionTree.root()->size()); - CHECK(assert_equal(frontal2, junctionTree.root()->children()[0]->frontal())); - CHECK(assert_equal(sep2, junctionTree.root()->children()[0]->separator())); - LONGS_EQUAL(2, junctionTree.root()->children()[0]->size()); -} - /* ************************************************************************* */ /** * x1 - x2 - x3 - x4 @@ -69,7 +35,7 @@ TEST( GaussianFactorGraph, constructor ) * * 1 0 0 1 */ -TEST( GaussianFactorGraph, eliminate ) +TEST( GaussianJunctionTree, eliminate ) { typedef GaussianFactorGraph::sharedFactor Factor; SharedDiagonal model(Vector_(1, 0.5)); @@ -85,7 +51,7 @@ TEST( GaussianFactorGraph, eliminate ) fg.push_back(factor4); Ordering ordering; ordering += "x2","x1","x3","x4"; - GaussianJunctionTree junctionTree(fg, ordering); + GaussianJunctionTree junctionTree(fg, ordering); BayesTree bayesTree = junctionTree.eliminate(); typedef BayesTree::sharedConditional sharedConditional; diff --git a/tests/testGaussianJunctionTree.cpp b/tests/testGaussianJunctionTree.cpp index 40df0fed1..88155f547 100644 --- a/tests/testGaussianJunctionTree.cpp +++ b/tests/testGaussianJunctionTree.cpp @@ -1,5 +1,5 @@ /* - * testJunctionTree.cpp + * testGaussianJunctionTree.cpp * * Created on: Jul 8, 2010 * Author: nikai @@ -16,8 +16,9 @@ using namespace boost::assign; #define GTSAM_MAGIC_KEY +#include "Ordering.h" +#include "GaussianJunctionTree.h" #include "smallExample.h" -#include "GaussianJunctionTree-inl.h" using namespace std; using namespace gtsam; @@ -30,7 +31,7 @@ using namespace example; C3 x1 : x2 C4 x7 : x6 /* ************************************************************************* */ -TEST( GaussianFactorGraph, constructor2 ) +TEST( GaussianJunctionTree, constructor2 ) { // create a graph GaussianFactorGraph fg = createSmoother(7); @@ -38,7 +39,7 @@ TEST( GaussianFactorGraph, constructor2 ) // create an ordering Ordering ordering; ordering += "x1","x3","x5","x7","x2","x6","x4"; - GaussianJunctionTree junctionTree(fg, ordering); + GaussianJunctionTree junctionTree(fg, ordering); Ordering frontal1; frontal1 += "x5", "x6", "x4"; Ordering frontal2; frontal2 += "x3", "x2"; Ordering frontal3; frontal3 += "x1"; @@ -62,7 +63,7 @@ TEST( GaussianFactorGraph, constructor2 ) } /* ************************************************************************* * -TEST( GaussianFactorGraph, optimizeMultiFrontal ) +TEST( GaussianJunctionTree, optimizeMultiFrontal ) { // create a graph GaussianFactorGraph fg = createSmoother(7);