Test more things in BayesTree
parent
40e5a1a6ab
commit
c8fe2fcff2
|
@ -13,10 +13,17 @@ Author: Frank Dellaert
|
|||
|
||||
import unittest
|
||||
|
||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||
DiscreteConditional, DiscreteFactorGraph, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
from gtsam import (
|
||||
DiscreteBayesNet,
|
||||
DiscreteBayesTreeClique,
|
||||
DiscreteConditional,
|
||||
DiscreteFactorGraph,
|
||||
DiscreteValues,
|
||||
Ordering,
|
||||
)
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
"""Tests for Discrete Bayes Nets."""
|
||||
|
@ -65,15 +72,34 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
# bayesTree[key].printSignature()
|
||||
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
||||
|
||||
self.assertFalse(bayesTree.empty())
|
||||
self.assertEqual(12, bayesTree.size())
|
||||
|
||||
# The root is P( 8 12 14), we can retrieve it by key:
|
||||
root = bayesTree[8]
|
||||
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
||||
self.assertTrue(root.isRoot())
|
||||
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
||||
|
||||
# Test all methods in DiscreteBayesTree
|
||||
self.gtsamAssertEquals(bayesTree, bayesTree)
|
||||
|
||||
# Check value at 0
|
||||
zero_values = DiscreteValues()
|
||||
for j in range(15):
|
||||
zero_values[j] = 0
|
||||
value_at_zeros = bayesTree.evaluate(zero_values)
|
||||
self.assertAlmostEqual(value_at_zeros, 0.0)
|
||||
|
||||
# Check value at max
|
||||
values_star = factorGraph.optimize()
|
||||
max_value = bayesTree.evaluate(values_star)
|
||||
self.assertAlmostEqual(max_value, 0.002548)
|
||||
|
||||
# Check operator sugar
|
||||
max_value = bayesTree(values_star)
|
||||
self.assertAlmostEqual(max_value, 0.002548)
|
||||
|
||||
self.assertFalse(bayesTree.empty())
|
||||
self.assertEqual(12, bayesTree.size())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue