Test more things in BayesTree
parent
40e5a1a6ab
commit
c8fe2fcff2
|
@ -13,10 +13,17 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
|
||||||
DiscreteConditional, DiscreteFactorGraph, Ordering)
|
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
from gtsam import (
|
||||||
|
DiscreteBayesNet,
|
||||||
|
DiscreteBayesTreeClique,
|
||||||
|
DiscreteConditional,
|
||||||
|
DiscreteFactorGraph,
|
||||||
|
DiscreteValues,
|
||||||
|
Ordering,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
@ -65,15 +72,34 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# bayesTree[key].printSignature()
|
# bayesTree[key].printSignature()
|
||||||
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
||||||
|
|
||||||
self.assertFalse(bayesTree.empty())
|
|
||||||
self.assertEqual(12, bayesTree.size())
|
|
||||||
|
|
||||||
# The root is P( 8 12 14), we can retrieve it by key:
|
# The root is P( 8 12 14), we can retrieve it by key:
|
||||||
root = bayesTree[8]
|
root = bayesTree[8]
|
||||||
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
||||||
self.assertTrue(root.isRoot())
|
self.assertTrue(root.isRoot())
|
||||||
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
||||||
|
|
||||||
|
# Test all methods in DiscreteBayesTree
|
||||||
|
self.gtsamAssertEquals(bayesTree, bayesTree)
|
||||||
|
|
||||||
|
# Check value at 0
|
||||||
|
zero_values = DiscreteValues()
|
||||||
|
for j in range(15):
|
||||||
|
zero_values[j] = 0
|
||||||
|
value_at_zeros = bayesTree.evaluate(zero_values)
|
||||||
|
self.assertAlmostEqual(value_at_zeros, 0.0)
|
||||||
|
|
||||||
|
# Check value at max
|
||||||
|
values_star = factorGraph.optimize()
|
||||||
|
max_value = bayesTree.evaluate(values_star)
|
||||||
|
self.assertAlmostEqual(max_value, 0.002548)
|
||||||
|
|
||||||
|
# Check operator sugar
|
||||||
|
max_value = bayesTree(values_star)
|
||||||
|
self.assertAlmostEqual(max_value, 0.002548)
|
||||||
|
|
||||||
|
self.assertFalse(bayesTree.empty())
|
||||||
|
self.assertEqual(12, bayesTree.size())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue