sampling tests

release/4.3a0
Varun Agrawal 2025-05-15 18:11:40 -04:00
parent 0d70a47571
commit f853b1584b
2 changed files with 43 additions and 8 deletions

View File

@ -33,6 +33,39 @@ XRay = (2, 2)
Dyspnea = (1, 2) Dyspnea = (1, 2)
class TestDiscreteConditional(GtsamTestCase):
"""Tests for Discrete Conditional"""
def setUp(self):
self.key = (0, 2)
self.parent = (1, 2)
self.parents = DiscreteKeys()
self.parents.push_back(self.parent)
def test_sample(self):
"""Tests to check sampling in DiscreteConditionals"""
rng = gtsam.MT19937(11)
niters = 1000
# Sample with only 1 variable
conditional = DiscreteConditional(self.key, "7/3")
p = 0
for _ in range(niters):
p += conditional.sample(rng)
self.assertAlmostEqual(p / niters, 0.3, 1)
# Sample with variable and parent
conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
p = 0
parentValues = gtsam.DiscreteValues()
parentValues[self.parent[0]] = 1
for _ in range(niters):
p += conditional.sample(parentValues, rng)
self.assertAlmostEqual(p / niters, 0.8, 1)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets.""" """Tests for Discrete Bayes Nets."""
@ -85,10 +118,12 @@ class TestDiscreteBayesNet(GtsamTestCase):
# solve # solve
actualMPE = fg.optimize() actualMPE = fg.optimize()
expectedMPE = DiscreteValues() expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: for key in [
Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
Bronchitis
]:
expectedMPE[key[0]] = 0 expectedMPE[key[0]] = 0
self.assertEqual(list(actualMPE.items()), self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))
list(expectedMPE.items()))
# Check value for MPE is the same # Check value for MPE is the same
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
@ -104,8 +139,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
expectedMPE2[key[0]] = 0 expectedMPE2[key[0]] = 0
for key in [Asia, Dyspnea, Smoking, Bronchitis]: for key in [Asia, Dyspnea, Smoking, Bronchitis]:
expectedMPE2[key[0]] = 1 expectedMPE2[key[0]] = 1
self.assertEqual(list(actualMPE2.items()), self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))
list(expectedMPE2.items()))
# now sample from it # now sample from it
chordal2 = fg.eliminateSequential(ordering) chordal2 = fg.eliminateSequential(ordering)
@ -135,8 +169,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
# self.assertEqual(len(values), 5) # self.assertEqual(len(values), 5)
for i in [0, 1, 2]: for i in [0, 1, 2]:
self.assertAlmostEqual(fragment.at(i).logProbability(values), self.assertAlmostEqual(
math.log(fragment.at(i).evaluate(values))) fragment.at(i).logProbability(values),
math.log(fragment.at(i).evaluate(values)))
self.assertAlmostEqual(fragment.logProbability(values), self.assertAlmostEqual(fragment.logProbability(values),
math.log(fragment.evaluate(values))) math.log(fragment.evaluate(values)))
actual = fragment.sample(given) actual = fragment.sample(given)

View File

@ -76,7 +76,7 @@ class TestGaussianBayesNet(GtsamTestCase):
niters = 10000 niters = 10000
for _ in range(niters): for _ in range(niters):
val += conditional.sample(rng).at(_x_).item() val += conditional.sample(rng).at(_x_).item()
self.assertAlmostEqual(val / niters, 9.0, 2) self.assertAlmostEqual(val / niters, 9.0, 1)
if __name__ == "__main__": if __name__ == "__main__":