Some test refactoring
							parent
							
								
									070cdb7018
								
							
						
					
					
						commit
						96e3eb7d8b
					
				|  | @ -220,25 +220,24 @@ TEST(HybridBayesNet, logProbability) { | |||
|   EXPECT_LONGS_EQUAL(5, hybridBayesNet->size()); | ||||
| 
 | ||||
|   HybridValues delta = hybridBayesNet->optimize(); | ||||
|   auto error_tree = hybridBayesNet->logProbability(delta.continuous()); | ||||
|   auto actual = hybridBayesNet->logProbability(delta.continuous()); | ||||
| 
 | ||||
|   std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}}; | ||||
|   std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374}; | ||||
|   AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); | ||||
|   AlgebraicDecisionTree<Key> expected(discrete_keys, leaves); | ||||
| 
 | ||||
|   // regression
 | ||||
|   EXPECT(assert_equal(expected_error, error_tree, 1e-6)); | ||||
|   EXPECT(assert_equal(expected, actual, 1e-6)); | ||||
| 
 | ||||
|   // logProbability on pruned Bayes net
 | ||||
|   auto prunedBayesNet = hybridBayesNet->prune(2); | ||||
|   auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous()); | ||||
|   auto pruned = prunedBayesNet.logProbability(delta.continuous()); | ||||
| 
 | ||||
|   std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374}; | ||||
|   AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys, | ||||
|                                                    pruned_leaves); | ||||
|   AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves); | ||||
| 
 | ||||
|   // regression
 | ||||
|   EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6)); | ||||
|   EXPECT(assert_equal(expected_pruned, pruned, 1e-6)); | ||||
| 
 | ||||
|   // Verify logProbability computation and check for specific logProbability
 | ||||
|   // value
 | ||||
|  | @ -253,9 +252,8 @@ TEST(HybridBayesNet, logProbability) { | |||
|       hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues); | ||||
| 
 | ||||
|   // TODO(dellaert): the discrete errors are not added in logProbability tree!
 | ||||
|   EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values), | ||||
|                        1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(logProbability, actual(discrete_values), 1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(logProbability, pruned(discrete_values), 1e-9); | ||||
| 
 | ||||
|   logProbability += | ||||
|       hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values); | ||||
|  |  | |||
|  | @ -172,10 +172,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         marginals /= marginals.sum() | ||||
|         return marginals | ||||
| 
 | ||||
|     @unittest.skip | ||||
|     def test_tiny(self): | ||||
|         """Test a tiny two variable hybrid model.""" | ||||
|         # P(x0)P(mode)P(z0|x0,mode) | ||||
|         # Create P(x0)P(mode)P(z0|x0,mode) | ||||
|         prior_sigma = 0.5 | ||||
|         bayesNet = self.tiny(prior_sigma=prior_sigma) | ||||
| 
 | ||||
|  | @ -210,9 +209,17 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         self.assertAlmostEqual(marginals[0], 0.74, delta=0.01) | ||||
|         self.assertAlmostEqual(marginals[1], 0.26, delta=0.01) | ||||
| 
 | ||||
|         # Convert to factor graph with given measurements. | ||||
|         fg = bayesNet.toFactorGraph(measurements) | ||||
|         self.assertEqual(fg.size(), 4) | ||||
| 
 | ||||
|         # Check ratio between unnormalized posterior and factor graph is the same for all modes: | ||||
|         for mode in [1, 0]: | ||||
|             values.insert_or_assign(M(0), mode) | ||||
|             self.assertAlmostEqual(bayesNet.evaluate(values) / | ||||
|                                    fg.error(values), | ||||
|                                    0.025178994744461187) | ||||
| 
 | ||||
|         # Test elimination. | ||||
|         posterior = fg.eliminateSequential() | ||||
| 
 | ||||
|  | @ -239,6 +246,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): | |||
|         return bayesNet.evaluate(sample) / fg.probPrime(sample) if \ | ||||
|             fg.probPrime(sample) > 0 else 0 | ||||
| 
 | ||||
|     @unittest.skip | ||||
|     def test_ratio(self): | ||||
|         """ | ||||
|         Given a tiny two variable hybrid model, with 2 measurements, test the | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue