gtsam/gtsam/inference/BayesTree-inst.h

523 lines
18 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file BayesTree-inst.h
* @brief Bayes Tree is a tree of cliques of a Bayes Chain
* @author Frank Dellaert
* @author Michael Kaess
* @author Viorela Ila
* @author Richard Roberts
*/
#pragma once
#include <gtsam/inference/Ordering.h>
#include <gtsam/inference/BayesTree.h>
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/base/timing.h>
#include <boost/optional.hpp>
#include <boost/assign/list_of.hpp>
#include <algorithm>
#include <fstream>
using boost::assign::cref_list_of;
namespace gtsam {
/* ************************************************************************* */
/**
 * Collect size statistics for every clique reachable from the roots.
 * @return a BayesTreeCliqueData filled by the recursive overload, one root at a time
 */
template<class CLIQUE>
BayesTreeCliqueData BayesTree<CLIQUE>::getCliqueData() const {
  BayesTreeCliqueData stats;
  for (const sharedClique& treeRoot : roots_) {
    getCliqueData(stats, treeRoot);
  }
  return stats;
}
/* ************************************************************************* */
/**
 * Recursive helper: record the sizes of @p clique, then descend into its children.
 * @param data   accumulator; one entry is appended to conditionalSizes and separatorSizes per clique
 * @param clique clique whose subtree is visited (parent before children)
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::getCliqueData(BayesTreeCliqueData& data, sharedClique clique) const {
  // Number of frontal variables and number of parent (separator) variables of this clique
  data.conditionalSizes.push_back(clique->conditional()->nrFrontals());
  data.separatorSizes.push_back(clique->conditional()->nrParents());

  for (const sharedClique& child : clique->children) {
    getCliqueData(data, child);
  }
}
/* ************************************************************************* */
/**
 * Sum, over all roots, the value each root reports from numCachedSeparatorMarginals().
 * @return the accumulated count (0 when there are no roots)
 */
template<class CLIQUE>
size_t BayesTree<CLIQUE>::numCachedSeparatorMarginals() const {
  size_t total = 0;
  for (const sharedClique& treeRoot : roots_) {
    total += treeRoot->numCachedSeparatorMarginals();
  }
  return total;
}
/* ************************************************************************* */
/**
 * Write the forest to a file as a Graphviz "digraph".
 * @param s            path of the file to create
 * @param keyFormatter formatter used to print each key in the clique labels
 * @throws std::invalid_argument if the tree has no roots
 * @note The stream is not checked after opening; if the file cannot be opened
 *       the writes below simply have no effect.
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
  if (roots_.empty()) {
    throw std::invalid_argument("the root of Bayes tree has not been initialized!");
  }

  std::ofstream dotFile(s.c_str());
  dotFile << "digraph G{\n";
  for (const sharedClique& treeRoot : roots_) {
    saveGraph(dotFile, treeRoot, keyFormatter);
  }
  dotFile << "}";
  dotFile.close();
}
/* ************************************************************************* */
/**
 * Recursively emit one Graphviz node per clique, plus an edge from its parent.
 *
 * Each node is labelled "frontals : separator" (the " : " is only written for
 * cliques that have a parent) using @p indexFormatter for every key.
 *
 * @param s              stream receiving the DOT statements
 * @param clique         clique whose subtree is written (parent before children)
 * @param indexFormatter formatter applied to every key in the label
 * @param parentnum      node id of the parent clique; only used when @p clique has a parent
 *
 * @note Node ids come from a function-local static counter that is never reset,
 *       so ids keep increasing across calls and the function is not re-entrant.
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
  static int num = 0;
  const bool isRoot = !clique->parent();

  // Start the node statement with this clique's id
  std::stringstream out;
  out << num;
  std::string parent = out.str();
  parent += "[label=\"";

  // Comma-separated frontal variables
  bool first = true;
  for (Key index : clique->conditional_->frontals()) {
    if (!first) parent += ",";
    first = false;
    parent += indexFormatter(index);
  }

  // Non-root cliques get the separator delimiter and an edge from their parent
  if (!isRoot) {
    parent += " : ";
    s << parentnum << "->" << num << "\n";
  }

  // Comma-separated separator (parent) variables
  first = true;
  for (Key sep : clique->conditional_->parents()) {
    if (!first) parent += ",";
    first = false;
    parent += indexFormatter(sep);
  }
  parent += "\"];\n";
  s << parent;

  // Children hang off this clique; each one takes the next free id
  parentnum = num;
  for (sharedClique c : clique->children) {
    num++;
    saveGraph(s, c, indexFormatter, parentnum);
  }

  // After a whole tree has been written, `num` still equals the id of the last
  // clique emitted. Advance it so that the next root (in a forest with several
  // roots) does not reuse that id and collide with an existing node.
  if (isRoot) {
    ++num;
  }
}
/* ************************************************************************* */
/**
 * Number of cliques in the forest.
 * @return the sum of treeSize() over all roots (0 when there are no roots)
 */
template<class CLIQUE>
size_t BayesTree<CLIQUE>::size() const {
  size_t cliqueCount = 0;
  for (const sharedClique& treeRoot : roots_) {
    cliqueCount += treeRoot->treeSize();
  }
  return cliqueCount;
}
/* ************************************************************************* */
/**
 * Register @p clique in the tree.
 * @param clique        clique to add; every one of its frontal keys is mapped to it in the nodes index
 * @param parent_clique parent to attach to; when null the clique becomes a new root
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::addClique(const sharedClique& clique, const sharedClique& parent_clique) {
  // Index the clique under each of its frontal variables (overwrites existing entries)
  for (Key frontal : clique->conditional()->frontals()) {
    nodes_[frontal] = clique;
  }

  // No parent given: the clique starts a new tree
  if (!parent_clique) {
    roots_.push_back(clique);
    return;
  }

  // Otherwise link parent and child in both directions
  clique->parent_ = parent_clique;
  parent_clique->children.push_back(clique);
}
/* ************************************************************************* */
// TODO: Clean up
namespace {
  /// Append the conditional of @p clique to @p fg. Always returns 0.
  template<class FACTOR, class CLIQUE>
  int _pushClique(FactorGraph<FACTOR>& fg, const boost::shared_ptr<CLIQUE>& clique) {
    fg.push_back(clique->conditional_);
    return 0;
  }

  /// Pre-order visitor that appends each visited clique's conditional to a factor graph.
  template<class FACTOR, class CLIQUE>
  struct _pushCliqueFunctor {
    FactorGraph<FACTOR>& graph;  ///< Graph receiving the conditionals

    _pushCliqueFunctor(FactorGraph<FACTOR>& graph_) : graph(graph_) {}

    /// Visit one clique; the int traversal data is ignored and 0 is handed to the children.
    int operator()(const boost::shared_ptr<CLIQUE>& clique, int /*unusedData*/) {
      graph.push_back(clique->conditional_);
      return 0;
    }
  };
}
/* ************************************************************************* */
/**
 * Append the conditional of every clique in this tree to @p graph.
 * @param graph factor graph that receives the conditionals, in depth-first visiting order
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::addFactorsToGraph(FactorGraph<FactorType>& graph) const
{
  // The traversal requires a data value to thread through the visit; it carries no information here
  int unusedData = 0;
  _pushCliqueFunctor<FactorType,CLIQUE> appendConditional(graph);
  treeTraversal::DepthFirstForest(*this, unusedData, appendConditional); // FIXME: sort of works?
  // treeTraversal::DepthFirstForest(*this, data, boost::bind(&_pushClique<FactorType,CLIQUE>, boost::ref(graph), _1));
}
/* ************************************************************************* */
/**
 * Copy constructor.
 * Delegates to the copy-assignment operator of this class.
 * @param other tree to copy from
 */
template<class CLIQUE>
BayesTree<CLIQUE>::BayesTree(const This& other) {
  this->operator=(other);
}
/* ************************************************************************* */
namespace {
  /**
   * Pre-order visitor used to copy a forest.
   * Copy-constructs @p node, drops the copied child list, and hooks the copy
   * underneath @p parentPointer.
   * @param node          clique being visited in the source forest
   * @param parentPointer already-cloned parent; must be non-null (it is dereferenced)
   * @return the new copy, to be used as the parent for the clones of @p node's children
   */
  template<typename NODE>
  boost::shared_ptr<NODE>
  BayesTreeCloneForestVisitorPre(const boost::shared_ptr<NODE>& node, const boost::shared_ptr<NODE>& parentPointer)
  {
    boost::shared_ptr<NODE> copy = boost::make_shared<NODE>(*node);

    // The copy constructor duplicated the child pointers of the source; discard
    // them, the children are re-attached as their own clones are visited.
    copy->children.clear();

    copy->parent_ = parentPointer;
    parentPointer->children.push_back(copy);
    return copy;
  }
}
/* ************************************************************************* */
/**
 * Copy assignment: replace the contents of this tree with clones of the cliques of @p other.
 *
 * Every clique of @p other is copy-constructed and re-linked, so the two trees
 * share no clique objects afterwards.
 *
 * @param other tree to copy from
 * @return reference to this tree
 */
template<class CLIQUE>
BayesTree<CLIQUE>& BayesTree<CLIQUE>::operator=(const This& other) {
  // Self-assignment must be a no-op: clearing first would otherwise wipe the
  // very tree we are about to copy from, leaving it empty.
  if (this == &other) {
    return *this;
  }

  this->clear();

  // Clone the forest underneath a temporary dummy clique that collects the cloned roots
  boost::shared_ptr<Clique> rootContainer = boost::make_shared<Clique>();
  treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre<Clique>);

