gtsam/gtsam/inference/BayesTreeCliqueBase-inl.h

261 lines
11 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file BayesTreeCliqueBase-inl.h
* @brief Base class for cliques of a BayesTree
* @author Richard Roberts and Frank Dellaert
*/
#pragma once
#include <gtsam/inference/GenericSequentialSolver.h>
namespace gtsam {
/* ************************************************************************* */
/**
 * Consistency check hook for this clique.
 *
 * The only check ever written here (that the keys of the conditional are
 * sorted and contain no duplicates) is commented out, so this function
 * currently does nothing in both debug and release builds.
 */
template<class DERIVED, class CONDITIONAL>
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::assertInvariants() const {
#ifndef NDEBUG
// We rely on the keys being sorted
// NOTE: the check below is disabled; it is kept for reference only.
// FastVector<Index> sortedUniqueKeys(conditional_->begin(), conditional_->end());
// std::sort(sortedUniqueKeys.begin(), sortedUniqueKeys.end());
// std::unique(sortedUniqueKeys.begin(), sortedUniqueKeys.end());
// assert(sortedUniqueKeys.size() == conditional_->size() &&
// std::equal(sortedUniqueKeys.begin(), sortedUniqueKeys.end(), conditional_->begin()));
#endif
}
/* ************************************************************************* */
/**
 * Construct a clique around an existing conditional.
 * @param conditional shared pointer stored as this clique's conditional
 */
template<class DERIVED, class CONDITIONAL>
BayesTreeCliqueBase<DERIVED,CONDITIONAL>::BayesTreeCliqueBase(
    const sharedConditional& conditional)
  : conditional_(conditional)
{
  // Check the class invariants on the newly built clique.
  this->assertInvariants();
}
/* ************************************************************************* */
/**
 * Construct a clique from a (conditional, factor) pair.
 * Only the conditional (result.first) is stored; the factor (result.second)
 * is not used by this constructor.
 * @param result pair whose first element becomes this clique's conditional
 */
template<class DERIVED, class CONDITIONAL>
BayesTreeCliqueBase<DERIVED,CONDITIONAL>::BayesTreeCliqueBase(
    const std::pair<sharedConditional, boost::shared_ptr<typename ConditionalType::FactorType> >& result)
  : conditional_(result.first)
{
  // Check the class invariants on the newly built clique.
  this->assertInvariants();
}
/* ************************************************************************* */
/**
 * Print this clique by delegating to its conditional.
 * @param s string forwarded unchanged to the conditional's print()
 */
template<class DERIVED, class CONDITIONAL>
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::print(const std::string& s) const
{
  this->conditional_->print(s);
}
/* ************************************************************************* */
/**
 * Count the cliques in the subtree rooted at this clique.
 * @return 1 (for this clique) plus the treeSize() of every child
 */
template<class DERIVED, class CONDITIONAL>
size_t BayesTreeCliqueBase<DERIVED,CONDITIONAL>::treeSize() const
{
  // Start at one to account for this clique itself.
  size_t count = 1;
  BOOST_FOREACH(const derived_ptr& subtreeRoot, children_) {
    count += subtreeRoot->treeSize();
  }
  return count;
}
/* ************************************************************************* */
/**
 * Recursively print the subtree rooted at this clique.
 * This clique is printed through the derived class's print() using
 * @p indent; each child is then printed with one extra space of indentation.
 * @param indent prefix string passed to print() for this clique
 */
template<class DERIVED, class CONDITIONAL>
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::printTree(const std::string& indent) const
{
  asDerived(this)->print(indent);

  // All children share the same, one-level-deeper prefix.
  const std::string childIndent = indent + " ";
  BOOST_FOREACH(const derived_ptr& subClique, children_) {
    subClique->printTree(childIndent);
  }
}
/* ************************************************************************* */
/**
 * Apply an inverse permutation to this clique and its whole subtree.
 * The clique's own conditional is permuted first, then every child is
 * permuted recursively with the same permutation.
 * @param inversePermutation permutation forwarded to the conditional and
 *                           to all descendants
 */
template<class DERIVED, class CONDITIONAL>
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::permuteWithInverse(const Permutation& inversePermutation)
{
  // This clique first ...
  conditional_->permuteWithInverse(inversePermutation);

  // ... then every descendant.
  BOOST_FOREACH(const derived_ptr& subClique, children_)
    subClique->permuteWithInverse(inversePermutation);

  assertInvariants();
}
/* ************************************************************************* */
/**
 * Apply an inverse permutation to the separator of this clique and, if that
 * changed anything, to the separators of all descendants.
 *
 * When this clique's conditional reports no change, the children are not
 * visited in release builds. In debug builds (NDEBUG undefined) the
 * unchanged case is cross-checked: every separator key must map to itself
 * under @p inversePermutation, and every child must report no change as well.
 *
 * @param inversePermutation permutation forwarded to the conditional
 * @return the value returned by the conditional's
 *         permuteSeparatorWithInverse() for this clique
 */
template<class DERIVED, class CONDITIONAL>
bool BayesTreeCliqueBase<DERIVED,CONDITIONAL>::permuteSeparatorWithInverse(const Permutation& inversePermutation)
{
  const bool changed = conditional_->permuteSeparatorWithInverse(inversePermutation);

  if (changed) {
    // Propagate to the children; their individual results are not needed.
    BOOST_FOREACH(const derived_ptr& subClique, children_) {
      (void)subClique->permuteSeparatorWithInverse(inversePermutation);
    }
  } else {
#ifndef NDEBUG
    // Nothing changed here, so the permutation must be the identity on our
    // separator and no child may change either.
    BOOST_FOREACH(Index& parentKey, conditional_->parents()) {
      assert(parentKey == inversePermutation[parentKey]);
    }
    BOOST_FOREACH(const derived_ptr& subClique, children_) {
      assert(!subClique->permuteSeparatorWithInverse(inversePermutation));
    }
#endif
  }

  assertInvariants();
  return changed;
}
/* ************************************************************************* */
// The shortcut density is a conditional P(S|R) of the separator of this
// clique on the root. We can compute it recursively from the parent shortcut
// P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p
/* ************************************************************************* */
/**
 * Compute the shortcut P(S|R) of this clique's separator S on the root R.
 *
 * The factors of the root conditional, the parent clique's conditional and
 * the parent's own shortcut (obtained recursively) are combined, permuted so
 * that the separator and root variables come last, and eliminated. The
 * conditionals whose frontal variable belongs to this clique's separator are
 * then collected and permuted back to the original indexing.
 *
 * @param R        root clique
 * @param function elimination function, passed to the recursive parent
 *                 shortcut and to EliminationTree::eliminate()
 * @return Bayes net with the conditionals on this clique's separator
 *         variables; an empty Bayes net if this clique is R or if its
 *         parent is R
 */
template<class DERIVED, class CONDITIONAL>
BayesNet<CONDITIONAL> BayesTreeCliqueBase<DERIVED,CONDITIONAL>::shortcut(derived_ptr R, Eliminate function) {

  static const bool debug = false;

  // A first base case is when this clique or its parent is the root,
  // in which case we return an empty Bayes net.
  derived_ptr parent(parent_.lock());
  if (R.get()==this || parent==R) {
    BayesNet<ConditionalType> empty;
    return empty;
  }

  // The root conditional
  FactorGraph<FactorType> p_R(BayesNet<ConditionalType>(R->conditional()));

  // The parent clique has a ConditionalType for each frontal node in Fp
  // so we can obtain P(Fp|Sp) in factor graph form
  FactorGraph<FactorType> p_Fp_Sp(BayesNet<ConditionalType>(parent->conditional()));

  // If not the base case, obtain the parent shortcut P(Sp|R) as factors
  FactorGraph<FactorType> p_Sp_R(parent->shortcut(R, function));

  // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R)
  FactorGraph<FactorType> p_Cp_R;
  p_Cp_R.push_back(p_R);
  p_Cp_R.push_back(p_Fp_Sp);
  p_Cp_R.push_back(p_Sp_R);

  // Eliminate into a Bayes net with ordering designed to integrate out
  // any variables not in *our* separator. Variables to integrate out must be
  // eliminated first hence the desired ordering is [Cp\S S].
  // However, an added wrinkle is that Cp might overlap with the root.
  // Keys corresponding to the root should not be added to the ordering at all.
  if(debug) {
    p_R.print("p_R: ");
    p_Fp_Sp.print("p_Fp_Sp: ");
    p_Sp_R.print("p_Sp_R: ");
  }

  // We want to factor into a conditional of the clique variables given the
  // root and the marginal on the root, integrating out all other variables.
  // The integrands include any parents of this clique and the variables of
  // the parent clique.
  FastSet<Index> variablesAtBack; // separator of this clique plus all root keys
  FastSet<Index> separator;       // separator of this clique only
  BOOST_FOREACH(const Index separatorIndex, this->conditional()->parents()) {
    variablesAtBack.insert(separatorIndex);
    separator.insert(separatorIndex);
    if(debug) std::cout << "At back (this): " << separatorIndex << std::endl;
  }
  BOOST_FOREACH(const Index key, R->conditional()->keys()) {
    // Root keys already present in the separator are simply not re-inserted.
    variablesAtBack.insert(key);
    if(debug) std::cout << "At back (root): " << key << std::endl;
  }

  // Permute the factors in place so that the "at back" variables are
  // eliminated last.
  Permutation toBack = Permutation::PushToBack(
      std::vector<Index>(variablesAtBack.begin(), variablesAtBack.end()),
      R->conditional()->lastFrontalKey() + 1);
  Permutation::shared_ptr toBackInverse(toBack.inverse());
  BOOST_FOREACH(const typename FactorType::shared_ptr& factor, p_Cp_R) {
    factor->permuteWithInverse(*toBackInverse);
  }
  typename BayesNet<ConditionalType>::shared_ptr eliminated(EliminationTree<
      FactorType>::Create(p_Cp_R)->eliminate(function));

  // Take only the conditionals for p(S|R). We check for each variable being
  // in the separator set because if some separator variables overlap with
  // root variables, we cannot rely on the number of root variables, and also
  // want to include those variables in the conditional.
  BayesNet<ConditionalType> p_S_R;
  BOOST_REVERSE_FOREACH(const typename ConditionalType::shared_ptr& cond, *eliminated) {
    assert(cond->nrFrontals() == 1);
    // toBack maps the permuted frontal key back to its original index.
    if(separator.find(toBack[cond->firstFrontalKey()]) != separator.end()) {
      if(debug)
        cond->print("Taking C|R conditional: ");
      p_S_R.push_front(cond);
    }
    // Stop as soon as every separator variable has been collected.
    if(p_S_R.size() == separator.size())
      break;
  }

  // Undo the permutation
  if(debug) toBack.print("toBack: ");
  p_S_R.permuteWithInverse(toBack);

  // return this clique's shortcut P(S|R)
  assertInvariants();
  return p_S_R;
}
/* ************************************************************************* */
// P(C) = \int_R P(F|S) P(S|R) P(R)
// TODO: Maybe we should integrate given parent marginal P(Cp),
// \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
// Because the root clique could be very big.
/* ************************************************************************* */
/**
 * Compute the marginal on the variables of this clique as a factor graph.
 *
 * If this clique is the root R, the root conditional is returned (converted
 * to a factor graph). Otherwise P(F|S), the shortcut P(S|R) and P(R) are
 * stacked into one Bayes net and a GenericSequentialSolver computes the
 * joint on this clique's keys.
 *
 * @param R        root clique
 * @param function elimination function, passed to shortcut() and to
 *                 GenericSequentialSolver::jointFactorGraph()
 * @return factor graph on the keys of this clique's conditional
 */
template<class DERIVED, class CONDITIONAL>
FactorGraph<typename BayesTreeCliqueBase<DERIVED,CONDITIONAL>::FactorType> BayesTreeCliqueBase<DERIVED,CONDITIONAL>::marginal(
    derived_ptr R, Eliminate function) {

  // If we are the root, just return this root.
  // The Bayes net is only built in this branch (it is not needed otherwise)
  // and is implicitly converted to a factor graph on return.
  if (R.get()==this) {
    BayesNet<ConditionalType> rootNet(R->conditional());
    return rootNet;
  }

  // Combine P(F|S), P(S|R), and P(R)
  BayesNet<ConditionalType> p_FSR = this->shortcut(R, function);
  p_FSR.push_front(this->conditional());
  p_FSR.push_back(R->conditional());

  assertInvariants();

  // Marginalize onto the keys of this clique.
  GenericSequentialSolver<FactorType> solver(p_FSR);
  return *solver.jointFactorGraph(conditional_->keys(), function);
}
/* ************************************************************************* */
// P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S2) P(S2|R) P(R)
/* ************************************************************************* */
/**
 * Compute the joint on the variables of this clique (C1) and of clique C2,
 * returned as a factor graph.
 *
 * The factors P(F1|S1), P(S1|R), P(F2|S2), P(S2|R) and P(R) are collected;
 * the two terms belonging to a clique are skipped when that clique reports
 * isRoot(). A GenericSequentialSolver then computes the joint on the union
 * of the keys of both cliques.
 *
 * @param C2       the second clique
 * @param R        root clique
 * @param function elimination function, passed to shortcut() and to
 *                 GenericSequentialSolver::jointFactorGraph()
 * @return factor graph on the union of the keys of both cliques
 */
template<class DERIVED, class CONDITIONAL>
FactorGraph<typename BayesTreeCliqueBase<DERIVED,CONDITIONAL>::FactorType> BayesTreeCliqueBase<DERIVED,CONDITIONAL>::joint(
    derived_ptr C2, derived_ptr R, Eliminate function) {

  // Gather P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R), in that order.
  FactorGraph<FactorType> factors;
  if (!isRoot()) {
    factors.push_back(this->conditional()->toFactor()); // P(F1|S1)
    factors.push_back(shortcut(R, function));           // P(S1|R)
  }
  if (!C2->isRoot()) {
    factors.push_back(C2->conditional()->toFactor());   // P(F2|S2)
    factors.push_back(C2->shortcut(R, function));       // P(S2|R)
  }
  factors.push_back(R->conditional()->toFactor());      // P(R)

  // Union of the keys of C1 and C2 (the set removes duplicates).
  const std::vector<Index>& keys1 = conditional_->keys();
  const std::vector<Index>& keys2 = C2->conditional_->keys();
  FastSet<Index> keys12;
  keys12.insert(keys1.begin(), keys1.end());
  keys12.insert(keys2.begin(), keys2.end());

  // The solver takes its keys as a vector.
  std::vector<Index> keys12vector(keys12.begin(), keys12.end());

  assertInvariants();

  // Calculate the marginal
  GenericSequentialSolver<FactorType> solver(factors);
  return *solver.jointFactorGraph(keys12vector, function);
}
}