diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index c2d941eaa..13a34dd19 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) { DiscreteConditional prior(B % "1/2"); DiscreteConditional pAB = prior * conditional; + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional pA(A % "5/4"); EXPECT(assert_equal(pA, actualA)); - EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); + EXPECT(actualA.frontals() == KeyVector{1}); EXPECT_LONGS_EQUAL(0, actualA.nrParents()); - KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals()); - EXPECT((frontalsA == KeyVector{1})); DiscreteConditional actualB = pAB.marginal(B.first); EXPECT(assert_equal(prior, actualB)); - EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); + EXPECT(actualB.frontals() == KeyVector{0}); EXPECT_LONGS_EQUAL(0, actualB.nrParents()); - KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); - EXPECT((frontalsB == KeyVector{0})); +} + +/* ************************************************************************* */ +// Check calculation of marginals in case branches are pruned +TEST(DiscreteConditional, marginals2) { + DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen! + DiscreteConditional conditional(A | B = "2/2 3/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + GTSAM_PRINT(pAB); + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "8/4"); + EXPECT(assert_equal(pA, actualA)); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); } /* ************************************************************************* */