  for(const sharedClique& root: rootContainer->children) {
    root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique
    insertRoot(root);
  }
  return *this;
}
/* ************************************************************************* */
/**
 * Print a one-line summary (clique and variable counts) to std::cout, followed by the forest itself.
 * @param s            prefix string printed before the summary and passed on to the forest printer
 * @param keyFormatter formatter used to print keys
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::print(const std::string& s, const KeyFormatter& keyFormatter) const {
  std::cout << s << ": cliques: " << size();
  std::cout << ", variables: " << nodes_.size() << std::endl;
  treeTraversal::PrintForest(*this, s, keyFormatter);
}
/* ************************************************************************* */
/**
 * Binary predicate for use in equals: compares two (key, clique) index entries.
 * Two entries match when their keys are equal and the cliques are either both
 * null or both non-null and equal according to the clique's equals().
 */
template<class CLIQUE>
bool check_sharedCliques(
    const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v1,
    const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v2
) {
  if (!(v1.first == v2.first)) {
    return false;
  }
  // If either pointer is null, the entries match only when both are null
  if (!v1.second || !v2.second) {
    return !v1.second && !v2.second;
  }
  return v1.second->equals(*v2.second);
}
/* ************************************************************************* */
/**
 * Compare this tree with @p other.
 *
 * The trees are considered equal when they have the same number of cliques,
 * the same number of indexed variables, and their node indices compare equal
 * entry by entry (in iteration order) under check_sharedCliques.
 *
 * @param other tree to compare against
 * @param tol   currently not forwarded to the per-clique comparison
 *              (NOTE(review): consider threading it through check_sharedCliques)
 * @return true if the trees compare equal
 */
template<class CLIQUE>
bool BayesTree<CLIQUE>::equals(const BayesTree<CLIQUE>& other, double tol) const {
  (void)tol;
  // std::equal reads nodes_.size() elements from other.nodes_. Equal clique
  // counts do not imply equal index sizes, so compare the index sizes first to
  // avoid iterating past the end of other.nodes_.
  return size() == other.size()
      && nodes_.size() == other.nodes_.size()
      && std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques<CLIQUE>);
}
/* ************************************************************************* */
/**
 * Pick the key that identifies the parent clique: the smallest element of @p parents.
 *
 * @tparam CONTAINER any container exposing const_iterator, begin() and end(),
 *                   whose elements are comparable with operator<
 * @param parents candidate keys; must not be empty (checked by assert only, so
 *                an empty container is undefined behaviour when NDEBUG is defined)
 * @return the minimum key in @p parents
 */
