diff --git a/inference/BayesTree-inl.h b/inference/BayesTree-inl.h index 1608e2f70..fca7e01f3 100644 --- a/inference/BayesTree-inl.h +++ b/inference/BayesTree-inl.h @@ -348,7 +348,7 @@ namespace gtsam { p_FSR.push_back(*R); // Find marginal on the keys we are interested in - return FactorGraph(*Inference::Marginal(FactorGraph(p_FSR), keys())); + return Inference::Marginal(FactorGraph(p_FSR), keys()); } // /* ************************************************************************* */ diff --git a/inference/inference-inl.h b/inference/inference-inl.h index 22f543e55..45d2d7c55 100644 --- a/inference/inference-inl.h +++ b/inference/inference-inl.h @@ -278,7 +278,7 @@ Inference::EliminateOne(FactorGraph& factorGraph, typename FactorGraph::variable /* ************************************************************************* */ template -typename FactorGraph::bayesnet_type::shared_ptr Inference::Marginal(const FactorGraph& factorGraph, const VarContainer& variables) { +FactorGraph Inference::Marginal(const FactorGraph& factorGraph, const VarContainer& variables) { // Compute a COLAMD permutation with the marginal variables constrained to the end typename FactorGraph::variableindex_type varIndex(factorGraph); @@ -298,16 +298,18 @@ typename FactorGraph::bayesnet_type::shared_ptr Inference::Marginal(const Factor typename FactorGraph::bayesnet_type::shared_ptr bn(Inference::Eliminate(eliminationGraph, varIndex)); // The last conditionals in the eliminated BayesNet contain the marginal for - // the variables we want. - typename FactorGraph::bayesnet_type::shared_ptr marginal(new typename FactorGraph::bayesnet_type()); + // the variables we want. Undo the permutation as we add the marginal + // factors. + FactorGraph marginal; marginal.reserve(variables.size()); typename FactorGraph::bayesnet_type::const_reverse_iterator conditional = bn->rbegin(); for(Index j=0; jpush_front(*conditional); + typename FactorGraph::sharedFactor factor(new typename FactorGraph::factor_type(**conditional)); + factor->permuteWithInverse(*permutation); + marginal.push_back(factor); assert(std::find(variables.begin(), variables.end(), (*permutation)[(*conditional)->key()]) != variables.end()); } // Undo the permutation - marginal->permuteWithInverse(*permutation); return marginal; } diff --git a/inference/inference.h b/inference/inference.h index 6a08fd189..92d77e3ef 100644 --- a/inference/inference.h +++ b/inference/inference.h @@ -94,7 +94,7 @@ class Conditional; * variables. */ template - static typename FactorGraph::bayesnet_type::shared_ptr Marginal(const FactorGraph& factorGraph, const VarContainer& variables); + static FactorGraph Marginal(const FactorGraph& factorGraph, const VarContainer& variables); /** * Compute a permutation (variable ordering) using colamd diff --git a/tests/testInference.cpp b/tests/testInference.cpp index 9c20e97d6..e33e7eceb 100644 --- a/tests/testInference.cpp +++ b/tests/testInference.cpp @@ -31,8 +31,8 @@ TEST(GaussianFactorGraph, createSmoother) // eliminate list x3var; x3var.push_back(ordering["x3"]); list x1var; x1var.push_back(ordering["x1"]); - GaussianBayesNet p_x3 = *Inference::Marginal(fg2, x3var); - GaussianBayesNet p_x1 = *Inference::Marginal(fg2, x1var); + GaussianBayesNet p_x3 = *Inference::Eliminate(Inference::Marginal(fg2, x3var)); + GaussianBayesNet p_x1 = *Inference::Eliminate(Inference::Marginal(fg2, x1var)); CHECK(assert_equal(*p_x1.back(),*p_x3.front())); // should be the same because of symmetry } @@ -42,7 +42,7 @@ TEST( Inference, marginals ) // create and marginalize a small Bayes net on "x" GaussianBayesNet cbn = createSmallGaussianBayesNet(); list xvar; xvar.push_back(0); - GaussianBayesNet actual = *Inference::Marginal(GaussianFactorGraph(cbn), xvar); + GaussianBayesNet actual = *Inference::Eliminate(Inference::Marginal(GaussianFactorGraph(cbn), xvar)); // expected is just scalar Gaussian on x GaussianBayesNet expected = scalarGaussian(0, 4, sqrt(2));