diff --git a/cpp/BayesTree-inl.h b/cpp/BayesTree-inl.h index 796020a52..2c9fe4448 100644 --- a/cpp/BayesTree-inl.h +++ b/cpp/BayesTree-inl.h @@ -144,11 +144,12 @@ namespace gtsam { // For now, assume neither is the root // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) - sharedBayesNet p_FSR = this->shortcut(R); - p_FSR->push_front(*this); - p_FSR->push_front(*C2->shortcut(R)); - p_FSR->push_front(*C2); - p_FSR->push_back(*R); + sharedBayesNet bn(new BayesNet); + if (!isRoot()) bn->push_back(*this); // P(F1|S1) + if (!isRoot()) bn->push_back(*(shortcut(R))); // P(S1|R) + if (!C2->isRoot()) bn->push_back(*C2); // P(F2|S2) + if (!C2->isRoot()) bn->push_back(*C2->shortcut(R)); // P(S2|R) + bn->push_back(*R); // P(R) // Find the keys of both C1 and C2 Ordering keys12 = keys(); @@ -156,7 +157,7 @@ namespace gtsam { keys12.unique(); // Calculate the marginal - return marginals(*p_FSR,keys12); + return marginals(*bn,keys12); } /* ************************************************************************* */ diff --git a/cpp/testBayesTree.cpp b/cpp/testBayesTree.cpp index e6da23ea9..c89835a17 100644 --- a/cpp/testBayesTree.cpp +++ b/cpp/testBayesTree.cpp @@ -77,6 +77,12 @@ TEST( BayesTree, constructor ) CHECK(assert_equal(bayesTree,bayesTree2)); } +/* ************************************************************************* */ +// Some numbers that should be consistent among all smoother tests + +double sigmax1 = 0.786153, sigmax2 = 0.687131, sigmax3 = 0.671512, sigmax4 = + 0.669534, sigmax5 = sigmax3, sigmax6 = sigmax2, sigmax7 = sigmax1; + /* ************************************************************************* * Bayes tree for smoother with "natural" ordering: C1 x6 x7 @@ -86,7 +92,7 @@ C4 x3 : x4 C5 x2 : x3 C6 x1 : x2 /* ************************************************************************* */ -TEST( BayesTree, smoother ) +TEST( BayesTree, linear_smoother_shortcuts ) { // Create smoother with 7 nodes LinearFactorGraph smoother = createSmoother(7); @@ -161,9 +167,8 @@ TEST( BayesTree, balanced_smoother_marginals ) GaussianBayesNet::shared_ptr chordalBayesNet = smoother.eliminate(ordering); VectorConfig expectedSolution; - Vector delta = zero(2); BOOST_FOREACH(string key, ordering) - expectedSolution.insert(key,delta); + expectedSolution.insert(key,zero(2)); boost::shared_ptr actualSolution = chordalBayesNet->optimize(); CHECK(assert_equal(expectedSolution,*actualSolution,1e-4)); @@ -172,19 +177,29 @@ TEST( BayesTree, balanced_smoother_marginals ) LONGS_EQUAL(7,bayesTree.size()); // Check marginal on x1 - GaussianBayesNet expected1("x1", delta, 0.786153); - BayesNet actual1 = bayesTree.marginal("x1"); + GaussianBayesNet expected1("x1", zero(2), sigmax1); + BayesNet actual1 = bayesTree.marginal("x1"); CHECK(assert_equal((BayesNet)expected1,actual1,1e-4)); // Check marginal on x2 - GaussianBayesNet expected2("x2", delta, 0.687131); - BayesNet actual2 = bayesTree.marginal("x2"); + GaussianBayesNet expected2("x2", zero(2), sigmax2); + BayesNet actual2 = bayesTree.marginal("x2"); CHECK(assert_equal((BayesNet)expected2,actual2,1e-4)); // Check marginal on x3 - GaussianBayesNet expected3("x3", delta, 0.671512); - BayesNet actual3 = bayesTree.marginal("x3"); + GaussianBayesNet expected3("x3", zero(2), sigmax3); + BayesNet actual3 = bayesTree.marginal("x3"); CHECK(assert_equal((BayesNet)expected3,actual3,1e-4)); + + // Check marginal on x4 + GaussianBayesNet expected4("x4", zero(2), sigmax4); + BayesNet actual4 = bayesTree.marginal("x4"); + CHECK(assert_equal((BayesNet)expected4,actual4,1e-4)); + + // Check marginal on x7 (should be equal to x1) + GaussianBayesNet expected7("x7", zero(2), sigmax7); + BayesNet actual7 = bayesTree.marginal("x7"); + CHECK(assert_equal((BayesNet)expected7,actual7,1e-4)); } /* ************************************************************************* */ @@ -231,7 +246,7 @@ TEST( BayesTree, balanced_smoother_clique_marginals ) Gaussian bayesTree(*chordalBayesNet); // Check the clique marginal P(C3) - GaussianBayesNet expected("x2",zero(2),0.687131); + GaussianBayesNet expected("x2",zero(2),sigmax2); Vector sigma = repeat(2, 0.707107); Matrix A12 = (-0.5)*eye(2); ConditionalGaussian::shared_ptr cg(new ConditionalGaussian("x1", zero(2), eye(2), "x2", A12, sigma)); @@ -258,18 +273,36 @@ TEST( BayesTree, balanced_smoother_joint ) Matrix A = (-0.00429185)*eye(2); // Check the joint density P(x1,x7) factored as P(x1|x7)P(x7) - GaussianBayesNet expected1("x7", zero(2), 0.786153); + GaussianBayesNet expected1("x7", zero(2), sigmax7); ConditionalGaussian::shared_ptr cg1(new ConditionalGaussian("x1", zero(2), eye(2), "x7", A, sigma)); expected1.push_front(cg1); BayesNet actual1 = bayesTree.joint("x1","x7"); CHECK(assert_equal((BayesNet)expected1,actual1,1e-4)); // Check the joint density P(x7,x1) factored as P(x7|x1)P(x1) - GaussianBayesNet expected2("x1", zero(2), 0.786153); + GaussianBayesNet expected2("x1", zero(2), sigmax1); ConditionalGaussian::shared_ptr cg2(new ConditionalGaussian("x7", zero(2), eye(2), "x1", A, sigma)); expected2.push_front(cg2); BayesNet actual2 = bayesTree.joint("x7","x1"); CHECK(assert_equal((BayesNet)expected2,actual2,1e-4)); + + // Check the joint density P(x1,x4), i.e. with a root variable + GaussianBayesNet expected3("x4", zero(2), sigmax4); + Vector sigma14 = repeat(2, 0.784465); + Matrix A14 = (-0.0769231)*eye(2); + ConditionalGaussian::shared_ptr cg3(new ConditionalGaussian("x1", zero(2), eye(2), "x4", A14, sigma14)); + expected3.push_front(cg3); + BayesNet actual3 = bayesTree.joint("x1","x4"); + CHECK(assert_equal((BayesNet)expected3,actual3,1e-4)); + + // Check the joint density P(x4,x1), i.e. with a root variable, factored the other way + GaussianBayesNet expected4("x1", zero(2), sigmax1); + Vector sigma41 = repeat(2, 0.668096); + Matrix A41 = (-0.055794)*eye(2); + ConditionalGaussian::shared_ptr cg4(new ConditionalGaussian("x4", zero(2), eye(2), "x1", A41, sigma41)); + expected4.push_front(cg4); + BayesNet actual4 = bayesTree.joint("x4","x1"); + CHECK(assert_equal((BayesNet)expected4,actual4,1e-4)); } /* ************************************************************************* */