Test eliminate

release/4.3a0
Frank Dellaert 2022-12-31 02:07:05 -05:00
parent efd8eb1984
commit dcb07fea8c
1 changed files with 34 additions and 12 deletions

View File

@ -119,7 +119,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return bayesNet
@staticmethod
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
def measurements(sample: HybridValues, indices) -> gtsam.VectorValues:
"""Create measurements from a sample, grabbing Z(i) where i in indices."""
measurements = gtsam.VectorValues()
for i in indices:
measurements.insert(Z(i), sample.at(Z(i)))
return measurements
@classmethod
def factor_graph_from_bayes_net(cls, bayesNet: HybridBayesNet, sample: HybridValues):
"""Create a factor graph from the Bayes net with sampled measurements.
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
and thus represents the same joint probability as the Bayes net.
@ -128,9 +136,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
num_measurements = bayesNet.size() - 2
for i in range(num_measurements):
conditional = bayesNet.atMixture(i)
measurement = gtsam.VectorValues()
measurement.insert(Z(i), sample.at(Z(i)))
factor = conditional.likelihood(measurement)
factor = conditional.likelihood(cls.measurements(sample, [i]))
fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(num_measurements))
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
@ -147,11 +153,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
# Do importance sampling.
num_measurements = bayesNet.size() - 2
measurements = cls.measurements(sample, range(num_measurements))
for s in range(N):
proposed = prior.sample()
for i in range(num_measurements):
z_i = sample.at(Z(i))
proposed.insert(Z(i), z_i)
proposed.insert(measurements)
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
marginals[proposed.atDiscrete(M(0))] += weight
@ -213,15 +218,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 4)
# Create measurements from the sample.
measurements = self.measurements(sample, [0, 1])
# Calculate ratio between Bayes net probability and the factor graph:
expected_ratio = self.calculate_ratio(bayesNet, fg, sample)
# print(f"expected_ratio: {expected_ratio}\n")
# Create measurements from the sample.
measurements = gtsam.VectorValues()
for i in range(2):
measurements.insert(Z(i), sample.at(Z(i)))
# Check with a number of other samples.
for i in range(10):
other = bayesNet.sample()
@ -231,6 +234,25 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
if (ratio > 0):
self.assertAlmostEqual(ratio, expected_ratio)
# Test elimination.
ordering = gtsam.Ordering()
ordering.push_back(X(0))
ordering.push_back(M(0))
posterior = fg.eliminateSequential(ordering)
print(posterior)
# Calculate ratio between Bayes net probability and the factor graph:
expected_ratio = self.calculate_ratio(posterior, fg, sample)
print(f"expected_ratio: {expected_ratio}\n")
# Check with a number of other samples.
for i in range(10):
other = posterior.sample()
other.insert(measurements)
ratio = self.calculate_ratio(posterior, fg, other)
print(f"Ratio: {ratio}\n")
# if (ratio > 0):
# self.assertAlmostEqual(ratio, expected_ratio)
if __name__ == "__main__":
unittest.main()