diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5fce25cf5..8bcb8b4aa 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -95,10 +95,14 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + gtsam::DiscreteConditional operator*( + const gtsam::DiscreteConditional& other) const; void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + size_t nrFrontals() const; + size_t nrParents() const; void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 03d9f82d7..a13a43e26 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -40,7 +40,7 @@ class TestDecisionTreeFactor(GtsamTestCase): prior = DiscretePrior(v1, [1, 3]) f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") - self.gtsamAssertEquals(prior * f1, expected) + self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) self.gtsamAssertEquals(f1 * prior, expected) # Multiply two factors diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 0ae66c2d4..190c22181 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -16,6 +16,13 @@ import unittest from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase +# Some DiscreteKeys for binary variables: +A = 0, 2 +B = 1, 2 +C = 2, 2 +D = 4, 2 +E = 3, 2 + class TestDiscreteConditional(GtsamTestCase): """Tests for Discrete Conditionals.""" @@ -36,6 +43,44 @@ class TestDiscreteConditional(GtsamTestCase): actual = conditional.sample(2) self.assertIsInstance(actual, int) + def test_multiply(self): + """Check calculation of joint P(A,B)""" + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + + # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for actual in [prior * conditional, conditional * prior]: + self.assertEqual(2, actual.nrFrontals()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), conditional(v) * prior(v)) + + def test_multiply2(self): + """Check calculation of conditional joint P(A,B|C)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_C = DiscreteConditional(B, [C], "1/3 3/1") + + # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for actual in [A_given_B * B_given_C, B_given_C * A_given_B]: + self.assertEqual(2, actual.nrFrontals()) + self.assertEqual(1, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v)) + + def test_multiply4(self): + """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_D = DiscreteConditional(B, [D], "1/3 3/1") + AB_given_D = A_given_B * B_given_D + C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4") + + # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]: + self.assertEqual(3, actual.nrFrontals()) + self.assertEqual(2, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual( + actual(v), AB_given_D(v) * C_given_DE(v)) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" @@ -48,8 +93,7 @@ class TestDiscreteConditional(GtsamTestCase): conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0") - expected = \ - " *P(A|B,C):*\n\n" \ + expected = " *P(A|B,C):*\n\n" \ "|*B*|*C*|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \