Wrapped multiplication
parent
f9dd225ca5
commit
23a8dba716
|
@ -95,10 +95,14 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
const gtsam::DecisionTreeFactor& marginal,
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
const gtsam::Ordering& orderedKeys);
|
const gtsam::Ordering& orderedKeys);
|
||||||
|
gtsam::DiscreteConditional operator*(
|
||||||
|
const gtsam::DiscreteConditional& other) const;
|
||||||
void print(string s = "Discrete Conditional\n",
|
void print(string s = "Discrete Conditional\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||||
|
size_t nrFrontals() const;
|
||||||
|
size_t nrParents() const;
|
||||||
void printSignature(
|
void printSignature(
|
||||||
string s = "Discrete Conditional: ",
|
string s = "Discrete Conditional: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
@ -40,7 +40,7 @@ class TestDecisionTreeFactor(GtsamTestCase):
|
||||||
prior = DiscretePrior(v1, [1, 3])
|
prior = DiscretePrior(v1, [1, 3])
|
||||||
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
|
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
|
||||||
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
|
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
|
||||||
self.gtsamAssertEquals(prior * f1, expected)
|
self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)
|
||||||
self.gtsamAssertEquals(f1 * prior, expected)
|
self.gtsamAssertEquals(f1 * prior, expected)
|
||||||
|
|
||||||
# Multiply two factors
|
# Multiply two factors
|
||||||
|
|
|
@ -16,6 +16,13 @@ import unittest
|
||||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
# Some DiscreteKeys for binary variables:
|
||||||
|
A = 0, 2
|
||||||
|
B = 1, 2
|
||||||
|
C = 2, 2
|
||||||
|
D = 4, 2
|
||||||
|
E = 3, 2
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteConditional(GtsamTestCase):
|
class TestDiscreteConditional(GtsamTestCase):
|
||||||
"""Tests for Discrete Conditionals."""
|
"""Tests for Discrete Conditionals."""
|
||||||
|
@ -36,6 +43,44 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
actual = conditional.sample(2)
|
actual = conditional.sample(2)
|
||||||
self.assertIsInstance(actual, int)
|
self.assertIsInstance(actual, int)
|
||||||
|
|
||||||
|
def test_multiply(self):
|
||||||
|
"""Check calculation of joint P(A,B)"""
|
||||||
|
conditional = DiscreteConditional(A, [B], "1/2 2/1")
|
||||||
|
prior = DiscreteConditional(B, "1/2")
|
||||||
|
|
||||||
|
# P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||||
|
for actual in [prior * conditional, conditional * prior]:
|
||||||
|
self.assertEqual(2, actual.nrFrontals())
|
||||||
|
for v, value in actual.enumerate():
|
||||||
|
self.assertAlmostEqual(actual(v), conditional(v) * prior(v))
|
||||||
|
|
||||||
|
def test_multiply2(self):
|
||||||
|
"""Check calculation of conditional joint P(A,B|C)"""
|
||||||
|
A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
|
||||||
|
B_given_C = DiscreteConditional(B, [C], "1/3 3/1")
|
||||||
|
|
||||||
|
# P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for actual in [A_given_B * B_given_C, B_given_C * A_given_B]:
|
||||||
|
self.assertEqual(2, actual.nrFrontals())
|
||||||
|
self.assertEqual(1, actual.nrParents())
|
||||||
|
for v, value in actual.enumerate():
|
||||||
|
self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v))
|
||||||
|
|
||||||
|
def test_multiply4(self):
|
||||||
|
"""Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)"""
|
||||||
|
A_given_B = DiscreteConditional(A, [B], "1/3 3/1")
|
||||||
|
B_given_D = DiscreteConditional(B, [D], "1/3 3/1")
|
||||||
|
AB_given_D = A_given_B * B_given_D
|
||||||
|
C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4")
|
||||||
|
|
||||||
|
# P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
|
||||||
|
for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]:
|
||||||
|
self.assertEqual(3, actual.nrFrontals())
|
||||||
|
self.assertEqual(2, actual.nrParents())
|
||||||
|
for v, value in actual.enumerate():
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
actual(v), AB_given_D(v) * C_given_DE(v))
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test whether the _repr_markdown_ method."""
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
@ -48,8 +93,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
|
|
||||||
conditional = DiscreteConditional(A, parents,
|
conditional = DiscreteConditional(A, parents,
|
||||||
"0/1 1/3 1/1 3/1 0/1 1/0")
|
"0/1 1/3 1/1 3/1 0/1 1/0")
|
||||||
expected = \
|
expected = " *P(A|B,C):*\n\n" \
|
||||||
" *P(A|B,C):*\n\n" \
|
|
||||||
"|*B*|*C*|0|1|\n" \
|
"|*B*|*C*|0|1|\n" \
|
||||||
"|:-:|:-:|:-:|:-:|\n" \
|
"|:-:|:-:|:-:|:-:|\n" \
|
||||||
"|0|0|0|1|\n" \
|
"|0|0|0|1|\n" \
|
||||||
|
|
Loading…
Reference in New Issue