diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 177457cfa..03175dec9 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -382,6 +382,7 @@ namespace gtsam { template FactorGraph BayesTree::Clique::marginal(shared_ptr R) { // If we are the root, just return this root + // NOTE: immediately cast to a factor graph if (R.get()==this) return *R; // Combine P(F|S), P(S|R), and P(R) @@ -389,9 +390,6 @@ namespace gtsam { p_FSR.push_front(*this); p_FSR.push_back(*R); - // Find marginal on the keys we are interested in - FactorGraph p_FSRfg(p_FSR); - assertInvariants(); return *GenericSequentialSolver(p_FSR).jointFactorGraph(keys()); } @@ -731,7 +729,7 @@ namespace gtsam { // First finds clique marginal then marginalizes that /* ************************************************************************* */ template - typename CONDITIONAL::Factor::shared_ptr BayesTree::marginal(Index key) const { + typename CONDITIONAL::Factor::shared_ptr BayesTree::marginalFactor(Index key) const { // get clique containing key sharedClique clique = (*this)[key]; @@ -748,7 +746,7 @@ namespace gtsam { // calculate marginal as a factor graph FactorGraph fg; - fg.push_back(this->marginal(key)); + fg.push_back(this->marginalFactor(key)); // eliminate factor graph marginal to a Bayes net return GenericSequentialSolver(fg).eliminate(); diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 6f99d4470..37c7bbca0 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -262,9 +262,13 @@ namespace gtsam { CliqueData getCliqueData() const; /** return marginal on any variable */ - typename CONDITIONAL::Factor::shared_ptr marginal(Index key) const; + typename CONDITIONAL::Factor::shared_ptr marginalFactor(Index key) const; - /** return marginal on any variable, as a Bayes Net */ + /** + * return marginal on any variable, as a Bayes Net + * NOTE: this function calls marginal, and then eliminates it into a Bayes Net + * This is more expensive than the above function + */ typename BayesNet::shared_ptr marginalBayesNet(Index key) const; /** return joint on two variables */ diff --git a/gtsam/inference/GenericMultifrontalSolver-inl.h b/gtsam/inference/GenericMultifrontalSolver-inl.h index 6bf8c55a9..918f500b3 100644 --- a/gtsam/inference/GenericMultifrontalSolver-inl.h +++ b/gtsam/inference/GenericMultifrontalSolver-inl.h @@ -75,7 +75,7 @@ GenericMultifrontalSolver::jointFactorGraph(const std::vec /* ************************************************************************* */ template typename FACTOR::shared_ptr GenericMultifrontalSolver::marginalFactor(Index j) const { - return eliminate()->marginal(j); + return eliminate()->marginalFactor(j); } } diff --git a/gtsam/linear/GaussianISAM.cpp b/gtsam/linear/GaussianISAM.cpp index 1a3ff89aa..d54e4cc31 100644 --- a/gtsam/linear/GaussianISAM.cpp +++ b/gtsam/linear/GaussianISAM.cpp @@ -24,8 +24,20 @@ using namespace gtsam; #include template class ISAM; +namespace ublas = boost::numeric::ublas; + namespace gtsam { +/* ************************************************************************* */ +std::pair GaussianISAM::marginal(Index j) const { + GaussianFactor::shared_ptr factor = this->marginalFactor(j); + FactorGraph graph; + graph.push_back(factor); + GaussianConditional::shared_ptr conditional = GaussianFactor::CombineAndEliminate(graph,1).first->front(); + Matrix R = conditional->get_R(); + return make_pair(conditional->get_d(), inverse(ublas::prod(ublas::trans(R), R))); +} + /* ************************************************************************* */ void optimize(const GaussianISAM::sharedClique& clique, VectorValues& result) { // parents are assumed to already be solved and available in result diff --git a/gtsam/linear/GaussianISAM.h b/gtsam/linear/GaussianISAM.h index 5e2ab7488..37653c485 100644 --- a/gtsam/linear/GaussianISAM.h +++ b/gtsam/linear/GaussianISAM.h @@ -69,6 +69,9 @@ public: friend VectorValues optimize(const GaussianISAM&); + /** return marginal on any variable */ + std::pair marginal(Index key) const; + }; // recursively optimize this conditional and all subtrees diff --git a/tests/testGaussianISAM.cpp b/tests/testGaussianISAM.cpp index b355a094d..8c98ecadf 100644 --- a/tests/testGaussianISAM.cpp +++ b/tests/testGaussianISAM.cpp @@ -195,28 +195,58 @@ TEST( BayesTree, balanced_smoother_marginals ) // Check marginal on x1 GaussianBayesNet expected1 = simpleGaussian(ordering["x1"], zero(2), sigmax1); GaussianBayesNet actual1 = *bayesTree.marginalBayesNet(ordering["x1"]); - CHECK(assert_equal(expected1,actual1,tol)); + Matrix expectedCovarianceX1 = eye(2,2) * (sigmax1 * sigmax1); + Vector expectedMeanX1 = zero(2); + Matrix actualCovarianceX1; Vector actualMeanX1; + boost::tie(actualMeanX1, actualCovarianceX1) = bayesTree.marginal(ordering["x1"]); + EXPECT(assert_equal(expectedCovarianceX1, actualCovarianceX1, tol)); + EXPECT(assert_equal(expectedMeanX1, actualMeanX1, tol)); + EXPECT(assert_equal(expected1,actual1,tol)); // Check marginal on x2 double sigx2 = 0.68712938; // FIXME: this should be corrected analytically GaussianBayesNet expected2 = simpleGaussian(ordering["x2"], zero(2), sigx2); GaussianBayesNet actual2 = *bayesTree.marginalBayesNet(ordering["x2"]); - CHECK(assert_equal(expected2,actual2,tol)); // FAILS + Matrix expectedCovarianceX2 = eye(2,2) * (sigx2 * sigx2); + Vector expectedMeanX2 = zero(2); + Matrix actualCovarianceX2; Vector actualMeanX2; + boost::tie(actualMeanX2, actualCovarianceX2) = bayesTree.marginal(ordering["x2"]); + EXPECT(assert_equal(expectedCovarianceX2, actualCovarianceX2, tol)); + EXPECT(assert_equal(expectedMeanX2, actualMeanX2, tol)); + EXPECT(assert_equal(expected2,actual2,tol)); // Check marginal on x3 GaussianBayesNet expected3 = simpleGaussian(ordering["x3"], zero(2), sigmax3); GaussianBayesNet actual3 = *bayesTree.marginalBayesNet(ordering["x3"]); - CHECK(assert_equal(expected3,actual3,tol)); + Matrix expectedCovarianceX3 = eye(2,2) * (sigmax3 * sigmax3); + Vector expectedMeanX3 = zero(2); + Matrix actualCovarianceX3; Vector actualMeanX3; + boost::tie(actualMeanX3, actualCovarianceX3) = bayesTree.marginal(ordering["x3"]); + EXPECT(assert_equal(expectedCovarianceX3, actualCovarianceX3, tol)); + EXPECT(assert_equal(expectedMeanX3, actualMeanX3, tol)); + EXPECT(assert_equal(expected3,actual3,tol)); // Check marginal on x4 GaussianBayesNet expected4 = simpleGaussian(ordering["x4"], zero(2), sigmax4); GaussianBayesNet actual4 = *bayesTree.marginalBayesNet(ordering["x4"]); - CHECK(assert_equal(expected4,actual4,tol)); + Matrix expectedCovarianceX4 = eye(2,2) * (sigmax4 * sigmax4); + Vector expectedMeanX4 = zero(2); + Matrix actualCovarianceX4; Vector actualMeanX4; + boost::tie(actualMeanX4, actualCovarianceX4) = bayesTree.marginal(ordering["x4"]); + EXPECT(assert_equal(expectedCovarianceX4, actualCovarianceX4, tol)); + EXPECT(assert_equal(expectedMeanX4, actualMeanX4, tol)); + EXPECT(assert_equal(expected4,actual4,tol)); // Check marginal on x7 (should be equal to x1) GaussianBayesNet expected7 = simpleGaussian(ordering["x7"], zero(2), sigmax7); GaussianBayesNet actual7 = *bayesTree.marginalBayesNet(ordering["x7"]); - CHECK(assert_equal(expected7,actual7,tol)); + Matrix expectedCovarianceX7 = eye(2,2) * (sigmax7 * sigmax7); + Vector expectedMeanX7 = zero(2); + Matrix actualCovarianceX7; Vector actualMeanX7; + boost::tie(actualMeanX7, actualCovarianceX7) = bayesTree.marginal(ordering["x7"]); + EXPECT(assert_equal(expectedCovarianceX7, actualCovarianceX7, tol)); + EXPECT(assert_equal(expectedMeanX7, actualMeanX7, tol)); + EXPECT(assert_equal(expected7,actual7,tol)); } /* ************************************************************************* */