Some test refactoring
parent
070cdb7018
commit
96e3eb7d8b
|
@ -220,25 +220,24 @@ TEST(HybridBayesNet, logProbability) {
|
||||||
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
|
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
|
||||||
|
|
||||||
HybridValues delta = hybridBayesNet->optimize();
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
auto error_tree = hybridBayesNet->logProbability(delta.continuous());
|
auto actual = hybridBayesNet->logProbability(delta.continuous());
|
||||||
|
|
||||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||||
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
|
std::vector<double> leaves = {4.1609374, 4.1706942, 4.141568, 4.1609374};
|
||||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
EXPECT(assert_equal(expected, actual, 1e-6));
|
||||||
|
|
||||||
// logProbability on pruned Bayes net
|
// logProbability on pruned Bayes net
|
||||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||||
auto pruned_error_tree = prunedBayesNet.logProbability(delta.continuous());
|
auto pruned = prunedBayesNet.logProbability(delta.continuous());
|
||||||
|
|
||||||
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
|
std::vector<double> pruned_leaves = {2e50, 4.1706942, 2e50, 4.1609374};
|
||||||
AlgebraicDecisionTree<Key> expected_pruned_error(discrete_keys,
|
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||||
pruned_leaves);
|
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-6));
|
EXPECT(assert_equal(expected_pruned, pruned, 1e-6));
|
||||||
|
|
||||||
// Verify logProbability computation and check for specific logProbability
|
// Verify logProbability computation and check for specific logProbability
|
||||||
// value
|
// value
|
||||||
|
@ -253,9 +252,8 @@ TEST(HybridBayesNet, logProbability) {
|
||||||
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
|
hybridBayesNet->at(2)->asMixture()->logProbability(hybridValues);
|
||||||
|
|
||||||
// TODO(dellaert): the discrete errors are not added in logProbability tree!
|
// TODO(dellaert): the discrete errors are not added in logProbability tree!
|
||||||
EXPECT_DOUBLES_EQUAL(logProbability, error_tree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(logProbability, actual(discrete_values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(logProbability, pruned_error_tree(discrete_values),
|
EXPECT_DOUBLES_EQUAL(logProbability, pruned(discrete_values), 1e-9);
|
||||||
1e-9);
|
|
||||||
|
|
||||||
logProbability +=
|
logProbability +=
|
||||||
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
|
hybridBayesNet->at(3)->asDiscrete()->logProbability(discrete_values);
|
||||||
|
|
|
@ -172,10 +172,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
marginals /= marginals.sum()
|
marginals /= marginals.sum()
|
||||||
return marginals
|
return marginals
|
||||||
|
|
||||||
@unittest.skip
|
|
||||||
def test_tiny(self):
|
def test_tiny(self):
|
||||||
"""Test a tiny two variable hybrid model."""
|
"""Test a tiny two variable hybrid model."""
|
||||||
# P(x0)P(mode)P(z0|x0,mode)
|
# Create P(x0)P(mode)P(z0|x0,mode)
|
||||||
prior_sigma = 0.5
|
prior_sigma = 0.5
|
||||||
bayesNet = self.tiny(prior_sigma=prior_sigma)
|
bayesNet = self.tiny(prior_sigma=prior_sigma)
|
||||||
|
|
||||||
|
@ -210,9 +209,17 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
||||||
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
|
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
|
||||||
|
|
||||||
|
# Convert to factor graph with given measurements.
|
||||||
fg = bayesNet.toFactorGraph(measurements)
|
fg = bayesNet.toFactorGraph(measurements)
|
||||||
self.assertEqual(fg.size(), 4)
|
self.assertEqual(fg.size(), 4)
|
||||||
|
|
||||||
|
# Check ratio between unnormalized posterior and factor graph is the same for all modes:
|
||||||
|
for mode in [1, 0]:
|
||||||
|
values.insert_or_assign(M(0), mode)
|
||||||
|
self.assertAlmostEqual(bayesNet.evaluate(values) /
|
||||||
|
fg.error(values),
|
||||||
|
0.025178994744461187)
|
||||||
|
|
||||||
# Test elimination.
|
# Test elimination.
|
||||||
posterior = fg.eliminateSequential()
|
posterior = fg.eliminateSequential()
|
||||||
|
|
||||||
|
@ -239,6 +246,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
|
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
|
||||||
fg.probPrime(sample) > 0 else 0
|
fg.probPrime(sample) > 0 else 0
|
||||||
|
|
||||||
|
@unittest.skip
|
||||||
def test_ratio(self):
|
def test_ratio(self):
|
||||||
"""
|
"""
|
||||||
Given a tiny two variable hybrid model, with 2 measurements, test the
|
Given a tiny two variable hybrid model, with 2 measurements, test the
|
||||||
|
|
Loading…
Reference in New Issue