diff --git a/gtsam/inference/FactorGraph-inl.h b/gtsam/inference/FactorGraph-inl.h index b0cb4e3ec..9125ff31b 100644 --- a/gtsam/inference/FactorGraph-inl.h +++ b/gtsam/inference/FactorGraph-inl.h @@ -149,6 +149,68 @@ namespace gtsam { return std::make_pair(eliminationResult.first, remainingFactors); } + /* ************************************************************************* */ + template + std::pair::sharedConditional, FactorGraph > + FactorGraph::eliminate(const std::vector& variables, const Eliminate& eliminateFcn, + boost::optional variableIndex_) + { + const VariableIndex& variableIndex = + variableIndex_ ? *variableIndex_ : VariableIndex(*this); + + // First find the involved factors + FactorGraph involvedFactors; + Index highestInvolvedVariable = 0; // Largest index of the variables in the involved factors + + // First get the indices of the involved factors, but uniquely in a set + FastSet involvedFactorIndices; + BOOST_FOREACH(Index variable, variables) { + involvedFactorIndices.insert(variableIndex[variable].begin(), variableIndex[variable].end()); } + + // Add the factors themselves to involvedFactors and update largest index + involvedFactors.reserve(involvedFactorIndices.size()); + BOOST_FOREACH(size_t factorIndex, involvedFactorIndices) { + const sharedFactor factor = this->at(factorIndex); + involvedFactors.push_back(factor); // Add involved factor + highestInvolvedVariable = std::max( // Updated largest index + highestInvolvedVariable, + *std::max_element(factor->begin(), factor->end())); + } + + // Now permute the variables to be eliminated to the front of the ordering + Permutation toFront = Permutation::PullToFront(variables, highestInvolvedVariable+1); + Permutation toFrontInverse = *toFront.inverse(); + BOOST_FOREACH(const sharedFactor& factor, involvedFactors) { + factor->permuteWithInverse(toFrontInverse); + } + + // Eliminate into conditional and remaining factor + EliminationResult eliminated = eliminateFcn(involvedFactors, variables.size()); + sharedConditional conditional = eliminated.first; + sharedFactor remainingFactor = eliminated.second; + + // Undo the permutation + conditional->permuteWithInverse(toFront); + remainingFactor->permuteWithInverse(toFront); + + // Build the remaining graph, without the removed factors + FactorGraph remainingGraph; + remainingGraph.reserve(this->size() - involvedFactors.size() + 1); + FastSet::const_iterator involvedFactorIndexIt = involvedFactorIndices.begin(); + for(size_t i = 0; i < this->size(); ++i) { + if(involvedFactorIndexIt != involvedFactorIndices.end() && *involvedFactorIndexIt == i) + ++ involvedFactorIndexIt; + else + remainingGraph.push_back(this->at(i)); + } + + // Add the remaining factor if it is not empty. + if(remainingFactor->size() != 0) + remainingGraph.push_back(remainingFactor); + + return std::make_pair(conditional, remainingGraph); + + } /* ************************************************************************* */ template void FactorGraph::replace(size_t index, sharedFactor factor) { diff --git a/gtsam/inference/FactorGraph.h b/gtsam/inference/FactorGraph.h index 0c70601ba..5be010e2c 100644 --- a/gtsam/inference/FactorGraph.h +++ b/gtsam/inference/FactorGraph.h @@ -33,6 +33,7 @@ namespace gtsam { // Forward declarations template class BayesTree; +class VariableIndex; /** * A factor graph is a bipartite graph with factor nodes connected to variable nodes. @@ -182,6 +183,34 @@ template class BayesTree; * JunctionTree. */ std::pair > eliminateFrontals(size_t nFrontals, const Eliminate& eliminate) const; + + /** Factor the factor graph into a conditional and a remaining factor graph. + * Given the factor graph \f$ f(X) \f$, and \c variables to factorize out + * \f$ V \f$, this function factorizes into \f$ f(X) = f(V;Y)f(Y) \f$, where + * \f$ Y := X\V \f$ are the remaining variables. If \f$ f(X) = p(X) \f$ is + * a probability density or likelihood, the factorization produces a + * conditional probability density and a marginal \f$ p(X) = p(V|Y)p(Y) \f$. + * + * For efficiency, this function treats the variables to eliminate + * \c variables as fully-connected, so produces a dense (fully-connected) + * conditional on all of the variables in \c variables, instead of a sparse + * BayesNet. If the variables are not fully-connected, it is more efficient + * to sequentially factorize multiple times. + */ + std::pair::sharedConditional, FactorGraph > + eliminate( + const std::vector& variables, const Eliminate& eliminateFcn, + boost::optional variableIndex = boost::none); + + /** Eliminate a single variable, by calling + * eliminate(const Graph&, const std::vector&, const typename Graph::Eliminate&, boost::optional) + */ + std::pair > eliminateOne( + KeyType variable, const Eliminate& eliminateFcn, + boost::optional variableIndex = boost::none) { + std::vector variables(1, variable); + return eliminate(variables, eliminateFcn, variableIndex); + } /// @} /// @name Modifying Factor Graphs (imperative, discouraged) diff --git a/gtsam/inference/inference-inl.h b/gtsam/inference/inference-inl.h index 33e5e1b65..0ed2d1ec7 100644 --- a/gtsam/inference/inference-inl.h +++ b/gtsam/inference/inference-inl.h @@ -79,69 +79,6 @@ inline Permutation::shared_ptr PermutationCOLAMD(const VariableIndex& variableIn return PermutationCOLAMD_(variableIndex, cmember); } -/* ************************************************************************* */ -template -std::pair eliminate( - const Graph& factorGraph, - const std::vector& variables, - const typename Graph::Eliminate& eliminateFcn, - boost::optional variableIndex_) { - - const VariableIndex& variableIndex = - variableIndex_ ? *variableIndex_ : VariableIndex(factorGraph); - - // First find the involved factors - Graph involvedFactors; - Index highestInvolvedVariable = 0; // Largest index of the variables in the involved factors - - // First get the indices of the involved factors, but uniquely in a set - FastSet involvedFactorIndices; - BOOST_FOREACH(Index variable, variables) { - involvedFactorIndices.insert(variableIndex[variable].begin(), variableIndex[variable].end()); } - - // Add the factors themselves to involvedFactors and update largest index - involvedFactors.reserve(involvedFactorIndices.size()); - BOOST_FOREACH(size_t factorIndex, involvedFactorIndices) { - const typename Graph::sharedFactor factor = factorGraph[factorIndex]; - involvedFactors.push_back(factor); // Add involved factor - highestInvolvedVariable = std::max( // Updated largest index - highestInvolvedVariable, - *std::max_element(factor->begin(), factor->end())); - } - - // Now permute the variables to be eliminated to the front of the ordering - Permutation toFront = Permutation::PullToFront(variables, highestInvolvedVariable+1); - Permutation toFrontInverse = *toFront.inverse(); - involvedFactors.permuteWithInverse(toFrontInverse); - - // Eliminate into conditional and remaining factor - typename Graph::EliminationResult eliminated = eliminateFcn(involvedFactors, variables.size()); - boost::shared_ptr conditional = eliminated.first; - typename Graph::sharedFactor remainingFactor = eliminated.second; - - // Undo the permutation - conditional->permuteWithInverse(toFront); - remainingFactor->permuteWithInverse(toFront); - - // Build the remaining graph, without the removed factors - Graph remainingGraph; - remainingGraph.reserve(factorGraph.size() - involvedFactors.size() + 1); - FastSet::const_iterator involvedFactorIndexIt = involvedFactorIndices.begin(); - for(size_t i = 0; i < factorGraph.size(); ++i) { - if(involvedFactorIndexIt != involvedFactorIndices.end() && *involvedFactorIndexIt == i) - ++ involvedFactorIndexIt; - else - remainingGraph.push_back(factorGraph[i]); - } - - // Add the remaining factor if it is not empty. - if(remainingFactor->size() != 0) - remainingGraph.push_back(remainingFactor); - - return std::make_pair(conditional, remainingGraph); - -} // eliminate - } // namespace inference } // namespace gtsam diff --git a/gtsam/inference/inference.h b/gtsam/inference/inference.h index b0082b556..781a31ef4 100644 --- a/gtsam/inference/inference.h +++ b/gtsam/inference/inference.h @@ -75,38 +75,6 @@ namespace gtsam { Permutation::shared_ptr PermutationCOLAMD_( const VariableIndex& variableIndex, std::vector& cmember); - /** Factor the factor graph into a conditional and a remaining factor graph. - * Given the factor graph \f$ f(X) \f$, and \c variables to factorize out - * \f$ V \f$, this function factorizes into \f$ f(X) = f(V;Y)f(Y) \f$, where - * \f$ Y := X\V \f$ are the remaining variables. If \f$ f(X) = p(X) \f$ is - * a probability density or likelihood, the factorization produces a - * conditional probability density and a marginal \f$ p(X) = p(V|Y)p(Y) \f$. - * - * For efficiency, this function treats the variables to eliminate - * \c variables as fully-connected, so produces a dense (fully-connected) - * conditional on all of the variables in \c variables, instead of a sparse - * BayesNet. If the variables are not fully-connected, it is more efficient - * to sequentially factorize multiple times. - */ - template - std::pair eliminate( - const Graph& factorGraph, - const std::vector& variables, - const typename Graph::Eliminate& eliminateFcn, - boost::optional variableIndex = boost::none); - - /** Eliminate a single variable, by calling - * eliminate(const Graph&, const std::vector&, const typename Graph::Eliminate&, boost::optional) - */ - template - std::pair eliminateOne( - const Graph& factorGraph, typename Graph::KeyType variable, - const typename Graph::Eliminate& eliminateFcn, - boost::optional variableIndex = boost::none) { - std::vector variables(1, variable); - return eliminate(factorGraph, variables, eliminateFcn, variableIndex); - } - } // \namespace inference } // \namespace gtsam diff --git a/tests/testGaussianFactorGraphB.cpp b/tests/testGaussianFactorGraphB.cpp index 6d48e45ad..94476d87a 100644 --- a/tests/testGaussianFactorGraphB.cpp +++ b/tests/testGaussianFactorGraphB.cpp @@ -76,7 +76,7 @@ TEST( GaussianFactorGraph, eliminateOne_x1 ) GaussianConditional::shared_ptr conditional; GaussianFactorGraph remaining; - boost::tie(conditional,remaining) = inference::eliminateOne(fg, 0, EliminateQR); + boost::tie(conditional,remaining) = fg.eliminateOne(0, EliminateQR); // create expected Conditional Gaussian Matrix I = 15*eye(2), R11 = I, S12 = -0.111111*I, S13 = -0.444444*I; @@ -91,7 +91,7 @@ TEST( GaussianFactorGraph, eliminateOne_x2 ) { Ordering ordering; ordering += X(2),L(1),X(1); GaussianFactorGraph fg = createGaussianFactorGraph(ordering); - GaussianConditional::shared_ptr actual = inference::eliminateOne(fg, 0, EliminateQR).first; + GaussianConditional::shared_ptr actual = fg.eliminateOne(0, EliminateQR).first; // create expected Conditional Gaussian double sig = 0.0894427; @@ -107,7 +107,7 @@ TEST( GaussianFactorGraph, eliminateOne_l1 ) { Ordering ordering; ordering += L(1),X(1),X(2); GaussianFactorGraph fg = createGaussianFactorGraph(ordering); - GaussianConditional::shared_ptr actual = inference::eliminateOne(fg, 0, EliminateQR).first; + GaussianConditional::shared_ptr actual = fg.eliminateOne(0, EliminateQR).first; // create expected Conditional Gaussian double sig = sqrt(2.0)/10.; @@ -125,7 +125,7 @@ TEST( GaussianFactorGraph, eliminateOne_x1_fast ) GaussianFactorGraph fg = createGaussianFactorGraph(ordering); GaussianConditional::shared_ptr conditional; GaussianFactorGraph remaining; - boost::tie(conditional,remaining) = inference::eliminateOne(fg, ordering[X(1)], EliminateQR); + boost::tie(conditional,remaining) = fg.eliminateOne(ordering[X(1)], EliminateQR); // create expected Conditional Gaussian Matrix I = 15*eye(2), R11 = I, S12 = -0.111111*I, S13 = -0.444444*I; @@ -154,7 +154,7 @@ TEST( GaussianFactorGraph, eliminateOne_x2_fast ) { Ordering ordering; ordering += X(1),L(1),X(2); GaussianFactorGraph fg = createGaussianFactorGraph(ordering); - GaussianConditional::shared_ptr actual = inference::eliminateOne(fg, ordering[X(2)], EliminateQR).first; + GaussianConditional::shared_ptr actual = fg.eliminateOne(ordering[X(2)], EliminateQR).first; // create expected Conditional Gaussian double sig = 0.0894427; @@ -170,7 +170,7 @@ TEST( GaussianFactorGraph, eliminateOne_l1_fast ) { Ordering ordering; ordering += X(1),L(1),X(2); GaussianFactorGraph fg = createGaussianFactorGraph(ordering); - GaussianConditional::shared_ptr actual = inference::eliminateOne(fg, ordering[L(1)], EliminateQR).first; + GaussianConditional::shared_ptr actual = fg.eliminateOne(ordering[L(1)], EliminateQR).first; // create expected Conditional Gaussian double sig = sqrt(2.0)/10.;