Fixed dimensions bug in Marginals and added unit test

release/4.3a0
Richard Roberts 2012-07-23 19:29:52 +00:00
parent 656f573c0a
commit c32d1c7e02
3 changed files with 64 additions and 6 deletions

View File

@ -88,7 +88,7 @@ namespace gtsam {
GenericSequentialSolver<FACTOR>::jointFactorGraph( GenericSequentialSolver<FACTOR>::jointFactorGraph(
const std::vector<Index>& js, Eliminate function) const { const std::vector<Index>& js, Eliminate function) const {
// Compute a COLAMD permutation with the marginal variable constrained to the end. // Compute a COLAMD permutation with the marginal variables constrained to the end.
Permutation::shared_ptr permutation(inference::PermutationCOLAMD(*structure_, js)); Permutation::shared_ptr permutation(inference::PermutationCOLAMD(*structure_, js));
Permutation::shared_ptr permutationInverse(permutation->inverse()); Permutation::shared_ptr permutationInverse(permutation->inverse());

View File

@ -110,11 +110,11 @@ JointMarginal Marginals::jointMarginalInformation(const std::vector<Key>& variab
return JointMarginal(info, dims, indices); return JointMarginal(info, dims, indices);
} else { } else {
// Convert keys to linear indices // Obtain requested variables as ordered indices
vector<Index> indices(variables.size()); vector<Index> indices(variables.size());
for(size_t i=0; i<variables.size(); ++i) { indices[i] = ordering_[variables[i]]; } for(size_t i=0; i<variables.size(); ++i) { indices[i] = ordering_[variables[i]]; }
// Compute joint factor graph // Compute joint marginal factor graph.
GaussianFactorGraph jointFG; GaussianFactorGraph jointFG;
if(variables.size() == 2) { if(variables.size() == 2) {
if(factorization_ == CHOLESKY) if(factorization_ == CHOLESKY)
@ -128,13 +128,15 @@ JointMarginal Marginals::jointMarginalInformation(const std::vector<Key>& variab
jointFG = *GaussianSequentialSolver(graph_, true).jointFactorGraph(indices); jointFG = *GaussianSequentialSolver(graph_, true).jointFactorGraph(indices);
} }
// Conversion from variable keys to position in factor graph variables, // Build map from variable keys to position in factor graph variables,
// which are sorted in index order. // which are sorted in index order.
Ordering variableConversion; Ordering variableConversion;
{ {
// First build map from index to key
FastMap<Index,Key> usedIndices; FastMap<Index,Key> usedIndices;
for(size_t i=0; i<variables.size(); ++i) for(size_t i=0; i<variables.size(); ++i)
usedIndices.insert(make_pair(indices[i], variables[i])); usedIndices.insert(make_pair(indices[i], variables[i]));
// Next run over indices in sorted order
size_t slot = 0; size_t slot = 0;
typedef pair<Index,Key> Index_Key; typedef pair<Index,Key> Index_Key;
BOOST_FOREACH(const Index_Key& index_key, usedIndices) { BOOST_FOREACH(const Index_Key& index_key, usedIndices) {
@ -145,8 +147,9 @@ JointMarginal Marginals::jointMarginalInformation(const std::vector<Key>& variab
// Get dimensions from factor graph // Get dimensions from factor graph
std::vector<size_t> dims(indices.size(), 0); std::vector<size_t> dims(indices.size(), 0);
for(size_t i = 0; i < variables.size(); ++i) BOOST_FOREACH(Key key, variables) {
dims[i] = values_.at(variables[i]).dim(); dims[variableConversion[key]] = values_.at(key).dim();
}
// Get information matrix // Get information matrix
Matrix augmentedInfo = jointFG.denseHessian(); Matrix augmentedInfo = jointFG.denseHessian();

View File

@ -180,7 +180,62 @@ TEST(Marginals, planarSLAMmarginals) {
EXPECT(assert_equal(expectedx1, Matrix(joint_l2x1(x1,x1)), 1e-6)); EXPECT(assert_equal(expectedx1, Matrix(joint_l2x1(x1,x1)), 1e-6));
} }
/* ************************************************************************* */
TEST(Marginals, order) {
NonlinearFactorGraph fg;
fg.add(PriorFactor<Pose2>(0, Pose2(), noiseModel::Unit::Create(3)));
fg.add(BetweenFactor<Pose2>(0, 1, Pose2(1,0,0), noiseModel::Unit::Create(3)));
fg.add(BetweenFactor<Pose2>(1, 2, Pose2(1,0,0), noiseModel::Unit::Create(3)));
fg.add(BetweenFactor<Pose2>(2, 3, Pose2(1,0,0), noiseModel::Unit::Create(3)));
Values vals;
vals.insert(0, Pose2());
vals.insert(1, Pose2(1,0,0));
vals.insert(2, Pose2(2,0,0));
vals.insert(3, Pose2(3,0,0));
vals.insert(100, Point2(0,1));
vals.insert(101, Point2(1,1));
fg.add(BearingRangeFactor<Pose2,Point2>(0, 100,
vals.at<Pose2>(0).bearing(vals.at<Point2>(100)),
vals.at<Pose2>(0).range(vals.at<Point2>(100)), noiseModel::Unit::Create(2)));
fg.add(BearingRangeFactor<Pose2,Point2>(0, 101,
vals.at<Pose2>(0).bearing(vals.at<Point2>(101)),
vals.at<Pose2>(0).range(vals.at<Point2>(101)), noiseModel::Unit::Create(2)));
fg.add(BearingRangeFactor<Pose2,Point2>(1, 100,
vals.at<Pose2>(1).bearing(vals.at<Point2>(100)),
vals.at<Pose2>(1).range(vals.at<Point2>(100)), noiseModel::Unit::Create(2)));
fg.add(BearingRangeFactor<Pose2,Point2>(1, 101,
vals.at<Pose2>(1).bearing(vals.at<Point2>(101)),
vals.at<Pose2>(1).range(vals.at<Point2>(101)), noiseModel::Unit::Create(2)));
fg.add(BearingRangeFactor<Pose2,Point2>(2, 100,
vals.at<Pose2>(2).bearing(vals.at<Point2>(100)),
vals.at<Pose2>(2).range(vals.at<Point2>(100)), noiseModel::Unit::Create(2)));
fg.add(BearingRangeFactor<Pose2,Point2>(2, 101,
vals.at<Pose2>(2).bearing(vals.at<Point2>(101)),
vals.at<Pose2>(2).range(vals.at<Point2>(101)), noiseModel::Unit::Create(2)));
fg.add(BearingRangeFactor<Pose2,Point2>(3, 100,
vals.at<Pose2>(3).bearing(vals.at<Point2>(100)),
vals.at<Pose2>(3).range(vals.at<Point2>(100)), noiseModel::Unit::Create(2)));
fg.add(BearingRangeFactor<Pose2,Point2>(3, 101,
vals.at<Pose2>(3).bearing(vals.at<Point2>(101)),
vals.at<Pose2>(3).range(vals.at<Point2>(101)), noiseModel::Unit::Create(2)));
Marginals marginals(fg, vals);
FastVector<Key> keys(fg.keys());
JointMarginal joint = marginals.jointMarginalCovariance(keys);
LONGS_EQUAL(3, joint(0,0).rows());
LONGS_EQUAL(3, joint(1,1).rows());
LONGS_EQUAL(3, joint(2,2).rows());
LONGS_EQUAL(3, joint(3,3).rows());
LONGS_EQUAL(2, joint(100,100).rows());
LONGS_EQUAL(2, joint(101,101).rows());
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);} int main() { TestResult tr; return TestRegistry::runAllTests(tr);}