template<class CLIQUE>
template<class CONTAINER>
Key BayesTree<CLIQUE>::findParentClique(const CONTAINER& parents) const {
  // Qualified call: an unqualified min_element is only found through ADL, which
  // fails for containers whose iterators do not live in namespace std.
  typename CONTAINER::const_iterator lowestOrderedParent = std::min_element(parents.begin(), parents.end());
  assert(lowestOrderedParent != parents.end());
  return *lowestOrderedParent;
}
/* ************************************************************************* */
/**
 * Add every clique of a subtree to the nodes index.
 * @param subtree root of the subtree; each frontal key of each clique is mapped to its clique.
 *                In debug builds an assert fires if a key is already present in the index.
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::fillNodesIndex(const sharedClique& subtree) {
  // Index this clique under each of its frontal variables
  for (const Key& frontal : subtree->conditional()->frontals()) {
    const bool isNewKey = nodes_.insert(std::make_pair(frontal, subtree)).second;
    assert(isNewKey); (void)isNewKey;
  }

  // Then recurse into the children
  for (const sharedClique& child : subtree->children) {
    fillNodesIndex(child);
  }
}
/* ************************************************************************* */
/**
 * Attach an already-linked subtree as a new root of this forest.
 * @param subtree root clique of the subtree; it is appended to the roots and
 *                all of its cliques are entered into the nodes index
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::insertRoot(const sharedClique& subtree) {
  // Register as a root first, then index the whole subtree
  this->roots_.push_back(subtree);
  this->fillNodesIndex(subtree);
}
/* ************************************************************************* */
// First finds clique marginal then marginalizes that
/* ************************************************************************* */
/**
 * Marginal on a single variable, returned as a conditional.
 * @param j        key of the variable
 * @param function elimination function used for the clique marginal and the final marginalization
 * @return the first conditional of the Bayes net obtained by marginalizing the
 *         containing clique's marginal down to variable @p j
 */
template<class CLIQUE>
typename BayesTree<CLIQUE>::sharedConditional
BayesTree<CLIQUE>::marginalFactor(Key j, const Eliminate& function) const
{
  gttic(BayesTree_marginalFactor);

  // Look up the clique that holds j, and calculate or retrieve its marginal P(C) = P(F,S)
  sharedClique owner = this->clique(j);
  FactorGraphType ownerMarginal = owner->marginal2(function);

  // Marginalize out everything except j
  BayesNetType onlyJ = *ownerMarginal.marginalMultifrontalBayesNet(
      Ordering(cref_list_of<1,Key>(j)), boost::none, function);

  // The resulting Bayes net should hold a single conditional, the one on j
  return onlyJ.front();
}
/* ************************************************************************* */
// Find two cliques, their joint, then marginalizes
/* ************************************************************************* */
/**
 * Joint on two variables, returned as a factor graph.
 * @param j1       first key
 * @param j2       second key
 * @param function elimination function forwarded to jointBayesNet
 * @return a newly allocated factor graph constructed from the Bayes net given by jointBayesNet
 */
template<class CLIQUE>
typename BayesTree<CLIQUE>::sharedFactorGraph
BayesTree<CLIQUE>::joint(Key j1, Key j2, const Eliminate& function) const
{
  gttic(BayesTree_joint);
  const sharedBayesNet jointAsBayesNet = jointBayesNet(j1, j2, function);
  return boost::make_shared<FactorGraphType>(*jointAsBayesNet);
}
/* ************************************************************************* */
/**
 * Joint on two variables, returned as a Bayes net.
 *
 * Outline of the steps below:
 *  1. Look up the cliques C1 and C2 indexed under j1 and j2.
 *  2. Walk from each clique up to its root and take the deepest clique B that
 *     both root-first paths have in common (B stays null if the roots differ).
 *  3. If B exists, assemble a factor graph from B's marginal, the factored
 *     shortcuts of C1 and C2, and the conditionals of C1/C2 when they differ from B.
 *     Otherwise add the two clique marginals.
 *  4. Marginalize that graph with the ordering (j1, j2).
 *
 * @param j1       first key
 * @param j2       second key
 * @param function elimination function used for every elimination step
 * @return the result of marginalMultifrontalBayesNet on the assembled graph
 */
