Errors:dot, VectorConfig::operator*/-, as a result Conjugate Gradient Descent template now works for factor graphs
parent
5dfd1921e1
commit
1fac98b4cb
|
@ -31,6 +31,17 @@ bool Errors::equals(const Errors& expected, double tol) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double dot(const Errors& a, const Errors& b) {
|
||||||
|
size_t m = a.size();
|
||||||
|
if (b.size()!=m)
|
||||||
|
throw(std::invalid_argument("Errors::dot: incompatible sizes"));
|
||||||
|
double result = 0.0;
|
||||||
|
for (size_t i = 0; i < m; i++)
|
||||||
|
result += gtsam::dot(a[i], b[i]);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
} // gtsam
|
} // gtsam
|
||||||
|
|
|
@ -28,4 +28,9 @@ namespace gtsam {
|
||||||
|
|
||||||
}; // Errors
|
}; // Errors
|
||||||
|
|
||||||
|
/**
|
||||||
|
* dot product
|
||||||
|
*/
|
||||||
|
double dot(const Errors& a, const Errors& b);
|
||||||
|
|
||||||
} // gtsam
|
} // gtsam
|
||||||
|
|
|
@ -69,6 +69,15 @@ VectorConfig VectorConfig::operator*(double s) const {
|
||||||
return scale(s);
|
return scale(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
VectorConfig VectorConfig::operator-() const {
|
||||||
|
VectorConfig result;
|
||||||
|
string j; Vector v;
|
||||||
|
FOREACH_PAIR(j, v, values)
|
||||||
|
result.insert(j, -v);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void VectorConfig::operator+=(const VectorConfig& b) {
|
void VectorConfig::operator+=(const VectorConfig& b) {
|
||||||
string j; Vector b_j;
|
string j; Vector b_j;
|
||||||
|
|
|
@ -62,7 +62,9 @@ namespace gtsam {
|
||||||
const Vector& get(const std::string& name) const;
|
const Vector& get(const std::string& name) const;
|
||||||
|
|
||||||
/** operator[] syntax for get */
|
/** operator[] syntax for get */
|
||||||
inline const Vector& operator[](const std::string& name) const { return get(name); }
|
inline const Vector& operator[](const std::string& name) const {
|
||||||
|
return get(name);
|
||||||
|
}
|
||||||
|
|
||||||
bool contains(const std::string& name) const {
|
bool contains(const std::string& name) const {
|
||||||
const_iterator it = values.find(name);
|
const_iterator it = values.find(name);
|
||||||
|
@ -79,6 +81,9 @@ namespace gtsam {
|
||||||
VectorConfig scale(double s) const;
|
VectorConfig scale(double s) const;
|
||||||
VectorConfig operator*(double s) const;
|
VectorConfig operator*(double s) const;
|
||||||
|
|
||||||
|
/** Negation */
|
||||||
|
VectorConfig operator-() const;
|
||||||
|
|
||||||
/** Add in place */
|
/** Add in place */
|
||||||
void operator+=(const VectorConfig &b);
|
void operator+=(const VectorConfig &b);
|
||||||
|
|
||||||
|
@ -109,6 +114,9 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}; // VectorConfig
|
}; // VectorConfig
|
||||||
|
|
||||||
|
/** scalar product */
|
||||||
|
inline VectorConfig operator*(double s, const VectorConfig& x) {return x*s;}
|
||||||
|
|
||||||
/** Dot product */
|
/** Dot product */
|
||||||
double dot(const VectorConfig&, const VectorConfig&);
|
double dot(const VectorConfig&, const VectorConfig&);
|
||||||
|
|
||||||
|
|
|
@ -606,6 +606,11 @@ TEST( GaussianFactorGraph, transposeMultiplication )
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected,actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
VectorConfig gradient(const GaussianFactorGraph& Ab, const VectorConfig& x) {
|
||||||
|
return Ab.gradient(x);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
typedef pair<Matrix,Vector> System;
|
typedef pair<Matrix,Vector> System;
|
||||||
|
|
||||||
|
@ -640,7 +645,7 @@ Vector operator^(const System& Ab, const Vector& x) {
|
||||||
// "Vector" class V needs dot(v,v), -v, v+v, s*v
|
// "Vector" class V needs dot(v,v), -v, v+v, s*v
|
||||||
// "Vector" class E needs dot(v,v)
|
// "Vector" class E needs dot(v,v)
|
||||||
template <class S, class V, class E>
|
template <class S, class V, class E>
|
||||||
Vector conjugateGradientDescent(const S& Ab, V x, double threshold = 1e-9) {
|
V CGD(const S& Ab, V x, double threshold = 1e-9) {
|
||||||
|
|
||||||
// Start with g0 = A'*(A*x0-b), d0 = - g0
|
// Start with g0 = A'*(A*x0-b), d0 = - g0
|
||||||
// i.e., first step is in direction of negative gradient
|
// i.e., first step is in direction of negative gradient
|
||||||
|
@ -676,11 +681,18 @@ Vector conjugateGradientDescent(const S& Ab, V x, double threshold = 1e-9) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Method of conjugate gradients (CG)
|
// Method of conjugate gradients (CG), Matrix version
|
||||||
Vector conjugateGradientDescent(const Matrix& A, const Vector& b,
|
Vector conjugateGradientDescent(const Matrix& A, const Vector& b,
|
||||||
const Vector& x, double threshold = 1e-9) {
|
const Vector& x, double threshold = 1e-9) {
|
||||||
System Ab = make_pair(A, b);
|
System Ab = make_pair(A, b);
|
||||||
return conjugateGradientDescent<System, Vector, Vector> (Ab, x);
|
return CGD<System, Vector, Vector> (Ab, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Method of conjugate gradients (CG), Gaussian Factor Graph version
|
||||||
|
VectorConfig conjugateGradientDescent(const GaussianFactorGraph& Ab,
|
||||||
|
const VectorConfig& x, double threshold = 1e-9) {
|
||||||
|
return CGD<GaussianFactorGraph, VectorConfig, Errors> (Ab, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -699,8 +711,8 @@ TEST( GaussianFactorGraph, gradientDescent )
|
||||||
CHECK(assert_equal(expected,actual,1e-2));
|
CHECK(assert_equal(expected,actual,1e-2));
|
||||||
|
|
||||||
// Do conjugate gradient descent
|
// Do conjugate gradient descent
|
||||||
VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
|
//VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
|
||||||
//VectorConfig actual2 = conjugateGradientDescent(fg2,zero,zero);
|
VectorConfig actual2 = conjugateGradientDescent(fg2,zero);
|
||||||
CHECK(assert_equal(expected,actual2,1e-2));
|
CHECK(assert_equal(expected,actual2,1e-2));
|
||||||
|
|
||||||
// Do conjugate gradient descent, Matrix version
|
// Do conjugate gradient descent, Matrix version
|
||||||
|
@ -715,7 +727,7 @@ TEST( GaussianFactorGraph, gradientDescent )
|
||||||
|
|
||||||
// Do conjugate gradient descent, System version
|
// Do conjugate gradient descent, System version
|
||||||
System Ab = make_pair(A,b);
|
System Ab = make_pair(A,b);
|
||||||
Vector actualX2 = conjugateGradientDescent<System,Vector,Vector>(Ab,x0);
|
Vector actualX2 = CGD<System,Vector,Vector>(Ab,x0);
|
||||||
CHECK(assert_equal(expectedX,actualX2,1e-9));
|
CHECK(assert_equal(expectedX,actualX2,1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue