diff --git a/gtsam/nonlinear/ISAM2.cpp b/gtsam/nonlinear/ISAM2.cpp index 367b88833..7783d8e94 100644 --- a/gtsam/nonlinear/ISAM2.cpp +++ b/gtsam/nonlinear/ISAM2.cpp @@ -792,6 +792,7 @@ void ISAM2::experimentalMarginalizeLeaves(const FastList& leafKeys) BOOST_FOREACH(Key key, leafKeys) { indices.insert(ordering_[key]); } + FastSet origIndices = indices; // For each clique containing variables to be marginalized, we need to // reeliminate the marginalized variables and add their linear contribution @@ -823,7 +824,7 @@ void ISAM2::experimentalMarginalizeLeaves(const FastList& leafKeys) } #endif - // Now loop over the indices + // Now loop over the indices, the iterator jI is advanced inside the loop. FastSet factorIndicesToRemove; GaussianFactorGraph factorsToAdd; for(FastSet::iterator jI = indices.begin(); jI != indices.end(); ) @@ -834,22 +835,19 @@ void ISAM2::experimentalMarginalizeLeaves(const FastList& leafKeys) FastMap::iterator clique_lastIndex = cliquesToMarginalize.find(nodes_[*jI]); assert(clique_lastIndex != cliquesToMarginalize.end()); // Assert that we indexed the clique + const size_t originalnFrontals = clique_lastIndex->first->conditional()->nrFrontals(); + // Check that the clique has no children if(!clique_lastIndex->first->children().empty()) throw MarginalizeNonleafException(ordering_.key(*jI), params_.keyFormatter); - // Mark factors to be removed - BOOST_FOREACH(size_t i, variableIndex_[*jI]) { - factorIndicesToRemove.insert(i); - } - // Check that all previous variables in the clique are also being eliminated and no later ones. // At the same time, remove the indices marginalized with this clique from the indices set. // This is where the iterator j is advanced. size_t nFrontals = 0; { bool foundLast = false; - BOOST_FOREACH(Index cliqueVar, *clique_lastIndex->first->conditional()) { + BOOST_FOREACH(Index cliqueVar, clique_lastIndex->first->conditional()->frontals()) { if(!foundLast && indices.find(cliqueVar) == indices.end()) throw MarginalizeNonleafException(ordering_.key(j), params_.keyFormatter); if(foundLast && indices.find(cliqueVar) != indices.end()) @@ -863,6 +861,10 @@ void ISAM2::experimentalMarginalizeLeaves(const FastList& leafKeys) } if(cliqueVar == clique_lastIndex->second) foundLast = true; + // Mark factors to be removed + BOOST_FOREACH(size_t i, variableIndex_[cliqueVar]) { + factorIndicesToRemove.insert(i); + } } } @@ -876,11 +878,31 @@ void ISAM2::experimentalMarginalizeLeaves(const FastList& leafKeys) EliminateQR(cliqueGraph, nFrontals) : EliminatePreferCholesky(cliqueGraph, nFrontals); + // Now we discard the conditional part and add the marginal part back into + // the graph. Also we need to rebuild the leaf clique using the marginal. + // Add the marginal into the factor graph factorsToAdd.push_back(eliminationResult.second); + // Get the parent of the clique to be removed + sharedClique parent = clique_lastIndex->first->parent(); + // Remove the clique this->removeClique(clique_lastIndex->first); + + // Add a new leaf clique if necessary + const size_t newnFrontals = originalnFrontals - nFrontals; + if(newnFrontals > 0) { + // Do the elimination for the new leaf clique + GaussianFactorGraph newCliqueGraph; + newCliqueGraph.push_back(eliminationResult.second); + pair newEliminationResult = + params_.factorization==ISAM2Params::QR ? + EliminateQR(newCliqueGraph, newnFrontals) : + EliminatePreferCholesky(newCliqueGraph, newnFrontals); + // Create and add the new clique + this->addClique(ISAM2Clique::Create(newEliminationResult), parent); + } } // Remove any factors touching the marginalized-out variables @@ -893,6 +915,7 @@ void ISAM2::experimentalMarginalizeLeaves(const FastList& leafKeys) if(params_.cacheLinearizedFactors) linearFactors_.remove(i); } + variableIndex_.remove(removedFactorIndices, removedFactors); // Add the new factors and fix linearization points of involved variables BOOST_FOREACH(const GaussianFactor::shared_ptr& factor, factorsToAdd) { @@ -907,7 +930,6 @@ void ISAM2::experimentalMarginalizeLeaves(const FastList& leafKeys) variableIndex_.augment(factorsToAdd); // Augment the variable index // Remove the marginalized variables - variableIndex_.remove(removedFactorIndices, removedFactors); Impl::RemoveVariables(FastSet(leafKeys.begin(), leafKeys.end()), root_, theta_, variableIndex_, delta_, deltaNewton_, RgProd_, deltaReplacedMask_, ordering_, nodes_, linearFactors_, fixedVariables_); } diff --git a/tests/testGaussianISAM2.cpp b/tests/testGaussianISAM2.cpp index 96bbd5949..b397c4685 100644 --- a/tests/testGaussianISAM2.cpp +++ b/tests/testGaussianISAM2.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -832,41 +833,99 @@ TEST(ISAM2, slamlike_solution_partial_relinearization_check) CHECK(isam_check(fullgraph, fullinit, isam, *this, result_)); } +namespace { + bool checkMarginalizeLeaves(ISAM2& isam, const FastList& leafKeys) { + Matrix expectedAugmentedHessian, expected3AugmentedHessian; + vector toKeep; + const Index lastVar = isam.getOrdering().size() - 1; + for(Index i=0; i<=lastVar; ++i) + if(find(leafKeys.begin(), leafKeys.end(), isam.getOrdering().key(i)) == leafKeys.end()) + toKeep.push_back(i); + + // Calculate expected marginal from iSAM2 tree + GaussianFactorGraph isamAsGraph(isam); + GaussianSequentialSolver solver(isamAsGraph); + GaussianFactorGraph marginalgfg = *solver.jointFactorGraph(toKeep); + expectedAugmentedHessian = marginalgfg.augmentedHessian(); + + //// Calculate expected marginal from cached linear factors + //assert(isam.params().cacheLinearizedFactors); + //GaussianSequentialSolver solver2(isam.linearFactors_, isam.params().factorization == ISAM2Params::QR); + //expected2AugmentedHessian = solver2.jointFactorGraph(toKeep)->augmentedHessian(); + + // Calculate expected marginal from original nonlinear factors + GaussianSequentialSolver solver3( + *isam.getFactorsUnsafe().linearize(isam.getLinearizationPoint(), isam.getOrdering()), + isam.params().factorization == ISAM2Params::QR); + expected3AugmentedHessian = solver3.jointFactorGraph(toKeep)->augmentedHessian(); + + // Do marginalization + isam.experimentalMarginalizeLeaves(leafKeys); + + // Check + GaussianFactorGraph actualMarginalGraph(isam); + Matrix actualAugmentedHessian = actualMarginalGraph.augmentedHessian(); + //Matrix actual2AugmentedHessian = linearFactors_.augmentedHessian(); + Matrix actual3AugmentedHessian = isam.getFactorsUnsafe().linearize( + isam.getLinearizationPoint(), isam.getOrdering())->augmentedHessian(); + assert(actualAugmentedHessian.unaryExpr(std::ptr_fun(&std::isfinite)).all()); + + // Check full marginalization + bool treeEqual = assert_equal(expectedAugmentedHessian, actualAugmentedHessian, 1e-6); + //bool linEqual = assert_equal(expected2AugmentedHessian, actualAugmentedHessian, 1e-6); + bool nonlinEqual = assert_equal(expected3AugmentedHessian, actualAugmentedHessian, 1e-6); + //bool linCorrect = assert_equal(expected3AugmentedHessian, expected2AugmentedHessian, 1e-6); + //actual2AugmentedHessian.bottomRightCorner(1,1) = expected3AugmentedHessian.bottomRightCorner(1,1); bool afterLinCorrect = assert_equal(expected3AugmentedHessian, actual2AugmentedHessian, 1e-6); + actual3AugmentedHessian.bottomRightCorner(1,1) = expected3AugmentedHessian.bottomRightCorner(1,1); bool afterNonlinCorrect = assert_equal(expected3AugmentedHessian, actual3AugmentedHessian, 1e-6); + + bool ok = treeEqual && /*linEqual &&*/ nonlinEqual && /*linCorrect &&*/ /*afterLinCorrect &&*/ afterNonlinCorrect; + return ok; + } +} + /* ************************************************************************* */ TEST_UNSAFE(ISAM2, marginalizeLeaves) +{ + ISAM2 isam; + + NonlinearFactorGraph factors; + factors.add(PriorFactor(0, LieVector(0.0), noiseModel::Unit::Create(1))); + + factors.add(BetweenFactor(0, 1, LieVector(0.0), noiseModel::Unit::Create(1))); + factors.add(BetweenFactor(1, 2, LieVector(0.0), noiseModel::Unit::Create(1))); + factors.add(BetweenFactor(0, 2, LieVector(0.0), noiseModel::Unit::Create(1))); + + factors.add(BetweenFactor(2, 3, LieVector(0.0), noiseModel::Unit::Create(1))); + + factors.add(BetweenFactor(3, 4, LieVector(0.0), noiseModel::Unit::Create(1))); + factors.add(BetweenFactor(4, 5, LieVector(0.0), noiseModel::Unit::Create(1))); + factors.add(BetweenFactor(3, 5, LieVector(0.0), noiseModel::Unit::Create(1))); + + Values values; + values.insert(0, LieVector(0.0)); + values.insert(1, LieVector(0.0)); + values.insert(2, LieVector(0.0)); + values.insert(3, LieVector(0.0)); + values.insert(4, LieVector(0.0)); + values.insert(5, LieVector(0.0)); + + isam.update(factors, values); + + FastList leafKeys; + leafKeys.push_back(0); + EXPECT(checkMarginalizeLeaves(isam, leafKeys)); +} + +/* ************************************************************************* */ +TEST_UNSAFE(ISAM2, marginalizeLeaves2) { // Create isam2 ISAM2 isam = createSlamlikeISAM2(); - - // Get linearization point - Values soln = isam.calculateBestEstimate(); - - // Calculate expected marginal - GaussianFactorGraph isamAsGraph(isam); - GaussianSequentialSolver solver(isamAsGraph); - vector toKeep; - const Index lastVar = isam.getOrdering().size() - 1; - for(Index i=0; i<=lastVar; ++i) - if(i != isam.getOrdering()[0]) - toKeep.push_back(i); - GaussianFactorGraph marginalgfg = *solver.jointFactorGraph(toKeep); - vector toFrontI; - toFrontI.push_back(isam.getOrdering()[0]); - Permutation toFront = Permutation::PullToFront(toFrontI, lastVar+1); - marginalgfg.permuteWithInverse(*toFront.inverse()); - Matrix expectedAugmentedHessian = marginalgfg.augmentedHessian(); // Marginalize FastList marginalizeKeys; marginalizeKeys.push_back(isam.getOrdering().key(0)); - isam.experimentalMarginalizeLeaves(marginalizeKeys); - - // Check - GaussianFactorGraph actualMarginalGraph(isam); - Matrix actualAugmentedHessian = actualMarginalGraph.augmentedHessian(); - - LONGS_EQUAL(lastVar-1, isam.getOrdering().size()-1); - EXPECT(assert_equal(expectedAugmentedHessian, actualAugmentedHessian)); + EXPECT(checkMarginalizeLeaves(isam, marginalizeKeys)); } /* ************************************************************************* */