133 lines
3.8 KiB
C++
133 lines
3.8 KiB
C++
/*
 * iterative.cpp
 * @brief Iterative methods, implementation
 * @author Frank Dellaert
 * Created on: Dec 28, 2009
 */
|
|
|
|
#include "GaussianFactorGraph.h"
|
|
#include "iterative.h"
|
|
|
|
using namespace std;
|
|
|
|
namespace gtsam {
|
|
|
|
/* ************************************************************************* */
|
|
// Method of conjugate gradients (CG) template, minimizing 0.5*|Ax-b|^2
// (gradient g = A'(Ax-b)).
//   "System" class S needs gradient(S,v), e=S*v, v=S^e
//   "Vector" class V needs dot(v,v), -v, v+v, s*v, dim(v), v.size()
//   "Error"  class E needs dot(e,e)
// If steepest is true, does steepest descent instead (d = -g every step).
//
// NOTE: maxIterations comes BEFORE epsilon in this signature; callers must
// pass them in that order (an int/double mix-up compiles silently).
// @param Ab            the system
// @param x             initial estimate (taken by value and updated)
// @param maxIterations iteration cap; 0 means use dim(x)
// @param epsilon       stop once |g|^2 < epsilon^2 * |g0|^2
// @param steepest      use steepest descent instead of CG
// @return the final estimate
template<class S, class V, class E>
V conjugateGradients(const S& Ab, V x, size_t maxIterations, double epsilon,
		bool steepest = false) {

	if (maxIterations == 0) maxIterations = dim(x);

	// Start with g0 = A'*(A*x0-b), d0 = - g0
	// i.e., first step is in direction of negative gradient
	V g = gradient(Ab, x);
	V d = -g;
	double dotg0 = dot(g, g), prev_dotg = dotg0;
	// squared form of |g| < epsilon*|g0|, avoids a square root
	double threshold = epsilon * epsilon * dotg0;

	// loop max n times
	// NOTE(review): for some V (e.g. VectorConfig) size() may count variables
	// rather than scalar dimensions — TODO confirm dim(x) was not intended.
	size_t n = x.size();
	for (size_t k = 0; k < maxIterations; k++) {

		// calculate optimal step-size: alpha = -<d,g> / |A d|^2
		E Ad = Ab * d;
		double dotAd = dot(Ad, Ad);
		// No finite step exists when |A d|^2 is exactly zero (in exact
		// arithmetic this only happens once the gradient is already zero, e.g.
		// x is already a minimizer). Stop instead of computing 0/0 = NaN,
		// which would corrupt x.
		if (dotAd == 0.0) break;
		double alpha = -dot(d, g) / dotAd;

		// do step in new search direction
		x = x + alpha * d;
		if (k == n) break;

		// update gradient: A'(A(x + alpha d) - b) = g + alpha * A'(A d)
		g = g + alpha * (Ab ^ Ad);

		// check for convergence
		double dotg = dot(g, g);
		if (dotg < threshold) break;

		// calculate new search direction
		if (steepest)
			d = -g;
		else {
			// Fletcher-Reeves: beta = |g_new|^2 / |g_old|^2
			double beta = dotg / prev_dotg;
			prev_dotg = dotg;
			d = -g + beta * d;
		}
	}

	return x;
}
|
|
|
|
/* ************************************************************************* */
|
|
|
|
/** gradient of objective function 0.5*|Ax-b|^2 at x = A'*(Ax-b) */
|
|
Vector gradient(const System& Ab, const Vector& x) {
|
|
const Matrix& A = Ab.first;
|
|
const Vector& b = Ab.second;
|
|
return A ^ (A * x - b);
|
|
}
|
|
|
|
/** Apply operator A */
|
|
Vector operator*(const System& Ab, const Vector& x) {
|
|
const Matrix& A = Ab.first;
|
|
return A * x;
|
|
}
|
|
|
|
/** Apply operator A^T */
|
|
Vector operator^(const System& Ab, const Vector& x) {
|
|
const Matrix& A = Ab.first;
|
|
return A ^ x;
|
|
}
|
|
|
|
/**
 * Steepest descent on 0.5*|Ax-b|^2 for the system Ab = (A, b).
 * @param Ab            pair (A, b)
 * @param x             initial estimate
 * @param epsilon       stop once |g| < epsilon * |g0|
 * @param maxIterations iteration cap (0 = use dim(x))
 * @return the final estimate
 */
Vector steepestDescent(const System& Ab, const Vector& x, double epsilon,
		size_t maxIterations) {
	// conjugateGradients takes maxIterations BEFORE epsilon; passing them in
	// this function's order would turn epsilon into an iteration count and
	// maxIterations into a tolerance.
	return conjugateGradients<System, Vector, Vector> (Ab, x, maxIterations,
			epsilon, true);
}
|
|
|
|
/**
 * Conjugate-gradient descent on 0.5*|Ax-b|^2 for the system Ab = (A, b).
 * @param Ab            pair (A, b)
 * @param x             initial estimate
 * @param epsilon       stop once |g| < epsilon * |g0|
 * @param maxIterations iteration cap (0 = use dim(x))
 * @return the final estimate
 */
Vector conjugateGradientDescent(const System& Ab, const Vector& x,
		double epsilon, size_t maxIterations) {
	// conjugateGradients takes maxIterations BEFORE epsilon.
	return conjugateGradients<System, Vector, Vector> (Ab, x, maxIterations,
			epsilon);
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * Steepest descent on 0.5*|Ax-b|^2 given A and b separately.
 * @param A             system matrix
 * @param b             right-hand side
 * @param x             initial estimate
 * @param epsilon       stop once |g| < epsilon * |g0|
 * @param maxIterations iteration cap (0 = use dim(x))
 * @return the final estimate
 */
Vector steepestDescent(const Matrix& A, const Vector& b, const Vector& x,
		double epsilon, size_t maxIterations) {
	System Ab = make_pair(A, b);
	// conjugateGradients takes maxIterations BEFORE epsilon.
	return conjugateGradients<System, Vector, Vector> (Ab, x, maxIterations,
			epsilon, true);
}
|
|
|
|
/**
 * Conjugate-gradient descent on 0.5*|Ax-b|^2 given A and b separately.
 * @param A             system matrix
 * @param b             right-hand side
 * @param x             initial estimate
 * @param epsilon       stop once |g| < epsilon * |g0|
 * @param maxIterations iteration cap (0 = use dim(x))
 * @return the final estimate
 */
Vector conjugateGradientDescent(const Matrix& A, const Vector& b,
		const Vector& x, double epsilon, size_t maxIterations) {
	System Ab = make_pair(A, b);
	// conjugateGradients takes maxIterations BEFORE epsilon.
	return conjugateGradients<System, Vector, Vector> (Ab, x, maxIterations,
			epsilon);
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * Gradient of the factor graph's objective at x.
 * Thin adapter so conjugateGradients can call gradient(S, v) uniformly;
 * forwards to GaussianFactorGraph::gradient.
 */
VectorConfig gradient(const GaussianFactorGraph& fg, const VectorConfig& x) {
	VectorConfig g = fg.gradient(x);
	return g;
}
|
|
|
|
/**
 * Steepest descent on a Gaussian factor graph.
 * @param fg            the factor graph (acts as the system S)
 * @param x             initial estimate
 * @param epsilon       stop once |g| < epsilon * |g0|
 * @param maxIterations iteration cap (0 = use dim(x))
 * @return the final estimate
 */
VectorConfig steepestDescent(const GaussianFactorGraph& fg,
		const VectorConfig& x, double epsilon, size_t maxIterations) {
	// conjugateGradients takes maxIterations BEFORE epsilon.
	return conjugateGradients<GaussianFactorGraph, VectorConfig, Errors> (fg,
			x, maxIterations, epsilon, true);
}
|
|
|
|
/**
 * Conjugate-gradient descent on a Gaussian factor graph.
 * @param fg            the factor graph (acts as the system S)
 * @param x             initial estimate
 * @param epsilon       stop once |g| < epsilon * |g0|
 * @param maxIterations iteration cap (0 = use dim(x))
 * @return the final estimate
 */
VectorConfig conjugateGradientDescent(const GaussianFactorGraph& fg,
		const VectorConfig& x, double epsilon, size_t maxIterations) {
	// conjugateGradients takes maxIterations BEFORE epsilon.
	return conjugateGradients<GaussianFactorGraph, VectorConfig, Errors> (fg,
			x, maxIterations, epsilon);
}
|
|
|
|
/* ************************************************************************* */
|
|
|
|
} // namespace gtsam
|