diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index b1ed4fe69..e1754ca64 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -13,10 +13,17 @@ Author: Frank Dellaert import unittest -from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, - DiscreteConditional, DiscreteFactorGraph, Ordering) from gtsam.utils.test_case import GtsamTestCase +from gtsam import ( + DiscreteBayesNet, + DiscreteBayesTreeClique, + DiscreteConditional, + DiscreteFactorGraph, + DiscreteValues, + Ordering, +) + class TestDiscreteBayesNet(GtsamTestCase): """Tests for Discrete Bayes Nets.""" @@ -65,15 +72,34 @@ class TestDiscreteBayesNet(GtsamTestCase): # bayesTree[key].printSignature() # bayesTree.saveGraph("test_DiscreteBayesTree.dot") - self.assertFalse(bayesTree.empty()) - self.assertEqual(12, bayesTree.size()) - # The root is P( 8 12 14), we can retrieve it by key: root = bayesTree[8] self.assertIsInstance(root, DiscreteBayesTreeClique) self.assertTrue(root.isRoot()) self.assertIsInstance(root.conditional(), DiscreteConditional) + # Test all methods in DiscreteBayesTree + self.gtsamAssertEquals(bayesTree, bayesTree) + + # Check value at 0 + zero_values = DiscreteValues() + for j in range(15): + zero_values[j] = 0 + value_at_zeros = bayesTree.evaluate(zero_values) + self.assertAlmostEqual(value_at_zeros, 0.0) + + # Check value at max + values_star = factorGraph.optimize() + max_value = bayesTree.evaluate(values_star) + self.assertAlmostEqual(max_value, 0.002548) + + # Check operator sugar + max_value = bayesTree(values_star) + self.assertAlmostEqual(max_value, 0.002548) + + self.assertFalse(bayesTree.empty()) + self.assertEqual(12, bayesTree.size()) + if __name__ == "__main__": unittest.main()