Bayes tree constructor implemented and tested with ASIA, as well as smoother example from frankcvs meeting
parent
1e5a2d692a
commit
c046fed37c
|
@ -41,15 +41,18 @@ namespace gtsam {
|
||||||
/** check equality */
|
/** check equality */
|
||||||
bool equals(const BayesChain& other, double tol = 1e-9) const;
|
bool equals(const BayesChain& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
/** size is the number of nodes */
|
|
||||||
inline size_t size() const {return nodes_.size();}
|
|
||||||
|
|
||||||
/** insert: use reverse topological sort (i.e. parents last) */
|
/** insert: use reverse topological sort (i.e. parents last) */
|
||||||
void insert(const std::string& key, boost::shared_ptr<Conditional> node);
|
void insert(const std::string& key, boost::shared_ptr<Conditional> node);
|
||||||
|
|
||||||
/** delete */
|
/** delete */
|
||||||
void erase(const std::string& key);
|
void erase(const std::string& key);
|
||||||
|
|
||||||
|
/** size is the number of nodes */
|
||||||
|
inline size_t size() const {return nodes_.size();}
|
||||||
|
|
||||||
|
/** return keys in topological sort order (parents first), i.e., reverse elimination order */
|
||||||
|
inline std::list<std::string> keys() const { return keys_;}
|
||||||
|
|
||||||
inline boost::shared_ptr<Conditional> operator[](const std::string& key) const {
|
inline boost::shared_ptr<Conditional> operator[](const std::string& key) const {
|
||||||
const_iterator cg = nodes_.find(key); // get node
|
const_iterator cg = nodes_.find(key); // get node
|
||||||
assert( cg != nodes_.end() );
|
assert( cg != nodes_.end() );
|
||||||
|
|
|
@ -4,26 +4,64 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <boost/foreach.hpp>
|
||||||
#include "BayesTree.h"
|
#include "BayesTree.h"
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template<class Conditional>
|
||||||
|
Front<Conditional>::Front(string key, cond_ptr conditional) {
|
||||||
|
add(key, conditional);
|
||||||
|
separator_ = conditional->parents();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template<class Conditional>
|
||||||
|
void Front<Conditional>::print(const string& s) const {
|
||||||
|
cout << s;
|
||||||
|
BOOST_FOREACH(string key, keys_) cout << " " << key;
|
||||||
|
if (!separator_.empty()) {
|
||||||
|
cout << " :";
|
||||||
|
BOOST_FOREACH(string key, separator_)
|
||||||
|
cout << " " << key;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template<class Conditional>
|
||||||
|
bool Front<Conditional>::equals(const Front<Conditional>& other, double tol) const {
|
||||||
|
return (keys_ == other.keys_) &&
|
||||||
|
equal(conditionals_.begin(),conditionals_.end(),other.conditionals_.begin(),equals_star<Conditional>);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template<class Conditional>
|
||||||
|
void Front<Conditional>::add(string key, cond_ptr conditional) {
|
||||||
|
keys_.push_front(key);
|
||||||
|
conditionals_.push_front(conditional);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class Conditional>
|
template<class Conditional>
|
||||||
BayesTree<Conditional>::BayesTree() {
|
BayesTree<Conditional>::BayesTree() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
// TODO: traversal is O(n*log(n)) but could be O(n) with better bayesChain
|
||||||
template<class Conditional>
|
template<class Conditional>
|
||||||
BayesTree<Conditional>::BayesTree(BayesChain<Conditional>& bayesChain) {
|
BayesTree<Conditional>::BayesTree(BayesChain<Conditional>& bayesChain, bool verbose) {
|
||||||
list<string> ordering;// = bayesChain.ordering();
|
list<string> reverseOrdering = bayesChain.keys();
|
||||||
|
BOOST_FOREACH(string key, reverseOrdering)
|
||||||
|
insert(key,bayesChain[key],verbose);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class Conditional>
|
template<class Conditional>
|
||||||
void BayesTree<Conditional>::print(const std::string& s) const {
|
void BayesTree<Conditional>::print(const string& s) const {
|
||||||
cout << s << ": size == " << nodes_.size() << endl;
|
cout << s << ": size == " << nodes_.size() << endl;
|
||||||
if (nodes_.empty()) return;
|
if (nodes_.empty()) return;
|
||||||
nodes_[0]->printTree("");
|
nodes_[0]->printTree("");
|
||||||
|
@ -34,19 +72,24 @@ namespace gtsam {
|
||||||
bool BayesTree<Conditional>::equals(const BayesTree<Conditional>& other,
|
bool BayesTree<Conditional>::equals(const BayesTree<Conditional>& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
return size()==other.size() &&
|
return size()==other.size() &&
|
||||||
equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) &&
|
equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) &&
|
||||||
equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star<Node>);
|
equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star<Node>);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class Conditional>
|
template<class Conditional>
|
||||||
void BayesTree<Conditional>::insert(string key, conditional_ptr conditional) {
|
void BayesTree<Conditional>::insert(string key, conditional_ptr conditional, bool verbose) {
|
||||||
|
|
||||||
// get any parent
|
if (verbose) cout << "Inserting " << key << "| ";
|
||||||
|
|
||||||
|
// get parents
|
||||||
list<string> parents = conditional->parents();
|
list<string> parents = conditional->parents();
|
||||||
|
if (verbose) BOOST_FOREACH(string p, parents) cout << p << " ";
|
||||||
|
if (verbose) cout << endl;
|
||||||
|
|
||||||
// if no parents, start a new root clique
|
// if no parents, start a new root clique
|
||||||
if (parents.empty()) {
|
if (parents.empty()) {
|
||||||
|
if (verbose) cout << "Creating root clique" << endl;
|
||||||
node_ptr root(new Node(key, conditional));
|
node_ptr root(new Node(key, conditional));
|
||||||
nodes_.push_back(root);
|
nodes_.push_back(root);
|
||||||
nodeMap_.insert(make_pair(key, 0));
|
nodeMap_.insert(make_pair(key, 0));
|
||||||
|
@ -57,18 +100,20 @@ namespace gtsam {
|
||||||
string parent = parents.front();
|
string parent = parents.front();
|
||||||
NodeMap::const_iterator it = nodeMap_.find(parent);
|
NodeMap::const_iterator it = nodeMap_.find(parent);
|
||||||
if (it == nodeMap_.end()) throw(invalid_argument(
|
if (it == nodeMap_.end()) throw(invalid_argument(
|
||||||
"BayesTree::insert: parent with key " + key + "was not yet inserted"));
|
"BayesTree::insert('"+key+"'): parent '" + parent + "' was not yet inserted"));
|
||||||
int index = it->second;
|
int index = it->second;
|
||||||
node_ptr parent_clique = nodes_[index];
|
node_ptr parent_clique = nodes_[index];
|
||||||
|
|
||||||
// if the parents and parent clique have the same size, add to parent clique
|
// if the parents and parent clique have the same size, add to parent clique
|
||||||
if (parent_clique->size() == parents.size()) {
|
if (parent_clique->size() == parents.size()) {
|
||||||
|
if (verbose) cout << "Adding to clique " << index << endl;
|
||||||
nodeMap_.insert(make_pair(key, index));
|
nodeMap_.insert(make_pair(key, index));
|
||||||
parent_clique->add(key, conditional);
|
parent_clique->add(key, conditional);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise, start a new clique and add it to the tree
|
// otherwise, start a new clique and add it to the tree
|
||||||
|
if (verbose) cout << "Starting new clique" << endl;
|
||||||
node_ptr new_clique(new Node(key, conditional));
|
node_ptr new_clique(new Node(key, conditional));
|
||||||
new_clique->parent_ = parent_clique;
|
new_clique->parent_ = parent_clique;
|
||||||
parent_clique->children_.push_back(new_clique);
|
parent_clique->children_.push_back(new_clique);
|
||||||
|
@ -76,6 +121,6 @@ namespace gtsam {
|
||||||
nodes_.push_back(new_clique);
|
nodes_.push_back(new_clique);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
} /// namespace gtsam
|
} /// namespace gtsam
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <boost/serialization/map.hpp>
|
#include <boost/serialization/map.hpp>
|
||||||
#include <boost/serialization/list.hpp>
|
#include <boost/serialization/list.hpp>
|
||||||
#include <boost/foreach.hpp> // TODO: make cpp file
|
|
||||||
#include "Testable.h"
|
#include "Testable.h"
|
||||||
#include "BayesChain.h"
|
#include "BayesChain.h"
|
||||||
|
|
||||||
|
@ -30,35 +29,16 @@ namespace gtsam {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** constructor */
|
/** constructor */
|
||||||
Front(std::string key, cond_ptr conditional) {
|
Front(std::string key, cond_ptr conditional);
|
||||||
add(key, conditional);
|
|
||||||
separator_ = conditional->parents();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(const std::string& s = "") const {
|
void print(const std::string& s = "") const;
|
||||||
std::cout << s;
|
|
||||||
BOOST_FOREACH(std::string key, keys_)
|
|
||||||
std::cout << " " << key;
|
|
||||||
if (!separator_.empty()) {
|
|
||||||
std::cout << " :";
|
|
||||||
BOOST_FOREACH(std::string key, separator_)
|
|
||||||
std::cout << " " << key;
|
|
||||||
}
|
|
||||||
std::cout << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** check equality. TODO: only keys */
|
/** check equality */
|
||||||
bool equals(const Front<Conditional>& other, double tol = 1e-9) const {
|
bool equals(const Front<Conditional>& other, double tol = 1e-9) const;
|
||||||
return (keys_ == other.keys_) &&
|
|
||||||
equal(conditionals_.begin(),conditionals_.end(),other.conditionals_.begin(),equals_star<Conditional>);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** add a frontal node */
|
/** add a frontal node */
|
||||||
void add(std::string key, cond_ptr conditional) {
|
void add(std::string key, cond_ptr conditional);
|
||||||
keys_.push_front(key);
|
|
||||||
conditionals_.push_front(conditional);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** return size of the clique */
|
/** return size of the clique */
|
||||||
inline size_t size() const {return keys_.size() + separator_.size();}
|
inline size_t size() const {return keys_.size() + separator_.size();}
|
||||||
|
@ -109,7 +89,7 @@ namespace gtsam {
|
||||||
BayesTree();
|
BayesTree();
|
||||||
|
|
||||||
/** Create a Bayes Tree from a SymbolicBayesChain */
|
/** Create a Bayes Tree from a SymbolicBayesChain */
|
||||||
BayesTree(BayesChain<Conditional>& bayesChain);
|
BayesTree(BayesChain<Conditional>& bayesChain, bool verbose=false);
|
||||||
|
|
||||||
/** Destructor */
|
/** Destructor */
|
||||||
virtual ~BayesTree() {}
|
virtual ~BayesTree() {}
|
||||||
|
@ -121,7 +101,7 @@ namespace gtsam {
|
||||||
bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const;
|
bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
/** insert a new conditional */
|
/** insert a new conditional */
|
||||||
void insert(std::string key, conditional_ptr conditional);
|
void insert(std::string key, conditional_ptr conditional, bool verbose=false);
|
||||||
|
|
||||||
/** number of cliques */
|
/** number of cliques */
|
||||||
inline size_t size() const { return nodes_.size();}
|
inline size_t size() const { return nodes_.size();}
|
||||||
|
|
|
@ -12,23 +12,23 @@ using namespace boost::assign;
|
||||||
|
|
||||||
#include "SymbolicBayesChain.h"
|
#include "SymbolicBayesChain.h"
|
||||||
#include "BayesTree-inl.h"
|
#include "BayesTree-inl.h"
|
||||||
|
#include "SmallExample.h"
|
||||||
|
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
// Conditionals for ASIA example from the tutorial with A and D evidence
|
// Conditionals for ASIA example from the tutorial with A and D evidence
|
||||||
SymbolicConditional::shared_ptr
|
SymbolicConditional::shared_ptr B(new SymbolicConditional()), L(
|
||||||
B(new SymbolicConditional()),
|
new SymbolicConditional("B")), E(new SymbolicConditional("L", "B")), S(
|
||||||
L(new SymbolicConditional("B")),
|
new SymbolicConditional("L", "B")), T(new SymbolicConditional("E", "L")),
|
||||||
E(new SymbolicConditional("L","B")),
|
|
||||||
S(new SymbolicConditional("L","B")),
|
|
||||||
T(new SymbolicConditional("L","E")),
|
|
||||||
X(new SymbolicConditional("E"));
|
X(new SymbolicConditional("E"));
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( BayesTree, Front )
|
TEST( BayesTree, Front )
|
||||||
{
|
{
|
||||||
Front<SymbolicConditional> f1("B",B); f1.add("L",L);
|
Front<SymbolicConditional> f1("B", B);
|
||||||
Front<SymbolicConditional> f2("L",L); f2.add("B",B);
|
f1.add("L", L);
|
||||||
|
Front<SymbolicConditional> f2("L", L);
|
||||||
|
f2.add("B", B);
|
||||||
CHECK(f1.equals(f1));
|
CHECK(f1.equals(f1));
|
||||||
CHECK(!f1.equals(f2));
|
CHECK(!f1.equals(f2));
|
||||||
}
|
}
|
||||||
|
@ -38,31 +38,82 @@ TEST( BayesTree, constructor )
|
||||||
{
|
{
|
||||||
// Create using insert
|
// Create using insert
|
||||||
BayesTree<SymbolicConditional> bayesTree;
|
BayesTree<SymbolicConditional> bayesTree;
|
||||||
bayesTree.insert("B",B);
|
bayesTree.insert("B", B);
|
||||||
bayesTree.insert("L",L);
|
bayesTree.insert("L", L);
|
||||||
bayesTree.insert("E",E);
|
bayesTree.insert("E", E);
|
||||||
bayesTree.insert("S",S);
|
bayesTree.insert("S", S);
|
||||||
bayesTree.insert("T",T);
|
bayesTree.insert("T", T);
|
||||||
bayesTree.insert("X",X);
|
bayesTree.insert("X", X);
|
||||||
|
|
||||||
// Check Size
|
// Check Size
|
||||||
LONGS_EQUAL(4,bayesTree.size());
|
LONGS_EQUAL(4,bayesTree.size());
|
||||||
|
|
||||||
// Check root
|
// Check root
|
||||||
Front<SymbolicConditional> expected_root("B",B);
|
Front<SymbolicConditional> expected_root("B", B);
|
||||||
expected_root.add("L",L);
|
expected_root.add("L", L);
|
||||||
expected_root.add("E",E);
|
expected_root.add("E", E);
|
||||||
Front<SymbolicConditional> actual_root = bayesTree.root();
|
Front<SymbolicConditional> actual_root = bayesTree.root();
|
||||||
CHECK(assert_equal(expected_root,actual_root));
|
CHECK(assert_equal(expected_root,actual_root));
|
||||||
|
|
||||||
// Create from symbolic Bayes chain in which we want to discover cliques
|
// Create from symbolic Bayes chain in which we want to discover cliques
|
||||||
map<string, SymbolicConditional::shared_ptr> nodes;
|
SymbolicBayesChain ASIA;
|
||||||
insert(nodes)("B",B)("L",L)("E",E)("S",S)("T",T)("X",X);
|
ASIA.insert("X", X);
|
||||||
SymbolicBayesChain ASIA(nodes);
|
ASIA.insert("T", T);
|
||||||
|
ASIA.insert("S", S);
|
||||||
|
ASIA.insert("E", E);
|
||||||
|
ASIA.insert("L", L);
|
||||||
|
ASIA.insert("B", B);
|
||||||
BayesTree<SymbolicConditional> bayesTree2(ASIA);
|
BayesTree<SymbolicConditional> bayesTree2(ASIA);
|
||||||
|
|
||||||
// Check whether the same
|
// Check whether the same
|
||||||
//CHECK(assert_equal(bayesTree,bayesTree2));
|
CHECK(assert_equal(bayesTree,bayesTree2));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* *
|
||||||
|
Bayes tree for smoother with "natural" ordering:
|
||||||
|
x6 x7
|
||||||
|
x5 : x6
|
||||||
|
x4 : x5
|
||||||
|
x3 : x4
|
||||||
|
x2 : x3
|
||||||
|
x1 : x2
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( BayesTree, smoother )
|
||||||
|
{
|
||||||
|
// Create smoother with 7 nodes
|
||||||
|
LinearFactorGraph smoother = createSmoother(7);
|
||||||
|
Ordering ordering;
|
||||||
|
for (int t = 1; t <= 7; t++)
|
||||||
|
ordering.push_back(symbol('x', t));
|
||||||
|
|
||||||
|
// eliminate using the "natural" ordering
|
||||||
|
ChordalBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
||||||
|
|
||||||
|
// Create the Bayes tree
|
||||||
|
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false);
|
||||||
|
LONGS_EQUAL(6,bayesTree.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* *
|
||||||
|
Bayes tree for smoother with "nested dissection" ordering:
|
||||||
|
x5 x6 x4
|
||||||
|
x3 x2 : x4
|
||||||
|
x1 : x2
|
||||||
|
x7 : x6
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( BayesTree, balanced_smoother )
|
||||||
|
{
|
||||||
|
// Create smoother with 7 nodes
|
||||||
|
LinearFactorGraph smoother = createSmoother(7);
|
||||||
|
Ordering ordering;
|
||||||
|
ordering += "x1","x3","x5","x7","x2","x6","x4";
|
||||||
|
|
||||||
|
// eliminate using a "nested dissection" ordering
|
||||||
|
ChordalBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
||||||
|
|
||||||
|
// Create the Bayes tree
|
||||||
|
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false);
|
||||||
|
LONGS_EQUAL(4,bayesTree.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue