/** * @file BayesTree.cpp * @brief Bayes Tree is a tree of cliques of a Bayes Chain * @author Frank Dellaert * @author Michael Kaess * @author Viorela Ila */ #pragma once #include #include #include // for operator += #include using namespace boost::assign; #include "Conditional.h" #include "BayesTree.h" #include "Ordering.h" #include "inference-inl.h" #include "Key.h" namespace gtsam { using namespace std; /* ************************************************************************* */ template BayesTree::Clique::Clique() {} /* ************************************************************************* */ template BayesTree::Clique::Clique(const sharedConditional& conditional) { separator_ = conditional->parents(); this->push_back(conditional); } /* ************************************************************************* */ template Ordering BayesTree::Clique::keys() const { Ordering frontal_keys = this->ordering(), keys = separator_; keys.splice(keys.begin(),frontal_keys); return keys; } /* ************************************************************************* */ template void BayesTree::Clique::print(const string& s) const { cout << s; BOOST_FOREACH(const sharedConditional& conditional, this->conditionals_) { conditional->print("conditioanl"); cout << " " << (string)(conditional->key()); } if (!separator_.empty()) { cout << " :"; BOOST_FOREACH(const Symbol& key, separator_) cout << " " << (std::string)key; } cout << endl; } /* ************************************************************************* */ template size_t BayesTree::Clique::treeSize() const { size_t size = 1; BOOST_FOREACH(const shared_ptr& child, children_) size += child->treeSize(); return size; } /* ************************************************************************* */ template void BayesTree::Clique::printTree(const string& indent) const { print(indent); BOOST_FOREACH(const shared_ptr& child, children_) child->printTree(indent+" "); } /* ************************************************************************* */ template typename BayesTree::CliqueData BayesTree::getCliqueData() const { CliqueData data; getCliqueData(data, root_); return data; } template void BayesTree::getCliqueData(CliqueData& data, BayesTree::sharedClique clique) const { data.conditionalSizes.push_back(clique->conditionals_.size()); data.separatorSizes.push_back(clique->separator_.size()); BOOST_FOREACH(sharedClique c, clique->children_) { getCliqueData(data, c); } } /* ************************************************************************* */ template void BayesTree::saveGraph(const std::string &s) const { if (!root_.get()) throw invalid_argument("the root of bayes tree has not been initialized!"); ofstream of(s.c_str()); of<< "digraph G{\n"; saveGraph(of, root_); of<<"}"; of.close(); } template void BayesTree::saveGraph(ostream &s, BayesTree::sharedClique clique, int parentnum) const { static int num = 0; bool first = true; std::stringstream out; out << num; string parent = out.str(); parent += "[label=\""; BOOST_FOREACH(boost::shared_ptr c, clique->conditionals_) { if(!first) parent += ","; first = false; parent += (string(c->key())).c_str(); } if( clique != root_){ parent += " : "; s << parentnum << "->" << num << "\n"; } first = true; BOOST_FOREACH(const Symbol& sep, clique->separator_) { if(!first) parent += ","; first = false; parent += ((string)sep).c_str(); } parent += "\"];\n"; s << parent; parentnum = num; BOOST_FOREACH(sharedClique c, clique->children_) { num++; saveGraph(s, c, parentnum); } } template typename BayesTree::CliqueStats BayesTree::CliqueData::getStats() const { CliqueStats stats; double sum = 0.0; size_t max = 0; BOOST_FOREACH(size_t s, conditionalSizes) { sum += (double)s; if(s > max) max = s; } stats.avgConditionalSize = sum / (double)conditionalSizes.size(); stats.maxConditionalSize = max; sum = 0.0; max = 1; BOOST_FOREACH(size_t s, separatorSizes) { sum += (double)s; if(s > max) max = s; } stats.avgSeparatorSize = sum / (double)separatorSizes.size(); stats.maxSeparatorSize = max; return stats; } /* ************************************************************************* */ // The shortcut density is a conditional P(S|R) of the separator of this // clique on the root. We can compute it recursively from the parent shortcut // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p // TODO, why do we actually return a shared pointer, why does eliminate? /* ************************************************************************* */ template template BayesNet BayesTree::Clique::shortcut(shared_ptr R) { // A first base case is when this clique or its parent is the root, // in which case we return an empty Bayes net. if (R.get()==this || parent_==R) { BayesNet empty; return empty; } // The parent clique has a Conditional for each frontal node in Fp // so we can obtain P(Fp|Sp) in factor graph form FactorGraph p_Fp_Sp(*parent_); // If not the base case, obtain the parent shortcut P(Sp|R) as factors FactorGraph p_Sp_R(parent_->shortcut(R)); // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R) FactorGraph p_Cp_R = combine(p_Fp_Sp, p_Sp_R); // Eliminate into a Bayes net with ordering designed to integrate out // any variables not in *our* separator. Variables to integrate out must be // eliminated first hence the desired ordering is [Cp\S S]. // However, an added wrinkle is that Cp might overlap with the root. // Keys corresponding to the root should not be added to the ordering at all. // Get the key list Cp=Fp+Sp, which will form the basis for the integrands Ordering integrands = parent_->keys(); // Start ordering with the separator Ordering ordering = separator_; // remove any variables in the root, after this integrands = Cp\R, ordering = S\R BOOST_FOREACH(const Symbol& key, R->ordering()) { integrands.remove(key); ordering.remove(key); } // remove any variables in the separator, after this integrands = Cp\R\S BOOST_FOREACH(const Symbol& key, separator_) integrands.remove(key); // form the ordering as [Cp\R\S S\R] BOOST_REVERSE_FOREACH(const Symbol& key, integrands) ordering.push_front(key); // eliminate to get marginal BayesNet p_S_R = eliminate(p_Cp_R,ordering); // remove all integrands BOOST_FOREACH(const Symbol& key, integrands) p_S_R.pop_front(); // return the parent shortcut P(Sp|R) return p_S_R; } /* ************************************************************************* */ // P(C) = \int_R P(F|S) P(S|R) P(R) // TODO: Maybe we should integrate given parent marginal P(Cp), // \int(Cp\S) P(F|S)P(S|Cp)P(Cp) // Because the root clique could be very big. /* ************************************************************************* */ template template FactorGraph BayesTree::Clique::marginal(shared_ptr R) { // If we are the root, just return this root if (R.get()==this) return *R; // Combine P(F|S), P(S|R), and P(R) BayesNet p_FSR = this->shortcut(R); p_FSR.push_front(*this); p_FSR.push_back(*R); // Find marginal on the keys we are interested in return marginalize(p_FSR,keys()); } /* ************************************************************************* */ // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) /* ************************************************************************* */ template template pair, Ordering> BayesTree::Clique::joint(shared_ptr C2, shared_ptr R) { // For now, assume neither is the root // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) sharedBayesNet bn(new BayesNet); if (!isRoot()) bn->push_back(*this); // P(F1|S1) if (!isRoot()) bn->push_back(shortcut(R)); // P(S1|R) if (!C2->isRoot()) bn->push_back(*C2); // P(F2|S2) if (!C2->isRoot()) bn->push_back(C2->shortcut(R)); // P(S2|R) bn->push_back(*R); // P(R) // Find the keys of both C1 and C2 Ordering keys12 = keys(); BOOST_FOREACH(const Symbol& key,C2->keys()) keys12.push_back(key); keys12.unique(); // Calculate the marginal return make_pair(marginalize(*bn,keys12), keys12); } /* ************************************************************************* */ template void BayesTree::Cliques::print(const std::string& s) const { cout << s << ":\n"; BOOST_FOREACH(sharedClique clique, *this) clique->printTree(); } /* ************************************************************************* */ template bool BayesTree::Cliques::equals(const Cliques& other, double tol) const { return other == *this; } /* ************************************************************************* */ template typename BayesTree::sharedClique BayesTree::addClique (const sharedConditional& conditional, sharedClique parent_clique) { sharedClique new_clique(new Clique(conditional)); nodes_.insert(make_pair(conditional->key(), new_clique)); if (parent_clique != NULL) { new_clique->parent_ = parent_clique; parent_clique->children_.push_back(new_clique); } return new_clique; } /* ************************************************************************* */ template typename BayesTree::sharedClique BayesTree::addClique (const sharedConditional& conditional, list& child_cliques) { sharedClique new_clique(new Clique(conditional)); nodes_.insert(make_pair(conditional->key(), new_clique)); new_clique->children_ = child_cliques; BOOST_FOREACH(sharedClique& child, child_cliques) child->parent_ = new_clique; return new_clique; } /* ************************************************************************* */ template void BayesTree::removeClique(sharedClique clique) { if (clique->isRoot()) root_.reset(); else // detach clique from parent clique->parent_->children_.remove(clique); // orphan my children BOOST_FOREACH(sharedClique child, clique->children_) child->parent_.reset(); BOOST_FOREACH(const Symbol& key, clique->ordering()) { nodes_.erase(key); } } /* ************************************************************************* */ template BayesTree::BayesTree() { } /* ************************************************************************* */ template BayesTree::BayesTree(const BayesNet& bayesNet) { IndexTable index(bayesNet.ordering()); typename BayesNet::const_reverse_iterator rit; for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) insert(*rit, index); } /* ************************************************************************* */ template BayesTree::BayesTree(const BayesNet& bayesNet, std::list > subtrees) { if (bayesNet.size() == 0) throw invalid_argument("BayesTree::insert: empty bayes net!"); // get the roots of child subtrees and merge their nodes_ list childRoots; BOOST_FOREACH(const BayesTree& subtree, subtrees) { nodes_.insert(subtree.nodes_.begin(), subtree.nodes_.end()); childRoots.push_back(subtree.root()); } // create a new clique and add all the conditionals to the clique sharedClique new_clique; typename BayesNet::sharedConditional conditional; BOOST_REVERSE_FOREACH(conditional, bayesNet) { if (!new_clique.get()) { new_clique = addClique(conditional,childRoots); } else { nodes_.insert(make_pair(conditional->key(), new_clique)); new_clique->push_front(conditional); } } root_ = new_clique; } /* ************************************************************************* */ template void BayesTree::print(const string& s) const { if (root_.use_count() == 0) { printf("WARNING: BayesTree.print encountered a forest...\n"); return; } cout << s << ": clique size == " << size() << ", node size == " << nodes_.size() << endl; if (nodes_.empty()) return; root_->printTree(""); } /* ************************************************************************* */ // binary predicate to test equality of a pair for use in equals template bool check_pair( const pair::sharedClique >& v1, const pair::sharedClique >& v2 ) { return v1.first == v2.first && v1.second->equals(*(v2.second)); } /* ************************************************************************* */ template bool BayesTree::equals(const BayesTree& other, double tol) const { return size()==other.size() && equal(nodes_.begin(),nodes_.end(),other.nodes_.begin(),check_pair); } /* ************************************************************************* */ template Symbol BayesTree::findParentClique(const list& parents, const IndexTable& index) const { boost::optional parentCliqueRepresentative; boost::optional lowest; BOOST_FOREACH(const Symbol& p, parents) { size_t i = index(p); if (!lowest || i<*lowest) { lowest.reset(i); parentCliqueRepresentative.reset(p); } } if (!lowest) throw invalid_argument("BayesTree::findParentClique: no parents given or key not present in index"); return *parentCliqueRepresentative; } /* ************************************************************************* */ template void BayesTree::insert(const sharedConditional& conditional, const IndexTable& index) { // get key and parents const Symbol& key = conditional->key(); list parents = conditional->parents(); // todo: const reference? // if no parents, start a new root clique if (parents.empty()) { root_ = addClique(conditional); return; } // otherwise, find the parent clique by using the index data structure // to find the lowest-ordered parent Symbol parentRepresentative = findParentClique(parents, index); sharedClique parent_clique = (*this)[parentRepresentative]; // if the parents and parent clique have the same size, add to parent clique if (parent_clique->size() == parents.size()) { nodes_.insert(make_pair(key, parent_clique)); parent_clique->push_front(conditional); return; } // otherwise, start a new clique and add it to the tree addClique(conditional,parent_clique); } /* ************************************************************************* */ //TODO: remove this function after removing TSAM.cpp template typename BayesTree::sharedClique BayesTree::insert( const BayesNet& bayesNet, list& children, bool isRootClique) { if (bayesNet.size() == 0) throw invalid_argument("BayesTree::insert: empty bayes net!"); // create a new clique and add all the conditionals to the clique sharedClique new_clique; typename BayesNet::sharedConditional conditional; BOOST_REVERSE_FOREACH(conditional, bayesNet) { if (!new_clique.get()) { new_clique = addClique(conditional,children); } else { nodes_.insert(make_pair(conditional->key(), new_clique)); new_clique->push_front(conditional); } } if (isRootClique) root_ = new_clique; return new_clique; } /* ************************************************************************* */ // First finds clique marginal then marginalizes that /* ************************************************************************* */ template template FactorGraph BayesTree::marginal(const Symbol& key) const { // get clique containing key sharedClique clique = (*this)[key]; // calculate or retrieve its marginal FactorGraph cliqueMarginal = clique->marginal(root_); // create an ordering where only the requested key is not eliminated Ordering ord = clique->keys(); ord.remove(key); // partially eliminate, remaining factor graph is requested marginal eliminate(cliqueMarginal,ord); return cliqueMarginal; } /* ************************************************************************* */ template template BayesNet BayesTree::marginalBayesNet(const Symbol& key) const { // calculate marginal as a factor graph FactorGraph fg = this->marginal(key); // eliminate further to Bayes net return eliminate(fg,Ordering(key)); } /* ************************************************************************* */ // Find two cliques, their joint, then marginalizes /* ************************************************************************* */ template template FactorGraph BayesTree::joint(const Symbol& key1, const Symbol& key2) const { // get clique C1 and C2 sharedClique C1 = (*this)[key1], C2 = (*this)[key2]; // calculate joint Ordering ord; FactorGraph p_C1C2; boost::tie(p_C1C2,ord) = C1->joint(C2,root_); // create an ordering where both requested keys are not eliminated ord.remove(key1); ord.remove(key2); // partially eliminate, remaining factor graph is requested joint // TODO, make eliminate functional eliminate(p_C1C2,ord); return p_C1C2; } /* ************************************************************************* */ template template BayesNet BayesTree::jointBayesNet(const Symbol& key1, const Symbol& key2) const { // calculate marginal as a factor graph FactorGraph fg = this->joint(key1,key2); // eliminate further to Bayes net Ordering ordering; ordering += key1, key2; return eliminate(fg,ordering); } /* ************************************************************************* */ template void BayesTree::clear() { // Remove all nodes and clear the root pointer nodes_.clear(); root_.reset(); } /* ************************************************************************* */ template void BayesTree::removePath(sharedClique clique, BayesNet& bn, typename BayesTree::Cliques& orphans) { // base case is NULL, if so we do nothing and return empties above if (clique!=NULL) { // remove the clique from orphans in case it has been added earlier orphans.remove(clique); // remove me this->removeClique(clique); // remove path above me this->removePath(clique->parent_, bn, orphans); // add children to list of orphans (splice also removed them from clique->children_) orphans.splice (orphans.begin(), clique->children_); bn.push_back(*clique); } } /* ************************************************************************* */ template void BayesTree::removeTop(const list& keys, BayesNet& bn, typename BayesTree::Cliques& orphans) { // process each key of the new factor BOOST_FOREACH(const Symbol& key, keys) try { // get the clique sharedClique clique = (*this)[key]; // remove path from clique to root this->removePath(clique, bn, orphans); } catch (std::invalid_argument e) { } } /* ************************************************************************* */ } /// namespace gtsam