insert bayes net as a clique

release/4.3a0
Kai Ni 2010-02-13 07:09:27 +00:00
parent 4408eaf6f4
commit d0a93ad9dd
3 changed files with 116 additions and 6 deletions

View File

@ -23,6 +23,10 @@ namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */
template<class Conditional>
BayesTree<Conditional>::Clique::Clique() {}
/* ************************************************************************* */ /* ************************************************************************* */
template<class Conditional> template<class Conditional>
BayesTree<Conditional>::Clique::Clique(const sharedConditional& conditional) { BayesTree<Conditional>::Clique::Clique(const sharedConditional& conditional) {
@ -298,6 +302,18 @@ namespace gtsam {
return new_clique; return new_clique;
} }
/* ************************************************************************* */
template<class Conditional>
typename BayesTree<Conditional>::sharedClique BayesTree<Conditional>::addClique
(const sharedConditional& conditional, list<sharedClique>& child_cliques) {
sharedClique new_clique(new Clique(conditional));
nodes_.insert(make_pair(conditional->key(), new_clique));
new_clique->children_ = child_cliques;
BOOST_FOREACH(sharedClique& child, child_cliques)
child->parent_ = new_clique;
return new_clique;
}
/* ************************************************************************* */ /* ************************************************************************* */
template<class Conditional> template<class Conditional>
void BayesTree<Conditional>::removeClique(sharedClique clique) { void BayesTree<Conditional>::removeClique(sharedClique clique) {
@ -337,8 +353,9 @@ namespace gtsam {
printf("WARNING: BayesTree.print encountered a forest...\n"); printf("WARNING: BayesTree.print encountered a forest...\n");
return; return;
} }
cout << s << ": size == " << size() << endl; cout << s << ": clique size == " << size() << ", node size == " << nodes_.size() << endl;
if (nodes_.empty()) return; if (nodes_.empty()) return;
printf("printing tree!\n");
root_->printTree(""); root_->printTree("");
} }
@ -409,6 +426,31 @@ namespace gtsam {
addClique(conditional,parent_clique); addClique(conditional,parent_clique);
} }
/* ************************************************************************* */
template<class Conditional>
typename BayesTree<Conditional>::sharedClique BayesTree<Conditional>::insert(
const BayesNet<Conditional>& bayesNet, list<sharedClique>& children, bool isRootClique)
{
if (bayesNet.size() == 0)
throw invalid_argument("BayesTree::insert: empty bayes net!");
// create a new clique and add all the conditionals to the clique
sharedClique new_clique;
typename BayesNet<Conditional>::sharedConditional conditional;
BOOST_REVERSE_FOREACH(conditional, bayesNet) {
if (!new_clique.get()) {
new_clique = addClique(conditional,children);
} else {
nodes_.insert(make_pair(conditional->key(), new_clique));
new_clique->push_front(conditional);
}
}
if (isRootClique) root_ = new_clique;
return new_clique;
}
/* ************************************************************************* */ /* ************************************************************************* */
// First finds clique marginal then marginalizes that // First finds clique marginal then marginalizes that
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -51,6 +51,8 @@ namespace gtsam {
//* Constructor */ //* Constructor */
Clique(const sharedConditional& conditional); Clique(const sharedConditional& conditional);
Clique();
/** return keys in frontal:separator order */ /** return keys in frontal:separator order */
Ordering keys() const; Ordering keys() const;
@ -100,17 +102,21 @@ namespace gtsam {
typedef SymbolMap<sharedClique> Nodes; typedef SymbolMap<sharedClique> Nodes;
Nodes nodes_; Nodes nodes_;
protected:
/** Root clique */ /** Root clique */
sharedClique root_; sharedClique root_;
/** add a clique */ /** remove a clique: warning, can result in a forest */
void removeClique(sharedClique clique);
/** add a clique (top down) */
sharedClique addClique(const sharedConditional& conditional, sharedClique addClique(const sharedConditional& conditional,
sharedClique parent_clique = sharedClique()); sharedClique parent_clique = sharedClique());
protected: /** add a clique (bottom up) */
sharedClique addClique(const sharedConditional& conditional,
/** remove a clique: warning, can result in a forest */ std::list<sharedClique>& child_cliques);
void removeClique(sharedClique clique);
public: public:
@ -146,6 +152,12 @@ namespace gtsam {
/** insert a new conditional */ /** insert a new conditional */
void insert(const sharedConditional& conditional, const IndexTable<Symbol>& index); void insert(const sharedConditional& conditional, const IndexTable<Symbol>& index);
/** insert a new clique corresponding to the given bayes net.
* it is the caller's responsibility to decide whether the given bayes net is a valid clique,
* i.e. all the variables (frontal and separator) are connected */
sharedClique insert(const BayesNet<Conditional>& bayesNet,
std::list<sharedClique>& children, bool isRootClique = false);
/** number of cliques */ /** number of cliques */
inline size_t size() const { inline size_t size() const {
if(root_) if(root_)

View File

@ -333,6 +333,62 @@ TEST( BayesTree, removeTop3 )
CHECK(orphans.size() == 0); CHECK(orphans.size() == 0);
} }
/* ************************************************************************* */ /* ************************************************************************* */
/**
* x2 - x3 - x4 - x5
* | / \ |
* x1 / \ x6
*/
TEST( BayesTree, insert )
{
// construct bayes tree by split the graph along the separator x3 - x4
Symbol _x4_('x', 4), _x5_('x', 5), _x6_('x', 6);
SymbolicFactorGraph fg1, fg2, fg3;
fg1.push_factor(_x3_, _x4_);
fg2.push_factor(_x1_, _x2_);
fg2.push_factor(_x2_, _x3_);
fg2.push_factor(_x1_, _x3_);
fg3.push_factor(_x4_, _x5_);
fg3.push_factor(_x5_, _x6_);
fg3.push_factor(_x4_, _x6_);
Ordering ordering1; ordering1 += _x3_, _x4_;
Ordering ordering2; ordering2 += _x1_, _x2_;
Ordering ordering3; ordering3 += _x6_, _x5_;
BayesNet<SymbolicConditional> bn1, bn2, bn3;
bn1 = fg1.eliminate(ordering1);
bn2 = fg2.eliminate(ordering2);
bn3 = fg3.eliminate(ordering3);
// insert child cliques
SymbolicBayesTree actual;
list<SymbolicBayesTree::sharedClique> children;
SymbolicBayesTree::sharedClique r1 = actual.insert(bn2, children);
SymbolicBayesTree::sharedClique r2 = actual.insert(bn3, children);
// insert root clique
children.push_back(r1);
children.push_back(r2);
actual.insert(bn1, children, true);
// traditional way
SymbolicFactorGraph fg;
fg.push_factor(_x3_, _x4_);
fg.push_factor(_x1_, _x2_);
fg.push_factor(_x2_, _x3_);
fg.push_factor(_x1_, _x3_);
fg.push_factor(_x4_, _x5_);
fg.push_factor(_x5_, _x6_);
fg.push_factor(_x4_, _x6_);
Ordering ordering; ordering += _x1_, _x2_, _x6_, _x5_, _x3_, _x4_;
BayesNet<SymbolicConditional> bn;
bn = fg.eliminate(ordering);
SymbolicBayesTree expected(bn);
CHECK(assert_equal(expected, actual));
}
/* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;