sampling tests
							parent
							
								
									0d70a47571
								
							
						
					
					
						commit
						f853b1584b
					
				|  | @ -33,6 +33,39 @@ XRay = (2, 2) | |||
| Dyspnea = (1, 2) | ||||
| 
 | ||||
| 
 | ||||
| class TestDiscreteConditional(GtsamTestCase): | ||||
|     """Tests for Discrete Conditional""" | ||||
| 
 | ||||
|     def setUp(self): | ||||
|         self.key = (0, 2) | ||||
|         self.parent = (1, 2) | ||||
|         self.parents = DiscreteKeys() | ||||
|         self.parents.push_back(self.parent) | ||||
| 
 | ||||
|     def test_sample(self): | ||||
|         """Tests to check sampling in DiscreteConditionals""" | ||||
|         rng = gtsam.MT19937(11) | ||||
|         niters = 1000 | ||||
| 
 | ||||
|         # Sample with only 1 variable | ||||
|         conditional = DiscreteConditional(self.key, "7/3") | ||||
|         p = 0 | ||||
|         for _ in range(niters): | ||||
|             p += conditional.sample(rng) | ||||
| 
 | ||||
|         self.assertAlmostEqual(p / niters, 0.3, 1) | ||||
| 
 | ||||
|         # Sample with variable and parent | ||||
|         conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8") | ||||
|         p = 0 | ||||
|         parentValues = gtsam.DiscreteValues() | ||||
|         parentValues[self.parent[0]] = 1 | ||||
|         for _ in range(niters): | ||||
|             p += conditional.sample(parentValues, rng) | ||||
| 
 | ||||
|         self.assertAlmostEqual(p / niters, 0.8, 1) | ||||
| 
 | ||||
| 
 | ||||
| class TestDiscreteBayesNet(GtsamTestCase): | ||||
|     """Tests for Discrete Bayes Nets.""" | ||||
| 
 | ||||
|  | @ -85,10 +118,12 @@ class TestDiscreteBayesNet(GtsamTestCase): | |||
|         # solve | ||||
|         actualMPE = fg.optimize() | ||||
|         expectedMPE = DiscreteValues() | ||||
|         for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: | ||||
|         for key in [ | ||||
|                 Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, | ||||
|                 Bronchitis | ||||
|         ]: | ||||
|             expectedMPE[key[0]] = 0 | ||||
|         self.assertEqual(list(actualMPE.items()), | ||||
|                          list(expectedMPE.items())) | ||||
|         self.assertEqual(list(actualMPE.items()), list(expectedMPE.items())) | ||||
| 
 | ||||
|         # Check value for MPE is the same | ||||
|         self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) | ||||
|  | @ -104,8 +139,7 @@ class TestDiscreteBayesNet(GtsamTestCase): | |||
|             expectedMPE2[key[0]] = 0 | ||||
|         for key in [Asia, Dyspnea, Smoking, Bronchitis]: | ||||
|             expectedMPE2[key[0]] = 1 | ||||
|         self.assertEqual(list(actualMPE2.items()), | ||||
|                          list(expectedMPE2.items())) | ||||
|         self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items())) | ||||
| 
 | ||||
|         # now sample from it | ||||
|         chordal2 = fg.eliminateSequential(ordering) | ||||
|  | @ -135,8 +169,9 @@ class TestDiscreteBayesNet(GtsamTestCase): | |||
|         # self.assertEqual(len(values), 5) | ||||
| 
 | ||||
|         for i in [0, 1, 2]: | ||||
|             self.assertAlmostEqual(fragment.at(i).logProbability(values), | ||||
|                                    math.log(fragment.at(i).evaluate(values))) | ||||
|             self.assertAlmostEqual( | ||||
|                 fragment.at(i).logProbability(values), | ||||
|                 math.log(fragment.at(i).evaluate(values))) | ||||
|         self.assertAlmostEqual(fragment.logProbability(values), | ||||
|                                math.log(fragment.evaluate(values))) | ||||
|         actual = fragment.sample(given) | ||||
|  |  | |||
|  | @ -76,7 +76,7 @@ class TestGaussianBayesNet(GtsamTestCase): | |||
|         niters = 10000 | ||||
|         for _ in range(niters): | ||||
|             val += conditional.sample(rng).at(_x_).item() | ||||
|         self.assertAlmostEqual(val / niters, 9.0, 2) | ||||
|         self.assertAlmostEqual(val / niters, 9.0, 1) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue