Symbolic Bayes Tree successfully constructed
parent
cefeca149b
commit
53890c4ba6
|
@ -6,22 +6,74 @@
|
|||
|
||||
#include "BayesTree.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
using namespace std;
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
BayesTree<Conditional>::BayesTree() {
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
BayesTree<Conditional>::BayesTree(BayesChain<Conditional>& bayesChain) {
|
||||
list<string> ordering;// = bayesChain.ordering();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
void BayesTree<Conditional>::print(const string& s) const {
|
||||
void BayesTree<Conditional>::print(const std::string& s) const {
|
||||
cout << s << ": size == " << nodes_.size() << endl;
|
||||
if (nodes_.empty()) return;
|
||||
nodes_[0]->printTree("");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
bool BayesTree<Conditional>::equals(const BayesTree<Conditional>& other,
|
||||
double tol) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
void BayesTree<Conditional>::insert(string key, conditional_ptr conditional) {
|
||||
|
||||
// get any parent
|
||||
list<string> parents = conditional->parents();
|
||||
|
||||
// if no parents, start a new root clique
|
||||
if (parents.empty()) {
|
||||
node_ptr root(new Node(key, conditional));
|
||||
nodes_.push_back(root);
|
||||
nodeMap_.insert(make_pair(key, 0));
|
||||
return;
|
||||
}
|
||||
|
||||
// otherwise, find the parent clique
|
||||
string parent = parents.front();
|
||||
NodeMap::const_iterator it = nodeMap_.find(parent);
|
||||
if (it == nodeMap_.end()) throw(invalid_argument(
|
||||
"BayesTree::insert: parent with key " + key + "was not yet inserted"));
|
||||
int index = it->second;
|
||||
node_ptr parent_clique = nodes_[index];
|
||||
|
||||
// if the parents and parent clique have the same size, add to parent clique
|
||||
if (parent_clique->size() == parents.size()) {
|
||||
nodeMap_.insert(make_pair(key, index));
|
||||
parent_clique->add(key, conditional);
|
||||
return;
|
||||
}
|
||||
|
||||
// otherwise, start a new clique and add it to the tree
|
||||
node_ptr new_clique(new Node(key, conditional));
|
||||
new_clique->parent_ = parent_clique;
|
||||
parent_clique->children_.push_back(new_clique);
|
||||
nodeMap_.insert(make_pair(key, nodes_.size()));
|
||||
nodes_.push_back(new_clique);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} /// namespace gtsam
|
||||
|
|
124
cpp/BayesTree.h
124
cpp/BayesTree.h
|
@ -8,37 +8,123 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include <boost/serialization/map.hpp>
|
||||
#include <boost/serialization/list.hpp>
|
||||
|
||||
#include <boost/foreach.hpp> // TODO: make cpp file
|
||||
#include "Testable.h"
|
||||
#include "BayesChain.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* Bayes tree
|
||||
* Templated on the Conditional class, the type of node in the underlying Bayes chain.
|
||||
* This could be a ConditionalProbabilityTable, a ConditionalGaussian, or a SymbolicConditional
|
||||
*/
|
||||
template <class Conditional>
|
||||
class BayesTree : public Testable<BayesTree<Conditional> >
|
||||
{
|
||||
public:
|
||||
/** A clique in a Bayes tree consisting of frontal nodes and conditionals */
|
||||
template<class Conditional>
|
||||
class Front: Testable<Front<Conditional> > {
|
||||
private:
|
||||
typedef boost::shared_ptr<Conditional> cond_ptr;
|
||||
std::list<std::string> keys_; /** frontal keys */
|
||||
std::list<cond_ptr> nodes_; /** conditionals */
|
||||
std::list<std::string> separator_; /** separator keys */
|
||||
public:
|
||||
|
||||
/** Create a Bayes Tree from a SymbolicBayesChain */
|
||||
BayesTree(BayesChain<Conditional>& bayesChain);
|
||||
/** constructor */
|
||||
Front(std::string key, cond_ptr conditional) {
|
||||
add(key, conditional);
|
||||
separator_ = conditional->parents();
|
||||
}
|
||||
|
||||
/** Destructor */
|
||||
virtual ~BayesTree() {}
|
||||
/** print */
|
||||
void print(const std::string& s = "") const {
|
||||
std::cout << s;
|
||||
BOOST_FOREACH(std::string key, keys_)
|
||||
std::cout << " " << key;
|
||||
if (!separator_.empty()) {
|
||||
std::cout << " :";
|
||||
BOOST_FOREACH(std::string key, separator_)
|
||||
std::cout << " " << key;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
/** print */
|
||||
void print(const std::string& s = "") const;
|
||||
/** check equality. TODO: only keys */
|
||||
bool equals(const Front<Conditional>& other, double tol = 1e-9) const {
|
||||
return (keys_ == other.keys_);
|
||||
}
|
||||
|
||||
/** check equality */
|
||||
bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const;
|
||||
/** add a frontal node */
|
||||
void add(std::string key, cond_ptr conditional) {
|
||||
keys_.push_front(key);
|
||||
nodes_.push_front(conditional);
|
||||
}
|
||||
|
||||
}; // BayesTree
|
||||
/** return size of the clique */
|
||||
inline size_t size() const {return keys_.size() + separator_.size();}
|
||||
};
|
||||
|
||||
/**
|
||||
* Bayes tree
|
||||
* Templated on the Conditional class, the type of node in the underlying Bayes chain.
|
||||
* This could be a ConditionalProbabilityTable, a ConditionalGaussian, or a SymbolicConditional
|
||||
*/
|
||||
template<class Conditional>
|
||||
class BayesTree: public Testable<BayesTree<Conditional> > {
|
||||
|
||||
public:
|
||||
|
||||
typedef boost::shared_ptr<Conditional> conditional_ptr;
|
||||
|
||||
private:
|
||||
|
||||
/** A Node in the tree is a Front with tree connectivity */
|
||||
struct Node : public Front<Conditional> {
|
||||
typedef boost::shared_ptr<Node> shared_ptr;
|
||||
shared_ptr parent_;
|
||||
std::list<shared_ptr> children_;
|
||||
|
||||
Node(std::string key, conditional_ptr conditional):Front<Conditional>(key,conditional) {}
|
||||
|
||||
/** print this node and entire subtree below it*/
|
||||
void printTree(const std::string& indent) const {
|
||||
print(indent);
|
||||
BOOST_FOREACH(shared_ptr child, children_)
|
||||
child->printTree(indent+" ");
|
||||
}
|
||||
};
|
||||
|
||||
/** vector of Nodes */
|
||||
typedef boost::shared_ptr<Node> node_ptr;
|
||||
typedef std::vector<node_ptr> Nodes;
|
||||
Nodes nodes_;
|
||||
|
||||
/** Map from keys to Node index */
|
||||
typedef std::map<std::string, int> NodeMap;
|
||||
NodeMap nodeMap_;
|
||||
|
||||
public:
|
||||
|
||||
/** Create an empty Bayes Tree */
|
||||
BayesTree();
|
||||
|
||||
/** Create a Bayes Tree from a SymbolicBayesChain */
|
||||
BayesTree(BayesChain<Conditional>& bayesChain);
|
||||
|
||||
/** Destructor */
|
||||
virtual ~BayesTree() {}
|
||||
|
||||
/** print */
|
||||
void print(const std::string& s = "") const;
|
||||
|
||||
/** check equality */
|
||||
bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const;
|
||||
|
||||
/** insert a new conditional */
|
||||
void insert(std::string key, conditional_ptr conditional);
|
||||
|
||||
/** return root clique */
|
||||
const Front<Conditional>& root() const {return *(nodes_[0]);}
|
||||
|
||||
}; // BayesTree
|
||||
|
||||
} /// namespace gtsam
|
||||
|
|
|
@ -43,7 +43,7 @@ ConditionalGaussian::ConditionalGaussian(Vector d,
|
|||
/* ************************************************************************* */
|
||||
ConditionalGaussian::ConditionalGaussian(const Vector& d,
|
||||
const Matrix& R,
|
||||
const map<std::string, Matrix>& parents)
|
||||
const map<string, Matrix>& parents)
|
||||
: R_(R), d_(d), parents_(parents)
|
||||
{
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ void ConditionalGaussian::print(const string &s) const
|
|||
{
|
||||
cout << s << ":" << endl;
|
||||
gtsam::print(R_,"R");
|
||||
for(map<string, Matrix>::const_iterator it = parents_.begin() ; it != parents_.end() ; it++ ) {
|
||||
for(Parents::const_iterator it = parents_.begin() ; it != parents_.end() ; it++ ) {
|
||||
const string& j = it->first;
|
||||
const Matrix& Aj = it->second;
|
||||
gtsam::print(Aj, "A["+j+"]");
|
||||
|
@ -63,7 +63,7 @@ void ConditionalGaussian::print(const string &s) const
|
|||
|
||||
/* ************************************************************************* */
|
||||
bool ConditionalGaussian::equals(const ConditionalGaussian &cg, double tol) const {
|
||||
map<string, Matrix>::const_iterator it = parents_.begin();
|
||||
Parents::const_iterator it = parents_.begin();
|
||||
|
||||
// check if the size of the parents_ map is the same
|
||||
if (parents_.size() != cg.parents_.size()) return false;
|
||||
|
@ -77,21 +77,26 @@ bool ConditionalGaussian::equals(const ConditionalGaussian &cg, double tol) cons
|
|||
// check if the matrices are the same
|
||||
// iterate over the parents_ map
|
||||
for (it = parents_.begin(); it != parents_.end(); it++) {
|
||||
map<string, Matrix>::const_iterator it2 = cg.parents_.find(
|
||||
it->first.c_str());
|
||||
Parents::const_iterator it2 = cg.parents_.find(it->first.c_str());
|
||||
if (it2 != cg.parents_.end()) {
|
||||
if (!(equal_with_abs_tol(it->second, it2->second, tol))) return false;
|
||||
} else {
|
||||
} else
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
list<string> ConditionalGaussian::parents() {
|
||||
list<string> result;
|
||||
for (Parents::const_iterator it = parents_.begin(); it != parents_.end(); it++)
|
||||
result.push_back(it->first);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Vector ConditionalGaussian::solve(const VectorConfig& x) const {
|
||||
Vector rhs = d_;
|
||||
for (map<string, Matrix>::const_iterator it = parents_.begin(); it
|
||||
for (Parents::const_iterator it = parents_.begin(); it
|
||||
!= parents_.end(); it++) {
|
||||
const string& j = it->first;
|
||||
const Matrix& Aj = it->second;
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <boost/utility.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/serialization/map.hpp>
|
||||
|
@ -29,7 +30,9 @@ namespace gtsam {
|
|||
class ConditionalGaussian : boost::noncopyable, public Testable<ConditionalGaussian>
|
||||
{
|
||||
public:
|
||||
typedef std::map<std::string, Matrix>::const_iterator const_iterator;
|
||||
typedef std::map<std::string, Matrix> Parents;
|
||||
typedef Parents::const_iterator const_iterator;
|
||||
typedef boost::shared_ptr<ConditionalGaussian> shared_ptr;
|
||||
|
||||
protected:
|
||||
|
||||
|
@ -37,13 +40,12 @@ namespace gtsam {
|
|||
Matrix R_;
|
||||
|
||||
/** the names and the matrices connecting to parent nodes */
|
||||
std::map<std::string, Matrix> parents_;
|
||||
Parents parents_;
|
||||
|
||||
/** the RHS vector */
|
||||
Vector d_;
|
||||
|
||||
public:
|
||||
typedef boost::shared_ptr<ConditionalGaussian> shared_ptr;
|
||||
|
||||
/** constructor */
|
||||
ConditionalGaussian() {};
|
||||
|
@ -84,7 +86,7 @@ namespace gtsam {
|
|||
*/
|
||||
ConditionalGaussian(const Vector& d,
|
||||
const Matrix& R,
|
||||
const std::map<std::string, Matrix>& parents);
|
||||
const Parents& parents);
|
||||
|
||||
/** deconstructor */
|
||||
virtual ~ConditionalGaussian() {};
|
||||
|
@ -98,6 +100,9 @@ namespace gtsam {
|
|||
/** dimension of multivariate variable */
|
||||
size_t dim() const {return R_.size2();}
|
||||
|
||||
/** return all parents */
|
||||
std::list<std::string> parents();
|
||||
|
||||
/** return stuff contained in ConditionalGaussian */
|
||||
const Vector& get_d() const {return d_;}
|
||||
const Matrix& get_R() const {return R_;}
|
||||
|
|
|
@ -20,6 +20,7 @@ namespace gtsam {
|
|||
* Conditional node for use in a Bayes net
|
||||
*/
|
||||
class SymbolicConditional: Testable<SymbolicConditional> {
|
||||
|
||||
private:
|
||||
|
||||
std::list<std::string> parents_;
|
||||
|
@ -68,6 +69,9 @@ namespace gtsam {
|
|||
return parents_ == other.parents_;
|
||||
}
|
||||
|
||||
/** return any parent */
|
||||
std::list<std::string> parents() { return parents_;}
|
||||
|
||||
};
|
||||
|
||||
} /// namespace gtsam
|
||||
|
|
|
@ -4,24 +4,73 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <boost/assign/list_inserter.hpp> // for 'insert()'
|
||||
#include <boost/assign/std/vector.hpp> // for operator +=
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include "SymbolicBayesChain.h"
|
||||
#include "smallExample.h"
|
||||
#include "BayesTree-inl.h"
|
||||
|
||||
//using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace boost::assign;
|
||||
|
||||
// Conditionals for ASIA example from the tutorial with A and D evidence
|
||||
SymbolicConditional::shared_ptr
|
||||
B(new SymbolicConditional()),
|
||||
L(new SymbolicConditional("B")),
|
||||
E(new SymbolicConditional("L","B")),
|
||||
S(new SymbolicConditional("L","B")),
|
||||
T(new SymbolicConditional("L","E")),
|
||||
X(new SymbolicConditional("E"));
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( BayesTree, Front )
|
||||
{
|
||||
Front<SymbolicConditional> f1("B",B); f1.add("L",L);
|
||||
Front<SymbolicConditional> f2("L",L); f2.add("B",B);
|
||||
CHECK(f1.equals(f1));
|
||||
CHECK(!f1.equals(f2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( BayesTree, insert )
|
||||
{
|
||||
// Insert
|
||||
BayesTree<SymbolicConditional> bayesTree;
|
||||
bayesTree.insert("B",B);
|
||||
bayesTree.insert("L",L);
|
||||
bayesTree.insert("E",E);
|
||||
bayesTree.insert("S",S);
|
||||
bayesTree.insert("T",T);
|
||||
bayesTree.insert("X",X);
|
||||
//bayesTree.print("bayesTree");
|
||||
|
||||
//LONGS_EQUAL(1,bayesTree.size());
|
||||
|
||||
// Check root
|
||||
Front<SymbolicConditional> expected_root("B",B);
|
||||
//CHECK(assert_equal(expected_root,bayesTree.root()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( BayesTree, constructor )
|
||||
{
|
||||
LinearFactorGraph factorGraph = createLinearFactorGraph();
|
||||
Ordering ordering;
|
||||
ordering.push_back("x2");
|
||||
ordering.push_back("l1");
|
||||
ordering.push_back("x1");
|
||||
SymbolicBayesChain symbolicBayesChain(factorGraph,ordering);
|
||||
BayesTree<SymbolicConditional> bayesTree(symbolicBayesChain);
|
||||
// Create Symbolic Bayes Chain in which we want to discover cliques
|
||||
map<string, SymbolicConditional::shared_ptr> nodes;
|
||||
insert(nodes)("B",B)("L",L)("E",E)("S",S)("T",T)("X",X);
|
||||
SymbolicBayesChain ASIA(nodes);
|
||||
|
||||
// Create Bayes Tree from Symbolic Bayes Chain
|
||||
BayesTree<SymbolicConditional> bayesTree(ASIA);
|
||||
bayesTree.insert("B",B);
|
||||
//bayesTree.print("bayesTree");
|
||||
|
||||
//LONGS_EQUAL(1,bayesTree.size());
|
||||
|
||||
// Check root
|
||||
Front<SymbolicConditional> expected_root("B",B);
|
||||
//CHECK(assert_equal(expected_root,bayesTree.root()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -4,6 +4,10 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <boost/assign/list_inserter.hpp> // for 'insert()'
|
||||
#include <boost/assign/std/vector.hpp> // for operator +=
|
||||
using namespace boost::assign;
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include "smallExample.h"
|
||||
|
@ -16,23 +20,22 @@ using namespace gtsam;
|
|||
TEST( SymbolicBayesChain, constructor )
|
||||
{
|
||||
// Create manually
|
||||
SymbolicConditional::shared_ptr x2(new SymbolicConditional("x1", "l1"));
|
||||
SymbolicConditional::shared_ptr l1(new SymbolicConditional("x1"));
|
||||
SymbolicConditional::shared_ptr x1(new SymbolicConditional());
|
||||
SymbolicConditional::shared_ptr
|
||||
x2(new SymbolicConditional("x1", "l1")),
|
||||
l1(new SymbolicConditional("x1")),
|
||||
x1(new SymbolicConditional());
|
||||
map<string, SymbolicConditional::shared_ptr> nodes;
|
||||
nodes.insert(make_pair("x2", x2));
|
||||
nodes.insert(make_pair("l1", l1));
|
||||
nodes.insert(make_pair("x1", x1));
|
||||
insert(nodes)("x2", x2)("l1", l1)("x1", x1);
|
||||
SymbolicBayesChain expected(nodes);
|
||||
|
||||
// Create from a factor graph
|
||||
Ordering ordering;
|
||||
ordering.push_back("x2");
|
||||
ordering.push_back("l1");
|
||||
ordering.push_back("x1");
|
||||
ordering += "x2","l1","x1";
|
||||
LinearFactorGraph factorGraph = createLinearFactorGraph();
|
||||
SymbolicBayesChain actual(factorGraph, ordering);
|
||||
//CHECK(assert_equal(expected, actual));
|
||||
CHECK(assert_equal(expected, actual));
|
||||
|
||||
//bayesChain.ordering();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue