add python test for sampling from GaussianConditional with a PRNG

release/4.3a0
Varun Agrawal 2025-05-15 17:15:30 -04:00
parent 82190cb7eb
commit b58d509b68
1 changed files with 15 additions and 3 deletions

View File

@ -23,11 +23,12 @@ _x_ = 11
_y_ = 22
_z_ = 33
I_1x1 = np.eye(1, dtype=float)
def smallBayesNet():
"""Create a small Bayes Net for testing"""
bayesNet = GaussianBayesNet()
I_1x1 = np.eye(1, dtype=float)
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
return bayesNet
@ -51,8 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
values.insert(_x_, np.array([9.0]))
values.insert(_y_, np.array([5.0]))
for i in [0, 1]:
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
np.log(bayesNet.at(i).evaluate(values)))
self.assertAlmostEqual(
bayesNet.at(i).logProbability(values),
np.log(bayesNet.at(i).evaluate(values)))
self.assertAlmostEqual(bayesNet.logProbability(values),
np.log(bayesNet.evaluate(values)))
@ -66,6 +68,16 @@ class TestGaussianBayesNet(GtsamTestCase):
mean = bayesNet.optimize()
self.gtsamAssertEquals(sample, mean, tol=3.0)
# Sample with rng
rng = gtsam.MT19937(42)
conditional = GaussianConditional(_x_, [9.0], I_1x1)
# Sample multiple times and average to gey mean
val = 0
niters = 10000
for _ in range(niters):
val += conditional.sample(rng).at(_x_).item()
self.assertAlmostEqual(val / niters, 9.0, 2)
if __name__ == "__main__":
unittest.main()