From f853b1584b4243073752940a9e9c3aa45cc4b0a0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 15 May 2025 18:11:40 -0400 Subject: [PATCH] sampling tests --- python/gtsam/tests/test_DiscreteBayesNet.py | 49 ++++++++++++++++++--- python/gtsam/tests/test_GaussianBayesNet.py | 2 +- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index 36bed816d..471d82264 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -33,6 +33,39 @@ XRay = (2, 2) Dyspnea = (1, 2) +class TestDiscreteConditional(GtsamTestCase): + """Tests for Discrete Conditional""" + + def setUp(self): + self.key = (0, 2) + self.parent = (1, 2) + self.parents = DiscreteKeys() + self.parents.push_back(self.parent) + + def test_sample(self): + """Tests to check sampling in DiscreteConditionals""" + rng = gtsam.MT19937(11) + niters = 1000 + + # Sample with only 1 variable + conditional = DiscreteConditional(self.key, "7/3") + p = 0 + for _ in range(niters): + p += conditional.sample(rng) + + self.assertAlmostEqual(p / niters, 0.3, 1) + + # Sample with variable and parent + conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8") + p = 0 + parentValues = gtsam.DiscreteValues() + parentValues[self.parent[0]] = 1 + for _ in range(niters): + p += conditional.sample(parentValues, rng) + + self.assertAlmostEqual(p / niters, 0.8, 1) + + class TestDiscreteBayesNet(GtsamTestCase): """Tests for Discrete Bayes Nets.""" @@ -85,10 +118,12 @@ class TestDiscreteBayesNet(GtsamTestCase): # solve actualMPE = fg.optimize() expectedMPE = DiscreteValues() - for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: + for key in [ + Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, + Bronchitis + ]: expectedMPE[key[0]] = 0 - self.assertEqual(list(actualMPE.items()), - list(expectedMPE.items())) + self.assertEqual(list(actualMPE.items()), list(expectedMPE.items())) # Check value for MPE is the same self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) @@ -104,8 +139,7 @@ class TestDiscreteBayesNet(GtsamTestCase): expectedMPE2[key[0]] = 0 for key in [Asia, Dyspnea, Smoking, Bronchitis]: expectedMPE2[key[0]] = 1 - self.assertEqual(list(actualMPE2.items()), - list(expectedMPE2.items())) + self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items())) # now sample from it chordal2 = fg.eliminateSequential(ordering) @@ -135,8 +169,9 @@ class TestDiscreteBayesNet(GtsamTestCase): # self.assertEqual(len(values), 5) for i in [0, 1, 2]: - self.assertAlmostEqual(fragment.at(i).logProbability(values), - math.log(fragment.at(i).evaluate(values))) + self.assertAlmostEqual( + fragment.at(i).logProbability(values), + math.log(fragment.at(i).evaluate(values))) self.assertAlmostEqual(fragment.logProbability(values), math.log(fragment.evaluate(values))) actual = fragment.sample(given) diff --git a/python/gtsam/tests/test_GaussianBayesNet.py b/python/gtsam/tests/test_GaussianBayesNet.py index f06512e93..468ee8de5 100644 --- a/python/gtsam/tests/test_GaussianBayesNet.py +++ b/python/gtsam/tests/test_GaussianBayesNet.py @@ -76,7 +76,7 @@ class TestGaussianBayesNet(GtsamTestCase): niters = 10000 for _ in range(niters): val += conditional.sample(rng).at(_x_).item() - self.assertAlmostEqual(val / niters, 9.0, 2) + self.assertAlmostEqual(val / niters, 9.0, 1) if __name__ == "__main__":