gtsam/gtsam/nonlinear/Marginals.cpp

176 lines
6.4 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Marginals.cpp
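* @brief Implementation of Marginals (marginal and joint-marginal information/covariance queries on a linearized factor graph) and of JointMarginal::print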
* @author Richard Roberts
* @date May 14, 2012
*/
#include <gtsam/base/timing.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/linear/HessianFactor.h>
#include <gtsam/nonlinear/Marginals.h>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
/**
 * Construct from a nonlinear factor graph and the point to linearize it at.
 *
 * @param graph         nonlinear factor graph; it is linearized at `solution`
 *                      and the linearized graph is stored in graph_
 * @param solution      values used as linearization point; copied into values_
 * @param factorization elimination variant (compared against CHOLESKY / QR
 *                      when the Bayes tree is built)
 * @param ordering      optional elimination ordering, forwarded to
 *                      computeBayesTree()
 */
Marginals::Marginals(const NonlinearFactorGraph& graph, const Values& solution,
                     Factorization factorization,
                     EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
    : values_(solution), factorization_(factorization) {
  // Timing instrumentation (gttic comes from gtsam/base/timing.h).
  gttic(MarginalsConstructor);

  // Linearize around the supplied solution and keep a copy of the result.
  const auto linearizedGraph = graph.linearize(solution);
  graph_ = *linearizedGraph;

  // Eliminate the linear graph into bayesTree_.
  computeBayesTree(ordering);
}
/* ************************************************************************* */
/**
 * Construct from an already-linear factor graph and a VectorValues solution.
 *
 * @param graph         linear factor graph; copied into graph_
 * @param solution      per-key vectors; each (key, vector) entry is inserted
 *                      into values_
 * @param factorization elimination variant (compared against CHOLESKY / QR
 *                      when the Bayes tree is built)
 * @param ordering      optional elimination ordering, forwarded to
 *                      computeBayesTree()
 */
Marginals::Marginals(const GaussianFactorGraph& graph, const VectorValues& solution,
                     Factorization factorization,
                     EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
    : graph_(graph), factorization_(factorization) {
  // Timing instrumentation (gttic comes from gtsam/base/timing.h).
  gttic(MarginalsConstructor);

  // values_ is not part of the initializer list, so it is populated here with
  // one entry per (key, vector) pair of the linear solution.
  for (const auto& keyedVector : solution) {
    values_.insert(keyedVector.first, keyedVector.second);
  }

  // Eliminate the linear graph into bayesTree_.
  computeBayesTree(ordering);
}
/* ************************************************************************* */
/**
 * Construct from an already-linear factor graph and a Values solution.
 *
 * @param graph         linear factor graph; copied into graph_
 * @param solution      values; copied into values_
 * @param factorization elimination variant (compared against CHOLESKY / QR
 *                      when the Bayes tree is built)
 * @param ordering      optional elimination ordering, forwarded to
 *                      computeBayesTree()
 */
Marginals::Marginals(const GaussianFactorGraph& graph, const Values& solution,
                     Factorization factorization,
                     EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
    : graph_(graph),
      values_(solution),
      factorization_(factorization) {
  // Timing instrumentation (gttic comes from gtsam/base/timing.h).
  gttic(MarginalsConstructor);

  // All members are already set; only the Bayes tree remains to be built.
  computeBayesTree(ordering);
}
/* ************************************************************************* */
/**
 * Eliminate the stored linear graph (graph_) into the Bayes tree (bayesTree_).
 *
 * The dense elimination routine is chosen from factorization_:
 * EliminatePreferCholesky for CHOLESKY, EliminateQR for QR. For any other
 * value bayesTree_ is left untouched.
 *
 * @param ordering optional elimination ordering, forwarded unchanged to
 *                 eliminateMultifrontal()
 */
void Marginals::computeBayesTree(EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering) {
  // factorization_ is set by every constructor's initializer list before this
  // helper runs, so it is only read here. (There is no `factorization`
  // variable in this scope to assign from.)
  if (factorization_ == CHOLESKY)
    bayesTree_ = *graph_.eliminateMultifrontal(ordering, EliminatePreferCholesky);
  else if (factorization_ == QR)
    bayesTree_ = *graph_.eliminateMultifrontal(ordering, EliminateQR);
}
/* ************************************************************************* */
/**
 * Print the linear graph, the solution and the Bayes tree, in that order.
 *
 * @param str          prefix prepended to each of the three section labels
 * @param keyFormatter key formatter; passed to the solution printout only
 */
void Marginals::print(const std::string& str, const KeyFormatter& keyFormatter) const {
  // Section labels, each prefixed with the caller-supplied string.
  const std::string graphLabel = str + "Graph: ";
  const std::string solutionLabel = str + "Solution: ";
  const std::string treeLabel = str + "Bayes Tree: ";

  graph_.print(graphLabel);
  values_.print(solutionLabel, keyFormatter);
  bayesTree_.print(treeLabel);
}
/* ************************************************************************* */
/**
 * Compute the marginal factor on a single variable from the Bayes tree.
 *
 * @param variable key of the variable to marginalize onto
 * @return the factor returned by bayesTree_.marginalFactor() for that key
 * @throws std::runtime_error if factorization_ is neither CHOLESKY nor QR
 */
GaussianFactor::shared_ptr Marginals::marginalFactor(Key variable) const {
  // Timing instrumentation (gttic comes from gtsam/base/timing.h).
  gttic(marginalFactor);

  // Use the elimination routine that matches the configured factorization.
  if (factorization_ == CHOLESKY) {
    return bayesTree_.marginalFactor(variable, EliminatePreferCholesky);
  }
  if (factorization_ == QR) {
    return bayesTree_.marginalFactor(variable, EliminateQR);
  }

  throw std::runtime_error("Marginals::marginalFactor: Unknown factorization");
}
/* ************************************************************************* */
/**
 * Information matrix of the marginal on a single variable.
 *
 * @param variable key of the variable to query
 * @return result of information() on the variable's marginal factor
 */
Matrix Marginals::marginalInformation(Key variable) const {
  // Timing instrumentation (gttic comes from gtsam/base/timing.h).
  gttic(marginalInformation);

  // Obtain the marginal factor first, then ask it for its information matrix.
  const GaussianFactor::shared_ptr factor = marginalFactor(variable);
  return factor->information();
}
/* ************************************************************************* */
/**
 * Covariance of the marginal on a single variable.
 *
 * @param variable key of the variable to query
 * @return inverse of marginalInformation(variable)
 */
Matrix Marginals::marginalCovariance(Key variable) const {
  const Matrix information = marginalInformation(variable);
  return information.inverse();
}
/* ************************************************************************* */
/**
 * Joint marginal covariance over a set of variables.
 *
 * @param variables keys of the variables to include
 * @return the joint marginal from jointMarginalInformation(), with its block
 *         matrix inverted in place
 */
JointMarginal Marginals::jointMarginalCovariance(const KeyVector& variables) const {
  // Start from the information form and invert its stored block matrix.
  JointMarginal joint = jointMarginalInformation(variables);
  joint.blockMatrix_.invertInPlace();
  return joint;
}
/* ************************************************************************* */
/**
 * Joint marginal information over a set of variables.
 *
 * - Exactly one key: wraps marginalInformation() of that key; the key list is
 *   passed through as given.
 * - Exactly two keys: uses bayesTree_.joint() on the pair.
 * - Any other count: uses graph_.marginalMultifrontalBayesTree() on the keys.
 *
 * In the last two cases the returned JointMarginal is built with the keys
 * sorted in ascending order, and the per-key block sizes are read from
 * values_.
 *
 * @param variables keys of the variables to include
 * @return JointMarginal holding the information matrix, block dimensions and
 *         keys
 */
JointMarginal Marginals::jointMarginalInformation(const KeyVector& variables) const {
  // Single key: the plain marginal already is the answer.
  if (variables.size() == 1) {
    Matrix singleInfo = marginalInformation(variables.front());
    std::vector<size_t> singleDims;
    singleDims.push_back(singleInfo.rows());
    return JointMarginal(singleInfo, singleDims, variables);
  }

  // Build the factor graph of the joint marginal, choosing the elimination
  // routine from factorization_ and the query method from the number of keys.
  const bool isPair = (variables.size() == 2);
  GaussianFactorGraph jointGraph;
  if (factorization_ == CHOLESKY) {
    if (isPair) {
      jointGraph = *bayesTree_.joint(variables[0], variables[1], EliminatePreferCholesky);
    } else {
      jointGraph = GaussianFactorGraph(
          *graph_.marginalMultifrontalBayesTree(variables, boost::none, EliminatePreferCholesky));
    }
  } else if (factorization_ == QR) {
    if (isPair) {
      jointGraph = *bayesTree_.joint(variables[0], variables[1], EliminateQR);
    } else {
      jointGraph = GaussianFactorGraph(
          *graph_.marginalMultifrontalBayesTree(variables, boost::none, EliminateQR));
    }
  }

  // Take the augmented Hessian and drop its last row and column
  // (presumably the linear/constant part of the augmented system -- confirm
  // against GaussianFactorGraph::augmentedHessian).
  const Matrix augmented = jointGraph.augmentedHessian();
  const Matrix information =
      augmented.topLeftCorner(augmented.rows() - 1, augmented.cols() - 1);

  // The result is reported with the keys in ascending order.
  KeyVector sortedKeys(variables);
  std::sort(sortedKeys.begin(), sortedKeys.end());

  // Block size of every key, looked up in the stored solution.
  std::vector<size_t> blockDims;
  blockDims.reserve(sortedKeys.size());
  for (const Key key : sortedKeys) {
    blockDims.push_back(values_.at(key).dim());
  }

  return JointMarginal(information, blockDims, sortedKeys);
}
/* ************************************************************************* */
/**
 * Solve the stored Bayes tree.
 *
 * @return the result of bayesTree_.optimize()
 */
VectorValues Marginals::optimize() const {
  VectorValues solution = bayesTree_.optimize();
  return solution;
}
/* ************************************************************************* */
void JointMarginal::print(const std::string& s, const KeyFormatter& formatter) const {
cout << s << "Joint marginal on keys ";
bool first = true;
for(Key key: keys_) {
if(!first)
cout << ", ";
else
first = false;
cout << formatter(key);
}
cout << ". Use 'at' or 'operator()' to query matrix blocks." << endl;
}
} /* namespace gtsam */