Bayes tree constructor implemented and tested with ASIA, as well as smoother example from frankcvs meeting
parent
1e5a2d692a
commit
c046fed37c
|
@ -41,15 +41,18 @@ namespace gtsam {
|
|||
/** check equality */
|
||||
bool equals(const BayesChain& other, double tol = 1e-9) const;
|
||||
|
||||
/** size is the number of nodes */
|
||||
inline size_t size() const {return nodes_.size();}
|
||||
|
||||
/** insert: use reverse topological sort (i.e. parents last) */
|
||||
void insert(const std::string& key, boost::shared_ptr<Conditional> node);
|
||||
|
||||
/** delete */
|
||||
void erase(const std::string& key);
|
||||
|
||||
/** size is the number of nodes */
|
||||
inline size_t size() const {return nodes_.size();}
|
||||
|
||||
/** return keys in topological sort order (parents first), i.e., reverse elimination order */
|
||||
inline std::list<std::string> keys() const { return keys_;}
|
||||
|
||||
inline boost::shared_ptr<Conditional> operator[](const std::string& key) const {
|
||||
const_iterator cg = nodes_.find(key); // get node
|
||||
assert( cg != nodes_.end() );
|
||||
|
|
|
@ -4,26 +4,64 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <boost/foreach.hpp>
|
||||
#include "BayesTree.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
using namespace std;
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
Front<Conditional>::Front(string key, cond_ptr conditional) {
|
||||
add(key, conditional);
|
||||
separator_ = conditional->parents();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
void Front<Conditional>::print(const string& s) const {
|
||||
cout << s;
|
||||
BOOST_FOREACH(string key, keys_) cout << " " << key;
|
||||
if (!separator_.empty()) {
|
||||
cout << " :";
|
||||
BOOST_FOREACH(string key, separator_)
|
||||
cout << " " << key;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
bool Front<Conditional>::equals(const Front<Conditional>& other, double tol) const {
|
||||
return (keys_ == other.keys_) &&
|
||||
equal(conditionals_.begin(),conditionals_.end(),other.conditionals_.begin(),equals_star<Conditional>);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
void Front<Conditional>::add(string key, cond_ptr conditional) {
|
||||
keys_.push_front(key);
|
||||
conditionals_.push_front(conditional);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
BayesTree<Conditional>::BayesTree() {
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// TODO: traversal is O(n*log(n)) but could be O(n) with better bayesChain
|
||||
template<class Conditional>
|
||||
BayesTree<Conditional>::BayesTree(BayesChain<Conditional>& bayesChain) {
|
||||
list<string> ordering;// = bayesChain.ordering();
|
||||
BayesTree<Conditional>::BayesTree(BayesChain<Conditional>& bayesChain, bool verbose) {
|
||||
list<string> reverseOrdering = bayesChain.keys();
|
||||
BOOST_FOREACH(string key, reverseOrdering)
|
||||
insert(key,bayesChain[key],verbose);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
void BayesTree<Conditional>::print(const std::string& s) const {
|
||||
void BayesTree<Conditional>::print(const string& s) const {
|
||||
cout << s << ": size == " << nodes_.size() << endl;
|
||||
if (nodes_.empty()) return;
|
||||
nodes_[0]->printTree("");
|
||||
|
@ -34,19 +72,24 @@ namespace gtsam {
|
|||
bool BayesTree<Conditional>::equals(const BayesTree<Conditional>& other,
|
||||
double tol) const {
|
||||
return size()==other.size() &&
|
||||
equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) &&
|
||||
equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star<Node>);
|
||||
equal(nodeMap_.begin(),nodeMap_.end(),other.nodeMap_.begin()) &&
|
||||
equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),equals_star<Node>);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
void BayesTree<Conditional>::insert(string key, conditional_ptr conditional) {
|
||||
void BayesTree<Conditional>::insert(string key, conditional_ptr conditional, bool verbose) {
|
||||
|
||||
// get any parent
|
||||
if (verbose) cout << "Inserting " << key << "| ";
|
||||
|
||||
// get parents
|
||||
list<string> parents = conditional->parents();
|
||||
if (verbose) BOOST_FOREACH(string p, parents) cout << p << " ";
|
||||
if (verbose) cout << endl;
|
||||
|
||||
// if no parents, start a new root clique
|
||||
if (parents.empty()) {
|
||||
if (verbose) cout << "Creating root clique" << endl;
|
||||
node_ptr root(new Node(key, conditional));
|
||||
nodes_.push_back(root);
|
||||
nodeMap_.insert(make_pair(key, 0));
|
||||
|
@ -57,18 +100,20 @@ namespace gtsam {
|
|||
string parent = parents.front();
|
||||
NodeMap::const_iterator it = nodeMap_.find(parent);
|
||||
if (it == nodeMap_.end()) throw(invalid_argument(
|
||||
"BayesTree::insert: parent with key " + key + "was not yet inserted"));
|
||||
"BayesTree::insert('"+key+"'): parent '" + parent + "' was not yet inserted"));
|
||||
int index = it->second;
|
||||
node_ptr parent_clique = nodes_[index];
|
||||
|
||||
// if the parents and parent clique have the same size, add to parent clique
|
||||
if (parent_clique->size() == parents.size()) {
|
||||
if (verbose) cout << "Adding to clique " << index << endl;
|
||||
nodeMap_.insert(make_pair(key, index));
|
||||
parent_clique->add(key, conditional);
|
||||
return;
|
||||
}
|
||||
|
||||
// otherwise, start a new clique and add it to the tree
|
||||
if (verbose) cout << "Starting new clique" << endl;
|
||||
node_ptr new_clique(new Node(key, conditional));
|
||||
new_clique->parent_ = parent_clique;
|
||||
parent_clique->children_.push_back(new_clique);
|
||||
|
@ -76,6 +121,6 @@ namespace gtsam {
|
|||
nodes_.push_back(new_clique);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************* */
|
||||
|
||||
} /// namespace gtsam
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
#include <vector>
|
||||
#include <boost/serialization/map.hpp>
|
||||
#include <boost/serialization/list.hpp>
|
||||
#include <boost/foreach.hpp> // TODO: make cpp file
|
||||
#include "Testable.h"
|
||||
#include "BayesChain.h"
|
||||
|
||||
|
@ -30,35 +29,16 @@ namespace gtsam {
|
|||
public:
|
||||
|
||||
/** constructor */
|
||||
Front(std::string key, cond_ptr conditional) {
|
||||
add(key, conditional);
|
||||
separator_ = conditional->parents();
|
||||
}
|
||||
Front(std::string key, cond_ptr conditional);
|
||||
|
||||
/** print */
|
||||
void print(const std::string& s = "") const {
|
||||
std::cout << s;
|
||||
BOOST_FOREACH(std::string key, keys_)
|
||||
std::cout << " " << key;
|
||||
if (!separator_.empty()) {
|
||||
std::cout << " :";
|
||||
BOOST_FOREACH(std::string key, separator_)
|
||||
std::cout << " " << key;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
void print(const std::string& s = "") const;
|
||||
|
||||
/** check equality. TODO: only keys */
|
||||
bool equals(const Front<Conditional>& other, double tol = 1e-9) const {
|
||||
return (keys_ == other.keys_) &&
|
||||
equal(conditionals_.begin(),conditionals_.end(),other.conditionals_.begin(),equals_star<Conditional>);
|
||||
}
|
||||
/** check equality */
|
||||
bool equals(const Front<Conditional>& other, double tol = 1e-9) const;
|
||||
|
||||
/** add a frontal node */
|
||||
void add(std::string key, cond_ptr conditional) {
|
||||
keys_.push_front(key);
|
||||
conditionals_.push_front(conditional);
|
||||
}
|
||||
void add(std::string key, cond_ptr conditional);
|
||||
|
||||
/** return size of the clique */
|
||||
inline size_t size() const {return keys_.size() + separator_.size();}
|
||||
|
@ -109,7 +89,7 @@ namespace gtsam {
|
|||
BayesTree();
|
||||
|
||||
/** Create a Bayes Tree from a SymbolicBayesChain */
|
||||
BayesTree(BayesChain<Conditional>& bayesChain);
|
||||
BayesTree(BayesChain<Conditional>& bayesChain, bool verbose=false);
|
||||
|
||||
/** Destructor */
|
||||
virtual ~BayesTree() {}
|
||||
|
@ -121,7 +101,7 @@ namespace gtsam {
|
|||
bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const;
|
||||
|
||||
/** insert a new conditional */
|
||||
void insert(std::string key, conditional_ptr conditional);
|
||||
void insert(std::string key, conditional_ptr conditional, bool verbose=false);
|
||||
|
||||
/** number of cliques */
|
||||
inline size_t size() const { return nodes_.size();}
|
||||
|
|
|
@ -12,23 +12,23 @@ using namespace boost::assign;
|
|||
|
||||
#include "SymbolicBayesChain.h"
|
||||
#include "BayesTree-inl.h"
|
||||
#include "SmallExample.h"
|
||||
|
||||
using namespace gtsam;
|
||||
|
||||
// Conditionals for ASIA example from the tutorial with A and D evidence
|
||||
SymbolicConditional::shared_ptr
|
||||
B(new SymbolicConditional()),
|
||||
L(new SymbolicConditional("B")),
|
||||
E(new SymbolicConditional("L","B")),
|
||||
S(new SymbolicConditional("L","B")),
|
||||
T(new SymbolicConditional("L","E")),
|
||||
SymbolicConditional::shared_ptr B(new SymbolicConditional()), L(
|
||||
new SymbolicConditional("B")), E(new SymbolicConditional("L", "B")), S(
|
||||
new SymbolicConditional("L", "B")), T(new SymbolicConditional("E", "L")),
|
||||
X(new SymbolicConditional("E"));
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( BayesTree, Front )
|
||||
{
|
||||
Front<SymbolicConditional> f1("B",B); f1.add("L",L);
|
||||
Front<SymbolicConditional> f2("L",L); f2.add("B",B);
|
||||
Front<SymbolicConditional> f1("B", B);
|
||||
f1.add("L", L);
|
||||
Front<SymbolicConditional> f2("L", L);
|
||||
f2.add("B", B);
|
||||
CHECK(f1.equals(f1));
|
||||
CHECK(!f1.equals(f2));
|
||||
}
|
||||
|
@ -38,31 +38,82 @@ TEST( BayesTree, constructor )
|
|||
{
|
||||
// Create using insert
|
||||
BayesTree<SymbolicConditional> bayesTree;
|
||||
bayesTree.insert("B",B);
|
||||
bayesTree.insert("L",L);
|
||||
bayesTree.insert("E",E);
|
||||
bayesTree.insert("S",S);
|
||||
bayesTree.insert("T",T);
|
||||
bayesTree.insert("X",X);
|
||||
bayesTree.insert("B", B);
|
||||
bayesTree.insert("L", L);
|
||||
bayesTree.insert("E", E);
|
||||
bayesTree.insert("S", S);
|
||||
bayesTree.insert("T", T);
|
||||
bayesTree.insert("X", X);
|
||||
|
||||
// Check Size
|
||||
LONGS_EQUAL(4,bayesTree.size());
|
||||
|
||||
// Check root
|
||||
Front<SymbolicConditional> expected_root("B",B);
|
||||
expected_root.add("L",L);
|
||||
expected_root.add("E",E);
|
||||
Front<SymbolicConditional> expected_root("B", B);
|
||||
expected_root.add("L", L);
|
||||
expected_root.add("E", E);
|
||||
Front<SymbolicConditional> actual_root = bayesTree.root();
|
||||
CHECK(assert_equal(expected_root,actual_root));
|
||||
|
||||
// Create from symbolic Bayes chain in which we want to discover cliques
|
||||
map<string, SymbolicConditional::shared_ptr> nodes;
|
||||
insert(nodes)("B",B)("L",L)("E",E)("S",S)("T",T)("X",X);
|
||||
SymbolicBayesChain ASIA(nodes);
|
||||
SymbolicBayesChain ASIA;
|
||||
ASIA.insert("X", X);
|
||||
ASIA.insert("T", T);
|
||||
ASIA.insert("S", S);
|
||||
ASIA.insert("E", E);
|
||||
ASIA.insert("L", L);
|
||||
ASIA.insert("B", B);
|
||||
BayesTree<SymbolicConditional> bayesTree2(ASIA);
|
||||
|
||||
// Check whether the same
|
||||
//CHECK(assert_equal(bayesTree,bayesTree2));
|
||||
CHECK(assert_equal(bayesTree,bayesTree2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* *
|
||||
Bayes tree for smoother with "natural" ordering:
|
||||
x6 x7
|
||||
x5 : x6
|
||||
x4 : x5
|
||||
x3 : x4
|
||||
x2 : x3
|
||||
x1 : x2
|
||||
/* ************************************************************************* */
|
||||
TEST( BayesTree, smoother )
|
||||
{
|
||||
// Create smoother with 7 nodes
|
||||
LinearFactorGraph smoother = createSmoother(7);
|
||||
Ordering ordering;
|
||||
for (int t = 1; t <= 7; t++)
|
||||
ordering.push_back(symbol('x', t));
|
||||
|
||||
// eliminate using the "natural" ordering
|
||||
ChordalBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
||||
|
||||
// Create the Bayes tree
|
||||
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false);
|
||||
LONGS_EQUAL(6,bayesTree.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* *
|
||||
Bayes tree for smoother with "nested dissection" ordering:
|
||||
x5 x6 x4
|
||||
x3 x2 : x4
|
||||
x1 : x2
|
||||
x7 : x6
|
||||
/* ************************************************************************* */
|
||||
TEST( BayesTree, balanced_smoother )
|
||||
{
|
||||
// Create smoother with 7 nodes
|
||||
LinearFactorGraph smoother = createSmoother(7);
|
||||
Ordering ordering;
|
||||
ordering += "x1","x3","x5","x7","x2","x6","x4";
|
||||
|
||||
// eliminate using a "nested dissection" ordering
|
||||
ChordalBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
||||
|
||||
// Create the Bayes tree
|
||||
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false);
|
||||
LONGS_EQUAL(4,bayesTree.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue