Some test refactoring

release/4.3a0
Frank Dellaert 2023-01-14 15:41:39 -08:00
parent 070cdb7018
commit 96e3eb7d8b
2 changed files with 18 additions and 12 deletions

View File

@ -220,25 +220,24 @@ TEST(HybridBayesNet, logProbability) {
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->logProbability(delta.continuous());
auto actual = hybridBayesNet->logProbability(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
EXPECT(assert_equal(expected, actual, 1e-6));
// logProbability on pruned Bayes net
auto prunedBayesNet = hybridBayesNet->prune(2);
auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
auto pruned = prunedBayesNet.logProbability(delta.continuous());
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
pruned_leaves);
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
// regression
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
EXPECT(assert_equal(expected_pruned, pruned, 1e-6));
// Verify logProbability computation and check for specific logProbability
// value
@ -253,9 +252,8 @@ TEST(HybridBayesNet, logProbability) {
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
// TODO(dellaert): the discrete errors are not added in logProbability tree!
EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, actual(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, pruned(discrete_values), 1e-9);
logProbability +=
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);

View File

@ -172,10 +172,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
marginals /= marginals.sum()
return marginals
@unittest.skip
def test_tiny(self):
"""Test a tiny two variable hybrid model."""
# P(x0)P(mode)P(z0|x0,mode)
# Create P(x0)P(mode)P(z0|x0,mode)
prior_sigma = 0.5
bayesNet = self.tiny(prior_sigma=prior_sigma)
@ -210,9 +209,17 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
# Convert to factor graph with given measurements.
fg = bayesNet.toFactorGraph(measurements)
self.assertEqual(fg.size(), 4)
# Check ratio between unnormalized posterior and factor graph is the same for all modes:
for mode in [1, 0]:
values.insert_or_assign(M(0), mode)
self.assertAlmostEqual(bayesNet.evaluate(values) /
fg.error(values),
0.025178994744461187)
# Test elimination.
posterior = fg.eliminateSequential()
@ -239,6 +246,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
fg.probPrime(sample) > 0 else 0
@unittest.skip
def test_ratio(self):
"""
Given a tiny two variable hybrid model, with 2 measurements, test the