Test eliminate
parent
efd8eb1984
commit
dcb07fea8c
|
@ -119,7 +119,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
return bayesNet
|
return bayesNet
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
|
def measurements(sample: HybridValues, indices) -> gtsam.VectorValues:
|
||||||
|
"""Create measurements from a sample, grabbing Z(i) where i in indices."""
|
||||||
|
measurements = gtsam.VectorValues()
|
||||||
|
for i in indices:
|
||||||
|
measurements.insert(Z(i), sample.at(Z(i)))
|
||||||
|
return measurements
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def factor_graph_from_bayes_net(cls, bayesNet: HybridBayesNet, sample: HybridValues):
|
||||||
"""Create a factor graph from the Bayes net with sampled measurements.
|
"""Create a factor graph from the Bayes net with sampled measurements.
|
||||||
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
|
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
|
||||||
and thus represents the same joint probability as the Bayes net.
|
and thus represents the same joint probability as the Bayes net.
|
||||||
|
@ -128,9 +136,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
num_measurements = bayesNet.size() - 2
|
num_measurements = bayesNet.size() - 2
|
||||||
for i in range(num_measurements):
|
for i in range(num_measurements):
|
||||||
conditional = bayesNet.atMixture(i)
|
conditional = bayesNet.atMixture(i)
|
||||||
measurement = gtsam.VectorValues()
|
factor = conditional.likelihood(cls.measurements(sample, [i]))
|
||||||
measurement.insert(Z(i), sample.at(Z(i)))
|
|
||||||
factor = conditional.likelihood(measurement)
|
|
||||||
fg.push_back(factor)
|
fg.push_back(factor)
|
||||||
fg.push_back(bayesNet.atGaussian(num_measurements))
|
fg.push_back(bayesNet.atGaussian(num_measurements))
|
||||||
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
||||||
|
@ -147,11 +153,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
|
|
||||||
# Do importance sampling.
|
# Do importance sampling.
|
||||||
num_measurements = bayesNet.size() - 2
|
num_measurements = bayesNet.size() - 2
|
||||||
|
measurements = cls.measurements(sample, range(num_measurements))
|
||||||
for s in range(N):
|
for s in range(N):
|
||||||
proposed = prior.sample()
|
proposed = prior.sample()
|
||||||
for i in range(num_measurements):
|
proposed.insert(measurements)
|
||||||
z_i = sample.at(Z(i))
|
|
||||||
proposed.insert(Z(i), z_i)
|
|
||||||
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
|
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
|
||||||
marginals[proposed.atDiscrete(M(0))] += weight
|
marginals[proposed.atDiscrete(M(0))] += weight
|
||||||
|
|
||||||
|
@ -213,15 +218,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||||
self.assertEqual(fg.size(), 4)
|
self.assertEqual(fg.size(), 4)
|
||||||
|
|
||||||
|
# Create measurements from the sample.
|
||||||
|
measurements = self.measurements(sample, [0, 1])
|
||||||
|
|
||||||
# Calculate ratio between Bayes net probability and the factor graph:
|
# Calculate ratio between Bayes net probability and the factor graph:
|
||||||
expected_ratio = self.calculate_ratio(bayesNet, fg, sample)
|
expected_ratio = self.calculate_ratio(bayesNet, fg, sample)
|
||||||
# print(f"expected_ratio: {expected_ratio}\n")
|
# print(f"expected_ratio: {expected_ratio}\n")
|
||||||
|
|
||||||
# Create measurements from the sample.
|
|
||||||
measurements = gtsam.VectorValues()
|
|
||||||
for i in range(2):
|
|
||||||
measurements.insert(Z(i), sample.at(Z(i)))
|
|
||||||
|
|
||||||
# Check with a number of other samples.
|
# Check with a number of other samples.
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
other = bayesNet.sample()
|
other = bayesNet.sample()
|
||||||
|
@ -231,6 +234,25 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
if (ratio > 0):
|
if (ratio > 0):
|
||||||
self.assertAlmostEqual(ratio, expected_ratio)
|
self.assertAlmostEqual(ratio, expected_ratio)
|
||||||
|
|
||||||
|
# Test elimination.
|
||||||
|
ordering = gtsam.Ordering()
|
||||||
|
ordering.push_back(X(0))
|
||||||
|
ordering.push_back(M(0))
|
||||||
|
posterior = fg.eliminateSequential(ordering)
|
||||||
|
print(posterior)
|
||||||
|
|
||||||
|
# Calculate ratio between Bayes net probability and the factor graph:
|
||||||
|
expected_ratio = self.calculate_ratio(posterior, fg, sample)
|
||||||
|
print(f"expected_ratio: {expected_ratio}\n")
|
||||||
|
|
||||||
|
# Check with a number of other samples.
|
||||||
|
for i in range(10):
|
||||||
|
other = posterior.sample()
|
||||||
|
other.insert(measurements)
|
||||||
|
ratio = self.calculate_ratio(posterior, fg, other)
|
||||||
|
print(f"Ratio: {ratio}\n")
|
||||||
|
# if (ratio > 0):
|
||||||
|
# self.assertAlmostEqual(ratio, expected_ratio)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue