sampling tests
parent
0d70a47571
commit
f853b1584b
|
@ -33,6 +33,39 @@ XRay = (2, 2)
|
||||||
Dyspnea = (1, 2)
|
Dyspnea = (1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscreteConditional(GtsamTestCase):
|
||||||
|
"""Tests for Discrete Conditional"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.key = (0, 2)
|
||||||
|
self.parent = (1, 2)
|
||||||
|
self.parents = DiscreteKeys()
|
||||||
|
self.parents.push_back(self.parent)
|
||||||
|
|
||||||
|
def test_sample(self):
|
||||||
|
"""Tests to check sampling in DiscreteConditionals"""
|
||||||
|
rng = gtsam.MT19937(11)
|
||||||
|
niters = 1000
|
||||||
|
|
||||||
|
# Sample with only 1 variable
|
||||||
|
conditional = DiscreteConditional(self.key, "7/3")
|
||||||
|
p = 0
|
||||||
|
for _ in range(niters):
|
||||||
|
p += conditional.sample(rng)
|
||||||
|
|
||||||
|
self.assertAlmostEqual(p / niters, 0.3, 1)
|
||||||
|
|
||||||
|
# Sample with variable and parent
|
||||||
|
conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
|
||||||
|
p = 0
|
||||||
|
parentValues = gtsam.DiscreteValues()
|
||||||
|
parentValues[self.parent[0]] = 1
|
||||||
|
for _ in range(niters):
|
||||||
|
p += conditional.sample(parentValues, rng)
|
||||||
|
|
||||||
|
self.assertAlmostEqual(p / niters, 0.8, 1)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
|
||||||
|
@ -85,10 +118,12 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# solve
|
# solve
|
||||||
actualMPE = fg.optimize()
|
actualMPE = fg.optimize()
|
||||||
expectedMPE = DiscreteValues()
|
expectedMPE = DiscreteValues()
|
||||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
for key in [
|
||||||
|
Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
|
||||||
|
Bronchitis
|
||||||
|
]:
|
||||||
expectedMPE[key[0]] = 0
|
expectedMPE[key[0]] = 0
|
||||||
self.assertEqual(list(actualMPE.items()),
|
self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))
|
||||||
list(expectedMPE.items()))
|
|
||||||
|
|
||||||
# Check value for MPE is the same
|
# Check value for MPE is the same
|
||||||
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
||||||
|
@ -104,8 +139,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
expectedMPE2[key[0]] = 0
|
expectedMPE2[key[0]] = 0
|
||||||
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
||||||
expectedMPE2[key[0]] = 1
|
expectedMPE2[key[0]] = 1
|
||||||
self.assertEqual(list(actualMPE2.items()),
|
self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))
|
||||||
list(expectedMPE2.items()))
|
|
||||||
|
|
||||||
# now sample from it
|
# now sample from it
|
||||||
chordal2 = fg.eliminateSequential(ordering)
|
chordal2 = fg.eliminateSequential(ordering)
|
||||||
|
@ -135,8 +169,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# self.assertEqual(len(values), 5)
|
# self.assertEqual(len(values), 5)
|
||||||
|
|
||||||
for i in [0, 1, 2]:
|
for i in [0, 1, 2]:
|
||||||
self.assertAlmostEqual(fragment.at(i).logProbability(values),
|
self.assertAlmostEqual(
|
||||||
math.log(fragment.at(i).evaluate(values)))
|
fragment.at(i).logProbability(values),
|
||||||
|
math.log(fragment.at(i).evaluate(values)))
|
||||||
self.assertAlmostEqual(fragment.logProbability(values),
|
self.assertAlmostEqual(fragment.logProbability(values),
|
||||||
math.log(fragment.evaluate(values)))
|
math.log(fragment.evaluate(values)))
|
||||||
actual = fragment.sample(given)
|
actual = fragment.sample(given)
|
||||||
|
|
|
@ -76,7 +76,7 @@ class TestGaussianBayesNet(GtsamTestCase):
|
||||||
niters = 10000
|
niters = 10000
|
||||||
for _ in range(niters):
|
for _ in range(niters):
|
||||||
val += conditional.sample(rng).at(_x_).item()
|
val += conditional.sample(rng).at(_x_).item()
|
||||||
self.assertAlmostEqual(val / niters, 9.0, 2)
|
self.assertAlmostEqual(val / niters, 9.0, 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue