/* ---------------------------------------------------------------------------- * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * See LICENSE for the license information * -------------------------------------------------------------------------- */ /** * @file BayesTree-inl.h * @brief Bayes Tree is a tree of cliques of a Bayes Chain * @author Frank Dellaert * @author Michael Kaess * @author Viorela Ila * @author Richard Roberts */ #pragma once #include #include #include #include #include #include #include #include using boost::assign::cref_list_of; namespace gtsam { /* ************************************************************************* */ template BayesTreeCliqueData BayesTree::getCliqueData() const { BayesTreeCliqueData data; BOOST_FOREACH(const sharedClique& root, roots_) getCliqueData(data, root); return data; } /* ************************************************************************* */ template void BayesTree::getCliqueData(BayesTreeCliqueData& data, sharedClique clique) const { data.conditionalSizes.push_back(clique->conditional()->nrFrontals()); data.separatorSizes.push_back(clique->conditional()->nrParents()); BOOST_FOREACH(sharedClique c, clique->children) { getCliqueData(data, c); } } /* ************************************************************************* */ template size_t BayesTree::numCachedSeparatorMarginals() const { size_t count = 0; BOOST_FOREACH(const sharedClique& root, roots_) count += root->numCachedSeparatorMarginals(); return count; } /* ************************************************************************* */ template void BayesTree::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); std::ofstream of(s.c_str()); of<< "digraph G{\n"; BOOST_FOREACH(const sharedClique& root, roots_) saveGraph(of, root, keyFormatter); of<<"}"; of.close(); } /* ************************************************************************* */ template void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { static int num = 0; bool first = true; std::stringstream out; out << num; std::string parent = out.str(); parent += "[label=\""; BOOST_FOREACH(Key index, clique->conditional_->frontals()) { if(!first) parent += ","; first = false; parent += indexFormatter(index); } if(clique->parent()){ parent += " : "; s << parentnum << "->" << num << "\n"; } first = true; BOOST_FOREACH(Key sep, clique->conditional_->parents()) { if(!first) parent += ","; first = false; parent += indexFormatter(sep); } parent += "\"];\n"; s << parent; parentnum = num; BOOST_FOREACH(sharedClique c, clique->children) { num++; saveGraph(s, c, indexFormatter, parentnum); } } /* ************************************************************************* */ template size_t BayesTree::size() const { size_t size = 0; BOOST_FOREACH(const sharedClique& clique, roots_) size += clique->treeSize(); return size; } /* ************************************************************************* */ template void BayesTree::addClique(const sharedClique& clique, const sharedClique& parent_clique) { BOOST_FOREACH(Index j, clique->conditional()->frontals()) nodes_[j] = clique; if (parent_clique != NULL) { clique->parent_ = parent_clique; parent_clique->children.push_back(clique); } else { roots_.push_back(clique); } } /* ************************************************************************* */ // TODO: Clean up namespace { template int _pushClique(FactorGraph& fg, const boost::shared_ptr& clique) { fg.push_back(clique->conditional_); return 0; } template struct _pushCliqueFunctor { _pushCliqueFunctor(FactorGraph& graph_) : graph(graph_) {} FactorGraph& graph; int operator()(const boost::shared_ptr& clique, int dummy) { graph.push_back(clique->conditional_); return 0; } }; } /* ************************************************************************* */ template void BayesTree::addFactorsToGraph(FactorGraph& graph) const { // Traverse the BayesTree and add all conditionals to this graph int data = 0; // Unused _pushCliqueFunctor functor(graph); treeTraversal::DepthFirstForest(*this, data, functor); // FIXME: sort of works? // treeTraversal::DepthFirstForest(*this, data, boost::bind(&_pushClique, boost::ref(graph), _1)); } /* ************************************************************************* */ template BayesTree::BayesTree(const This& other) { *this = other; } /* ************************************************************************* */ namespace { template boost::shared_ptr BayesTreeCloneForestVisitorPre(const boost::shared_ptr& node, const boost::shared_ptr& parentPointer) { // Clone the current node and add it to its cloned parent boost::shared_ptr clone = boost::make_shared(*node); clone->children.clear(); clone->parent_ = parentPointer; parentPointer->children.push_back(clone); return clone; } } /* ************************************************************************* */ template BayesTree& BayesTree::operator=(const This& other) { this->clear(); boost::shared_ptr rootContainer = boost::make_shared(); treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre); BOOST_FOREACH(const sharedClique& root, rootContainer->children) { root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique insertRoot(root); } return *this; } /* ************************************************************************* */ template void BayesTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { std::cout << s << ": cliques: " << size() << ", variables: " << nodes_.size() << std::endl; treeTraversal::PrintForest(*this, s, keyFormatter); } /* ************************************************************************* */ // binary predicate to test equality of a pair for use in equals template bool check_sharedCliques( const std::pair::sharedClique>& v1, const std::pair::sharedClique>& v2 ) { return v1.first == v2.first && ((!v1.second && !v2.second) || (v1.second && v2.second && v1.second->equals(*v2.second))); } /* ************************************************************************* */ template bool BayesTree::equals(const BayesTree& other, double tol) const { return size()==other.size() && std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques); } /* ************************************************************************* */ template template Key BayesTree::findParentClique(const CONTAINER& parents) const { typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end()); assert(lowestOrderedParent != parents.end()); return *lowestOrderedParent; } /* ************************************************************************* */ template void BayesTree::fillNodesIndex(const sharedClique& subtree) { // Add each frontal variable of this root node BOOST_FOREACH(const Key& j, subtree->conditional()->frontals()) { bool inserted = nodes_.insert(std::make_pair(j, subtree)).second; assert(inserted); (void)inserted; } // Fill index for each child typedef typename BayesTree::sharedClique sharedClique; BOOST_FOREACH(const sharedClique& child, subtree->children) { fillNodesIndex(child); } } /* ************************************************************************* */ template void BayesTree::insertRoot(const sharedClique& subtree) { roots_.push_back(subtree); // Add to roots fillNodesIndex(subtree); // Populate nodes index } /* ************************************************************************* */ // First finds clique marginal then marginalizes that /* ************************************************************************* */ template typename BayesTree::sharedConditional BayesTree::marginalFactor(Key j, const Eliminate& function) const { gttic(BayesTree_marginalFactor); // get clique containing Index j sharedClique clique = this->clique(j); // calculate or retrieve its marginal P(C) = P(F,S) FactorGraphType cliqueMarginal = clique->marginal2(function); // Now, marginalize out everything that is not variable j BayesNetType marginalBN = *cliqueMarginal.marginalMultifrontalBayesNet( Ordering(cref_list_of<1,Key>(j)), boost::none, function); // The Bayes net should contain only one conditional for variable j, so return it return marginalBN.front(); } /* ************************************************************************* */ // Find two cliques, their joint, then marginalizes /* ************************************************************************* */ template typename BayesTree::sharedFactorGraph BayesTree::joint(Key j1, Key j2, const Eliminate& function) const { gttic(BayesTree_joint); return boost::make_shared(*jointBayesNet(j1, j2, function)); } /* ************************************************************************* */ template typename BayesTree::sharedBayesNet BayesTree::jointBayesNet(Key j1, Key j2, const Eliminate& function) const { gttic(BayesTree_jointBayesNet); // get clique C1 and C2 sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; gttic(Lowest_common_ancestor); // Find lowest common ancestor clique sharedClique B; { // Build two paths to the root FastList path1, path2; { sharedClique p = C1; while(p) { path1.push_front(p); p = p->parent(); } } { sharedClique p = C2; while(p) { path2.push_front(p); p = p->parent(); } } // Find the path intersection typename FastList::const_iterator p1 = path1.begin(), p2 = path2.begin(); if(*p1 == *p2) B = *p1; while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) { B = *p1; ++p1; ++p2; } } if(!B) throw std::invalid_argument("BayesTree::jointBayesNet does not yet work for joints across a forest"); gttoc(Lowest_common_ancestor); // Compute marginal on lowest common ancestor clique gttic(LCA_marginal); FactorGraphType p_B = B->marginal2(function); gttoc(LCA_marginal); // Compute shortcuts of the requested cliques given the lowest common ancestor gttic(Clique_shortcuts); BayesNetType p_C1_Bred = C1->shortcut(B, function); BayesNetType p_C2_Bred = C2->shortcut(B, function); gttoc(Clique_shortcuts); // Factor the shortcuts to be conditioned on the full root // Get the set of variables to eliminate, which is C1\B. gttic(Full_root_factoring); boost::shared_ptr p_C1_B; { std::vector C1_minus_B; { FastSet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); BOOST_FOREACH(const Index j, *B->conditional()) { C1_minus_B_set.erase(j); } C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); } // Factor into C1\B | B. sharedFactorGraph temp_remaining; boost::tie(p_C1_B, temp_remaining) = FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function); } boost::shared_ptr p_C2_B; { std::vector C2_minus_B; { FastSet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); BOOST_FOREACH(const Index j, *B->conditional()) { C2_minus_B_set.erase(j); } C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); } // Factor into C2\B | B. sharedFactorGraph temp_remaining; boost::tie(p_C2_B, temp_remaining) = FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function); } gttoc(Full_root_factoring); gttic(Variable_joint); // Build joint on all involved variables FactorGraphType p_BC1C2; p_BC1C2 += p_B; p_BC1C2 += *p_C1_B; p_BC1C2 += *p_C2_B; if(C1 != B) p_BC1C2 += C1->conditional(); if(C2 != B) p_BC1C2 += C2->conditional(); // now, marginalize out everything that is not variable j1 or j2 return p_BC1C2.marginalMultifrontalBayesNet(Ordering(cref_list_of<2,Key>(j1)(j2)), boost::none, function); } /* ************************************************************************* */ template void BayesTree::clear() { // Remove all nodes and clear the root pointer nodes_.clear(); roots_.clear(); } /* ************************************************************************* */ template void BayesTree::deleteCachedShortcuts() { BOOST_FOREACH(const sharedClique& root, roots_) { root->deleteCachedShortcuts(); } } /* ************************************************************************* */ template void BayesTree::removeClique(sharedClique clique) { if (clique->isRoot()) { typename std::vector::iterator root = std::find(roots_.begin(), roots_.end(), clique); if(root != roots_.end()) roots_.erase(root); } else { // detach clique from parent sharedClique parent = clique->parent_.lock(); typename std::vector::iterator child = std::find(parent->children.begin(), parent->children.end(), clique); assert(child != parent->children.end()); parent->children.erase(child); } // orphan my children BOOST_FOREACH(sharedClique child, clique->children) child->parent_ = typename Clique::weak_ptr(); BOOST_FOREACH(Key j, clique->conditional()->frontals()) { nodes_.erase(j); } } /* ************************************************************************* */ template void BayesTree::removePath(sharedClique clique, BayesNetType& bn, Cliques& orphans) { // base case is NULL, if so we do nothing and return empties above if (clique) { // remove the clique from orphans in case it has been added earlier orphans.remove(clique); // remove me this->removeClique(clique); // remove path above me this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn, orphans); // add children to list of orphans (splice also removed them from clique->children_) orphans.insert(orphans.begin(), clique->children.begin(), clique->children.end()); clique->children.clear(); bn.push_back(clique->conditional_); } } /* ************************************************************************* */ template void BayesTree::removeTop(const std::vector& keys, BayesNetType& bn, Cliques& orphans) { // process each key of the new factor BOOST_FOREACH(const Key& j, keys) { // get the clique // TODO: Nodes will be searched again in removeClique typename Nodes::const_iterator node = nodes_.find(j); if(node != nodes_.end()) { // remove path from clique to root this->removePath(node->second, bn, orphans); } } // Delete cachedShortcuts for each orphan subtree //TODO: Consider Improving BOOST_FOREACH(sharedClique& orphan, orphans) orphan->deleteCachedShortcuts(); } /* ************************************************************************* */ template typename BayesTree::Cliques BayesTree::removeSubtree( const sharedClique& subtree) { // Result clique list Cliques cliques; cliques.push_back(subtree); // Remove the first clique from its parents if(!subtree->isRoot()) subtree->parent()->children.erase(std::find( subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree)); else roots_.erase(std::find(roots_.begin(), roots_.end(), subtree)); // Add all subtree cliques and erase the children and parent of each for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique) { // Add children BOOST_FOREACH(const sharedClique& child, (*clique)->children) { cliques.push_back(child); } // Delete cached shortcuts (*clique)->deleteCachedShortcutsNonRecursive(); // Remove this node from the nodes index BOOST_FOREACH(Key j, (*clique)->conditional()->frontals()) { nodes_.erase(j); } // Erase the parent and children pointers (*clique)->parent_.reset(); (*clique)->children.clear(); } return cliques; } } /// namespace gtsam