diff --git a/gtsam/inference/BayesTreeUnordered-inst.h b/gtsam/inference/BayesTreeUnordered-inst.h new file mode 100644 index 000000000..3ebb502f2 --- /dev/null +++ b/gtsam/inference/BayesTreeUnordered-inst.h @@ -0,0 +1,815 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BayesTree-inl.h + * @brief Bayes Tree is a tree of cliques of a Bayes Chain + * @author Frank Dellaert + * @author Michael Kaess + * @author Viorela Ila + * @author Richard Roberts + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include // for operator += +using boost::assign::operator+=; +#include + +namespace gtsam { + + /* ************************************************************************* */ + template + typename BayesTree::CliqueData + BayesTree::getCliqueData() const { + CliqueData data; + getCliqueData(data, root_); + return data; + } + + /* ************************************************************************* */ + template + void BayesTree::getCliqueData(CliqueData& data, sharedClique clique) const { + data.conditionalSizes.push_back((*clique)->nrFrontals()); + data.separatorSizes.push_back((*clique)->nrParents()); + BOOST_FOREACH(sharedClique c, clique->children_) { + getCliqueData(data, c); + } + } + + /* ************************************************************************* */ + template + size_t BayesTree::numCachedSeparatorMarginals() const { + return (root_) ? root_->numCachedSeparatorMarginals() : 0; + } + + /* ************************************************************************* */ + template + void BayesTree::saveGraph(const std::string &s, const IndexFormatter& indexFormatter) const { + if (!root_.get()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); + std::ofstream of(s.c_str()); + of<< "digraph G{\n"; + saveGraph(of, root_, indexFormatter); + of<<"}"; + of.close(); + } + + /* ************************************************************************* */ + template + void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const IndexFormatter& indexFormatter, int parentnum) const { + static int num = 0; + bool first = true; + std::stringstream out; + out << num; + std::string parent = out.str(); + parent += "[label=\""; + + BOOST_FOREACH(Index index, clique->conditional_->frontals()) { + if(!first) parent += ","; first = false; + parent += indexFormatter(index); + } + + if( clique != root_){ + parent += " : "; + s << parentnum << "->" << num << "\n"; + } + + first = true; + BOOST_FOREACH(Index sep, clique->conditional_->parents()) { + if(!first) parent += ","; first = false; + parent += indexFormatter(sep); + } + parent += "\"];\n"; + s << parent; + parentnum = num; + + BOOST_FOREACH(sharedClique c, clique->children_) { + num++; + saveGraph(s, c, indexFormatter, parentnum); + } + } + + + /* ************************************************************************* */ + template + void BayesTree::CliqueStats::print(const std::string& s) const { + std::cout << s + <<"avg Conditional Size: " << avgConditionalSize << std::endl + << "max Conditional Size: " << maxConditionalSize << std::endl + << "avg Separator Size: " << avgSeparatorSize << std::endl + << "max Separator Size: " << maxSeparatorSize << std::endl; + } + + /* ************************************************************************* */ + template + typename BayesTree::CliqueStats + BayesTree::CliqueData::getStats() const { + CliqueStats stats; + + double sum = 0.0; + size_t max = 0; + BOOST_FOREACH(size_t s, conditionalSizes) { + sum += (double)s; + if(s > max) max = s; + } + stats.avgConditionalSize = sum / (double)conditionalSizes.size(); + stats.maxConditionalSize = max; + + sum = 0.0; + max = 1; + BOOST_FOREACH(size_t s, separatorSizes) { + sum += (double)s; + if(s > max) max = s; + } + stats.avgSeparatorSize = sum / (double)separatorSizes.size(); + stats.maxSeparatorSize = max; + + return stats; + } + + /* ************************************************************************* */ + template + void BayesTree::Cliques::print(const std::string& s, const IndexFormatter& indexFormatter) const { + std::cout << s << ":\n"; + BOOST_FOREACH(sharedClique clique, *this) + clique->printTree("", indexFormatter); + } + + /* ************************************************************************* */ + template + bool BayesTree::Cliques::equals(const Cliques& other, double tol) const { + return other == *this; + } + + /* ************************************************************************* */ + template + typename BayesTree::sharedClique + BayesTree::addClique(const sharedConditional& conditional, const sharedClique& parent_clique) { + sharedClique new_clique(new Clique(conditional)); + addClique(new_clique, parent_clique); + return new_clique; + } + + /* ************************************************************************* */ + template + void BayesTree::addClique(const sharedClique& clique, const sharedClique& parent_clique) { + nodes_.resize(std::max((*clique)->lastFrontalKey()+1, nodes_.size())); + BOOST_FOREACH(Index j, (*clique)->frontals()) + nodes_[j] = clique; + if (parent_clique != NULL) { + clique->parent_ = parent_clique; + parent_clique->children_.push_back(clique); + } else { + assert(!root_); + root_ = clique; + } + clique->assertInvariants(); + } + + /* ************************************************************************* */ + template + typename BayesTree::sharedClique BayesTree::addClique( + const sharedConditional& conditional, std::list& child_cliques) { + sharedClique new_clique(new Clique(conditional)); + nodes_.resize(std::max(conditional->lastFrontalKey()+1, nodes_.size())); + BOOST_FOREACH(Index j, conditional->frontals()) + nodes_[j] = new_clique; + new_clique->children_ = child_cliques; + BOOST_FOREACH(sharedClique& child, child_cliques) + child->parent_ = new_clique; + new_clique->assertInvariants(); + return new_clique; + } + + /* ************************************************************************* */ + template + void BayesTree::permuteWithInverse(const Permutation& inversePermutation) { + // recursively permute the cliques and internal conditionals + if (root_) + root_->permuteWithInverse(inversePermutation); + + // need to know what the largest key is to get the right number of cliques + Index maxIndex = *std::max_element(inversePermutation.begin(), inversePermutation.end()); + + // Update the nodes structure + typename BayesTree::Nodes newNodes(maxIndex+1); +// inversePermutation.applyToCollection(newNodes, nodes_); // Uses the forward, rather than inverse permutation + for(size_t i = 0; i < nodes_.size(); ++i) + newNodes[inversePermutation[i]] = nodes_[i]; + + nodes_ = newNodes; + } + + /* ************************************************************************* */ + template + inline void BayesTree::addToCliqueFront(BayesTree& bayesTree, const sharedConditional& conditional, const sharedClique& clique) { + static const bool debug = false; +#ifndef NDEBUG + // Debug check to make sure the conditional variable is ordered lower than + // its parents and that all of its parents are present either in this + // clique or its separator. + BOOST_FOREACH(Index parent, conditional->parents()) { + assert(parent > conditional->lastFrontalKey()); + const Clique& cliquer(*clique); + assert(find(cliquer->begin(), cliquer->end(), parent) != cliquer->end()); + } +#endif + if(debug) conditional->print("Adding conditional "); + if(debug) clique->print("To clique "); + Index j = conditional->lastFrontalKey(); + bayesTree.nodes_.resize(std::max(j+1, bayesTree.nodes_.size())); + bayesTree.nodes_[j] = clique; + FastVector newIndices((*clique)->size() + 1); + newIndices[0] = j; + std::copy((*clique)->begin(), (*clique)->end(), newIndices.begin()+1); + clique->conditional_ = CONDITIONAL::FromKeys(newIndices, (*clique)->nrFrontals() + 1); + if(debug) clique->print("Expanded clique is "); + clique->assertInvariants(); + } + + /* ************************************************************************* */ + template + void BayesTree::removeClique(sharedClique clique) { + + if (clique->isRoot()) + root_.reset(); + else { // detach clique from parent + sharedClique parent = clique->parent_.lock(); + typename FastList::iterator child = std::find(parent->children().begin(), parent->children().end(), clique); + assert(child != parent->children().end()); + parent->children().erase(child); + } + + // orphan my children + BOOST_FOREACH(sharedClique child, clique->children_) + child->parent_ = typename Clique::weak_ptr(); + + BOOST_FOREACH(Index j, clique->conditional()->frontals()) { + nodes_[j].reset(); + } + } + + /* ************************************************************************* */ + template + void BayesTree::recursiveTreeBuild(const boost::shared_ptr >& symbolic, + const std::vector >& conditionals, + const typename BayesTree::sharedClique& parent) { + + // Helper function to build a non-symbolic tree (e.g. Gaussian) using a + // symbolic tree, used in the BT(BN) constructor. + + // Build the current clique + FastList cliqueConditionals; + BOOST_FOREACH(Index j, symbolic->conditional()->frontals()) { + cliqueConditionals.push_back(conditionals[j]); } + typename BayesTree::sharedClique thisClique(new CLIQUE(CONDITIONAL::Combine(cliqueConditionals.begin(), cliqueConditionals.end()))); + + // Add the new clique with the current parent + this->addClique(thisClique, parent); + + // Build the children, whose parent is the new clique + BOOST_FOREACH(const BayesTree::sharedClique& child, symbolic->children()) { + this->recursiveTreeBuild(child, conditionals, thisClique); } + } + + /* ************************************************************************* */ + template + BayesTree::BayesTree(const BayesNet& bayesNet) { + // First generate symbolic BT to determine clique structure + BayesTree sbt(bayesNet); + + // Build index of variables to conditionals + std::vector > conditionals(sbt.root()->conditional()->frontals().back() + 1); + BOOST_FOREACH(const boost::shared_ptr& c, bayesNet) { + if(c->nrFrontals() != 1) + throw std::invalid_argument("BayesTree constructor from BayesNet only supports single frontal variable conditionals"); + if(c->firstFrontalKey() >= conditionals.size()) + throw std::invalid_argument("An inconsistent BayesNet was passed into the BayesTree constructor!"); + if(conditionals[c->firstFrontalKey()]) + throw std::invalid_argument("An inconsistent BayesNet with duplicate frontal variables was passed into the BayesTree constructor!"); + + conditionals[c->firstFrontalKey()] = c; + } + + // Build the new tree + this->recursiveTreeBuild(sbt.root(), conditionals, sharedClique()); + } + + /* ************************************************************************* */ + template<> + inline BayesTree::BayesTree(const BayesNet& bayesNet) { + BayesNet::const_reverse_iterator rit; + for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) + insert(*this, *rit); + } + + /* ************************************************************************* */ + template + BayesTree::BayesTree(const BayesNet& bayesNet, std::list > subtrees) { + if (bayesNet.size() == 0) + throw std::invalid_argument("BayesTree::insert: empty bayes net!"); + + // get the roots of child subtrees and merge their nodes_ + std::list childRoots; + typedef BayesTree Tree; + BOOST_FOREACH(const Tree& subtree, subtrees) { + nodes_.assign(subtree.nodes_.begin(), subtree.nodes_.end()); + childRoots.push_back(subtree.root()); + } + + // create a new clique and add all the conditionals to the clique + sharedClique new_clique; + typename BayesNet::sharedConditional conditional; + BOOST_REVERSE_FOREACH(conditional, bayesNet) { + if (!new_clique.get()) + new_clique = addClique(conditional,childRoots); + else + addToCliqueFront(*this, conditional, new_clique); + } + + root_ = new_clique; + } + + /* ************************************************************************* */ + template + BayesTree::BayesTree(const This& other) { + *this = other; + } + + /* ************************************************************************* */ + template + BayesTree& BayesTree::operator=(const This& other) { + this->clear(); + other.cloneTo(*this); + return *this; + } + + /* ************************************************************************* */ + template + void BayesTree::print(const std::string& s, const IndexFormatter& indexFormatter) const { + if (root_.use_count() == 0) { + std::cout << "WARNING: BayesTree.print encountered a forest..." << std::endl; + return; + } + std::cout << s << ": clique size == " << size() << ", node size == " << nodes_.size() << std::endl; + if (nodes_.empty()) return; + root_->printTree("", indexFormatter); + } + + /* ************************************************************************* */ + // binary predicate to test equality of a pair for use in equals + template + bool check_sharedCliques( + const typename BayesTree::sharedClique& v1, + const typename BayesTree::sharedClique& v2 + ) { + return (!v1 && !v2) || (v1 && v2 && v1->equals(*v2)); + } + + /* ************************************************************************* */ + template + bool BayesTree::equals(const BayesTree& other, + double tol) const { + return size()==other.size() && + std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques); + } + + /* ************************************************************************* */ + template + template + inline Index BayesTree::findParentClique(const CONTAINER& parents) const { + typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end()); + assert(lowestOrderedParent != parents.end()); + return *lowestOrderedParent; + } + + /* ************************************************************************* */ + template + void BayesTree::insert(BayesTree& bayesTree, const sharedConditional& conditional) + { + static const bool debug = false; + + // get indices and parents + const typename CONDITIONAL::Parents& parents = conditional->parents(); + + if(debug) conditional->print("Adding conditional "); + + // if no parents, start a new root clique + if (parents.empty()) { + if(debug) std::cout << "No parents so making root" << std::endl; + bayesTree.root_ = bayesTree.addClique(conditional); + return; + } + + // otherwise, find the parent clique by using the index data structure + // to find the lowest-ordered parent + Index parentRepresentative = bayesTree.findParentClique(parents); + if(debug) std::cout << "First-eliminated parent is " << parentRepresentative << ", have " << bayesTree.nodes_.size() << " nodes." << std::endl; + sharedClique parent_clique = bayesTree[parentRepresentative]; + if(debug) parent_clique->print("Parent clique is "); + + // if the parents and parent clique have the same size, add to parent clique + if ((*parent_clique)->size() == size_t(parents.size())) { + if(debug) std::cout << "Adding to parent clique" << std::endl; +#ifndef NDEBUG + // Debug check that the parent indices of the new conditional match the indices + // currently in the clique. +// list::const_iterator parent = parents.begin(); +// typename Clique::const_iterator cond = parent_clique->begin(); +// while(parent != parents.end()) { +// assert(cond != parent_clique->end()); +// assert(*parent == (*cond)->key()); +// ++ parent; +// ++ cond; +// } +#endif + addToCliqueFront(bayesTree, conditional, parent_clique); + } else { + if(debug) std::cout << "Starting new clique" << std::endl; + // otherwise, start a new clique and add it to the tree + bayesTree.addClique(conditional,parent_clique); + } + } + + /* ************************************************************************* */ + //TODO: remove this function after removing TSAM.cpp + template + typename BayesTree::sharedClique BayesTree::insert( + const sharedConditional& clique, std::list& children, bool isRootClique) { + + if (clique->nrFrontals() == 0) + throw std::invalid_argument("BayesTree::insert: empty clique!"); + + // create a new clique and add all the conditionals to the clique + sharedClique new_clique = addClique(clique, children); + if (isRootClique) root_ = new_clique; + + return new_clique; + } + + /* ************************************************************************* */ + template + void BayesTree::fillNodesIndex(const sharedClique& subtree) { + // Add each frontal variable of this root node + BOOST_FOREACH(const Index& j, subtree->conditional()->frontals()) { nodes_[j] = subtree; } + // Fill index for each child + typedef typename BayesTree::sharedClique sharedClique; + BOOST_FOREACH(const sharedClique& child, subtree->children_) { + fillNodesIndex(child); } + } + + /* ************************************************************************* */ + template + void BayesTree::insert(const sharedClique& subtree) { + if(subtree) { + // Find the parent clique of the new subtree. By the running intersection + // property, those separator variables in the subtree that are ordered + // lower than the highest frontal variable of the subtree root will all + // appear in the separator of the subtree root. + if(subtree->conditional()->parents().empty()) { + assert(!root_); + root_ = subtree; + } else { + Index parentRepresentative = findParentClique(subtree->conditional()->parents()); + sharedClique parent = (*this)[parentRepresentative]; + parent->children_ += subtree; + subtree->parent_ = parent; // set new parent! + } + + // Now fill in the nodes index + if(nodes_.size() == 0 || + *std::max_element(subtree->conditional()->beginFrontals(), subtree->conditional()->endFrontals()) > (nodes_.size() - 1)) { + nodes_.resize(subtree->conditional()->lastFrontalKey() + 1); + } + fillNodesIndex(subtree); + } + } + + /* ************************************************************************* */ + // First finds clique marginal then marginalizes that + /* ************************************************************************* */ + template + typename CONDITIONAL::FactorType::shared_ptr BayesTree::marginalFactor( + Index j, Eliminate function) const + { + gttic(BayesTree_marginalFactor); + + // get clique containing Index j + sharedClique clique = (*this)[j]; + + // calculate or retrieve its marginal P(C) = P(F,S) +#ifdef OLD_SHORTCUT_MARGINALS + FactorGraph cliqueMarginal = clique->marginal(root_,function); +#else + FactorGraph cliqueMarginal = clique->marginal2(root_,function); +#endif + + // Reduce the variable indices to start at zero + gttic(Reduce); + const Permutation reduction = internal::createReducingPermutation(cliqueMarginal.keys()); + internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction); + BOOST_FOREACH(const boost::shared_ptr& factor, cliqueMarginal) { + if(factor) factor->reduceWithInverse(inverseReduction); } + gttoc(Reduce); + + // now, marginalize out everything that is not variable j + GenericSequentialSolver solver(cliqueMarginal); + boost::shared_ptr result = solver.marginalFactor(inverseReduction[j], function); + + // Undo the reduction + gttic(Undo_Reduce); + result->permuteWithInverse(reduction); + BOOST_FOREACH(const boost::shared_ptr& factor, cliqueMarginal) { + if(factor) factor->permuteWithInverse(reduction); } + gttoc(Undo_Reduce); + return result; + } + + /* ************************************************************************* */ + template + typename BayesNet::shared_ptr BayesTree::marginalBayesNet( + Index j, Eliminate function) const + { + gttic(BayesTree_marginalBayesNet); + + // calculate marginal as a factor graph + FactorGraph fg; + fg.push_back(this->marginalFactor(j,function)); + + // Reduce the variable indices to start at zero + gttic(Reduce); + const Permutation reduction = internal::createReducingPermutation(fg.keys()); + internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction); + BOOST_FOREACH(const boost::shared_ptr& factor, fg) { + if(factor) factor->reduceWithInverse(inverseReduction); } + gttoc(Reduce); + + // eliminate factor graph marginal to a Bayes net + boost::shared_ptr > bn = GenericSequentialSolver(fg).eliminate(function); + + // Undo the reduction + gttic(Undo_Reduce); + bn->permuteWithInverse(reduction); + BOOST_FOREACH(const boost::shared_ptr& factor, fg) { + if(factor) factor->permuteWithInverse(reduction); } + gttoc(Undo_Reduce); + return bn; + } + + /* ************************************************************************* */ + // Find two cliques, their joint, then marginalizes + /* ************************************************************************* */ + template + typename FactorGraph::shared_ptr + BayesTree::joint(Index j1, Index j2, Eliminate function) const { + gttic(BayesTree_joint); + + // get clique C1 and C2 + sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; + + gttic(Lowest_common_ancestor); + // Find lowest common ancestor clique + sharedClique B; { + // Build two paths to the root + FastList path1, path2; { + sharedClique p = C1; + while(p) { + path1.push_front(p); + p = p->parent(); + } + } { + sharedClique p = C2; + while(p) { + path2.push_front(p); + p = p->parent(); + } + } + // Find the path intersection + B = this->root(); + typename FastList::const_iterator p1 = path1.begin(), p2 = path2.begin(); + while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) { + B = *p1; + ++p1; + ++p2; + } + } + gttoc(Lowest_common_ancestor); + + // Compute marginal on lowest common ancestor clique + gttic(LCA_marginal); + FactorGraph p_B = B->marginal2(this->root(), function); + gttoc(LCA_marginal); + + // Compute shortcuts of the requested cliques given the lowest common ancestor + gttic(Clique_shortcuts); + BayesNet p_C1_Bred = C1->shortcut(B, function); + BayesNet p_C2_Bred = C2->shortcut(B, function); + gttoc(Clique_shortcuts); + + // Factor the shortcuts to be conditioned on the full root + // Get the set of variables to eliminate, which is C1\B. + gttic(Full_root_factoring); + sharedConditional p_C1_B; { + std::vector C1_minus_B; { + FastSet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); + BOOST_FOREACH(const Index j, *B->conditional()) { + C1_minus_B_set.erase(j); } + C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); + } + // Factor into C1\B | B. + FactorGraph temp_remaining; + boost::tie(p_C1_B, temp_remaining) = FactorGraph(p_C1_Bred).eliminate(C1_minus_B, function); + } + sharedConditional p_C2_B; { + std::vector C2_minus_B; { + FastSet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); + BOOST_FOREACH(const Index j, *B->conditional()) { + C2_minus_B_set.erase(j); } + C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); + } + // Factor into C2\B | B. + FactorGraph temp_remaining; + boost::tie(p_C2_B, temp_remaining) = FactorGraph(p_C2_Bred).eliminate(C2_minus_B, function); + } + gttoc(Full_root_factoring); + + gttic(Variable_joint); + // Build joint on all involved variables + FactorGraph p_BC1C2; + p_BC1C2.push_back(p_B); + p_BC1C2.push_back(p_C1_B->toFactor()); + p_BC1C2.push_back(p_C2_B->toFactor()); + if(C1 != B) + p_BC1C2.push_back(C1->conditional()->toFactor()); + if(C2 != B) + p_BC1C2.push_back(C2->conditional()->toFactor()); + + // Reduce the variable indices to start at zero + gttic(Reduce); + const Permutation reduction = internal::createReducingPermutation(p_BC1C2.keys()); + internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction); + BOOST_FOREACH(const boost::shared_ptr& factor, p_BC1C2) { + if(factor) factor->reduceWithInverse(inverseReduction); } + std::vector js; js.push_back(inverseReduction[j1]); js.push_back(inverseReduction[j2]); + gttoc(Reduce); + + // now, marginalize out everything that is not variable j + GenericSequentialSolver solver(p_BC1C2); + boost::shared_ptr > result = solver.jointFactorGraph(js, function); + + // Undo the reduction + gttic(Undo_Reduce); + BOOST_FOREACH(const boost::shared_ptr& factor, *result) { + if(factor) factor->permuteWithInverse(reduction); } + BOOST_FOREACH(const boost::shared_ptr& factor, p_BC1C2) { + if(factor) factor->permuteWithInverse(reduction); } + gttoc(Undo_Reduce); + return result; + + } + + /* ************************************************************************* */ + template + typename BayesNet::shared_ptr BayesTree::jointBayesNet( + Index j1, Index j2, Eliminate function) const { + + // eliminate factor graph marginal to a Bayes net + return GenericSequentialSolver ( + *this->joint(j1, j2, function)).eliminate(function); + } + + /* ************************************************************************* */ + template + void BayesTree::clear() { + // Remove all nodes and clear the root pointer + nodes_.clear(); + root_.reset(); + } + + /* ************************************************************************* */ + template + void BayesTree::removePath(sharedClique clique, + BayesNet& bn, typename BayesTree::Cliques& orphans) { + + // base case is NULL, if so we do nothing and return empties above + if (clique!=NULL) { + + // remove the clique from orphans in case it has been added earlier + orphans.remove(clique); + + // remove me + this->removeClique(clique); + + // remove path above me + this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn, orphans); + + // add children to list of orphans (splice also removed them from clique->children_) + orphans.splice(orphans.begin(), clique->children_); + + bn.push_back(clique->conditional()); + + } + } + + /* ************************************************************************* */ + template + template + void BayesTree::removeTop(const CONTAINER& keys, + BayesNet& bn, typename BayesTree::Cliques& orphans) { + + // process each key of the new factor + BOOST_FOREACH(const Index& j, keys) { + + // get the clique + if(j < nodes_.size()) { + const sharedClique& clique(nodes_[j]); + if(clique) { + // remove path from clique to root + this->removePath(clique, bn, orphans); + } + } + } + + // Delete cachedShortcuts for each orphan subtree + //TODO: Consider Improving + BOOST_FOREACH(sharedClique& orphan, orphans) + orphan->deleteCachedShortcuts(); + } + + /* ************************************************************************* */ + template + typename BayesTree::Cliques BayesTree::removeSubtree( + const sharedClique& subtree) + { + // Result clique list + Cliques cliques; + cliques.push_back(subtree); + + // Remove the first clique from its parents + if(!subtree->isRoot()) + subtree->parent()->children().remove(subtree); + else + root_.reset(); + + // Add all subtree cliques and erase the children and parent of each + for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique) + { + // Add children + BOOST_FOREACH(const sharedClique& child, (*clique)->children()) { + cliques.push_back(child); } + + // Delete cached shortcuts + (*clique)->deleteCachedShortcutsNonRecursive(); + + // Remove this node from the nodes index + BOOST_FOREACH(Index j, (*clique)->conditional()->frontals()) { + nodes_[j].reset(); } + + // Erase the parent and children pointers + (*clique)->parent_.reset(); + (*clique)->children_.clear(); + } + + return cliques; + } + + /* ************************************************************************* */ + template + void BayesTree::cloneTo(This& newTree) const { + if(root()) + cloneTo(newTree, root(), sharedClique()); + } + + /* ************************************************************************* */ + template + void BayesTree::cloneTo( + This& newTree, const sharedClique& subtree, const sharedClique& parent) const { + sharedClique newClique(subtree->clone()); + newTree.addClique(newClique, parent); + BOOST_FOREACH(const sharedClique& childClique, subtree->children()) { + cloneTo(newTree, childClique, newClique); + } + } + +} +/// namespace gtsam diff --git a/gtsam/inference/BayesTreeUnordered.h b/gtsam/inference/BayesTreeUnordered.h new file mode 100644 index 000000000..5e35b3f09 --- /dev/null +++ b/gtsam/inference/BayesTreeUnordered.h @@ -0,0 +1,415 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BayesTree.h + * @brief Bayes Tree is a tree of cliques of a Bayes Chain + * @author Frank Dellaert + */ + +// \callgraph + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace gtsam { + + // Forward declaration of BayesTreeClique which is defined below BayesTree in this file + template struct BayesTreeClique; + + /** + * Bayes tree + * @tparam CONDITIONAL The type of the conditional densities, i.e. the type of node in the underlying Bayes chain, + * which could be a ConditionalProbabilityTable, a GaussianConditional, or a SymbolicConditional. + * @tparam CLIQUE The type of the clique data structure, defaults to BayesTreeClique, normally do not change this + * as it is only used when developing special versions of BayesTree, e.g. for ISAM2. + * + * \addtogroup Multifrontal + * \nosubgrouping + */ + template > + class BayesTree { + + public: + + typedef BayesTree This; + typedef boost::shared_ptr > shared_ptr; + typedef boost::shared_ptr sharedConditional; + typedef boost::shared_ptr > sharedBayesNet; + typedef CONDITIONAL ConditionalType; + typedef typename CONDITIONAL::FactorType FactorType; + typedef typename FactorGraph::Eliminate Eliminate; + + typedef CLIQUE Clique; ///< The clique type, normally BayesTreeClique + + // typedef for shared pointers to cliques + typedef boost::shared_ptr sharedClique; + + // A convenience class for a list of shared cliques + struct Cliques : public FastList { + void print(const std::string& s = "Cliques", + const IndexFormatter& indexFormatter = DefaultIndexFormatter) const; + bool equals(const Cliques& other, double tol = 1e-9) const; + }; + + /** clique statistics */ + struct CliqueStats { + double avgConditionalSize; + std::size_t maxConditionalSize; + double avgSeparatorSize; + std::size_t maxSeparatorSize; + void print(const std::string& s = "") const ; + }; + + /** store all the sizes */ + struct CliqueData { + std::vector conditionalSizes; + std::vector separatorSizes; + CliqueStats getStats() const; + }; + + /** Map from indices to Clique */ + typedef std::vector Nodes; + + protected: + + /** Map from indices to Clique */ + Nodes nodes_; + + /** Root clique */ + sharedClique root_; + + public: + + /// @name Standard Constructors + /// @{ + + /** Create an empty Bayes Tree */ + BayesTree() {} + + /** Create a Bayes Tree from a Bayes Net (requires CONDITIONAL is IndexConditional *or* CONDITIONAL::Combine) */ + explicit BayesTree(const BayesNet& bayesNet); + + /** Copy constructor */ + BayesTree(const This& other); + + /** Assignment operator */ + This& operator=(const This& other); + + /// @} + /// @name Advanced Constructors + /// @{ + + /** + * Create a Bayes Tree from a Bayes Net and some subtrees. The Bayes net corresponds to the + * new root clique and the subtrees are connected to the root clique. + */ + BayesTree(const BayesNet& bayesNet, std::list > subtrees); + + /** Destructor */ + virtual ~BayesTree() {} + + /// @} + /// @name Testable + /// @{ + + /** check equality */ + bool equals(const BayesTree& other, double tol = 1e-9) const; + + /** print */ + void print(const std::string& s = "", + const IndexFormatter& indexFormatter = DefaultIndexFormatter ) const; + + /// @} + /// @name Standard Interface + /// @{ + + /** + * Find parent clique of a conditional. It will look at all parents and + * return the one with the lowest index in the ordering. + */ + template + Index findParentClique(const CONTAINER& parents) const; + + /** number of cliques */ + inline size_t size() const { + if(root_) + return root_->treeSize(); + else + return 0; + } + + /** Check if there are any cliques in the tree */ + inline bool empty() const { + return nodes_.empty(); + } + + /** return nodes */ + const Nodes& nodes() const { return nodes_; } + + /** return root clique */ + const sharedClique& root() const { return root_; } + + /** find the clique that contains the variable with Index j */ + inline sharedClique operator[](Index j) const { + return nodes_.at(j); + } + + /** alternate syntax for matlab: find the clique that contains the variable with Index j */ + inline sharedClique clique(Index j) const { + return nodes_.at(j); + } + + /** Gather data on all cliques */ + CliqueData getCliqueData() const; + + /** Collect number of cliques with cached separator marginals */ + size_t numCachedSeparatorMarginals() const; + + /** return marginal on any variable */ + typename FactorType::shared_ptr marginalFactor(Index j, Eliminate function) const; + + /** + * return marginal on any variable, as a Bayes Net + * NOTE: this function calls marginal, and then eliminates it into a Bayes Net + * This is more expensive than the above function + */ + typename BayesNet::shared_ptr marginalBayesNet(Index j, Eliminate function) const; + + /** + * return joint on two variables + * Limitation: can only calculate joint if cliques are disjoint or one of them is root + */ + typename FactorGraph::shared_ptr joint(Index j1, Index j2, Eliminate function) const; + + /** + * return joint on two variables as a BayesNet + * Limitation: can only calculate joint if cliques are disjoint or one of them is root + */ + typename BayesNet::shared_ptr jointBayesNet(Index j1, Index j2, Eliminate function) const; + + /** + * Read only with side effects + */ + + /** saves the Tree to a text file in GraphViz format */ + void saveGraph(const std::string& s, const IndexFormatter& indexFormatter = DefaultIndexFormatter ) const; + + /// @} + /// @name Advanced Interface + /// @{ + + /** Access the root clique (non-const version) */ + sharedClique& root() { return root_; } + + /** Access the nodes (non-cost version) */ + Nodes& nodes() { return nodes_; } + + /** Remove all nodes */ + void clear(); + + /** Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data */ + void deleteCachedShortcuts() { + root_->deleteCachedShortcuts(); + } + + /** Apply a permutation to the full tree - also updates the nodes structure */ + void permuteWithInverse(const Permutation& inversePermutation); + + /** + * Remove path from clique to root and return that path as factors + * plus a list of orphaned subtree roots. Used in removeTop below. + */ + void removePath(sharedClique clique, BayesNet& bn, Cliques& orphans); + + /** + * Given a list of indices, turn "contaminated" part of the tree back into a factor graph. + * Factors and orphans are added to the in/out arguments. + */ + template + void removeTop(const CONTAINER& indices, BayesNet& bn, Cliques& orphans); + + /** + * Remove the requested subtree. */ + Cliques removeSubtree(const sharedClique& subtree); + + /** + * Hang a new subtree off of the existing tree. This finds the appropriate + * parent clique for the subtree (which may be the root), and updates the + * nodes index with the new cliques in the subtree. None of the frontal + * variables in the subtree may appear in the separators of the existing + * BayesTree. + */ + void insert(const sharedClique& subtree); + + /** Insert a new conditional + * This function only applies for Symbolic case with IndexCondtional, + * We make it static so that it won't be compiled in GaussianConditional case. + * */ + static void insert(BayesTree& bayesTree, const sharedConditional& conditional); + + /** + * Insert a new clique corresponding to the given Bayes net. + * It is the caller's responsibility to decide whether the given Bayes net is a valid clique, + * i.e. all the variables (frontal and separator) are connected + */ + sharedClique insert(const sharedConditional& clique, + std::list& children, bool isRootClique = false); + + /** + * Create a clone of this object as a shared pointer + * Necessary for inheritance in matlab interface + */ + virtual shared_ptr clone() const { + return shared_ptr(new This(*this)); + } + + protected: + + /** private helper method for saving the Tree to a text file in GraphViz format */ + void saveGraph(std::ostream &s, sharedClique clique, const IndexFormatter& indexFormatter, + int parentnum = 0) const; + + /** Gather data on a single clique */ + void getCliqueData(CliqueData& stats, sharedClique clique) const; + + /** remove a clique: warning, can result in a forest */ + void removeClique(sharedClique clique); + + /** add a clique (top down) */ + sharedClique addClique(const sharedConditional& conditional, const sharedClique& parent_clique = sharedClique()); + + /** add a clique (top down) */ + void addClique(const sharedClique& clique, const sharedClique& parent_clique = sharedClique()); + + /** add a clique (bottom up) */ + sharedClique addClique(const sharedConditional& conditional, std::list& child_cliques); + + /** + * Add a conditional to the front of a clique, i.e. a conditional whose + * parents are already in the clique or its separators. This function does + * not check for this condition, it just updates the data structures. + */ + static void addToCliqueFront(BayesTree& bayesTree, + const sharedConditional& conditional, const sharedClique& clique); + + /** Fill the nodes index for a subtree */ + void fillNodesIndex(const sharedClique& subtree); + + /** Helper function to build a non-symbolic tree (e.g. Gaussian) using a + * symbolic tree, used in the BT(BN) constructor. + */ + void recursiveTreeBuild(const boost::shared_ptr >& symbolic, + const std::vector >& conditionals, + const typename BayesTree::sharedClique& parent); + + private: + + /** deep copy to another tree */ + void cloneTo(This& newTree) const; + + /** deep copy to another tree */ + void cloneTo(This& newTree, const sharedClique& subtree, const sharedClique& parent) const; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE & ar, const unsigned int version) { + ar & BOOST_SERIALIZATION_NVP(nodes_); + ar & BOOST_SERIALIZATION_NVP(root_); + } + + /// @} + + }; // BayesTree + + + /* ************************************************************************* */ + template + void _BayesTree_dim_adder( + std::vector& dims, + const typename BayesTree::sharedClique& clique) { + + if(clique) { + // Add dims from this clique + for(typename CONDITIONAL::const_iterator it = (*clique)->beginFrontals(); it != (*clique)->endFrontals(); ++it) + dims[*it] = (*clique)->dim(it); + + // Traverse children + typedef typename BayesTree::sharedClique sharedClique; + BOOST_FOREACH(const sharedClique& child, clique->children()) { + _BayesTree_dim_adder(dims, child); + } + } + } + + /* ************************************************************************* */ + template + boost::shared_ptr allocateVectorValues(const BayesTree& bt) { + std::vector dimensions(bt.nodes().size(), 0); + _BayesTree_dim_adder(dimensions, bt.root()); + return boost::shared_ptr(new VectorValues(dimensions)); + } + + + /* ************************************************************************* */ + /** + * A Clique in the tree is an incomplete Bayes net: the variables + * in the Bayes net are the frontal nodes, and the variables conditioned + * on are the separator. We also have pointers up and down the tree. + * + * Since our Conditional class already handles multiple frontal variables, + * this Clique contains exactly 1 conditional. + * + * This is the default clique type in a BayesTree, but some algorithms, like + * iSAM2 (see ISAM2Clique), use a different clique type in order to store + * extra data along with the clique. + */ + template + struct BayesTreeClique : public BayesTreeCliqueBase, CONDITIONAL> { + public: + typedef CONDITIONAL ConditionalType; + typedef BayesTreeClique This; + typedef BayesTreeCliqueBase Base; + typedef boost::shared_ptr shared_ptr; + typedef boost::weak_ptr weak_ptr; + BayesTreeClique() {} + BayesTreeClique(const typename ConditionalType::shared_ptr& conditional) : Base(conditional) {} + BayesTreeClique(const std::pair& result) : Base(result) {} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE & ar, const unsigned int version) { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } + }; + +} /// namespace gtsam + +#include +#include