211 lines
8.5 KiB
C++
211 lines
8.5 KiB
C++
/* ----------------------------------------------------------------------------
|
|
|
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
* Atlanta, Georgia 30332-0415
|
|
* All Rights Reserved
|
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
|
|
* See LICENSE for the license information
|
|
* -------------------------------------------------------------------------- */
|
|
|
|
/**
|
|
* @file BayesTreeCliqueBase-inst.h
|
|
* @brief Base class for cliques of a BayesTree
|
|
* @author Richard Roberts and Frank Dellaert
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
|
#include <gtsam/base/timing.h>
|
|
#include <boost/foreach.hpp>
|
|
|
|
namespace gtsam {
|
|
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
bool BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::equals(
|
|
const DERIVED& other, double tol) const
|
|
{
|
|
return (!conditional_ && !other.conditional())
|
|
|| conditional_->equals(*other.conditional(), tol);
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
std::vector<Key>
|
|
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separator_setminus_B(const derived_ptr& B) const
|
|
{
|
|
FastSet<Key> p_F_S_parents(this->conditional()->beginParents(), this->conditional()->endParents());
|
|
FastSet<Key> indicesB(B->conditional()->begin(), B->conditional()->end());
|
|
std::vector<Key> S_setminus_B;
|
|
std::set_difference(p_F_S_parents.begin(), p_F_S_parents.end(),
|
|
indicesB.begin(), indicesB.end(), back_inserter(S_setminus_B));
|
|
return S_setminus_B;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
std::vector<Key> BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut_indices(
|
|
const derived_ptr& B, const FactorGraphType& p_Cp_B) const
|
|
{
|
|
gttic(shortcut_indices);
|
|
FastSet<Key> allKeys = p_Cp_B.keys();
|
|
FastSet<Key> indicesB(B->conditional()->begin(), B->conditional()->end());
|
|
std::vector<Key> S_setminus_B = separator_setminus_B(B);
|
|
std::vector<Key> keep;
|
|
// keep = S\B intersect allKeys (S_setminus_B is already sorted)
|
|
std::set_intersection(S_setminus_B.begin(), S_setminus_B.end(), //
|
|
allKeys.begin(), allKeys.end(), back_inserter(keep));
|
|
// keep += B intersect allKeys
|
|
std::set_intersection(indicesB.begin(), indicesB.end(), //
|
|
allKeys.begin(), allKeys.end(), back_inserter(keep));
|
|
return keep;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
void BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::print(
|
|
const std::string& s, const KeyFormatter& keyFormatter) const
|
|
{
|
|
conditional_->print(s, keyFormatter);
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
size_t BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::treeSize() const {
|
|
size_t size = 1;
|
|
BOOST_FOREACH(const derived_ptr& child, children)
|
|
size += child->treeSize();
|
|
return size;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
size_t BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::numCachedSeparatorMarginals() const
|
|
{
|
|
if (!cachedSeparatorMarginal_)
|
|
return 0;
|
|
|
|
size_t subtree_count = 1;
|
|
BOOST_FOREACH(const derived_ptr& child, children)
|
|
subtree_count += child->numCachedSeparatorMarginals();
|
|
|
|
return subtree_count;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
// The shortcut density is a conditional P(S|R) of the separator of this
|
|
// clique on the root. We can compute it recursively from the parent shortcut
|
|
// P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::BayesNetType
|
|
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::shortcut(const derived_ptr& B, Eliminate function) const
|
|
{
|
|
gttic(BayesTreeCliqueBase_shortcut);
|
|
// We only calculate the shortcut when this clique is not B
|
|
// and when the S\B is not empty
|
|
std::vector<Key> S_setminus_B = separator_setminus_B(B);
|
|
if (!parent_.expired() /*(if we're not the root)*/ && !S_setminus_B.empty())
|
|
{
|
|
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
|
|
derived_ptr parent(parent_.lock());
|
|
gttoc(BayesTreeCliqueBase_shortcut);
|
|
FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
|
|
gttic(BayesTreeCliqueBase_shortcut);
|
|
p_Cp_B += parent->conditional_; // P(Fp|Sp)
|
|
|
|
// Determine the variables we want to keepSet, S union B
|
|
std::vector<Key> keep = shortcut_indices(B, p_Cp_B);
|
|
|
|
// Marginalize out everything except S union B
|
|
boost::shared_ptr<FactorGraphType> p_S_B = p_Cp_B.marginal(keep, function);
|
|
return *p_S_B->eliminatePartialSequential(S_setminus_B, function).first;
|
|
}
|
|
else
|
|
{
|
|
return BayesNetType();
|
|
}
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
// separator marginal, uses separator marginal of parent recursively
|
|
// P(C) = P(F|S) P(S)
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
|
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::separatorMarginal(Eliminate function) const
|
|
{
|
|
gttic(BayesTreeCliqueBase_separatorMarginal);
|
|
// Check if the Separator marginal was already calculated
|
|
if (!cachedSeparatorMarginal_)
|
|
{
|
|
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
// If this is the root, there is no separator
|
|
if (parent_.expired() /*(if we're the root)*/)
|
|
{
|
|
// we are root, return empty
|
|
FactorGraphType empty;
|
|
cachedSeparatorMarginal_ = empty;
|
|
}
|
|
else
|
|
{
|
|
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
|
// initialize P(Cp) with the parent separator marginal
|
|
derived_ptr parent(parent_.lock());
|
|
gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); // Flatten recursion in timing outline
|
|
gttoc(BayesTreeCliqueBase_separatorMarginal);
|
|
FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp)
|
|
gttic(BayesTreeCliqueBase_separatorMarginal);
|
|
gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss);
|
|
// now add the parent conditional
|
|
// p_Cp += parent->conditional_; // P(Fp|Sp) // FIXME: assign error?
|
|
p_Cp.push_back(parent->conditional_);
|
|
|
|
// The variables we want to keepSet are exactly the ones in S
|
|
std::vector<Key> indicesS(this->conditional()->beginParents(), this->conditional()->endParents());
|
|
cachedSeparatorMarginal_ = *p_Cp.marginalMultifrontalBayesNet(Ordering(indicesS), boost::none, function);
|
|
}
|
|
}
|
|
|
|
// return the shortcut P(S||B)
|
|
return *cachedSeparatorMarginal_; // return the cached version
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
// marginal2, uses separator marginal of parent recursively
|
|
// P(C) = P(F|S) P(S)
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
|
|
BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::marginal2(Eliminate function) const
|
|
{
|
|
gttic(BayesTreeCliqueBase_marginal2);
|
|
// initialize with separator marginal P(S)
|
|
FactorGraphType p_C = this->separatorMarginal(function);
|
|
// add the conditional P(F|S)
|
|
p_C += boost::shared_ptr<FactorType>(this->conditional_);
|
|
return p_C;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
template<class DERIVED, class FACTORGRAPH>
|
|
void BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::deleteCachedShortcuts() {
|
|
|
|
// When a shortcut is requested, all of the shortcuts between it and the
|
|
// root are also generated. So, if this clique's cached shortcut is set,
|
|
// recursively call over all child cliques. Otherwise, it is unnecessary.
|
|
if (cachedSeparatorMarginal_) {
|
|
BOOST_FOREACH(derived_ptr& child, children) {
|
|
child->deleteCachedShortcuts();
|
|
}
|
|
|
|
//Delete CachedShortcut for this clique
|
|
cachedSeparatorMarginal_ = boost::none;
|
|
}
|
|
|
|
}
|
|
|
|
}
|