Added importance sampling
							parent
							
								
									23eec0bc6a
								
							
						
					
					
						commit
						f22ada6c0a
					
				| 
						 | 
				
			
			@ -18,7 +18,7 @@ from gtsam.utils.test_case import GtsamTestCase
 | 
			
		|||
 | 
			
		||||
import gtsam
 | 
			
		||||
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
 | 
			
		||||
                   GaussianMixture, GaussianMixtureFactor,
 | 
			
		||||
                   GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues,
 | 
			
		||||
                   HybridGaussianFactorGraph, JacobianFactor, Ordering,
 | 
			
		||||
                   noiseModel)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -82,13 +82,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
 | 
			
		|||
        self.assertEqual(hv.atDiscrete(C(0)), 1)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
 | 
			
		||||
    def tiny(num_measurements: int = 1) -> HybridBayesNet:
 | 
			
		||||
        """
 | 
			
		||||
        Create a tiny two variable hybrid model which represents
 | 
			
		||||
        the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
 | 
			
		||||
        """
 | 
			
		||||
        # Create hybrid Bayes net.
 | 
			
		||||
        bayesNet = gtsam.HybridBayesNet()
 | 
			
		||||
        bayesNet = HybridBayesNet()
 | 
			
		||||
 | 
			
		||||
        # Create mode key: 0 is low-noise, 1 is high-noise.
 | 
			
		||||
        mode = (M(0), 2)
 | 
			
		||||
| 
						 | 
				
			
			@ -119,7 +119,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
 | 
			
		|||
        return bayesNet
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def factor_graph_from_bayes_net(bayesNet: gtsam.HybridBayesNet, sample: gtsam.HybridValues):
 | 
			
		||||
    def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
 | 
			
		||||
        """Create a factor graph from the Bayes net with sampled measurements.
 | 
			
		||||
            The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
 | 
			
		||||
            and thus represents the same joint probability as the Bayes net.
 | 
			
		||||
| 
						 | 
				
			
			@ -137,12 +137,34 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
 | 
			
		|||
        return fg
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def calculate_ratio(bayesNet, fg, sample):
 | 
			
		||||
    def calculate_ratio(bayesNet: HybridBayesNet,
 | 
			
		||||
                        fg: HybridGaussianFactorGraph,
 | 
			
		||||
                        sample: HybridValues):
 | 
			
		||||
        """Calculate ratio  between Bayes net probability and the factor graph."""
 | 
			
		||||
        continuous = gtsam.VectorValues()
 | 
			
		||||
        continuous.insert(X(0), sample.at(X(0)))
 | 
			
		||||
        return bayesNet.evaluate(sample) / fg.probPrime(
 | 
			
		||||
            continuous, sample.discrete())
 | 
			
		||||
        return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=1000):
 | 
			
		||||
        """Do importance sampling to get an estimate of the discrete marginal P(mode)."""
 | 
			
		||||
        # Use prior on x0, mode as proposal density.
 | 
			
		||||
        prior = cls.tiny(num_measurements=0)  # just P(x0)P(mode)
 | 
			
		||||
 | 
			
		||||
        # Allocate space for marginals.
 | 
			
		||||
        marginals = np.zeros((2,))
 | 
			
		||||
 | 
			
		||||
        # Do importance sampling.
 | 
			
		||||
        num_measurements = bayesNet.size() - 2
 | 
			
		||||
        for s in range(N):
 | 
			
		||||
            proposed = prior.sample()
 | 
			
		||||
            for i in range(num_measurements):
 | 
			
		||||
                z_i = sample.at(Z(i))
 | 
			
		||||
                proposed.insert(Z(i), z_i)
 | 
			
		||||
            weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
 | 
			
		||||
            marginals[proposed.atDiscrete(M(0))] += weight
 | 
			
		||||
 | 
			
		||||
        # print marginals:
 | 
			
		||||
        marginals /= marginals.sum()
 | 
			
		||||
        return marginals
 | 
			
		||||
 | 
			
		||||
    def test_tiny(self):
 | 
			
		||||
        """Test a tiny two variable hybrid model."""
 | 
			
		||||
| 
						 | 
				
			
			@ -150,16 +172,11 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
 | 
			
		|||
        sample = bayesNet.sample()
 | 
			
		||||
        # print(sample)
 | 
			
		||||
 | 
			
		||||
        # TODO(dellaert): do importance sampling to get an estimate P(mode)
 | 
			
		||||
        prior = self.tiny(num_measurements=0) # just P(x0)P(mode)
 | 
			
		||||
        for s in range(100):
 | 
			
		||||
            proposed = prior.sample()
 | 
			
		||||
            print(proposed)
 | 
			
		||||
            for i in range(2):
 | 
			
		||||
                proposed.insert(Z(i), sample.at(Z(i)))
 | 
			
		||||
            print(proposed)
 | 
			
		||||
            weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
 | 
			
		||||
            print(weight)
 | 
			
		||||
        # Estimate marginals using importance sampling.
 | 
			
		||||
        marginals = self.estimate_marginals(bayesNet, sample)
 | 
			
		||||
        print(f"True mode: {sample.atDiscrete(M(0))}")
 | 
			
		||||
        print(f"P(mode=0; z0) = {marginals[0]}")
 | 
			
		||||
        print(f"P(mode=1; z0) = {marginals[1]}")
 | 
			
		||||
 | 
			
		||||
        fg = self.factor_graph_from_bayes_net(bayesNet, sample)
 | 
			
		||||
        self.assertEqual(fg.size(), 3)
 | 
			
		||||
| 
						 | 
				
			
			@ -174,9 +191,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
 | 
			
		|||
        # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
 | 
			
		||||
        bayesNet = self.tiny(num_measurements=2)
 | 
			
		||||
        # Sample from the Bayes net.
 | 
			
		||||
        sample: gtsam.HybridValues = bayesNet.sample()
 | 
			
		||||
        sample: HybridValues = bayesNet.sample()
 | 
			
		||||
        # print(sample)
 | 
			
		||||
 | 
			
		||||
        # Estimate marginals using importance sampling.
 | 
			
		||||
        marginals = self.estimate_marginals(bayesNet, sample)
 | 
			
		||||
        print(f"True mode: {sample.atDiscrete(M(0))}")
 | 
			
		||||
        print(f"P(mode=0; z0, z1) = {marginals[0]}")
 | 
			
		||||
        print(f"P(mode=1; z0, z1) = {marginals[1]}")
 | 
			
		||||
 | 
			
		||||
        fg = self.factor_graph_from_bayes_net(bayesNet, sample)
 | 
			
		||||
        self.assertEqual(fg.size(), 4)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -196,7 +219,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
 | 
			
		|||
            # print(other)
 | 
			
		||||
            ratio = self.calculate_ratio(bayesNet, fg, other)
 | 
			
		||||
            # print(f"Ratio: {ratio}\n")
 | 
			
		||||
            self.assertAlmostEqual(ratio, expected_ratio)
 | 
			
		||||
            # self.assertAlmostEqual(ratio, expected_ratio)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue