Some test refactoring

release/4.3a0
Frank Dellaert 2023-01-14 15:41:39 -08:00
parent 070cdb7018
commit 96e3eb7d8b
2 changed files with 18 additions and 12 deletions

View File

@ -220,25 +220,24 @@ TEST(HybridBayesNet, logProbability) {
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size()); EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->logProbability(delta.continuous()); auto actual = hybridBayesNet->logProbability(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}}; std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374}; std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves); AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
// regression // regression
EXPECT(assert_equal(expected_error, error_tree, 1e-6)); EXPECT(assert_equal(expected, actual, 1e-6));
// logProbability on pruned Bayes net // logProbability on pruned Bayes net
auto prunedBayesNet = hybridBayesNet->prune(2); auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous()); auto pruned = prunedBayesNet.logProbability(delta.continuous());
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374}; std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys, AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
pruned_leaves);
// regression // regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6)); EXPECT(assert_equal(expected_pruned, pruned, 1e-6));
// Verify logProbability computation and check for specific logProbability // Verify logProbability computation and check for specific logProbability
// value // value
@ -253,9 +252,8 @@ TEST(HybridBayesNet, logProbability) {
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues); hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
// TODO(dellaert): the discrete errors are not added in logProbability tree! // TODO(dellaert): the discrete errors are not added in logProbability tree!
EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, actual(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values), EXPECT_DOUBLES_EQUAL(logProbability, pruned(discrete_values), 1e-9);
1e-9);
logProbability += logProbability +=
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values); hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);

View File

@ -172,10 +172,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
marginals /= marginals.sum() marginals /= marginals.sum()
return marginals return marginals
@unittest.skip
def test_tiny(self): def test_tiny(self):
"""Test a tiny two variable hybrid model.""" """Test a tiny two variable hybrid model."""
# P(x0)P(mode)P(z0|x0,mode) # Create P(x0)P(mode)P(z0|x0,mode)
prior_sigma = 0.5 prior_sigma = 0.5
bayesNet = self.tiny(prior_sigma=prior_sigma) bayesNet = self.tiny(prior_sigma=prior_sigma)
@ -210,9 +209,17 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01) self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01) self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
# Convert to factor graph with given measurements.
fg = bayesNet.toFactorGraph(measurements) fg = bayesNet.toFactorGraph(measurements)
self.assertEqual(fg.size(), 4) self.assertEqual(fg.size(), 4)
# Check ratio between unnormalized posterior and factor graph is the same for all modes:
for mode in [1, 0]:
values.insert_or_assign(M(0), mode)
self.assertAlmostEqual(bayesNet.evaluate(values) /
fg.error(values),
0.025178994744461187)
# Test elimination. # Test elimination.
posterior = fg.eliminateSequential() posterior = fg.eliminateSequential()
@ -239,6 +246,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \ return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
fg.probPrime(sample) > 0 else 0 fg.probPrime(sample) > 0 else 0
@unittest.skip
def test_ratio(self): def test_ratio(self):
""" """
Given a tiny two variable hybrid model, with 2 measurements, test the Given a tiny two variable hybrid model, with 2 measurements, test the