From 03baf8f75ed0aba3c971b2da5ad58460666d000d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 29 Dec 2022 08:33:14 +0530 Subject: [PATCH] formatting and fixes to test --- python/gtsam/tests/test_HybridBayesNet.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/gtsam/tests/test_HybridBayesNet.py b/python/gtsam/tests/test_HybridBayesNet.py index cfe080dcb..13ac3a3e2 100644 --- a/python/gtsam/tests/test_HybridBayesNet.py +++ b/python/gtsam/tests/test_HybridBayesNet.py @@ -10,8 +10,6 @@ Author: Frank Dellaert """ # pylint: disable=invalid-name, no-name-in-module, no-member -from __future__ import print_function - import unittest import numpy as np @@ -19,12 +17,12 @@ from gtsam.symbol_shorthand import A, X from gtsam.utils.test_case import GtsamTestCase import gtsam -from gtsam import GaussianConditional, GaussianMixture, HybridBayesNet, HybridValues, noiseModel +from gtsam import (DiscreteKeys, GaussianConditional, GaussianMixture, + HybridBayesNet, HybridValues, noiseModel) class TestHybridBayesNet(GtsamTestCase): """Unit tests for HybridValues.""" - def test_evaluate(self): """Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).""" asiaKey = A(0) @@ -32,7 +30,8 @@ class TestHybridBayesNet(GtsamTestCase): # Create the continuous conditional I_1x1 = np.eye(1) - gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4], 5.0) + gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4], + 5.0) # Create the noise models model0 = noiseModel.Diagonal.Sigmas([2.0]) @@ -41,7 +40,10 @@ class TestHybridBayesNet(GtsamTestCase): # Create the conditionals conditional0 = GaussianConditional(X(1), [5], I_1x1, model0) conditional1 = GaussianConditional(X(1), [2], I_1x1, model1) - # gm = GaussianMixture.FromConditionals([X(1)], [], [Asia], [conditional0, conditional1]) # + dkeys = DiscreteKeys() + dkeys.push_back(Asia) + gm = GaussianMixture.FromConditionals([X(1)], [], dkeys, + [conditional0, conditional1]) # # Create hybrid Bayes net. bayesNet = HybridBayesNet() @@ -57,9 +59,10 @@ class TestHybridBayesNet(GtsamTestCase): conditionalProbability = gc.evaluate(values.continuous()) mixtureProbability = conditional0.evaluate(values.continuous()) - assert self.assertAlmostEqual( - conditionalProbability * mixtureProbability * 0.99, bayesNet.evaluate(values), places=5 - ) + assert self.assertAlmostEqual(conditionalProbability * + mixtureProbability * 0.99, + bayesNet.evaluate(values), + places=5) if __name__ == "__main__":