template<class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
{
gttic(BayesTree_jointBayesNet);
// get clique C1 and C2
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
gttic(Lowest_common_ancestor);
// Find lowest common ancestor clique
sharedClique B; {
// Build two paths to the root
// (push_front while walking upwards, so each list is ordered root first)
FastList<sharedClique> path1, path2; {
sharedClique p = C1;
while(p) {
path1.push_front(p);
p = p->parent();
}
} {
sharedClique p = C2;
while(p) {
path2.push_front(p);
p = p->parent();
}
}
// Find the path intersection
// NOTE(review): the dereference below assumes both paths are non-empty, i.e.
// that C1 and C2 are non-null. The loop that follows would set B the same way.
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
if(*p1 == *p2)
B = *p1;
// Advance in lock-step while the paths agree; B ends up as the last common clique
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
B = *p1;
++p1;
++p2;
}
}
gttoc(Lowest_common_ancestor);
// Build joint on all involved variables
FactorGraphType p_BC1C2;
if(B)
{
// Compute marginal on lowest common ancestor clique
gttic(LCA_marginal);
FactorGraphType p_B = B->marginal2(function);
gttoc(LCA_marginal);
// Compute shortcuts of the requested cliques given the lowest common ancestor
gttic(Clique_shortcuts);
BayesNetType p_C1_Bred = C1->shortcut(B, function);
BayesNetType p_C2_Bred = C2->shortcut(B, function);
gttoc(Clique_shortcuts);
// Factor the shortcuts to be conditioned on the full root
// Get the set of variables to eliminate, which is C1\B.
// (built from the parent keys of C1's conditional, minus every key of B's conditional)
gttic(Full_root_factoring);
boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
KeyVector C1_minus_B; {
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
for(const Key j: *B->conditional()) {
C1_minus_B_set.erase(j); }
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
}
// Factor into C1\B | B.
// (the remaining factor graph returned alongside is discarded)
sharedFactorGraph temp_remaining;
boost::tie(p_C1_B, temp_remaining) =
FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function);
}
// Same construction for C2
boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
KeyVector C2_minus_B; {
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
for(const Key j: *B->conditional()) {
C2_minus_B_set.erase(j); }
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
}
// Factor into C2\B | B.
sharedFactorGraph temp_remaining;
boost::tie(p_C2_B, temp_remaining) =
FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function);
}
gttoc(Full_root_factoring);
// Collect everything into one factor graph; the clique conditionals of C1 and
// C2 are only added when the clique is not B itself.
gttic(Variable_joint);
p_BC1C2 += p_B;
p_BC1C2 += *p_C1_B;
p_BC1C2 += *p_C2_B;
if(C1 != B)
p_BC1C2 += C1->conditional();
if(C2 != B)
p_BC1C2 += C2->conditional();
gttoc(Variable_joint);
}
else
{
// The nodes have no common ancestor, they're in different trees, so their joint is just the
// product of their marginals.
gttic(Disjoint_marginals);
p_BC1C2 += C1->marginal2(function);
p_BC1C2 += C2->marginal2(function);
gttoc(Disjoint_marginals);
}
// now, marginalize out everything that is not variable j1 or j2
return p_BC1C2.marginalMultifrontalBayesNet(Ordering(cref_list_of<2,Key>(j1)(j2)), boost::none, function);
}
/* ************************************************************************* */
/**
 * Empty the tree: drop every entry of the nodes index and every root.
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::clear() {
  // Index first, then the list of roots
  this->nodes_.clear();
  this->roots_.clear();
}
/* ************************************************************************* */
/**
 * Ask every root clique to delete its cached shortcuts.
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::deleteCachedShortcuts() {
  for (const sharedClique& treeRoot : roots_)
    treeRoot->deleteCachedShortcuts();
}
/* ************************************************************************* */
/**
 * Detach a single clique from the tree.
 *
 * The clique is removed from the roots (if it is a root) or from its parent's
 * child list, its children are orphaned (their parent pointer is reset), and
 * its frontal keys are erased from the nodes index. The clique's own
 * `children` and `parent_` members are left untouched.
 *
 * @param clique clique to remove; must be non-null
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::removeClique(sharedClique clique)
{
  if (clique->isRoot()) {
    auto root = std::find(roots_.begin(), roots_.end(), clique);
    if (root != roots_.end())
      roots_.erase(root);
  } else { // detach clique from parent
    sharedClique parent = clique->parent_.lock();
    auto child = std::find(parent->children.begin(), parent->children.end(), clique);
    assert(child != parent->children.end());
    // Erasing end() is undefined behaviour; with asserts compiled out, guard the
    // erase the same way the root branch above does.
    if (child != parent->children.end())
      parent->children.erase(child);
  }

  // orphan my children
  for (const sharedClique& child : clique->children)
    child->parent_ = typename Clique::weak_ptr();

  // drop this clique's frontal variables from the index
  for (Key j : clique->conditional()->frontals()) {
    nodes_.unsafe_erase(j);
  }
}
/* ************************************************************************* */
/**
 * Remove @p clique and all of its ancestors up to the root.
 *
 * @param clique  clique to start from; a null pointer ends the recursion
 * @param bn      receives the conditionals of the removed cliques; ancestors are
 *                appended before descendants because the recursion runs first
 * @param orphans receives the children of every removed clique; a removed clique
 *                is itself taken out of this list if it had been added earlier
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::removePath(sharedClique clique, BayesNetType& bn, Cliques& orphans)
{
  // Base case: walked past the root, nothing to do
  if (!clique) {
    return;
  }

  // This clique may have been recorded as an orphan by an earlier removal
  orphans.remove(clique);

  // Detach this clique, then continue with the path above it
  this->removeClique(clique);
  this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn, orphans);

  // Hand the children over to the front of the orphan list and forget them here
  orphans.insert(orphans.begin(), clique->children.begin(), clique->children.end());
  clique->children.clear();

  bn.push_back(clique->conditional_);
}
/* ************************************************************************* */
/**
 * Remove, for each given key, the path from the clique indexed under it up to the root.
 *
 * @param keys    keys to process; keys that are not in the nodes index are skipped
 * @param bn      receives the conditionals of all removed cliques
 * @param orphans receives the subtrees cut loose by the removal; their cached
 *                shortcuts are deleted before returning
 */
