Test more things in BayesTree
							parent
							
								
									40e5a1a6ab
								
							
						
					
					
						commit
						c8fe2fcff2
					
				| 
						 | 
				
			
			@ -13,10 +13,17 @@ Author: Frank Dellaert
 | 
			
		|||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
 | 
			
		||||
                   DiscreteConditional, DiscreteFactorGraph, Ordering)
 | 
			
		||||
from gtsam.utils.test_case import GtsamTestCase
 | 
			
		||||
 | 
			
		||||
from gtsam import (
 | 
			
		||||
    DiscreteBayesNet,
 | 
			
		||||
    DiscreteBayesTreeClique,
 | 
			
		||||
    DiscreteConditional,
 | 
			
		||||
    DiscreteFactorGraph,
 | 
			
		||||
    DiscreteValues,
 | 
			
		||||
    Ordering,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDiscreteBayesNet(GtsamTestCase):
 | 
			
		||||
    """Tests for Discrete Bayes Nets."""
 | 
			
		||||
| 
						 | 
				
			
			@ -65,15 +72,34 @@ class TestDiscreteBayesNet(GtsamTestCase):
 | 
			
		|||
        #     bayesTree[key].printSignature()
 | 
			
		||||
        # bayesTree.saveGraph("test_DiscreteBayesTree.dot")
 | 
			
		||||
 | 
			
		||||
        self.assertFalse(bayesTree.empty())
 | 
			
		||||
        self.assertEqual(12, bayesTree.size())
 | 
			
		||||
 | 
			
		||||
        # The root is P( 8 12 14), we can retrieve it by key:
 | 
			
		||||
        root = bayesTree[8]
 | 
			
		||||
        self.assertIsInstance(root, DiscreteBayesTreeClique)
 | 
			
		||||
        self.assertTrue(root.isRoot())
 | 
			
		||||
        self.assertIsInstance(root.conditional(), DiscreteConditional)
 | 
			
		||||
 | 
			
		||||
        # Test all methods in DiscreteBayesTree
 | 
			
		||||
        self.gtsamAssertEquals(bayesTree, bayesTree)
 | 
			
		||||
 | 
			
		||||
        # Check value at 0
 | 
			
		||||
        zero_values = DiscreteValues()
 | 
			
		||||
        for j in range(15):
 | 
			
		||||
            zero_values[j] = 0
 | 
			
		||||
        value_at_zeros = bayesTree.evaluate(zero_values)
 | 
			
		||||
        self.assertAlmostEqual(value_at_zeros, 0.0)
 | 
			
		||||
 | 
			
		||||
        # Check value at max
 | 
			
		||||
        values_star = factorGraph.optimize()
 | 
			
		||||
        max_value = bayesTree.evaluate(values_star)
 | 
			
		||||
        self.assertAlmostEqual(max_value, 0.002548)
 | 
			
		||||
 | 
			
		||||
        # Check operator sugar
 | 
			
		||||
        max_value = bayesTree(values_star)
 | 
			
		||||
        self.assertAlmostEqual(max_value, 0.002548)
 | 
			
		||||
 | 
			
		||||
        self.assertFalse(bayesTree.empty())
 | 
			
		||||
        self.assertEqual(12, bayesTree.size())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue