diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 5acd7c0f6..e8aa4511d 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -110,7 +110,26 @@ DiscreteConditional DiscreteConditional::operator*( return DiscreteConditional(newFrontals.size(), discreteKeys, product); } -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::marginal(Key key) const { + if (nrParents() > 0) + throw std::invalid_argument( + "DiscreteConditional::marginal: single argument version only valid for " + "fully specified joint distributions (i.e., no parents)."); + + // Calculate the keys as the frontal keys without the given key. + DiscreteKeys discreteKeys{{key, cardinality(key)}}; + + // Calculate sum + ADT adt(*this); + for (auto&& k : frontals()) + if (k != key) adt = adt.sum(k, cardinality(k)); + + // Return new factor + return DiscreteConditional(1, discreteKeys, adt); +} + +/* ************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { cout << s << " P( "; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 450af57ab..836aa3920 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteConditional */ DiscreteConditional operator*(const DiscreteConditional& other) const; + /** Calculate marginal on given key, no parent case. */ + DiscreteConditional marginal(Key key) const; + /// @} /// @name Testable /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 8bcb8b4aa..cd3e85598 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -97,6 +97,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { const gtsam::Ordering& orderedKeys); gtsam::DiscreteConditional operator*( const gtsam::DiscreteConditional& other) const; + DiscreteConditional marginal(gtsam::Key key) const; void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 03766136c..125659517 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -97,10 +97,14 @@ TEST(DiscreteConditional, constructors3) { /* ************************************************************************* */ // Check calculation of joint P(A,B) TEST(DiscreteConditional, Multiply) { - DiscreteKey A(0, 2), B(1, 2); + DiscreteKey A(1, 2), B(0, 2); DiscreteConditional conditional(A | B = "1/2 2/1"); DiscreteConditional prior(B % "1/2"); + // The expected factor + DecisionTreeFactor f(A & B, "1 4 2 2"); + DiscreteConditional expected(2, f); + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) for (auto&& actual : {prior * conditional, conditional * prior}) { EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); @@ -110,8 +114,11 @@ TEST(DiscreteConditional, Multiply) { const DiscreteValues& v = it.first; EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); } + // And for good measure: + EXPECT(assert_equal(expected, actual)); } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C) TEST(DiscreteConditional, Multiply2) { @@ -131,6 +138,7 @@ TEST(DiscreteConditional, Multiply2) { } } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C), double check keys TEST(DiscreteConditional, Multiply3) { @@ -150,6 +158,7 @@ TEST(DiscreteConditional, Multiply3) { } } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) TEST(DiscreteConditional, Multiply4) { @@ -173,6 +182,31 @@ TEST(DiscreteConditional, Multiply4) { } } } + +/* ************************************************************************* */ +// Check calculation of marginals for joint P(A,B) +TEST(DiscreteConditional, marginals) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "5/4"); + EXPECT(assert_equal(pA, actualA)); + EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualA.nrParents()); + KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals()); + EXPECT((frontalsA == KeyVector{1})); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); + EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualB.nrParents()); + KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); + EXPECT((frontalsB == KeyVector{0})); +} + /* ************************************************************************* */ TEST(DiscreteConditional, likelihood) { DiscreteKey X(0, 2), Y(1, 3); diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 190c22181..f46a0e877 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -81,6 +81,15 @@ class TestDiscreteConditional(GtsamTestCase): self.assertAlmostEqual( actual(v), AB_given_D(v) * C_given_DE(v)) + def test_marginals(self): + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + pAB = prior * conditional + self.gtsamAssertEquals(prior, pAB.marginal(B[0])) + + pA = DiscreteConditional(A % "5/4") + self.gtsamAssertEquals(pA, pAB.marginal(A[0])) + def test_markdown(self): """Test whether the _repr_markdown_ method."""