template<class CLIQUE>
void BayesTree<CLIQUE>::removeTop(const KeyVector& keys, BayesNetType& bn, Cliques& orphans)
{
  for (const Key& key : keys)
  {
    // TODO: Nodes will be searched again in removeClique
    typename Nodes::const_iterator entry = nodes_.find(key);
    if (entry == nodes_.end()) {
      continue;  // key not present (or already removed by an earlier path)
    }
    // remove path from clique to root
    this->removePath(entry->second, bn, orphans);
  }

  // Delete cachedShortcuts for each orphan subtree
  //TODO: Consider Improving
  for (sharedClique& orphanRoot : orphans) {
    orphanRoot->deleteCachedShortcuts();
  }
}
/* ************************************************************************* */
/**
 * Cut a whole subtree out of the tree and return its cliques.
 *
 * @param subtree root of the subtree. It must currently be present in its
 *                parent's child list (or in the roots if it has no parent):
 *                the result of std::find is erased without being checked.
 * @return every clique of the subtree, starting with @p subtree itself, with
 *         children following their parents. Each returned clique has its
 *         parent and children pointers cleared and its frontal keys removed
 *         from the nodes index.
 */
template<class CLIQUE>
typename BayesTree<CLIQUE>::Cliques BayesTree<CLIQUE>::removeSubtree(
    const sharedClique& subtree)
{
  // Work list that doubles as the result; seeded with the subtree root
  Cliques removed;
  removed.push_back(subtree);

  // Unhook the subtree root from wherever it is attached
  if (subtree->isRoot()) {
    roots_.erase(std::find(roots_.begin(), roots_.end(), subtree));
  } else {
    const sharedClique parent = subtree->parent();
    parent->children.erase(std::find(
        parent->children.begin(), parent->children.end(), subtree));
  }

  // Walk the list while it grows: each visited clique appends its children
  for (typename Cliques::iterator it = removed.begin(); it != removed.end(); ++it)
  {
    const sharedClique& current = *it;

    // Queue the children before the links to them are dropped
    for (const sharedClique& child : current->children) {
      removed.push_back(child);
    }

    // Delete cached shortcuts
    current->deleteCachedShortcutsNonRecursive();

    // Remove this node from the nodes index
    for (Key j : current->conditional()->frontals()) {
      nodes_.unsafe_erase(j);
    }

    // Erase the parent and children pointers
    current->parent_.reset();
    current->children.clear();
  }

  return removed;
}
}
/// namespace gtsam