Calculate gradient of factor graph objective function.

release/4.3a0
Frank Dellaert 2009-12-10 20:19:15 +00:00
parent 67e1897e47
commit e7a912bd3b
5 changed files with 67 additions and 6 deletions

View File

@ -358,6 +358,24 @@ GaussianFactor::eliminate(const string& key) const
return make_pair(conditional, factor);
}
/* ************************************************************************* */
void GaussianFactor::addGradientContribution(const VectorConfig& x, VectorConfig& g) const {
// calculate the value of the factor
Vector e = -b_;
string j; Matrix Aj;
FOREACH_PAIR(j, Aj, As_) e += Vector(Aj * x[j]);
// transpose
Vector et = trans(e);
// contribute to gradient for each connected variable
FOREACH_PAIR(j, Aj, As_) {
Vector dj = trans(et*Aj); // this factor's contribution to gradient on j
Vector wdj = ediv(dj,emul(sigmas_,sigmas_)); // properly weight by sigmas
g.add(j,wdj);
}
}
/* ************************************************************************* */
namespace gtsam {

View File

@ -241,7 +241,14 @@ public:
* @param m final number of rows of f, needs to be known in advance
* @param pos where to insert in the m-sized matrices
*/
inline void append_factor(GaussianFactor::shared_ptr f, size_t m, size_t pos);
void append_factor(GaussianFactor::shared_ptr f, size_t m, size_t pos);
/**
* Add gradient contribution to gradient config g
* @param x: confif at which to evaluate gradient
* @param g: I/O parameter, evolving gradient
*/
void addGradientContribution(const VectorConfig& x, VectorConfig& g) const;
}; // GaussianFactor

View File

@ -200,3 +200,12 @@ Matrix GaussianFactorGraph::sparse(const Ordering& ordering) const {
}
/* ************************************************************************* */
VectorConfig GaussianFactorGraph::gradient(const VectorConfig& x) const {
VectorConfig g;
// For each factor add the gradient contribution
BOOST_FOREACH(sharedFactor factor,factors_)
factor->addGradientContribution(x,g);
return g;
}
/* ************************************************************************* */

View File

@ -159,6 +159,11 @@ namespace gtsam {
* @param ordering of variables needed for matrix column order
*/
Matrix sparse(const Ordering& ordering) const;
/**
* Calculate Gradient of 0.5*|Ax-b| for a given config
*/
VectorConfig gradient(const VectorConfig& x) const;
};
}

View File

@ -537,6 +537,28 @@ TEST( GaussianFactorGraph, involves )
CHECK(!fg.involves("x3"));
}
/* ************************************************************************* */
TEST( GaussianFactorGraph, gradient )
{
GaussianFactorGraph fg = createGaussianFactorGraph();
// Construct expected gradient
VectorConfig expected;
// 2*f(x) = 100*(x1+c["x1"])^2 + 100*(x2-x1-[0.2;-0.1])^2 + 25*(l1-x1-[0.0;0.2])^2 + 25*(l1-x2-[-0.2;0.3])^2
// worked out: df/dx1 = 100*[0.1;0.1] + 100*[0.2;-0.1]) + 25*[0.0;0.2] = [10+20;10-10+5] = [30;5]
expected.insert("x1",Vector_(2,30.0,5.0));
// from working implementation:
expected.insert("x2",Vector_(2,-25.0, 17.5));
expected.insert("l1",Vector_(2, 5.0,-12.5));
// calculate the gradient at delta=0
VectorConfig delta = createZeroDelta();
VectorConfig actual = fg.gradient(delta);
CHECK(assert_equal(expected,actual));
}
/* ************************************************************************* */
// Tests ported from ConstrainedGaussianFactorGraph
/* ************************************************************************* */
@ -554,7 +576,7 @@ TEST( GaussianFactorGraph, constrained_simple )
// verify
VectorConfig expected = createSimpleConstraintConfig();
CHECK(assert_equal(actual, expected));
CHECK(assert_equal(expected, actual));
}
/* ************************************************************************* */
@ -570,7 +592,7 @@ TEST( GaussianFactorGraph, constrained_single )
// verify
VectorConfig expected = createSingleConstraintConfig();
CHECK(assert_equal(actual, expected));
CHECK(assert_equal(expected, actual));
}
/* ************************************************************************* */
@ -586,7 +608,7 @@ TEST( GaussianFactorGraph, constrained_single2 )
// verify
VectorConfig expected = createSingleConstraintConfig();
CHECK(assert_equal(actual, expected));
CHECK(assert_equal(expected, actual));
}
/* ************************************************************************* */
@ -602,7 +624,7 @@ TEST( GaussianFactorGraph, constrained_multi1 )
// verify
VectorConfig expected = createMultiConstraintConfig();
CHECK(assert_equal(actual, expected));
CHECK(assert_equal(expected, actual));
}
/* ************************************************************************* */
@ -618,7 +640,7 @@ TEST( GaussianFactorGraph, constrained_multi2 )
// verify
VectorConfig expected = createMultiConstraintConfig();
CHECK(assert_equal(actual, expected));
CHECK(assert_equal(expected, actual));
}
/* ************************************************************************* */