From b58d509b68eaf214a9837a19ff4a3ac45eb403e4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 15 May 2025 17:15:30 -0400 Subject: [PATCH] add python test for sampling from GaussianConditional with a PRNG --- python/gtsam/tests/test_GaussianBayesNet.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/gtsam/tests/test_GaussianBayesNet.py b/python/gtsam/tests/test_GaussianBayesNet.py index 05522441b..f06512e93 100644 --- a/python/gtsam/tests/test_GaussianBayesNet.py +++ b/python/gtsam/tests/test_GaussianBayesNet.py @@ -23,11 +23,12 @@ _x_ = 11 _y_ = 22 _z_ = 33 +I_1x1 = np.eye(1, dtype=float) + def smallBayesNet(): """Create a small Bayes Net for testing""" bayesNet = GaussianBayesNet() - I_1x1 = np.eye(1, dtype=float) bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1)) bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1)) return bayesNet @@ -51,8 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase): values.insert(_x_, np.array([9.0])) values.insert(_y_, np.array([5.0])) for i in [0, 1]: - self.assertAlmostEqual(bayesNet.at(i).logProbability(values), - np.log(bayesNet.at(i).evaluate(values))) + self.assertAlmostEqual( + bayesNet.at(i).logProbability(values), + np.log(bayesNet.at(i).evaluate(values))) self.assertAlmostEqual(bayesNet.logProbability(values), np.log(bayesNet.evaluate(values))) @@ -66,6 +68,16 @@ class TestGaussianBayesNet(GtsamTestCase): mean = bayesNet.optimize() self.gtsamAssertEquals(sample, mean, tol=3.0) + # Sample with rng + rng = gtsam.MT19937(42) + conditional = GaussianConditional(_x_, [9.0], I_1x1) + # Sample multiple times and average to gey mean + val = 0 + niters = 10000 + for _ in range(niters): + val += conditional.sample(rng).at(_x_).item() + self.assertAlmostEqual(val / niters, 9.0, 2) + if __name__ == "__main__": unittest.main()