Made JunctionTree a subclass of ClusterTree
parent
c3a907127f
commit
d07dfac236
|
@ -20,7 +20,7 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
bool ClusterTree<FG>::Clique::equals(const ClusterTree<FG>::Clique& other) const {
|
||||
bool ClusterTree<FG>::Cluster::equals(const ClusterTree<FG>::Cluster& other) const {
|
||||
if (!frontal_.equals(other.frontal_))
|
||||
return false;
|
||||
|
||||
|
@ -43,7 +43,7 @@ namespace gtsam {
|
|||
* ClusterTree
|
||||
*/
|
||||
template <class FG>
|
||||
void ClusterTree<FG>::Clique::print(const string& indent) const {
|
||||
void ClusterTree<FG>::Cluster::print(const string& indent) const {
|
||||
// FG::print(indent);
|
||||
cout << indent;
|
||||
BOOST_FOREACH(const Symbol& key, frontal_)
|
||||
|
@ -56,99 +56,12 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
void ClusterTree<FG>::Clique::printTree(const string& indent) const {
|
||||
void ClusterTree<FG>::Cluster::printTree(const string& indent) const {
|
||||
print(indent);
|
||||
BOOST_FOREACH(const shared_ptr& child, children_)
|
||||
child->printTree(indent+" ");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
ClusterTree<FG>::ClusterTree(FG& fg, const Ordering& ordering) {
|
||||
// Symbolic factorization: GaussianFactorGraph -> SymbolicFactorGraph -> SymbolicBayesNet -> SymbolicBayesTree
|
||||
SymbolicFactorGraph sfg(fg);
|
||||
SymbolicBayesNet sbn = sfg.eliminate(ordering);
|
||||
BayesTree<SymbolicConditional> sbt(sbn);
|
||||
|
||||
// distribtue factors
|
||||
root_ = distributeFactors(fg, sbt.root());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
typename ClusterTree<FG>::sharedClique ClusterTree<FG>::distributeFactors(FG& fg,
|
||||
const BayesTree<SymbolicConditional>::sharedClique bayesClique) {
|
||||
// create a new clique in the junction tree
|
||||
sharedClique clique(new Clique());
|
||||
clique->frontal_ = bayesClique->ordering();
|
||||
clique->separator_.insert(bayesClique->separator_.begin(), bayesClique->separator_.end());
|
||||
|
||||
// recursively call the children
|
||||
BOOST_FOREACH(const BayesTree<SymbolicConditional>::sharedClique bayesChild, bayesClique->children()) {
|
||||
sharedClique child = distributeFactors(fg, bayesChild);
|
||||
clique->children_.push_back(child);
|
||||
child->parent_ = clique;
|
||||
}
|
||||
|
||||
// collect the factors
|
||||
typedef vector<typename FG::sharedFactor> Factors;
|
||||
BOOST_FOREACH(const Symbol& frontal, clique->frontal_) {
|
||||
Factors factors = fg.template findAndRemoveFactors<Factors>(frontal);
|
||||
BOOST_FOREACH(const typename FG::sharedFactor& factor_, factors) {
|
||||
clique->push_back(factor_);
|
||||
}
|
||||
}
|
||||
|
||||
return clique;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG> template <class Conditional>
|
||||
pair<FG, BayesTree<Conditional> >
|
||||
ClusterTree<FG>::eliminateOneClique(sharedClique current) {
|
||||
|
||||
// current->frontal_.print("current clique:");
|
||||
|
||||
typedef typename BayesTree<Conditional>::sharedClique sharedBtreeClique;
|
||||
FG fg; // factor graph will be assembled from local factors and marginalized children
|
||||
list<BayesTree<Conditional> > children;
|
||||
fg.push_back(*current); // add the local factor graph
|
||||
|
||||
// BOOST_FOREACH(const typename FG::sharedFactor& factor_, fg)
|
||||
// Ordering(factor_->keys()).print("local factor:");
|
||||
|
||||
BOOST_FOREACH(sharedClique& child, current->children_) {
|
||||
// receive the factors from the child and its clique point
|
||||
FG fgChild; BayesTree<Conditional> childTree;
|
||||
boost::tie(fgChild, childTree) = eliminateOneClique<Conditional>(child);
|
||||
|
||||
// BOOST_FOREACH(const typename FG::sharedFactor& factor_, fgChild)
|
||||
// Ordering(factor_->keys()).print("factor from child:");
|
||||
|
||||
fg.push_back(fgChild);
|
||||
children.push_back(childTree);
|
||||
}
|
||||
|
||||
// eliminate the combined factors
|
||||
// warning: fg is being eliminated in-place and will contain marginal afterwards
|
||||
BayesNet<Conditional> bn = fg.eliminateFrontals(current->frontal_);
|
||||
|
||||
// create a new clique corresponding the combined factors
|
||||
BayesTree<Conditional> bayesTree(bn, children);
|
||||
|
||||
return make_pair(fg, bayesTree);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG> template <class Conditional>
|
||||
BayesTree<Conditional> ClusterTree<FG>::eliminate() {
|
||||
pair<FG, BayesTree<Conditional> > ret = this->eliminateOneClique<Conditional>(root_);
|
||||
// ret.first.print("ret.first");
|
||||
if (ret.first.nrFactors() != 0)
|
||||
throw runtime_error("JuntionTree::eliminate: elimination failed because of factors left over!");
|
||||
return ret.second;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
bool ClusterTree<FG>::equals(const ClusterTree<FG>& other, double tol) const {
|
||||
|
|
|
@ -10,38 +10,45 @@
|
|||
|
||||
#include <set>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include "BayesTree.h"
|
||||
#include "SymbolicConditional.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* A cluster-tree is associated with a factor graph and is defined as in Koller-Friedman:
|
||||
* each node k represents a subset C_k \sub X, and the tree is family preserving, in that
|
||||
* each factor f_i is associated with a single cluster and scope(f_i) \sub C_k.
|
||||
*/
|
||||
template <class FG>
|
||||
class ClusterTree : public Testable<ClusterTree<FG> > {
|
||||
public:
|
||||
// the class for subgraphs that also include the pointers to the parents and two children
|
||||
class Clique : public FG {
|
||||
private:
|
||||
typedef typename boost::shared_ptr<Clique> shared_ptr;
|
||||
shared_ptr parent_; // the parent subgraph node
|
||||
std::vector<shared_ptr> children_; // the child cliques
|
||||
Ordering frontal_; // the frontal varaibles
|
||||
Unordered separator_; // the separator variables
|
||||
|
||||
friend class ClusterTree<FG>;
|
||||
public:
|
||||
|
||||
// the class for subgraphs that also include the pointers to the parents and two children
|
||||
class Cluster : public FG {
|
||||
|
||||
public:
|
||||
|
||||
typedef typename boost::shared_ptr<Cluster> shared_ptr;
|
||||
|
||||
/* commented private out to make compile but needs to be addressed */
|
||||
|
||||
shared_ptr parent_; // the parent subgraph node
|
||||
std::vector<shared_ptr> children_; // the child clusters
|
||||
Ordering frontal_; // the frontal variables
|
||||
Unordered separator_; // the separator variables
|
||||
|
||||
public:
|
||||
|
||||
// empty constructor
|
||||
Clique() {}
|
||||
Cluster() {}
|
||||
|
||||
// constructor with all the information
|
||||
Clique(const FG& fgLocal, const Ordering& frontal, const Unordered& separator,
|
||||
Cluster(const FG& fgLocal, const Ordering& frontal, const Unordered& separator,
|
||||
const shared_ptr& parent)
|
||||
: frontal_(frontal), separator_(separator), FG(fgLocal), parent_(parent) {}
|
||||
|
||||
// constructor for an empty graph
|
||||
Clique(const Ordering& frontal, const Unordered& separator, const shared_ptr& parent)
|
||||
Cluster(const Ordering& frontal, const Unordered& separator, const shared_ptr& parent)
|
||||
: frontal_(frontal), separator_(separator), parent_(parent) {}
|
||||
|
||||
// return the members
|
||||
|
@ -57,23 +64,15 @@ namespace gtsam {
|
|||
void printTree(const std::string& indent) const;
|
||||
|
||||
// check equality
|
||||
bool equals(const Clique& other) const;
|
||||
bool equals(const Cluster& other) const;
|
||||
};
|
||||
|
||||
// typedef for shared pointers to cliques
|
||||
typedef typename Clique::shared_ptr sharedClique;
|
||||
// typedef for shared pointers to clusters
|
||||
typedef typename Cluster::shared_ptr sharedCluster;
|
||||
|
||||
protected:
|
||||
// Root clique
|
||||
sharedClique root_;
|
||||
|
||||
private:
|
||||
// distribute the factors along the Bayes tree
|
||||
sharedClique distributeFactors(FG& fg, const BayesTree<SymbolicConditional>::sharedClique clique);
|
||||
|
||||
// utility function called by eliminate
|
||||
template <class Conditional>
|
||||
std::pair<FG, BayesTree<Conditional> > eliminateOneClique(sharedClique fg_);
|
||||
// Root cluster
|
||||
sharedCluster root_;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
|
@ -82,16 +81,12 @@ namespace gtsam {
|
|||
// constructor given a factor graph and the elimination ordering
|
||||
ClusterTree(FG& fg, const Ordering& ordering);
|
||||
|
||||
// return the root clique
|
||||
sharedClique root() const { return root_; }
|
||||
|
||||
// eliminate the factors in the subgraphs
|
||||
template <class Conditional>
|
||||
BayesTree<Conditional> eliminate();
|
||||
// return the root cluster
|
||||
sharedCluster root() const { return root_; }
|
||||
|
||||
// print the object
|
||||
void print(const std::string& str) const {
|
||||
cout << str << endl;
|
||||
std::cout << str << std::endl;
|
||||
if (root_.get()) root_->printTree("");
|
||||
}
|
||||
|
||||
|
|
|
@ -18,70 +18,28 @@ namespace gtsam {
|
|||
|
||||
using namespace std;
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
bool JunctionTree<FG>::Clique::equals(const JunctionTree<FG>::Clique& other) const {
|
||||
if (!frontal_.equals(other.frontal_))
|
||||
return false;
|
||||
|
||||
if (!separator_.equals(other.separator_))
|
||||
return false;
|
||||
|
||||
if (children_.size() != other.children_.size())
|
||||
return false;
|
||||
|
||||
typename vector<shared_ptr>::const_iterator it1 = children_.begin();
|
||||
typename vector<shared_ptr>::const_iterator it2 = other.children_.begin();
|
||||
for(; it1!=children_.end(); it1++, it2++)
|
||||
if (!(*it1)->equals(**it2)) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* JunctionTree
|
||||
*/
|
||||
template <class FG>
|
||||
void JunctionTree<FG>::Clique::print(const string& indent) const {
|
||||
// FG::print(indent);
|
||||
cout << indent;
|
||||
BOOST_FOREACH(const Symbol& key, frontal_)
|
||||
cout << (string)key << " ";
|
||||
cout << ":";
|
||||
BOOST_FOREACH(const Symbol& key, separator_)
|
||||
cout << (string)key << " ";
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
void JunctionTree<FG>::Clique::printTree(const string& indent) const {
|
||||
print(indent);
|
||||
BOOST_FOREACH(const shared_ptr& child, children_)
|
||||
child->printTree(indent+" ");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
JunctionTree<FG>::JunctionTree(FG& fg, const Ordering& ordering) {
|
||||
// Symbolic factorization: GaussianFactorGraph -> SymbolicFactorGraph -> SymbolicBayesNet -> SymbolicBayesTree
|
||||
// Symbolic factorization: GaussianFactorGraph -> SymbolicFactorGraph
|
||||
// -> SymbolicBayesNet -> SymbolicBayesTree
|
||||
SymbolicFactorGraph sfg(fg);
|
||||
SymbolicBayesNet sbn = sfg.eliminate(ordering);
|
||||
BayesTree<SymbolicConditional> sbt(sbn);
|
||||
|
||||
// distribtue factors
|
||||
root_ = distributeFactors(fg, sbt.root());
|
||||
this->root_ = distributeFactors(fg, sbt.root());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class FG>
|
||||
typename JunctionTree<FG>::sharedClique JunctionTree<FG>::distributeFactors(FG& fg,
|
||||
const BayesTree<SymbolicConditional>::sharedClique bayesClique) {
|
||||
typename JunctionTree<FG>::sharedClique JunctionTree<FG>::distributeFactors(
|
||||
FG& fg, const BayesTree<SymbolicConditional>::sharedClique bayesClique) {
|
||||
// create a new clique in the junction tree
|
||||
sharedClique clique(new Clique());
|
||||
clique->frontal_ = bayesClique->ordering();
|
||||
clique->separator_.insert(bayesClique->separator_.begin(), bayesClique->separator_.end());
|
||||
clique->separator_.insert(bayesClique->separator_.begin(),
|
||||
bayesClique->separator_.end());
|
||||
|
||||
// recursively call the children
|
||||
BOOST_FOREACH(const BayesTree<SymbolicConditional>::sharedClique bayesChild, bayesClique->children()) {
|
||||
|
@ -142,18 +100,10 @@ namespace gtsam {
|
|||
/* ************************************************************************* */
|
||||
template <class FG> template <class Conditional>
|
||||
BayesTree<Conditional> JunctionTree<FG>::eliminate() {
|
||||
pair<FG, BayesTree<Conditional> > ret = this->eliminateOneClique<Conditional>(root_);
|
||||
// ret.first.print("ret.first");
|
||||
pair<FG, BayesTree<Conditional> > ret = this->eliminateOneClique<Conditional>(this->root());
|
||||
if (ret.first.nrFactors() != 0)
|
||||
throw runtime_error("JuntionTree::eliminate: elimination failed because of factors left over!");
|
||||
return ret.second;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FG>
|
||||
bool JunctionTree<FG>::equals(const JunctionTree<FG>& other, double tol) const {
|
||||
if (!root_ || !other.root_) return false;
|
||||
return root_->equals(*other.root_);
|
||||
}
|
||||
|
||||
} //namespace gtsam
|
||||
|
|
|
@ -11,62 +11,29 @@
|
|||
#include <set>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include "BayesTree.h"
|
||||
#include "ClusterTree.h"
|
||||
#include "SymbolicConditional.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* A junction tree (or clique-tree) is a cluster-tree where each node k represents a
|
||||
* clique (maximal fully connected subset) of an associated chordal graph, such as a
|
||||
* chordal Bayes net resulting from elimination. In GTSAM the BayesTree is used to
|
||||
* represent the clique tree associated with a Bayes net, and the JunctionTree is
|
||||
* used to collect the factors associated with each clique during the elimination process.
|
||||
*/
|
||||
template <class FG>
|
||||
class JunctionTree : public Testable<JunctionTree<FG> > {
|
||||
public:
|
||||
// the class for subgraphs that also include the pointers to the parents and two children
|
||||
class Clique : public FG {
|
||||
private:
|
||||
typedef typename boost::shared_ptr<Clique> shared_ptr;
|
||||
shared_ptr parent_; // the parent subgraph node
|
||||
std::vector<shared_ptr> children_; // the child cliques
|
||||
Ordering frontal_; // the frontal varaibles
|
||||
Unordered separator_; // the separator variables
|
||||
|
||||
friend class JunctionTree<FG>;
|
||||
class JunctionTree : public ClusterTree<FG> {
|
||||
|
||||
public:
|
||||
|
||||
// empty constructor
|
||||
Clique() {}
|
||||
|
||||
// constructor with all the information
|
||||
Clique(const FG& fgLocal, const Ordering& frontal, const Unordered& separator,
|
||||
const shared_ptr& parent)
|
||||
: frontal_(frontal), separator_(separator), FG(fgLocal), parent_(parent) {}
|
||||
|
||||
// constructor for an empty graph
|
||||
Clique(const Ordering& frontal, const Unordered& separator, const shared_ptr& parent)
|
||||
: frontal_(frontal), separator_(separator), parent_(parent) {}
|
||||
|
||||
// return the members
|
||||
const Ordering& frontal() const { return frontal_;}
|
||||
const Unordered& separator() const { return separator_;}
|
||||
const std::vector<shared_ptr>& children() { return children_; }
|
||||
|
||||
// add a child node
|
||||
void addChild(const shared_ptr& child) { children_.push_back(child); }
|
||||
|
||||
// print the object
|
||||
void print(const std::string& indent) const;
|
||||
void printTree(const std::string& indent) const;
|
||||
|
||||
// check equality
|
||||
bool equals(const Clique& other) const;
|
||||
};
|
||||
|
||||
// typedef for shared pointers to cliques
|
||||
/**
|
||||
* In a junction tree each cluster is associated with a clique
|
||||
*/
|
||||
typedef typename ClusterTree<FG>::Cluster Clique;
|
||||
typedef typename Clique::shared_ptr sharedClique;
|
||||
|
||||
protected:
|
||||
// Root clique
|
||||
sharedClique root_;
|
||||
|
||||
private:
|
||||
// distribute the factors along the Bayes tree
|
||||
sharedClique distributeFactors(FG& fg, const BayesTree<SymbolicConditional>::sharedClique clique);
|
||||
|
@ -82,22 +49,10 @@ namespace gtsam {
|
|||
// constructor given a factor graph and the elimination ordering
|
||||
JunctionTree(FG& fg, const Ordering& ordering);
|
||||
|
||||
// return the root clique
|
||||
sharedClique root() const { return root_; }
|
||||
|
||||
// eliminate the factors in the subgraphs
|
||||
template <class Conditional>
|
||||
BayesTree<Conditional> eliminate();
|
||||
|
||||
// print the object
|
||||
void print(const std::string& str) const {
|
||||
std::cout << str << std::endl;
|
||||
if (root_.get()) root_->printTree("");
|
||||
}
|
||||
|
||||
/** check equality */
|
||||
bool equals(const JunctionTree<FG>& other, double tol = 1e-9) const;
|
||||
|
||||
}; // JunctionTree
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -15,6 +15,8 @@ using namespace boost::assign;
|
|||
|
||||
using namespace gtsam;
|
||||
|
||||
// explicit instantiation and typedef
|
||||
template class ClusterTree<SymbolicFactorGraph>;
|
||||
typedef ClusterTree<SymbolicFactorGraph> SymbolicClusterTree;
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -15,18 +15,21 @@ using namespace boost::assign;
|
|||
|
||||
#include "Ordering.h"
|
||||
#include "SymbolicFactorGraph.h"
|
||||
#include "JunctionTree.h"
|
||||
#include "ClusterTree-inl.h"
|
||||
#include "JunctionTree-inl.h"
|
||||
|
||||
using namespace gtsam;
|
||||
|
||||
// explicit instantiation and typedef
|
||||
template class JunctionTree<SymbolicFactorGraph>;
|
||||
typedef JunctionTree<SymbolicFactorGraph> SymbolicJunctionTree;
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
/* ************************************************************************* *
|
||||
* x1 - x2 - x3 - x4
|
||||
* x3 x4
|
||||
* x2 x1 : x3
|
||||
*/
|
||||
****************************************************************************/
|
||||
TEST( JunctionTree, constructor )
|
||||
{
|
||||
SymbolicFactorGraph fg;
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include <boost/foreach.hpp>
|
||||
|
||||
#include "ClusterTree-inl.h"
|
||||
#include "JunctionTree-inl.h"
|
||||
#include "GaussianJunctionTree.h"
|
||||
|
||||
|
|
Loading…
Reference in New Issue