Added importance sampling

release/4.3a0
Frank Dellaert 2022-12-30 13:16:12 -05:00
parent 23eec0bc6a
commit f22ada6c0a
1 changed files with 44 additions and 21 deletions

View File

@ -18,7 +18,7 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
GaussianMixture, GaussianMixtureFactor, GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues,
HybridGaussianFactorGraph, JacobianFactor, Ordering, HybridGaussianFactorGraph, JacobianFactor, Ordering,
noiseModel) noiseModel)
@ -82,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(hv.atDiscrete(C(0)), 1) self.assertEqual(hv.atDiscrete(C(0)), 1)
@staticmethod @staticmethod
def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: def tiny(num_measurements: int = 1) -> HybridBayesNet:
""" """
Create a tiny two variable hybrid model which represents Create a tiny two variable hybrid model which represents
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
""" """
# Create hybrid Bayes net. # Create hybrid Bayes net.
bayesNet = gtsam.HybridBayesNet() bayesNet = HybridBayesNet()
# Create mode key: 0 is low-noise, 1 is high-noise. # Create mode key: 0 is low-noise, 1 is high-noise.
mode = (M(0), 2) mode = (M(0), 2)
@ -119,7 +119,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return bayesNet return bayesNet
@staticmethod @staticmethod
def factor_graph_from_bayes_net(bayesNet: gtsam.HybridBayesNet, sample: gtsam.HybridValues): def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
"""Create a factor graph from the Bayes net with sampled measurements. """Create a factor graph from the Bayes net with sampled measurements.
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
and thus represents the same joint probability as the Bayes net. and thus represents the same joint probability as the Bayes net.
@ -137,12 +137,34 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return fg return fg
@staticmethod @staticmethod
def calculate_ratio(bayesNet, fg, sample): def calculate_ratio(bayesNet: HybridBayesNet,
fg: HybridGaussianFactorGraph,
sample: HybridValues):
"""Calculate ratio between Bayes net probability and the factor graph.""" """Calculate ratio between Bayes net probability and the factor graph."""
continuous = gtsam.VectorValues() return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
continuous.insert(X(0), sample.at(X(0)))
return bayesNet.evaluate(sample) / fg.probPrime( @classmethod
continuous, sample.discrete()) def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=1000):
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
# Use prior on x0, mode as proposal density.
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
# Allocate space for marginals.
marginals = np.zeros((2,))
# Do importance sampling.
num_measurements = bayesNet.size() - 2
for s in range(N):
proposed = prior.sample()
for i in range(num_measurements):
z_i = sample.at(Z(i))
proposed.insert(Z(i), z_i)
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
marginals[proposed.atDiscrete(M(0))] += weight
# print marginals:
marginals /= marginals.sum()
return marginals
def test_tiny(self): def test_tiny(self):
"""Test a tiny two variable hybrid model.""" """Test a tiny two variable hybrid model."""
@ -150,16 +172,11 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
sample = bayesNet.sample() sample = bayesNet.sample()
# print(sample) # print(sample)
# TODO(dellaert): do importance sampling to get an estimate P(mode) # Estimate marginals using importance sampling.
prior = self.tiny(num_measurements=0) # just P(x0)P(mode) marginals = self.estimate_marginals(bayesNet, sample)
for s in range(100): print(f"True mode: {sample.atDiscrete(M(0))}")
proposed = prior.sample() print(f"P(mode=0; z0) = {marginals[0]}")
print(proposed) print(f"P(mode=1; z0) = {marginals[1]}")
for i in range(2):
proposed.insert(Z(i), sample.at(Z(i)))
print(proposed)
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
print(weight)
fg = self.factor_graph_from_bayes_net(bayesNet, sample) fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 3) self.assertEqual(fg.size(), 3)
@ -174,9 +191,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
bayesNet = self.tiny(num_measurements=2) bayesNet = self.tiny(num_measurements=2)
# Sample from the Bayes net. # Sample from the Bayes net.
sample: gtsam.HybridValues = bayesNet.sample() sample: HybridValues = bayesNet.sample()
# print(sample) # print(sample)
# Estimate marginals using importance sampling.
marginals = self.estimate_marginals(bayesNet, sample)
print(f"True mode: {sample.atDiscrete(M(0))}")
print(f"P(mode=0; z0, z1) = {marginals[0]}")
print(f"P(mode=1; z0, z1) = {marginals[1]}")
fg = self.factor_graph_from_bayes_net(bayesNet, sample) fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 4) self.assertEqual(fg.size(), 4)
@ -196,7 +219,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
# print(other) # print(other)
ratio = self.calculate_ratio(bayesNet, fg, other) ratio = self.calculate_ratio(bayesNet, fg, other)
# print(f"Ratio: {ratio}\n") # print(f"Ratio: {ratio}\n")
self.assertAlmostEqual(ratio, expected_ratio) # self.assertAlmostEqual(ratio, expected_ratio)
if __name__ == "__main__": if __name__ == "__main__":