add python test for sampling from GaussianConditional with a PRNG
parent
82190cb7eb
commit
b58d509b68
|
@ -23,11 +23,12 @@ _x_ = 11
|
||||||
_y_ = 22
|
_y_ = 22
|
||||||
_z_ = 33
|
_z_ = 33
|
||||||
|
|
||||||
|
I_1x1 = np.eye(1, dtype=float)
|
||||||
|
|
||||||
|
|
||||||
def smallBayesNet():
|
def smallBayesNet():
|
||||||
"""Create a small Bayes Net for testing"""
|
"""Create a small Bayes Net for testing"""
|
||||||
bayesNet = GaussianBayesNet()
|
bayesNet = GaussianBayesNet()
|
||||||
I_1x1 = np.eye(1, dtype=float)
|
|
||||||
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
|
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
|
||||||
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
|
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
|
||||||
return bayesNet
|
return bayesNet
|
||||||
|
@ -51,8 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
|
||||||
values.insert(_x_, np.array([9.0]))
|
values.insert(_x_, np.array([9.0]))
|
||||||
values.insert(_y_, np.array([5.0]))
|
values.insert(_y_, np.array([5.0]))
|
||||||
for i in [0, 1]:
|
for i in [0, 1]:
|
||||||
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
|
self.assertAlmostEqual(
|
||||||
np.log(bayesNet.at(i).evaluate(values)))
|
bayesNet.at(i).logProbability(values),
|
||||||
|
np.log(bayesNet.at(i).evaluate(values)))
|
||||||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||||
np.log(bayesNet.evaluate(values)))
|
np.log(bayesNet.evaluate(values)))
|
||||||
|
|
||||||
|
@ -66,6 +68,16 @@ class TestGaussianBayesNet(GtsamTestCase):
|
||||||
mean = bayesNet.optimize()
|
mean = bayesNet.optimize()
|
||||||
self.gtsamAssertEquals(sample, mean, tol=3.0)
|
self.gtsamAssertEquals(sample, mean, tol=3.0)
|
||||||
|
|
||||||
|
# Sample with rng
|
||||||
|
rng = gtsam.MT19937(42)
|
||||||
|
conditional = GaussianConditional(_x_, [9.0], I_1x1)
|
||||||
|
# Sample multiple times and average to gey mean
|
||||||
|
val = 0
|
||||||
|
niters = 10000
|
||||||
|
for _ in range(niters):
|
||||||
|
val += conditional.sample(rng).at(_x_).item()
|
||||||
|
self.assertAlmostEqual(val / niters, 9.0, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue