diff --git a/gtsam/nonlinear/ISAM2.cpp b/gtsam/nonlinear/ISAM2.cpp index 1c15469cc..c4f6ddb69 100644 --- a/gtsam/nonlinear/ISAM2.cpp +++ b/gtsam/nonlinear/ISAM2.cpp @@ -552,9 +552,12 @@ void ISAM2::marginalizeLeaves( // We do not need the marginal factors associated with this clique // because their information is already incorporated in the new // marginal factor. So, now associate this marginal factor with the - // parent of this clique. - marginalFactors[clique->parent()->conditional()->front()].push_back( - marginalFactor); + // parent of this clique. If the clique is a root and has no parent, then + // we can discard it without keeping track of the marginal factor. + if (clique->parent()) { + marginalFactors[clique->parent()->conditional()->front()].push_back( + marginalFactor); + } // Now remove this clique and its subtree - all of its marginal // information has been stored in marginalFactors. trackingRemoveSubtree(clique); @@ -632,7 +635,7 @@ void ISAM2::marginalizeLeaves( // Make the clique's matrix appear as a subset const DenseIndex dimToRemove = cg->matrixObject().offset(nToRemove); - cg->matrixObject().firstBlock() = nToRemove; + cg->matrixObject().firstBlock() += nToRemove; cg->matrixObject().rowStart() = dimToRemove; // Change the keys in the clique @@ -658,42 +661,55 @@ void ISAM2::marginalizeLeaves( // At this point we have updated the BayesTree, now update the remaining iSAM2 // data structures + // Remove the factors to remove that will be summarized in marginal factors + NonlinearFactorGraph removedFactors; + for (const auto index : factorIndicesToRemove) { + removedFactors.push_back(nonlinearFactors_[index]); + nonlinearFactors_.remove(index); + if (params_.cacheLinearizedFactors) { + linearFactors_.remove(index); + } + } + variableIndex_.remove(factorIndicesToRemove.begin(), + factorIndicesToRemove.end(), removedFactors); + // Gather factors to add - the new marginal factors - GaussianFactorGraph factorsToAdd; + GaussianFactorGraph factorsToAdd{}; + NonlinearFactorGraph nonlinearFactorsToAdd{}; for (const auto& key_factors : marginalFactors) { for (const auto& factor : key_factors.second) { if (factor) { factorsToAdd.push_back(factor); - if (marginalFactorsIndices) - marginalFactorsIndices->push_back(nonlinearFactors_.size()); - nonlinearFactors_.push_back( - boost::make_shared(factor)); - if (params_.cacheLinearizedFactors) linearFactors_.push_back(factor); + nonlinearFactorsToAdd.emplace_shared(factor); for (Key factorKey : *factor) { fixedVariables_.insert(factorKey); } } } } - variableIndex_.augment(factorsToAdd); // Augment the variable index - - // Remove the factors to remove that have been summarized in the newly-added - // marginal factors - NonlinearFactorGraph removedFactors; - for (const auto index : factorIndicesToRemove) { - removedFactors.push_back(nonlinearFactors_[index]); - nonlinearFactors_.remove(index); - if (params_.cacheLinearizedFactors) linearFactors_.remove(index); + // Add the nonlinear factors and keep track of the new factor indices + auto newFactorIndices = nonlinearFactors_.add_factors(nonlinearFactorsToAdd, + params_.findUnusedFactorSlots); + // Add cached linear factors. + if (params_.cacheLinearizedFactors){ + linearFactors_.resize(nonlinearFactors_.size()); + for (std::size_t i = 0; i < nonlinearFactorsToAdd.size(); ++i){ + linearFactors_[newFactorIndices[i]] = factorsToAdd[i]; + } } - variableIndex_.remove(factorIndicesToRemove.begin(), - factorIndicesToRemove.end(), removedFactors); - - if (deletedFactorsIndices) - deletedFactorsIndices->assign(factorIndicesToRemove.begin(), - factorIndicesToRemove.end()); + // Augment the variable index + variableIndex_.augment(factorsToAdd, newFactorIndices); // Remove the marginalized variables removeVariables(KeySet(leafKeys.begin(), leafKeys.end())); + + if (deletedFactorsIndices) { + deletedFactorsIndices->assign(factorIndicesToRemove.begin(), + factorIndicesToRemove.end()); + } + if (marginalFactorsIndices){ + *marginalFactorsIndices = std::move(newFactorIndices); + } } /* ************************************************************************* */ diff --git a/tests/testGaussianISAM2.cpp b/tests/testGaussianISAM2.cpp index 8dbf3fff6..d6e36a4b8 100644 --- a/tests/testGaussianISAM2.cpp +++ b/tests/testGaussianISAM2.cpp @@ -660,6 +660,77 @@ namespace { bool ok = treeEqual && /*linEqual &&*/ nonlinEqual && /*linCorrect &&*/ /*afterLinCorrect &&*/ afterNonlinCorrect; return ok; } + + boost::optional> createOrderingConstraints(const ISAM2& isam, const KeyVector& newKeys, const KeySet& marginalizableKeys) + { + if (marginalizableKeys.empty()) { + return boost::none; + } else { + FastMap constrainedKeys = FastMap(); + // Generate ordering constraints so that the marginalizable variables will be eliminated first + // Set all existing and new variables to Group1 + for (const auto& key_val : isam.getDelta()) { + constrainedKeys.emplace(key_val.first, 1); + } + for (const auto& key : newKeys) { + constrainedKeys.emplace(key, 1); + } + // And then re-assign the marginalizable variables to Group0 so that they'll all be leaf nodes + for (const auto& key : marginalizableKeys) { + constrainedKeys.at(key) = 0; + } + return constrainedKeys; + } + } + + void markAffectedKeys(const Key& key, const ISAM2Clique::shared_ptr& rootClique, KeyList& additionalKeys) + { + std::stack frontier; + frontier.push(rootClique); + // Basic DFS to find additional keys + while (!frontier.empty()) { + // Get the top of the stack + const ISAM2Clique::shared_ptr clique = frontier.top(); + frontier.pop(); + // Check if we have more keys and children to add + if (std::find(clique->conditional()->beginParents(), clique->conditional()->endParents(), key) != + clique->conditional()->endParents()) { + for (Key i : clique->conditional()->frontals()) { + additionalKeys.push_back(i); + } + for (const ISAM2Clique::shared_ptr& child : clique->children) { + frontier.push(child); + } + } + } + } + + bool updateAndMarginalize(const NonlinearFactorGraph& newFactors, const Values& newValues, const KeySet& marginalizableKeys, ISAM2& isam) + { + // Force ISAM2 to put marginalizable variables at the beginning + const boost::optional> orderingConstraints = createOrderingConstraints(isam, newValues.keys(), marginalizableKeys); + + // Mark additional keys between the marginalized keys and the leaves + KeyList markedKeys; + for (Key key : marginalizableKeys) { + markedKeys.push_back(key); + ISAM2Clique::shared_ptr clique = isam[key]; + for (const ISAM2Clique::shared_ptr& child : clique->children) { + markAffectedKeys(key, child, markedKeys); + } + } + + // Update + isam.update(newFactors, newValues, FactorIndices{}, orderingConstraints, boost::none, markedKeys); + + if (!marginalizableKeys.empty()) { + FastList leafKeys(marginalizableKeys.begin(), marginalizableKeys.end()); + return checkMarginalizeLeaves(isam, leafKeys); + } + else { + return true; + } + } } /* ************************************************************************* */