Added DEBUG_MARGINALS flag
parent
e31884c9a1
commit
9af7236980
|
@ -22,6 +22,8 @@ from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
|||
HybridGaussianFactorGraph, HybridValues, JacobianFactor,
|
||||
Ordering, noiseModel)
|
||||
|
||||
DEBUG_MARGINALS = False
|
||||
|
||||
|
||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||
"""Unit tests for HybridGaussianFactorGraph."""
|
||||
|
@ -201,9 +203,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
||||
proposal_density=proposal_density)
|
||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
if DEBUG_MARGINALS:
|
||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
||||
|
@ -232,9 +235,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(target=true_posterior,
|
||||
proposal_density=proposal_density)
|
||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
print(f"P(mode=0; z0) = {marginals[0]}")
|
||||
print(f"P(mode=1; z0) = {marginals[1]}")
|
||||
if DEBUG_MARGINALS:
|
||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
print(f"P(mode=0; z0) = {marginals[0]}")
|
||||
print(f"P(mode=1; z0) = {marginals[1]}")
|
||||
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
||||
|
@ -247,7 +251,6 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
|
||||
fg.probPrime(sample) > 0 else 0
|
||||
|
||||
@unittest.skip
|
||||
def test_ratio(self):
|
||||
"""
|
||||
Given a tiny two variable hybrid model, with 2 measurements, test the
|
||||
|
@ -283,9 +286,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
||||
proposal_density=proposal_density)
|
||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
if DEBUG_MARGINALS:
|
||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
|
||||
|
|
Loading…
Reference in New Issue