diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 5398160dc..700137d21 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -119,7 +119,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): return bayesNet @staticmethod - def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): + def measurements(sample: HybridValues, indices) -> gtsam.VectorValues: + """Create measurements from a sample, grabbing Z(i) where i in indices.""" + measurements = gtsam.VectorValues() + for i in indices: + measurements.insert(Z(i), sample.at(Z(i))) + return measurements + + @classmethod + def factor_graph_from_bayes_net(cls, bayesNet: HybridBayesNet, sample: HybridValues): """Create a factor graph from the Bayes net with sampled measurements. The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` and thus represents the same joint probability as the Bayes net. @@ -128,9 +136,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): num_measurements = bayesNet.size() - 2 for i in range(num_measurements): conditional = bayesNet.atMixture(i) - measurement = gtsam.VectorValues() - measurement.insert(Z(i), sample.at(Z(i))) - factor = conditional.likelihood(measurement) + factor = conditional.likelihood(cls.measurements(sample, [i])) fg.push_back(factor) fg.push_back(bayesNet.atGaussian(num_measurements)) fg.push_back(bayesNet.atDiscrete(num_measurements+1)) @@ -147,11 +153,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): # Do importance sampling. num_measurements = bayesNet.size() - 2 + measurements = cls.measurements(sample, range(num_measurements)) for s in range(N): proposed = prior.sample() - for i in range(num_measurements): - z_i = sample.at(Z(i)) - proposed.insert(Z(i), z_i) + proposed.insert(measurements) weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) marginals[proposed.atDiscrete(M(0))] += weight @@ -213,15 +218,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 4) + # Create measurements from the sample. + measurements = self.measurements(sample, [0, 1]) + # Calculate ratio between Bayes net probability and the factor graph: expected_ratio = self.calculate_ratio(bayesNet, fg, sample) # print(f"expected_ratio: {expected_ratio}\n") - # Create measurements from the sample. - measurements = gtsam.VectorValues() - for i in range(2): - measurements.insert(Z(i), sample.at(Z(i))) - # Check with a number of other samples. for i in range(10): other = bayesNet.sample() @@ -231,6 +234,25 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): if (ratio > 0): self.assertAlmostEqual(ratio, expected_ratio) + # Test elimination. + ordering = gtsam.Ordering() + ordering.push_back(X(0)) + ordering.push_back(M(0)) + posterior = fg.eliminateSequential(ordering) + print(posterior) + + # Calculate ratio between Bayes net probability and the factor graph: + expected_ratio = self.calculate_ratio(posterior, fg, sample) + print(f"expected_ratio: {expected_ratio}\n") + + # Check with a number of other samples. + for i in range(10): + other = posterior.sample() + other.insert(measurements) + ratio = self.calculate_ratio(posterior, fg, other) + print(f"Ratio: {ratio}\n") + # if (ratio > 0): + # self.assertAlmostEqual(ratio, expected_ratio) if __name__ == "__main__": unittest.main()