Check marginals in addition to ratios for non-uniform mode prior
parent
b798f3ebb5
commit
cec26d16ea
|
@ -114,7 +114,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
bayesNet.addGaussian(prior_on_x0)
|
bayesNet.addGaussian(prior_on_x0)
|
||||||
|
|
||||||
# Add prior on mode.
|
# Add prior on mode.
|
||||||
bayesNet.emplaceDiscrete(mode, "6/4")
|
bayesNet.emplaceDiscrete(mode, "4/6")
|
||||||
|
|
||||||
return bayesNet
|
return bayesNet
|
||||||
|
|
||||||
|
@ -136,15 +136,8 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
||||||
return fg
|
return fg
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def calculate_ratio(bayesNet: HybridBayesNet,
|
|
||||||
fg: HybridGaussianFactorGraph,
|
|
||||||
sample: HybridValues):
|
|
||||||
"""Calculate ratio between Bayes net probability and the factor graph."""
|
|
||||||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=1000):
|
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000):
|
||||||
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
|
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
|
||||||
# Use prior on x0, mode as proposal density.
|
# Use prior on x0, mode as proposal density.
|
||||||
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
|
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
|
||||||
|
@ -174,13 +167,24 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
|
|
||||||
# Estimate marginals using importance sampling.
|
# Estimate marginals using importance sampling.
|
||||||
marginals = self.estimate_marginals(bayesNet, sample)
|
marginals = self.estimate_marginals(bayesNet, sample)
|
||||||
print(f"True mode: {sample.atDiscrete(M(0))}")
|
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||||
print(f"P(mode=0; z0) = {marginals[0]}")
|
# print(f"P(mode=0; z0) = {marginals[0]}")
|
||||||
print(f"P(mode=1; z0) = {marginals[1]}")
|
# print(f"P(mode=1; z0) = {marginals[1]}")
|
||||||
|
|
||||||
|
# Check that the estimate is close to the true value.
|
||||||
|
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1)
|
||||||
|
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1)
|
||||||
|
|
||||||
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||||
self.assertEqual(fg.size(), 3)
|
self.assertEqual(fg.size(), 3)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_ratio(bayesNet: HybridBayesNet,
|
||||||
|
fg: HybridGaussianFactorGraph,
|
||||||
|
sample: HybridValues):
|
||||||
|
"""Calculate ratio between Bayes net probability and the factor graph."""
|
||||||
|
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
|
||||||
|
|
||||||
def test_ratio(self):
|
def test_ratio(self):
|
||||||
"""
|
"""
|
||||||
Given a tiny two variable hybrid model, with 2 measurements,
|
Given a tiny two variable hybrid model, with 2 measurements,
|
||||||
|
@ -196,9 +200,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
|
|
||||||
# Estimate marginals using importance sampling.
|
# Estimate marginals using importance sampling.
|
||||||
marginals = self.estimate_marginals(bayesNet, sample)
|
marginals = self.estimate_marginals(bayesNet, sample)
|
||||||
print(f"True mode: {sample.atDiscrete(M(0))}")
|
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||||
print(f"P(mode=0; z0, z1) = {marginals[0]}")
|
# print(f"P(mode=0; z0, z1) = {marginals[0]}")
|
||||||
print(f"P(mode=1; z0, z1) = {marginals[1]}")
|
# print(f"P(mode=1; z0, z1) = {marginals[1]}")
|
||||||
|
|
||||||
|
# Check marginals based on sampled mode.
|
||||||
|
if sample.atDiscrete(M(0)) == 0:
|
||||||
|
self.assertGreater(marginals[0], marginals[1])
|
||||||
|
else:
|
||||||
|
self.assertGreater(marginals[1], marginals[0])
|
||||||
|
|
||||||
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||||
self.assertEqual(fg.size(), 4)
|
self.assertEqual(fg.size(), 4)
|
||||||
|
|
Loading…
Reference in New Issue