Removed debug code, added marginal function
parent
ec6611ae56
commit
4865edb883
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
#include <boost/foreach.hpp>
|
#include <boost/foreach.hpp>
|
||||||
#include "BayesTree.h"
|
#include "BayesTree.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -27,7 +28,7 @@ namespace gtsam {
|
||||||
if (!separator_.empty()) {
|
if (!separator_.empty()) {
|
||||||
cout << " :";
|
cout << " :";
|
||||||
BOOST_FOREACH(string key, separator_)
|
BOOST_FOREACH(string key, separator_)
|
||||||
cout << " " << key;
|
cout << " " << key;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
@ -48,10 +49,10 @@ namespace gtsam {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// TODO: traversal is O(n*log(n)) but could be O(n) with better bayesNet
|
// TODO: traversal is O(n*log(n)) but could be O(n) with better bayesNet
|
||||||
template<class Conditional>
|
template<class Conditional>
|
||||||
BayesTree<Conditional>::BayesTree(const BayesNet<Conditional>& bayesNet, bool verbose) {
|
BayesTree<Conditional>::BayesTree(const BayesNet<Conditional>& bayesNet) {
|
||||||
typename BayesNet<Conditional>::const_reverse_iterator rit;
|
typename BayesNet<Conditional>::const_reverse_iterator rit;
|
||||||
for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit )
|
for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit )
|
||||||
insert(*rit,verbose);
|
insert(*rit);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -73,22 +74,29 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class Conditional>
|
template<class Conditional>
|
||||||
void BayesTree<Conditional>::insert(const boost::shared_ptr<Conditional>& conditional, bool verbose) {
|
void BayesTree<Conditional>::addClique
|
||||||
|
(const boost::shared_ptr<Conditional>& conditional, node_ptr parent_clique)
|
||||||
|
{
|
||||||
|
node_ptr new_clique(new Node(conditional));
|
||||||
|
nodeMap_.insert(make_pair(conditional->key(), nodes_.size()));
|
||||||
|
nodes_.push_back(new_clique);
|
||||||
|
if (parent_clique==NULL) return;
|
||||||
|
new_clique->parent_ = parent_clique;
|
||||||
|
parent_clique->children_.push_back(new_clique);
|
||||||
|
}
|
||||||
|
|
||||||
string key = conditional->key();
|
/* ************************************************************************* */
|
||||||
if (verbose) cout << "Inserting " << key << "| ";
|
template<class Conditional>
|
||||||
|
void BayesTree<Conditional>::insert
|
||||||
// get parents
|
(const boost::shared_ptr<Conditional>& conditional)
|
||||||
|
{
|
||||||
|
// get key and parents
|
||||||
|
string key = conditional->key();
|
||||||
list<string> parents = conditional->parents();
|
list<string> parents = conditional->parents();
|
||||||
if (verbose) BOOST_FOREACH(string p, parents) cout << p << " ";
|
|
||||||
if (verbose) cout << endl;
|
|
||||||
|
|
||||||
// if no parents, start a new root clique
|
// if no parents, start a new root clique
|
||||||
if (parents.empty()) {
|
if (parents.empty()) {
|
||||||
if (verbose) cout << "Creating root clique" << endl;
|
addClique(conditional);
|
||||||
node_ptr root(new Node(conditional));
|
|
||||||
nodes_.push_back(root);
|
|
||||||
nodeMap_.insert(make_pair(key, 0));
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,26 +104,35 @@ namespace gtsam {
|
||||||
string parent = parents.front();
|
string parent = parents.front();
|
||||||
NodeMap::const_iterator it = nodeMap_.find(parent);
|
NodeMap::const_iterator it = nodeMap_.find(parent);
|
||||||
if (it == nodeMap_.end()) throw(invalid_argument(
|
if (it == nodeMap_.end()) throw(invalid_argument(
|
||||||
"BayesTree::insert('"+key+"'): parent '" + parent + "' was not yet inserted"));
|
"BayesTree::insert('"+key+"'): parent '" + parent + "' not yet inserted"));
|
||||||
int index = it->second;
|
int parent_index = it->second;
|
||||||
node_ptr parent_clique = nodes_[index];
|
node_ptr parent_clique = nodes_[parent_index];
|
||||||
if (verbose) cout << "Parent clique " << index << " of size " << parent_clique->size() << endl;
|
|
||||||
|
|
||||||
// if the parents and parent clique have the same size, add to parent clique
|
// if the parents and parent clique have the same size, add to parent clique
|
||||||
if (parent_clique->size() == parents.size()) {
|
if (parent_clique->size() == parents.size()) {
|
||||||
if (verbose) cout << "Adding to clique " << index << endl;
|
nodeMap_.insert(make_pair(key, parent_index));
|
||||||
nodeMap_.insert(make_pair(key, index));
|
|
||||||
parent_clique->push_front(conditional);
|
parent_clique->push_front(conditional);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise, start a new clique and add it to the tree
|
// otherwise, start a new clique and add it to the tree
|
||||||
if (verbose) cout << "Starting new clique" << endl;
|
addClique(conditional,parent_clique);
|
||||||
node_ptr new_clique(new Node(conditional));
|
}
|
||||||
new_clique->parent_ = parent_clique;
|
|
||||||
parent_clique->children_.push_back(new_clique);
|
/* ************************************************************************* */
|
||||||
nodeMap_.insert(make_pair(key, nodes_.size()));
|
template<class Conditional>
|
||||||
nodes_.push_back(new_clique);
|
boost::shared_ptr<Conditional> BayesTree<Conditional>::marginal(const string& key) const {
|
||||||
|
|
||||||
|
// find the clique to which key belongs
|
||||||
|
NodeMap::const_iterator it = nodeMap_.find(key);
|
||||||
|
if (it == nodeMap_.end()) throw(invalid_argument(
|
||||||
|
"BayesTree::marginal('"+key+"'): key not found"));
|
||||||
|
|
||||||
|
// find all cliques on the path to the root
|
||||||
|
// FactorGraph
|
||||||
|
|
||||||
|
boost::shared_ptr<Conditional> result(new Conditional);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -30,7 +30,6 @@ namespace gtsam {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef boost::shared_ptr<Conditional> conditional_ptr;
|
typedef boost::shared_ptr<Conditional> conditional_ptr;
|
||||||
typedef std::pair<std::string,conditional_ptr> NamedConditional;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
@ -66,13 +65,16 @@ namespace gtsam {
|
||||||
typedef std::map<std::string, int> NodeMap;
|
typedef std::map<std::string, int> NodeMap;
|
||||||
NodeMap nodeMap_;
|
NodeMap nodeMap_;
|
||||||
|
|
||||||
|
/** add a clique */
|
||||||
|
void addClique(const conditional_ptr& conditional, node_ptr parent_clique=node_ptr());
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** Create an empty Bayes Tree */
|
/** Create an empty Bayes Tree */
|
||||||
BayesTree();
|
BayesTree();
|
||||||
|
|
||||||
/** Create a Bayes Tree from a Bayes Net */
|
/** Create a Bayes Tree from a Bayes Net */
|
||||||
BayesTree(const BayesNet<Conditional>& bayesNet, bool verbose=false);
|
BayesTree(const BayesNet<Conditional>& bayesNet);
|
||||||
|
|
||||||
/** Destructor */
|
/** Destructor */
|
||||||
virtual ~BayesTree() {}
|
virtual ~BayesTree() {}
|
||||||
|
@ -84,7 +86,7 @@ namespace gtsam {
|
||||||
bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const;
|
bool equals(const BayesTree<Conditional>& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
/** insert a new conditional */
|
/** insert a new conditional */
|
||||||
void insert(const boost::shared_ptr<Conditional>& conditional, bool verbose=false);
|
void insert(const boost::shared_ptr<Conditional>& conditional);
|
||||||
|
|
||||||
/** number of cliques */
|
/** number of cliques */
|
||||||
inline size_t size() const { return nodes_.size();}
|
inline size_t size() const { return nodes_.size();}
|
||||||
|
@ -92,6 +94,9 @@ namespace gtsam {
|
||||||
/** return root clique */
|
/** return root clique */
|
||||||
const BayesNet<Conditional>& root() const {return *(nodes_[0]);}
|
const BayesNet<Conditional>& root() const {return *(nodes_[0]);}
|
||||||
|
|
||||||
|
/** return marginal on any variable */
|
||||||
|
boost::shared_ptr<Conditional> marginal(const std::string& key) const;
|
||||||
|
|
||||||
}; // BayesTree
|
}; // BayesTree
|
||||||
|
|
||||||
} /// namespace gtsam
|
} /// namespace gtsam
|
||||||
|
|
|
@ -27,10 +27,10 @@ SymbolicConditional::shared_ptr B(new SymbolicConditional("B")), L(
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( BayesTree, Front )
|
TEST( BayesTree, Front )
|
||||||
{
|
{
|
||||||
BayesNet<SymbolicConditional> f1;
|
SymbolicBayesNet f1;
|
||||||
f1.push_back(B);
|
f1.push_back(B);
|
||||||
f1.push_back(L);
|
f1.push_back(L);
|
||||||
BayesNet<SymbolicConditional> f2;
|
SymbolicBayesNet f2;
|
||||||
f2.push_back(L);
|
f2.push_back(L);
|
||||||
f2.push_back(B);
|
f2.push_back(B);
|
||||||
CHECK(f1.equals(f1));
|
CHECK(f1.equals(f1));
|
||||||
|
@ -68,9 +68,8 @@ TEST( BayesTree, constructor )
|
||||||
ASIA.push_back(E);
|
ASIA.push_back(E);
|
||||||
ASIA.push_back(L);
|
ASIA.push_back(L);
|
||||||
ASIA.push_back(B);
|
ASIA.push_back(B);
|
||||||
bool verbose = false;
|
BayesTree<SymbolicConditional> bayesTree2(ASIA);
|
||||||
BayesTree<SymbolicConditional> bayesTree2(ASIA,verbose);
|
//bayesTree2.print("bayesTree2");
|
||||||
if (verbose) bayesTree2.print("bayesTree2");
|
|
||||||
|
|
||||||
// Check whether the same
|
// Check whether the same
|
||||||
CHECK(assert_equal(bayesTree,bayesTree2));
|
CHECK(assert_equal(bayesTree,bayesTree2));
|
||||||
|
@ -97,7 +96,7 @@ TEST( BayesTree, smoother )
|
||||||
GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
||||||
|
|
||||||
// Create the Bayes tree
|
// Create the Bayes tree
|
||||||
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false);
|
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet);
|
||||||
LONGS_EQUAL(6,bayesTree.size());
|
LONGS_EQUAL(6,bayesTree.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +107,7 @@ TEST( BayesTree, smoother )
|
||||||
x1 : x2
|
x1 : x2
|
||||||
x7 : x6
|
x7 : x6
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( BayesTree, balanced_smoother )
|
TEST( BayesTree, balanced_smoother_marginals )
|
||||||
{
|
{
|
||||||
// Create smoother with 7 nodes
|
// Create smoother with 7 nodes
|
||||||
LinearFactorGraph smoother = createSmoother(7);
|
LinearFactorGraph smoother = createSmoother(7);
|
||||||
|
@ -119,8 +118,21 @@ TEST( BayesTree, balanced_smoother )
|
||||||
GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering);
|
||||||
|
|
||||||
// Create the Bayes tree
|
// Create the Bayes tree
|
||||||
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet,false);
|
BayesTree<ConditionalGaussian> bayesTree(*chordalBayesNet);
|
||||||
LONGS_EQUAL(4,bayesTree.size());
|
LONGS_EQUAL(4,bayesTree.size());
|
||||||
|
|
||||||
|
// Check root clique
|
||||||
|
//BayesNet<ConditionalGaussian> expected_root;
|
||||||
|
//BayesNet<ConditionalGaussian> actual_root = bayesTree.root();
|
||||||
|
//CHECK(assert_equal(expected_root,actual_root));
|
||||||
|
|
||||||
|
// Check marginal on x1
|
||||||
|
ConditionalGaussian expected;
|
||||||
|
ConditionalGaussian::shared_ptr actual = bayesTree.marginal("x1");
|
||||||
|
CHECK(assert_equal(expected,*actual));
|
||||||
|
|
||||||
|
// JunctionTree is an undirected tree of cliques
|
||||||
|
// JunctionTree<ConditionalGaussian> marginals = bayesTree.marginals();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue