From 9edeb1102c3c57affd4cde46a232b7c1f9d8bb4a Mon Sep 17 00:00:00 2001 From: Richard Roberts Date: Tue, 21 Dec 2010 18:23:56 +0000 Subject: [PATCH] Fixed bug in BayesTree shortcuts and marginals. Also added an input check to Permutation::PushToBack and PullToFront that catches the bad input of duplicate variables caused by the bug. --- gtsam/inference/BayesTree-inl.h | 40 ++++++++++-------- gtsam/inference/Permutation.cpp | 50 ++++++++++++++++++----- gtsam/inference/Permutation.h | 4 +- gtsam/linear/GaussianMultifrontalSolver.h | 2 +- gtsam/nonlinear/NonlinearOptimizer.h | 10 +++++ tests/testInference.cpp | 34 ++++++++++++++- 6 files changed, 108 insertions(+), 32 deletions(-) diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 24f57cc55..fa2fdecbc 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -306,37 +306,45 @@ namespace gtsam { // root and the marginal on the root, integrating out all other variables. // The integrands include any parents of this clique and the variables of // the parent clique. - vector variablesAtBack; - variablesAtBack.reserve(this->size() + R->size()); + FastSet variablesAtBack; + FastSet separator; + size_t uniqueRootVariables = 0; BOOST_FOREACH(const Index separatorIndex, this->separator_) { - variablesAtBack.push_back(separatorIndex); + variablesAtBack.insert(separatorIndex); + separator.insert(separatorIndex); if(debug) cout << "At back (this): " << separatorIndex << endl; } BOOST_FOREACH(const sharedConditional& conditional, *R) { - variablesAtBack.push_back(conditional->key()); + if(variablesAtBack.insert(conditional->key()).second) + ++ uniqueRootVariables; if(debug) cout << "At back (root): " << conditional->key() << endl; } - Permutation toBack = Permutation::PushToBack(variablesAtBack, R->back()->key() + 1); + Permutation toBack = Permutation::PushToBack( + vector(variablesAtBack.begin(), variablesAtBack.end()), + R->back()->key() + 1); Permutation::shared_ptr toBackInverse(toBack.inverse()); BOOST_FOREACH(const typename CONDITIONAL::Factor::shared_ptr& factor, p_Cp_R) { factor->permuteWithInverse(*toBackInverse); } typename BayesNet::shared_ptr eliminated(EliminationTree::Create(p_Cp_R)->eliminate()); - // take only the conditionals for p(S|R) + // Take only the conditionals for p(S|R). We check for each variable being + // in the separator set because if some separator variables overlap with + // root variables, we cannot rely on the number of root variables, and also + // want to include those variables in the conditional. BayesNet p_S_R; - typename BayesNet::const_reverse_iterator conditional = eliminated->rbegin(); - BOOST_FOREACH(const sharedConditional& c, *R) { - (void)c; ++conditional; } - BOOST_FOREACH(const Index c, this->separator_) { - if(debug) - (*conditional)->print("Taking C|R conditional: "); - (void)c; p_S_R.push_front(*(conditional++)); } - -// for(Index j=0; jkey()]) != separator.end()) { + if(debug) + conditional->print("Taking C|R conditional: "); + p_S_R.push_front(conditional); + } + if(p_S_R.size() == separator.size()) + break; + } // Undo the permutation + if(debug) toBack.print("toBack: "); p_S_R.permuteWithInverse(toBack); // return the parent shortcut P(Sp|R) diff --git a/gtsam/inference/Permutation.cpp b/gtsam/inference/Permutation.cpp index 99381ce4d..781ebb30a 100644 --- a/gtsam/inference/Permutation.cpp +++ b/gtsam/inference/Permutation.cpp @@ -19,6 +19,8 @@ #include #include +#include +#include #include using namespace std; @@ -34,7 +36,7 @@ Permutation Permutation::Identity(Index nVars) { } /* ************************************************************************* */ -Permutation Permutation::PullToFront(const vector& toFront, size_t size) { +Permutation Permutation::PullToFront(const vector& toFront, size_t size, bool filterDuplicates) { Permutation ret(size); @@ -43,13 +45,24 @@ Permutation Permutation::PullToFront(const vector& toFront, size_t size) // Put the pulled variables at the front of the permutation and set up the // pulled flags. + size_t toFrontUniqueSize; for(Index j=0; j #include +#include +#include using namespace std; using namespace gtsam; -using namespace example; /* ************************************************************************* */ // The tests below test the *generic* inference algorithms. Some of these have @@ -34,6 +34,7 @@ using namespace example; /* ************************************************************************* */ TEST(GaussianFactorGraph, createSmoother) { + using namespace example; GaussianFactorGraph fg2; Ordering ordering; boost::tie(fg2,ordering) = createSmoother(3); @@ -50,6 +51,7 @@ TEST(GaussianFactorGraph, createSmoother) /* ************************************************************************* */ TEST( Inference, marginals ) { + using namespace example; // create and marginalize a small Bayes net on "x" GaussianBayesNet cbn = createSmallGaussianBayesNet(); vector xvar; xvar.push_back(0); @@ -60,6 +62,34 @@ TEST( Inference, marginals ) CHECK(assert_equal(expected,actual)); } +/* ************************************************************************* */ +TEST( Inference, marginals2) +{ + using namespace gtsam::planarSLAM; + + Graph fg; + SharedDiagonal poseModel(sharedSigma(3, 0.1)); + SharedDiagonal pointModel(sharedSigma(3, 0.1)); + + fg.addPrior(PoseKey(0), Pose2(), poseModel); + fg.addOdometry(PoseKey(0), PoseKey(1), Pose2(1.0,0.0,0.0), poseModel); + fg.addOdometry(PoseKey(1), PoseKey(2), Pose2(1.0,0.0,0.0), poseModel); + fg.addBearingRange(PoseKey(0), PointKey(0), Rot2(), 1.0, pointModel); + fg.addBearingRange(PoseKey(1), PointKey(0), Rot2(), 1.0, pointModel); + fg.addBearingRange(PoseKey(2), PointKey(0), Rot2(), 1.0, pointModel); + + Values init; + init.insert(PoseKey(0), Pose2(0.0,0.0,0.0)); + init.insert(PoseKey(1), Pose2(1.0,0.0,0.0)); + init.insert(PoseKey(2), Pose2(2.0,0.0,0.0)); + init.insert(PointKey(0), Point2(1.0,1.0)); + + Ordering ordering(*fg.orderingCOLAMD(init)); + GaussianFactorGraph::shared_ptr gfg(fg.linearize(init, ordering)); + GaussianMultifrontalSolver solver(*gfg); + solver.marginalFactor(ordering[PointKey(0)]); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */