Test more things in BayesTree

release/4.3a0
Frank Dellaert 2023-06-04 22:39:14 +01:00
parent 40e5a1a6ab
commit c8fe2fcff2
1 changed files with 31 additions and 5 deletions

View File

@ -13,10 +13,17 @@ Author: Frank Dellaert
import unittest import unittest
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph, Ordering)
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam import (
DiscreteBayesNet,
DiscreteBayesTreeClique,
DiscreteConditional,
DiscreteFactorGraph,
DiscreteValues,
Ordering,
)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets.""" """Tests for Discrete Bayes Nets."""
@ -65,15 +72,34 @@ class TestDiscreteBayesNet(GtsamTestCase):
# bayesTree[key].printSignature() # bayesTree[key].printSignature()
# bayesTree.saveGraph("test_DiscreteBayesTree.dot") # bayesTree.saveGraph("test_DiscreteBayesTree.dot")
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
# The root is P( 8 12 14), we can retrieve it by key: # The root is P( 8 12 14), we can retrieve it by key:
root = bayesTree[8] root = bayesTree[8]
self.assertIsInstance(root, DiscreteBayesTreeClique) self.assertIsInstance(root, DiscreteBayesTreeClique)
self.assertTrue(root.isRoot()) self.assertTrue(root.isRoot())
self.assertIsInstance(root.conditional(), DiscreteConditional) self.assertIsInstance(root.conditional(), DiscreteConditional)
# Test all methods in DiscreteBayesTree
self.gtsamAssertEquals(bayesTree, bayesTree)
# Check value at 0
zero_values = DiscreteValues()
for j in range(15):
zero_values[j] = 0
value_at_zeros = bayesTree.evaluate(zero_values)
self.assertAlmostEqual(value_at_zeros, 0.0)
# Check value at max
values_star = factorGraph.optimize()
max_value = bayesTree.evaluate(values_star)
self.assertAlmostEqual(max_value, 0.002548)
# Check operator sugar
max_value = bayesTree(values_star)
self.assertAlmostEqual(max_value, 0.002548)
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()