sampling tests
parent
0d70a47571
commit
f853b1584b
|
@ -33,6 +33,39 @@ XRay = (2, 2)
|
|||
Dyspnea = (1, 2)
|
||||
|
||||
|
||||
class TestDiscreteConditional(GtsamTestCase):
|
||||
"""Tests for Discrete Conditional"""
|
||||
|
||||
def setUp(self):
|
||||
self.key = (0, 2)
|
||||
self.parent = (1, 2)
|
||||
self.parents = DiscreteKeys()
|
||||
self.parents.push_back(self.parent)
|
||||
|
||||
def test_sample(self):
|
||||
"""Tests to check sampling in DiscreteConditionals"""
|
||||
rng = gtsam.MT19937(11)
|
||||
niters = 1000
|
||||
|
||||
# Sample with only 1 variable
|
||||
conditional = DiscreteConditional(self.key, "7/3")
|
||||
p = 0
|
||||
for _ in range(niters):
|
||||
p += conditional.sample(rng)
|
||||
|
||||
self.assertAlmostEqual(p / niters, 0.3, 1)
|
||||
|
||||
# Sample with variable and parent
|
||||
conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
|
||||
p = 0
|
||||
parentValues = gtsam.DiscreteValues()
|
||||
parentValues[self.parent[0]] = 1
|
||||
for _ in range(niters):
|
||||
p += conditional.sample(parentValues, rng)
|
||||
|
||||
self.assertAlmostEqual(p / niters, 0.8, 1)
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
"""Tests for Discrete Bayes Nets."""
|
||||
|
||||
|
@ -85,10 +118,12 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
# solve
|
||||
actualMPE = fg.optimize()
|
||||
expectedMPE = DiscreteValues()
|
||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||
for key in [
|
||||
Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
|
||||
Bronchitis
|
||||
]:
|
||||
expectedMPE[key[0]] = 0
|
||||
self.assertEqual(list(actualMPE.items()),
|
||||
list(expectedMPE.items()))
|
||||
self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))
|
||||
|
||||
# Check value for MPE is the same
|
||||
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
||||
|
@ -104,8 +139,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
expectedMPE2[key[0]] = 0
|
||||
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
||||
expectedMPE2[key[0]] = 1
|
||||
self.assertEqual(list(actualMPE2.items()),
|
||||
list(expectedMPE2.items()))
|
||||
self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))
|
||||
|
||||
# now sample from it
|
||||
chordal2 = fg.eliminateSequential(ordering)
|
||||
|
@ -135,8 +169,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
# self.assertEqual(len(values), 5)
|
||||
|
||||
for i in [0, 1, 2]:
|
||||
self.assertAlmostEqual(fragment.at(i).logProbability(values),
|
||||
math.log(fragment.at(i).evaluate(values)))
|
||||
self.assertAlmostEqual(
|
||||
fragment.at(i).logProbability(values),
|
||||
math.log(fragment.at(i).evaluate(values)))
|
||||
self.assertAlmostEqual(fragment.logProbability(values),
|
||||
math.log(fragment.evaluate(values)))
|
||||
actual = fragment.sample(given)
|
||||
|
|
|
@ -76,7 +76,7 @@ class TestGaussianBayesNet(GtsamTestCase):
|
|||
niters = 10000
|
||||
for _ in range(niters):
|
||||
val += conditional.sample(rng).at(_x_).item()
|
||||
self.assertAlmostEqual(val / niters, 9.0, 2)
|
||||
self.assertAlmostEqual(val / niters, 9.0, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue