formatting and fixes to test
parent
873f5baf56
commit
03baf8f75e
|
@ -10,8 +10,6 @@ Author: Frank Dellaert
|
||||||
"""
|
"""
|
||||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -19,12 +17,12 @@ from gtsam.symbol_shorthand import A, X
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
from gtsam import GaussianConditional, GaussianMixture, HybridBayesNet, HybridValues, noiseModel
|
from gtsam import (DiscreteKeys, GaussianConditional, GaussianMixture,
|
||||||
|
HybridBayesNet, HybridValues, noiseModel)
|
||||||
|
|
||||||
|
|
||||||
class TestHybridBayesNet(GtsamTestCase):
|
class TestHybridBayesNet(GtsamTestCase):
|
||||||
"""Unit tests for HybridValues."""
|
"""Unit tests for HybridValues."""
|
||||||
|
|
||||||
def test_evaluate(self):
|
def test_evaluate(self):
|
||||||
"""Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
|
"""Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
|
||||||
asiaKey = A(0)
|
asiaKey = A(0)
|
||||||
|
@ -32,7 +30,8 @@ class TestHybridBayesNet(GtsamTestCase):
|
||||||
|
|
||||||
# Create the continuous conditional
|
# Create the continuous conditional
|
||||||
I_1x1 = np.eye(1)
|
I_1x1 = np.eye(1)
|
||||||
gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4], 5.0)
|
gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4],
|
||||||
|
5.0)
|
||||||
|
|
||||||
# Create the noise models
|
# Create the noise models
|
||||||
model0 = noiseModel.Diagonal.Sigmas([2.0])
|
model0 = noiseModel.Diagonal.Sigmas([2.0])
|
||||||
|
@ -41,7 +40,10 @@ class TestHybridBayesNet(GtsamTestCase):
|
||||||
# Create the conditionals
|
# Create the conditionals
|
||||||
conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
|
conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
|
||||||
conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)
|
conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)
|
||||||
# gm = GaussianMixture.FromConditionals([X(1)], [], [Asia], [conditional0, conditional1]) #
|
dkeys = DiscreteKeys()
|
||||||
|
dkeys.push_back(Asia)
|
||||||
|
gm = GaussianMixture.FromConditionals([X(1)], [], dkeys,
|
||||||
|
[conditional0, conditional1]) #
|
||||||
|
|
||||||
# Create hybrid Bayes net.
|
# Create hybrid Bayes net.
|
||||||
bayesNet = HybridBayesNet()
|
bayesNet = HybridBayesNet()
|
||||||
|
@ -57,9 +59,10 @@ class TestHybridBayesNet(GtsamTestCase):
|
||||||
|
|
||||||
conditionalProbability = gc.evaluate(values.continuous())
|
conditionalProbability = gc.evaluate(values.continuous())
|
||||||
mixtureProbability = conditional0.evaluate(values.continuous())
|
mixtureProbability = conditional0.evaluate(values.continuous())
|
||||||
assert self.assertAlmostEqual(
|
assert self.assertAlmostEqual(conditionalProbability *
|
||||||
conditionalProbability * mixtureProbability * 0.99, bayesNet.evaluate(values), places=5
|
mixtureProbability * 0.99,
|
||||||
)
|
bayesNet.evaluate(values),
|
||||||
|
places=5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue