formatting and fixes to test

release/4.3a0
Varun Agrawal 2022-12-29 08:33:14 +05:30
parent 873f5baf56
commit 03baf8f75e
1 changed files with 12 additions and 9 deletions

View File

@ -10,8 +10,6 @@ Author: Frank Dellaert
""" """
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
@ -19,12 +17,12 @@ from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import GaussianConditional, GaussianMixture, HybridBayesNet, HybridValues, noiseModel from gtsam import (DiscreteKeys, GaussianConditional, GaussianMixture,
HybridBayesNet, HybridValues, noiseModel)
class TestHybridBayesNet(GtsamTestCase): class TestHybridBayesNet(GtsamTestCase):
"""Unit tests for HybridValues.""" """Unit tests for HybridValues."""
def test_evaluate(self): def test_evaluate(self):
"""Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).""" """Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
asiaKey = A(0) asiaKey = A(0)
@ -32,7 +30,8 @@ class TestHybridBayesNet(GtsamTestCase):
# Create the continuous conditional # Create the continuous conditional
I_1x1 = np.eye(1) I_1x1 = np.eye(1)
gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4], 5.0) gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4],
5.0)
# Create the noise models # Create the noise models
model0 = noiseModel.Diagonal.Sigmas([2.0]) model0 = noiseModel.Diagonal.Sigmas([2.0])
@ -41,7 +40,10 @@ class TestHybridBayesNet(GtsamTestCase):
# Create the conditionals # Create the conditionals
conditional0 = GaussianConditional(X(1), [5], I_1x1, model0) conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
conditional1 = GaussianConditional(X(1), [2], I_1x1, model1) conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)
# gm = GaussianMixture.FromConditionals([X(1)], [], [Asia], [conditional0, conditional1]) # dkeys = DiscreteKeys()
dkeys.push_back(Asia)
gm = GaussianMixture.FromConditionals([X(1)], [], dkeys,
[conditional0, conditional1]) #
# Create hybrid Bayes net. # Create hybrid Bayes net.
bayesNet = HybridBayesNet() bayesNet = HybridBayesNet()
@ -57,9 +59,10 @@ class TestHybridBayesNet(GtsamTestCase):
conditionalProbability = gc.evaluate(values.continuous()) conditionalProbability = gc.evaluate(values.continuous())
mixtureProbability = conditional0.evaluate(values.continuous()) mixtureProbability = conditional0.evaluate(values.continuous())
assert self.assertAlmostEqual( assert self.assertAlmostEqual(conditionalProbability *
conditionalProbability * mixtureProbability * 0.99, bayesNet.evaluate(values), places=5 mixtureProbability * 0.99,
) bayesNet.evaluate(values),
places=5)
if __name__ == "__main__": if __name__ == "__main__":