/* ---------------------------------------------------------------------------- * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * See LICENSE for the license information * -------------------------------------------------------------------------- */ /** * @file BayesTree-inl.h * @brief Bayes Tree is a tree of cliques of a Bayes Chain * @author Frank Dellaert * @author Michael Kaess * @author Viorela Ila * @author Richard Roberts */ #pragma once #include #include #include #include #include #include #include using boost::assign::cref_list_of; namespace gtsam { /* ************************************************************************* */ template BayesTreeCliqueData BayesTree::getCliqueData() const { BayesTreeCliqueData data; for(const sharedClique& root: roots_) getCliqueData(data, root); return data; } /* ************************************************************************* */ template void BayesTree::getCliqueData(BayesTreeCliqueData& data, sharedClique clique) const { data.conditionalSizes.push_back(clique->conditional()->nrFrontals()); data.separatorSizes.push_back(clique->conditional()->nrParents()); for(sharedClique c: clique->children) { getCliqueData(data, c); } } /* ************************************************************************* */ template size_t BayesTree::numCachedSeparatorMarginals() const { size_t count = 0; for(const sharedClique& root: roots_) count += root->numCachedSeparatorMarginals(); return count; } /* ************************************************************************* */ template void BayesTree::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); std::ofstream of(s.c_str()); of<< "digraph G{\n"; for(const sharedClique& root: roots_) saveGraph(of, root, keyFormatter); of<<"}"; of.close(); } /* ************************************************************************* */ template void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { static int num = 0; bool first = true; std::stringstream out; out << num; std::string parent = out.str(); parent += "[label=\""; for(Key index: clique->conditional_->frontals()) { if(!first) parent += ","; first = false; parent += indexFormatter(index); } if(clique->parent()){ parent += " : "; s << parentnum << "->" << num << "\n"; } first = true; for(Key sep: clique->conditional_->parents()) { if(!first) parent += ","; first = false; parent += indexFormatter(sep); } parent += "\"];\n"; s << parent; parentnum = num; for(sharedClique c: clique->children) { num++; saveGraph(s, c, indexFormatter, parentnum); } } /* ************************************************************************* */ template size_t BayesTree::size() const { size_t size = 0; for(const sharedClique& clique: roots_) size += clique->treeSize(); return size; } /* ************************************************************************* */ template void BayesTree::addClique(const sharedClique& clique, const sharedClique& parent_clique) { for(Key j: clique->conditional()->frontals()) nodes_[j] = clique; if (parent_clique != NULL) { clique->parent_ = parent_clique; parent_clique->children.push_back(clique); } else { roots_.push_back(clique); } } /* ************************************************************************* */ // TODO: Clean up namespace { template int _pushClique(FactorGraph& fg, const boost::shared_ptr& clique) { fg.push_back(clique->conditional_); return 0; } template struct _pushCliqueFunctor { _pushCliqueFunctor(FactorGraph& graph_) : graph(graph_) {} FactorGraph& graph; int operator()(const boost::shared_ptr& clique, int dummy) { graph.push_back(clique->conditional_); return 0; } }; } /* ************************************************************************* */ template void BayesTree::addFactorsToGraph(FactorGraph& graph) const { // Traverse the BayesTree and add all conditionals to this graph int data = 0; // Unused _pushCliqueFunctor functor(graph); treeTraversal::DepthFirstForest(*this, data, functor); // FIXME: sort of works? // treeTraversal::DepthFirstForest(*this, data, boost::bind(&_pushClique, boost::ref(graph), _1)); } /* ************************************************************************* */ template BayesTree::BayesTree(const This& other) { *this = other; } /* ************************************************************************* */ namespace { template boost::shared_ptr BayesTreeCloneForestVisitorPre(const boost::shared_ptr& node, const boost::shared_ptr& parentPointer) { // Clone the current node and add it to its cloned parent boost::shared_ptr clone = boost::make_shared(*node); clone->children.clear(); clone->parent_ = parentPointer; parentPointer->children.push_back(clone); return clone; } } /* ************************************************************************* */ template BayesTree& BayesTree::operator=(const This& other) { this->clear(); boost::shared_ptr rootContainer = boost::make_shared(); treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre); for(const sharedClique& root: rootContainer->children) { root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique insertRoot(root); } return *this; } /* ************************************************************************* */ template void BayesTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { std::cout << s << ": cliques: " << size() << ", variables: " << nodes_.size() << std::endl; treeTraversal::PrintForest(*this, s, keyFormatter); } /* ************************************************************************* */ // binary predicate to test equality of a pair for use in equals template bool check_sharedCliques( const std::pair::sharedClique>& v1, const std::pair::sharedClique>& v2 ) { return v1.first == v2.first && ((!v1.second && !v2.second) || (v1.second && v2.second && v1.second->equals(*v2.second))); } /* ************************************************************************* */ template bool BayesTree::equals(const BayesTree& other, double tol) const { return size()==other.size() && std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques); } /* ************************************************************************* */ template template Key BayesTree::findParentClique(const CONTAINER& parents) const { typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end()); assert(lowestOrderedParent != parents.end()); return *lowestOrderedParent; } /* ************************************************************************* */ template void BayesTree::fillNodesIndex(const sharedClique& subtree) { // Add each frontal variable of this root node for(const Key& j: subtree->conditional()->frontals()) { bool inserted = nodes_.insert(std::make_pair(j, subtree)).second; assert(inserted); (void)inserted; } // Fill index for each child typedef typename BayesTree::sharedClique sharedClique; for(const sharedClique& child: subtree->children) { fillNodesIndex(child); } } /* ************************************************************************* */ template void BayesTree::insertRoot(const sharedClique& subtree) { roots_.push_back(subtree); // Add to roots fillNodesIndex(subtree); // Populate nodes index } /* ************************************************************************* */ // First finds clique marginal then marginalizes that /* ************************************************************************* */ template typename BayesTree::sharedConditional BayesTree::marginalFactor(Key j, const Eliminate& function) const { gttic(BayesTree_marginalFactor); // get clique containing Key j sharedClique clique = this->clique(j); // calculate or retrieve its marginal P(C) = P(F,S) FactorGraphType cliqueMarginal = clique->marginal2(function); // Now, marginalize out everything that is not variable j BayesNetType marginalBN = *cliqueMarginal.marginalMultifrontalBayesNet( Ordering(cref_list_of<1,Key>(j)), boost::none, function); // The Bayes net should contain only one conditional for variable j, so return it return marginalBN.front(); } /* ************************************************************************* */ // Find two cliques, their joint, then marginalizes /* ************************************************************************* */ template typename BayesTree::sharedFactorGraph BayesTree::joint(Key j1, Key j2, const Eliminate& function) const { gttic(BayesTree_joint); return boost::make_shared(*jointBayesNet(j1, j2, function)); } /* ************************************************************************* */ template typename BayesTree::sharedBayesNet BayesTree::jointBayesNet(Key j1, Key j2, const Eliminate& function) const { gttic(BayesTree_jointBayesNet); // get clique C1 and C2 sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; gttic(Lowest_common_ancestor); // Find lowest common ancestor clique sharedClique B; { // Build two paths to the root FastList path1, path2; { sharedClique p = C1; while(p) { path1.push_front(p); p = p->parent(); } } { sharedClique p = C2; while(p) { path2.push_front(p); p = p->parent(); } } // Find the path intersection typename FastList::const_iterator p1 = path1.begin(), p2 = path2.begin(); if(*p1 == *p2) B = *p1; while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) { B = *p1; ++p1; ++p2; } } gttoc(Lowest_common_ancestor); // Build joint on all involved variables FactorGraphType p_BC1C2; if(B) { // Compute marginal on lowest common ancestor clique gttic(LCA_marginal); FactorGraphType p_B = B->marginal2(function); gttoc(LCA_marginal); // Compute shortcuts of the requested cliques given the lowest common ancestor gttic(Clique_shortcuts); BayesNetType p_C1_Bred = C1->shortcut(B, function); BayesNetType p_C2_Bred = C2->shortcut(B, function); gttoc(Clique_shortcuts); // Factor the shortcuts to be conditioned on the full root // Get the set of variables to eliminate, which is C1\B. gttic(Full_root_factoring); boost::shared_ptr p_C1_B; { FastVector C1_minus_B; { KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); for(const Key j: *B->conditional()) { C1_minus_B_set.erase(j); } C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); } // Factor into C1\B | B. sharedFactorGraph temp_remaining; boost::tie(p_C1_B, temp_remaining) = FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function); } boost::shared_ptr p_C2_B; { FastVector C2_minus_B; { KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); for(const Key j: *B->conditional()) { C2_minus_B_set.erase(j); } C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); } // Factor into C2\B | B. sharedFactorGraph temp_remaining; boost::tie(p_C2_B, temp_remaining) = FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function); } gttoc(Full_root_factoring); gttic(Variable_joint); p_BC1C2 += p_B; p_BC1C2 += *p_C1_B; p_BC1C2 += *p_C2_B; if(C1 != B) p_BC1C2 += C1->conditional(); if(C2 != B) p_BC1C2 += C2->conditional(); gttoc(Variable_joint); } else { // The nodes have no common ancestor, they're in different trees, so they're joint is just the // product of their marginals. gttic(Disjoint_marginals); p_BC1C2 += C1->marginal2(function); p_BC1C2 += C2->marginal2(function); gttoc(Disjoint_marginals); } // now, marginalize out everything that is not variable j1 or j2 return p_BC1C2.marginalMultifrontalBayesNet(Ordering(cref_list_of<2,Key>(j1)(j2)), boost::none, function); } /* ************************************************************************* */ template void BayesTree::clear() { // Remove all nodes and clear the root pointer nodes_.clear(); roots_.clear(); } /* ************************************************************************* */ template void BayesTree::deleteCachedShortcuts() { for(const sharedClique& root: roots_) { root->deleteCachedShortcuts(); } } /* ************************************************************************* */ template void BayesTree::removeClique(sharedClique clique) { if (clique->isRoot()) { typename Roots::iterator root = std::find(roots_.begin(), roots_.end(), clique); if(root != roots_.end()) roots_.erase(root); } else { // detach clique from parent sharedClique parent = clique->parent_.lock(); typename Roots::iterator child = std::find(parent->children.begin(), parent->children.end(), clique); assert(child != parent->children.end()); parent->children.erase(child); } // orphan my children for(sharedClique child: clique->children) child->parent_ = typename Clique::weak_ptr(); for(Key j: clique->conditional()->frontals()) { nodes_.unsafe_erase(j); } } /* ************************************************************************* */ template void BayesTree::removePath(sharedClique clique, BayesNetType& bn, Cliques& orphans) { // base case is NULL, if so we do nothing and return empties above if (clique) { // remove the clique from orphans in case it has been added earlier orphans.remove(clique); // remove me this->removeClique(clique); // remove path above me this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn, orphans); // add children to list of orphans (splice also removed them from clique->children_) orphans.insert(orphans.begin(), clique->children.begin(), clique->children.end()); clique->children.clear(); bn.push_back(clique->conditional_); } } /* ************************************************************************* */ template void BayesTree::removeTop(const FastVector& keys, BayesNetType& bn, Cliques& orphans) { // process each key of the new factor for(const Key& j: keys) { // get the clique // TODO: Nodes will be searched again in removeClique typename Nodes::const_iterator node = nodes_.find(j); if(node != nodes_.end()) { // remove path from clique to root this->removePath(node->second, bn, orphans); } } // Delete cachedShortcuts for each orphan subtree //TODO: Consider Improving for(sharedClique& orphan: orphans) orphan->deleteCachedShortcuts(); } /* ************************************************************************* */ template typename BayesTree::Cliques BayesTree::removeSubtree( const sharedClique& subtree) { // Result clique list Cliques cliques; cliques.push_back(subtree); // Remove the first clique from its parents if(!subtree->isRoot()) subtree->parent()->children.erase(std::find( subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree)); else roots_.erase(std::find(roots_.begin(), roots_.end(), subtree)); // Add all subtree cliques and erase the children and parent of each for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique) { // Add children for(const sharedClique& child: (*clique)->children) { cliques.push_back(child); } // Delete cached shortcuts (*clique)->deleteCachedShortcutsNonRecursive(); // Remove this node from the nodes index for(Key j: (*clique)->conditional()->frontals()) { nodes_.unsafe_erase(j); } // Erase the parent and children pointers (*clique)->parent_.reset(); (*clique)->children.clear(); } return cliques; } } /// namespace gtsam