diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 481617db1..2d3513f12 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -6,7 +6,7 @@ All Rights Reserved See LICENSE for the license information Unit tests for Hybrid Factor Graphs. -Author: Fan Jiang +Author: Fan Jiang, Varun Agrawal, Frank Dellaert """ # pylint: disable=invalid-name, no-name-in-module, no-member @@ -25,6 +25,7 @@ from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, class TestHybridGaussianFactorGraph(GtsamTestCase): """Unit tests for HybridGaussianFactorGraph.""" + def test_create(self): """Test construction of hybrid factor graph.""" model = noiseModel.Unit.Create(3) @@ -117,23 +118,23 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): return bayesNet - def test_tiny(self): - """Test a tiny two variable hybrid model.""" - bayesNet = self.tiny() - sample = bayesNet.sample() - # print(sample) - - # Create a factor graph from the Bayes net with sampled measurements. + @staticmethod + def factor_graph_from_bayes_net(bayesNet: gtsam.HybridBayesNet, sample: gtsam.HybridValues): + """Create a factor graph from the Bayes net with sampled measurements. + The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` + and thus represents the same joint probability as the Bayes net. + """ fg = HybridGaussianFactorGraph() - conditional = bayesNet.atMixture(0) - measurement = gtsam.VectorValues() - measurement.insert(Z(0), sample.at(Z(0))) - factor = conditional.likelihood(measurement) - fg.push_back(factor) - fg.push_back(bayesNet.atGaussian(1)) - fg.push_back(bayesNet.atDiscrete(2)) - - self.assertEqual(fg.size(), 3) + num_measurements = bayesNet.size() - 2 + for i in range(num_measurements): + conditional = bayesNet.atMixture(i) + measurement = gtsam.VectorValues() + measurement.insert(Z(i), sample.at(Z(i))) + factor = conditional.likelihood(measurement) + fg.push_back(factor) + fg.push_back(bayesNet.atGaussian(num_measurements)) + fg.push_back(bayesNet.atDiscrete(num_measurements+1)) + return fg @staticmethod def calculate_ratio(bayesNet, fg, sample): @@ -143,6 +144,26 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): return bayesNet.evaluate(sample) / fg.probPrime( continuous, sample.discrete()) + def test_tiny(self): + """Test a tiny two variable hybrid model.""" + bayesNet = self.tiny() + sample = bayesNet.sample() + # print(sample) + + # TODO(dellaert): do importance sampling to get an estimate P(mode) + prior = self.tiny(num_measurements=0) # just P(x0)P(mode) + for s in range(100): + proposed = prior.sample() + print(proposed) + for i in range(2): + proposed.insert(Z(i), sample.at(Z(i))) + print(proposed) + weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) + print(weight) + + fg = self.factor_graph_from_bayes_net(bayesNet, sample) + self.assertEqual(fg.size(), 3) + def test_ratio(self): """ Given a tiny two variable hybrid model, with 2 measurements, @@ -156,20 +177,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): sample: gtsam.HybridValues = bayesNet.sample() # print(sample) - # Create a factor graph from the Bayes net with sampled measurements. - # The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)` - # and thus represents the same joint probability as the Bayes net. - fg = HybridGaussianFactorGraph() - for i in range(2): - conditional = bayesNet.atMixture(i) - measurement = gtsam.VectorValues() - measurement.insert(Z(i), sample.at(Z(i))) - factor = conditional.likelihood(measurement) - fg.push_back(factor) - fg.push_back(bayesNet.atGaussian(2)) - fg.push_back(bayesNet.atDiscrete(3)) - - # print(fg) + fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 4) # Calculate ratio between Bayes net probability and the factor graph: @@ -186,9 +194,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): other = bayesNet.sample() other.update(measurements) # print(other) - # ratio = self.calculate_ratio(bayesNet, fg, other) + ratio = self.calculate_ratio(bayesNet, fg, other) # print(f"Ratio: {ratio}\n") - # self.assertAlmostEqual(ratio, expected_ratio) + self.assertAlmostEqual(ratio, expected_ratio) if __name__ == "__main__":