Basic handling of constraints now works in factor graphs, assuming there is only one constraint on any given variable.
parent
a7b711db37
commit
ddc0173671
|
@ -286,6 +286,7 @@ void householder_update(Matrix &A, int j, double beta, const Vector& vjm) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::pair<Matrix, Vector> weighted_eliminate(Matrix& A, const Vector& sigmas) {
|
std::pair<Matrix, Vector> weighted_eliminate(Matrix& A, const Vector& sigmas) {
|
||||||
|
bool verbose = false;
|
||||||
// get sizes
|
// get sizes
|
||||||
size_t m = A.size1();
|
size_t m = A.size1();
|
||||||
size_t n = A.size2();
|
size_t n = A.size2();
|
||||||
|
@ -301,21 +302,26 @@ std::pair<Matrix, Vector> weighted_eliminate(Matrix& A, const Vector& sigmas) {
|
||||||
for (int j=0; j<maxRank; ++j) {
|
for (int j=0; j<maxRank; ++j) {
|
||||||
// extract the first column of A
|
// extract the first column of A
|
||||||
Vector a = column(A, j);
|
Vector a = column(A, j);
|
||||||
|
if (verbose) print(a,"a");
|
||||||
|
|
||||||
// find weighted pseudoinverse
|
// find weighted pseudoinverse
|
||||||
Vector pseudo; double precision;
|
Vector pseudo; double precision;
|
||||||
|
if (verbose) print(sigmas, "sigmas");
|
||||||
boost::tie(pseudo, precision) = weightedPseudoinverse(a, sigmas);
|
boost::tie(pseudo, precision) = weightedPseudoinverse(a, sigmas);
|
||||||
|
if (verbose) print(pseudo, "pseudo");
|
||||||
|
|
||||||
// create solution and copy into R
|
// create solution and copy into R
|
||||||
for (int j2=j; j2<n; ++j2) {
|
for (int j2=j; j2<n; ++j2) {
|
||||||
R(j,j2) = inner_prod(pseudo, column(A, j2));
|
R(j,j2) = inner_prod(pseudo, column(A, j2));
|
||||||
}
|
}
|
||||||
|
if (verbose) print(R, "updatedR");
|
||||||
|
|
||||||
// update A
|
// update A
|
||||||
for (int i=0;i<m;++i) // update all rows
|
for (int i=0;i<m;++i) // update all rows
|
||||||
for (int j2=j+1;j2<n;++j2) { // limit to only columns in separator
|
for (int j2=j+1;j2<n;++j2) { // limit to only columns in separator
|
||||||
A(i,j2) -= R(j,j2)*a(i);
|
A(i,j2) -= R(j,j2)*a(i);
|
||||||
}
|
}
|
||||||
|
if (verbose) print(A, "updatedA");
|
||||||
|
|
||||||
// save precision information
|
// save precision information
|
||||||
newSigmas[j] = sqrt(1./precision);
|
newSigmas[j] = sqrt(1./precision);
|
||||||
|
|
|
@ -185,16 +185,42 @@ namespace gtsam {
|
||||||
pair<Vector, double> weightedPseudoinverse(const Vector& v, const Vector& sigmas) {
|
pair<Vector, double> weightedPseudoinverse(const Vector& v, const Vector& sigmas) {
|
||||||
if (v.size() != sigmas.size())
|
if (v.size() != sigmas.size())
|
||||||
throw invalid_argument("V and precisions have different sizes!");
|
throw invalid_argument("V and precisions have different sizes!");
|
||||||
double normV = 0;
|
|
||||||
Vector precisions(sigmas.size());
|
// detect constraints and sanity-check
|
||||||
for(int i = 0; i<v.size(); i++) {
|
int constraint_index = -1;
|
||||||
precisions[i] = 1./(sigmas[i]*sigmas[i]);
|
for(int i=0; i<sigmas.size(); ++i) {
|
||||||
normV += v[i]*v[i]*precisions[i];
|
if (sigmas[i] < 1e-9 && v[i] > 1e-9) {
|
||||||
|
if (constraint_index != -1)
|
||||||
|
throw invalid_argument("Multiple constraints on a single node!");
|
||||||
|
else
|
||||||
|
constraint_index = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute pseudoinverse
|
||||||
|
if (constraint_index != -1) {
|
||||||
|
// constrained case
|
||||||
|
Vector sol = zero(sigmas.size());
|
||||||
|
sol(constraint_index) = 1.0;
|
||||||
|
return make_pair(sol, 1.0/0.0);
|
||||||
|
} else {
|
||||||
|
// normal case
|
||||||
|
double normV = 0.;
|
||||||
|
Vector precisions(sigmas.size());
|
||||||
|
for(int i = 0; i<v.size(); i++) {
|
||||||
|
if (sigmas[i] < 1e-5) {
|
||||||
|
precisions[i] = 1./0.;
|
||||||
|
} else {
|
||||||
|
precisions[i] = 1./(sigmas[i]*sigmas[i]);
|
||||||
|
normV += v[i]*v[i]*precisions[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Vector sol = zero(v.size());
|
||||||
|
for(int i = 0; i<v.size(); i++)
|
||||||
|
if (sigmas[i] > 1e-5)
|
||||||
|
sol[i] = precisions[i]*v[i];
|
||||||
|
return make_pair(sol/normV, normV);
|
||||||
}
|
}
|
||||||
Vector sol(v.size());
|
|
||||||
for(int i = 0; i<v.size(); i++)
|
|
||||||
sol[i] = precisions[i]*v[i];
|
|
||||||
return make_pair(sol/normV, normV);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -109,6 +109,9 @@ std::pair<double,Vector> house(Vector &x);
|
||||||
/**
|
/**
|
||||||
* Weighted Householder solution vector,
|
* Weighted Householder solution vector,
|
||||||
* a.k.a., the pseudoinverse of the column
|
* a.k.a., the pseudoinverse of the column
|
||||||
|
* NOTE: if any sigmas are zero (indicating a constraint)
|
||||||
|
* the pseudoinverse will be a selection vector, and the
|
||||||
|
* precision will be infinite
|
||||||
* @param v is the first column of the matrix to solve
|
* @param v is the first column of the matrix to solve
|
||||||
* @param simgas is a vector of standard deviations
|
* @param simgas is a vector of standard deviations
|
||||||
* @return a pair of the pseudoinverse of v and the precision
|
* @return a pair of the pseudoinverse of v and the precision
|
||||||
|
|
|
@ -639,28 +639,29 @@ TEST( LinearFactor, CONSTRUCTOR_ConditionalGaussian )
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST ( LinearFactor, constraint_eliminate1 )
|
TEST ( LinearFactor, constraint_eliminate1 )
|
||||||
{
|
{
|
||||||
// // construct a linear constraint
|
// construct a linear constraint
|
||||||
// Vector v(2); v(0)=1.2; v(1)=3.4;
|
Vector v(2); v(0)=1.2; v(1)=3.4;
|
||||||
// string key = "x0";
|
string key = "x0";
|
||||||
// LinearFactor lc(key, eye(2), v, 0.0);
|
LinearFactor lc(key, eye(2), v, 0.0);
|
||||||
//
|
|
||||||
// // eliminate it
|
// eliminate it
|
||||||
// ConditionalGaussian::shared_ptr actualCG;
|
ConditionalGaussian::shared_ptr actualCG;
|
||||||
// LinearFactor::shared_ptr actualLF;
|
LinearFactor::shared_ptr actualLF;
|
||||||
// boost::tie(actualCG,actualLF) = lc.eliminate("x0");
|
boost::tie(actualCG,actualLF) = lc.eliminate("x0");
|
||||||
//
|
|
||||||
// // verify linear factor
|
// verify linear factor
|
||||||
// CHECK(actualLF->size() == 0);
|
CHECK(actualLF->size() == 0);
|
||||||
//
|
|
||||||
// // verify conditional Gaussian
|
// verify conditional Gaussian
|
||||||
// Vector sigmas = Vector_(2, 0.0, 0.0);
|
Vector sigmas = Vector_(2, 0.0, 0.0);
|
||||||
// ConditionalGaussian expCG("x0", v, eye(2), sigmas);
|
ConditionalGaussian expCG("x0", v, eye(2), sigmas);
|
||||||
// CHECK(assert_equal(expCG, *actualCG)); // FAILS - gets NaN values
|
CHECK(assert_equal(expCG, *actualCG));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test fails due to multiple constraints on a node
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST ( LinearFactor, constraint_eliminate2 )
|
//TEST ( LinearFactor, constraint_eliminate2 )
|
||||||
{
|
//{
|
||||||
// // Construct a linear constraint
|
// // Construct a linear constraint
|
||||||
// // RHS
|
// // RHS
|
||||||
// Vector b(2); b(0)=3.0; b(1)=4.0;
|
// Vector b(2); b(0)=3.0; b(1)=4.0;
|
||||||
|
@ -685,9 +686,12 @@ TEST ( LinearFactor, constraint_eliminate2 )
|
||||||
// Vector expected = Vector_(2, -3.3333, 0.6667);
|
// Vector expected = Vector_(2, -3.3333, 0.6667);
|
||||||
//
|
//
|
||||||
// // eliminate x for basic check
|
// // eliminate x for basic check
|
||||||
// ConditionalGaussian::shared_ptr actual = lc.eliminate("x");
|
|
||||||
// CHECK(assert_equal(expected, actual->solve(fg1), 1e-4));
|
|
||||||
//
|
//
|
||||||
|
// ConditionalGaussian::shared_ptr actualCG;
|
||||||
|
// LinearFactor::shared_ptr actualLF;
|
||||||
|
// boost::tie(actualCG, actualLF) = lc.eliminate("x");
|
||||||
|
// CHECK(assert_equal(expected, actualCG->solve(fg1), 1e-4));
|
||||||
|
|
||||||
// // eliminate y to test thrown error
|
// // eliminate y to test thrown error
|
||||||
// VectorConfig fg2;
|
// VectorConfig fg2;
|
||||||
// fg2.insert("x", expected);
|
// fg2.insert("x", expected);
|
||||||
|
@ -698,7 +702,7 @@ TEST ( LinearFactor, constraint_eliminate2 )
|
||||||
// } catch (...) {
|
// } catch (...) {
|
||||||
// CHECK(true);
|
// CHECK(true);
|
||||||
// }
|
// }
|
||||||
}
|
//}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -26,7 +26,7 @@ using namespace gtsam;
|
||||||
double tol=1e-4;
|
double tol=1e-4;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/* unit test for equals (LinearFactorGraph1 == LinearFactorGraph2) */
|
/* unit test for equals (LinearFactorGraph1 == LinearFactorGraph2) */
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( LinearFactorGraph, equals ){
|
TEST( LinearFactorGraph, equals ){
|
||||||
|
|
||||||
|
@ -49,10 +49,10 @@ TEST( LinearFactorGraph, error )
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/* unit test for find seperator */
|
/* unit test for find seperator */
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( LinearFactorGraph, find_separator )
|
TEST( LinearFactorGraph, find_separator )
|
||||||
{
|
{
|
||||||
LinearFactorGraph fg = createLinearFactorGraph();
|
LinearFactorGraph fg = createLinearFactorGraph();
|
||||||
|
|
||||||
set<string> separator = fg.find_separator("x2");
|
set<string> separator = fg.find_separator("x2");
|
||||||
|
@ -68,7 +68,7 @@ TEST( LinearFactorGraph, find_separator )
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( LinearFactorGraph, combine_factors_x1 )
|
TEST( LinearFactorGraph, combine_factors_x1 )
|
||||||
{
|
{
|
||||||
// create a small example for a linear factor graph
|
// create a small example for a linear factor graph
|
||||||
LinearFactorGraph fg = createLinearFactorGraph();
|
LinearFactorGraph fg = createLinearFactorGraph();
|
||||||
|
|
||||||
|
@ -77,8 +77,8 @@ TEST( LinearFactorGraph, combine_factors_x1 )
|
||||||
double sigma2 = 0.1;
|
double sigma2 = 0.1;
|
||||||
double sigma3 = 0.2;
|
double sigma3 = 0.2;
|
||||||
Vector sigmas = Vector_(6, sigma1, sigma1, sigma2, sigma2, sigma3, sigma3);
|
Vector sigmas = Vector_(6, sigma1, sigma1, sigma2, sigma2, sigma3, sigma3);
|
||||||
|
|
||||||
// combine all factors
|
// combine all factors
|
||||||
LinearFactor::shared_ptr actual = fg.removeAndCombineFactors("x1");
|
LinearFactor::shared_ptr actual = fg.removeAndCombineFactors("x1");
|
||||||
|
|
||||||
// the expected linear factor
|
// the expected linear factor
|
||||||
|
@ -131,7 +131,7 @@ TEST( LinearFactorGraph, combine_factors_x1 )
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( LinearFactorGraph, combine_factors_x2 )
|
TEST( LinearFactorGraph, combine_factors_x2 )
|
||||||
{
|
{
|
||||||
// create a small example for a linear factor graph
|
// create a small example for a linear factor graph
|
||||||
LinearFactorGraph fg = createLinearFactorGraph();
|
LinearFactorGraph fg = createLinearFactorGraph();
|
||||||
|
|
||||||
|
@ -214,7 +214,7 @@ TEST( LinearFactorGraph, eliminateOne_x1 )
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
TEST( LinearFactorGraph, eliminateOne_x2 )
|
TEST( LinearFactorGraph, eliminateOne_x2 )
|
||||||
{
|
{
|
||||||
LinearFactorGraph fg = createLinearFactorGraph();
|
LinearFactorGraph fg = createLinearFactorGraph();
|
||||||
|
@ -587,23 +587,25 @@ TEST( LinearFactorGraph, variables )
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Tests ported from ConstrainedLinearFactorGraph
|
// Tests ported from ConstrainedLinearFactorGraph
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
||||||
///* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
//TEST( LinearFactorGraph, constrained_simple )
|
TEST( LinearFactorGraph, constrained_simple )
|
||||||
//{
|
{
|
||||||
// // get a graph with a constraint in it
|
// get a graph with a constraint in it
|
||||||
// LinearFactorGraph fg = createSimpleConstraintGraph();
|
LinearFactorGraph fg = createSimpleConstraintGraph();
|
||||||
//
|
|
||||||
// // eliminate and solve
|
// eliminate and solve
|
||||||
// Ordering ord;
|
Ordering ord;
|
||||||
// ord += "x", "y";
|
ord += "x", "y";
|
||||||
// VectorConfig actual = fg.optimize(ord);
|
VectorConfig actual = fg.optimize(ord);
|
||||||
//
|
|
||||||
// // verify
|
// verify
|
||||||
// VectorConfig expected = createSimpleConstraintConfig();
|
VectorConfig expected = createSimpleConstraintConfig();
|
||||||
// CHECK(assert_equal(actual, expected));
|
CHECK(assert_equal(actual, expected));
|
||||||
//}
|
}
|
||||||
//
|
|
||||||
|
// These tests require multiple constraints on a single node and will fail
|
||||||
///* ************************************************************************* */
|
///* ************************************************************************* */
|
||||||
//TEST( LinearFactorGraph, constrained_single )
|
//TEST( LinearFactorGraph, constrained_single )
|
||||||
//{
|
//{
|
||||||
|
|
|
@ -151,6 +151,44 @@ TEST( TestVector, weightedPseudoinverse )
|
||||||
CHECK(fabs(expPrecision-precision) < 1e-5);
|
CHECK(fabs(expPrecision-precision) < 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( TestVector, weightedPseudoinverse_constraint )
|
||||||
|
{
|
||||||
|
// column from a matrix
|
||||||
|
Vector x(2);
|
||||||
|
x(0) = 1.0; x(1) = 2.0;
|
||||||
|
|
||||||
|
// create sigmas
|
||||||
|
Vector sigmas(2);
|
||||||
|
sigmas(0) = 0.0; sigmas(1) = 0.2;
|
||||||
|
|
||||||
|
// perform solve
|
||||||
|
Vector act; double precision;
|
||||||
|
boost::tie(act, precision) = weightedPseudoinverse(x, sigmas);
|
||||||
|
|
||||||
|
// construct expected
|
||||||
|
Vector exp(2);
|
||||||
|
exp(0) = 1.0; exp(1) = 0.0;
|
||||||
|
|
||||||
|
// verify
|
||||||
|
CHECK(assert_equal(act, exp));
|
||||||
|
CHECK(isinf(precision));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( TestVector, weightedPseudoinverse_nan )
|
||||||
|
{
|
||||||
|
Vector a = Vector_(4, 1., 0., 0., 0.);
|
||||||
|
Vector sigmas = Vector_(4, 0.1, 0.1, 0., 0.);
|
||||||
|
Vector pseudo; double precision;
|
||||||
|
boost::tie(pseudo, precision) = weightedPseudoinverse(a, sigmas);
|
||||||
|
|
||||||
|
Vector exp = Vector_(4, 1., 0., 0.,0.);
|
||||||
|
CHECK(assert_equal(pseudo, exp));
|
||||||
|
DOUBLES_EQUAL(100, precision, 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( TestVector, ediv )
|
TEST( TestVector, ediv )
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue