gtsam/gtsam/inference/EliminationTree-inl.h

219 lines
7.3 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file EliminationTree-inl.h
* @author Frank Dellaert
* @date Oct 13, 2010
*/
#pragma once
#include <gtsam/base/timing.h>
#include <gtsam/base/FastSet.h>
#include <gtsam/inference/EliminationTree.h>
#include <gtsam/inference/VariableSlots.h>
#include <gtsam/inference/IndexFactor.h>
#include <gtsam/inference/IndexConditional.h>
#include <boost/foreach.hpp>
#include <boost/lambda/lambda.hpp>
#include <boost/static_assert.hpp>
#include <iostream>
#include <set>
#include <vector>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
template<class FACTOR>
typename EliminationTree<FACTOR>::sharedFactor EliminationTree<FACTOR>::eliminate_(
Eliminate function, Conditionals& conditionals) const {
static const bool debug = false;
if(debug) cout << "ETree: eliminating " << this->key_ << endl;
// Create the list of factors to be eliminated, initially empty, and reserve space
FactorGraph<FACTOR> factors;
factors.reserve(this->factors_.size() + this->subTrees_.size());
// Add all factors associated with the current node
factors.push_back(this->factors_.begin(), this->factors_.end());
// for all subtrees, eliminate into Bayes net and a separator factor, added to [factors]
BOOST_FOREACH(const shared_ptr& child, subTrees_)
factors.push_back(child->eliminate_(function, conditionals)); // TODO: spawn thread
// TODO: wait for completion of all threads
// Combine all factors (from this node and from subtrees) into a joint factor
typename FactorGraph<FACTOR>::EliminationResult
eliminated(function(factors, 1));
conditionals[this->key_] = eliminated.first;
if(debug) cout << "Eliminated " << this->key_ << " to get:\n";
if(debug) eliminated.first->print("Conditional: ");
if(debug) eliminated.second->print("Factor: ");
return eliminated.second;
}
/* ************************************************************************* */
template<class FACTOR>
vector<Index> EliminationTree<FACTOR>::ComputeParents(const VariableIndex& structure) {
// Number of factors and variables
const size_t m = structure.nFactors();
const size_t n = structure.size();
static const Index none = numeric_limits<Index>::max();
// Allocate result parent vector and vector of last factor columns
vector<Index> parents(n, none);
vector<Index> prevCol(m, none);
// for column j \in 1 to n do
for (Index j = 0; j < n; j++) {
// for row i \in Struct[A*j] do
BOOST_FOREACH(const size_t i, structure[j]) {
if (prevCol[i] != none) {
Index k = prevCol[i];
// find root r of the current tree that contains k
Index r = k;
while (parents[r] != none)
r = parents[r];
if (r != j) parents[r] = j;
}
prevCol[i] = j;
}
}
return parents;
}
/* ************************************************************************* */
template<class FACTOR>
template<class DERIVEDFACTOR>
typename EliminationTree<FACTOR>::shared_ptr EliminationTree<FACTOR>::Create(
const FactorGraph<DERIVEDFACTOR>& factorGraph,
const VariableIndex& structure) {
static const bool debug = false;
tic(1, "ET ComputeParents");
// Compute the tree structure
vector<Index> parents(ComputeParents(structure));
toc(1, "ET ComputeParents");
// Number of variables
const size_t n = structure.size();
static const Index none = numeric_limits<Index>::max();
// Create tree structure
tic(2, "assemble tree");
vector<shared_ptr> trees(n);
for (Index k = 1; k <= n; k++) {
Index j = n - k; // Start at the last variable and loop down to 0
trees[j].reset(new EliminationTree(j)); // Create a new node on this variable
if (parents[j] != none) // If this node has a parent, add it to the parent's children
trees[parents[j]]->add(trees[j]);
else if(!structure[j].empty() && j != n - 1) // If a node other than the last has no parents, this is a forest
throw DisconnectedGraphException();
}
toc(2, "assemble tree");
// Hang factors in right places
tic(3, "hang factors");
BOOST_FOREACH(const typename DERIVEDFACTOR::shared_ptr& derivedFactor, factorGraph) {
// Here we upwards-cast to the factor type of this EliminationTree. This
// allows performing symbolic elimination on, for example, GaussianFactors.
if(derivedFactor) {
sharedFactor factor(derivedFactor);
Index j = *std::min_element(factor->begin(), factor->end());
trees[j]->add(factor);
}
}
toc(3, "hang factors");
if(debug)
trees.back()->print("ETree: ");
return trees.back();
}
/* ************************************************************************* */
template<class FACTOR>
template<class DERIVEDFACTOR>
typename EliminationTree<FACTOR>::shared_ptr
EliminationTree<FACTOR>::Create(const FactorGraph<DERIVEDFACTOR>& factorGraph) {
// Build variable index
tic(0, "ET Create, variable index");
const VariableIndex variableIndex(factorGraph);
toc(0, "ET Create, variable index");
// Build elimination tree
return Create(factorGraph, variableIndex);
}
/* ************************************************************************* */
template<class FACTORGRAPH>
void EliminationTree<FACTORGRAPH>::print(const std::string& name) const {
cout << name << " (" << key_ << ")" << endl;
BOOST_FOREACH(const sharedFactor& factor, factors_) {
factor->print(name + " "); }
BOOST_FOREACH(const shared_ptr& child, subTrees_) {
child->print(name + " "); }
}
/* ************************************************************************* */
template<class FACTORGRAPH>
bool EliminationTree<FACTORGRAPH>::equals(const EliminationTree<FACTORGRAPH>& expected, double tol) const {
if(this->key_ == expected.key_ && this->factors_ == expected.factors_
&& this->subTrees_.size() == expected.subTrees_.size()) {
typename SubTrees::const_iterator this_subtree = this->subTrees_.begin();
typename SubTrees::const_iterator expected_subtree = expected.subTrees_.begin();
while(this_subtree != this->subTrees_.end())
if( ! (*(this_subtree++))->equals(**(expected_subtree++), tol))
return false;
return true;
} else
return false;
}
/* ************************************************************************* */
template<class FACTOR>
typename EliminationTree<FACTOR>::BayesNet::shared_ptr
EliminationTree<FACTOR>::eliminate(Eliminate function) const {
// call recursive routine
tic(1, "ET recursive eliminate");
size_t nrConditionals = this->key_ + 1; // root key has highest index
Conditionals conditionals(nrConditionals); // reserve a vector of conditional shared pointers
(void)eliminate_(function, conditionals); // modify in place
toc(1, "ET recursive eliminate");
// Add conditionals to BayesNet
tic(2, "assemble BayesNet");
typename BayesNet::shared_ptr bayesNet(new BayesNet);
BOOST_FOREACH(const typename BayesNet::sharedConditional& conditional, conditionals) {
if(conditional)
bayesNet->push_back(conditional);
}
toc(2, "assemble BayesNet");
return bayesNet;
}
}