Added DEBUG_MARGINALS flag
parent
e31884c9a1
commit
9af7236980
|
@ -22,6 +22,8 @@ from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
||||||
HybridGaussianFactorGraph, HybridValues, JacobianFactor,
|
HybridGaussianFactorGraph, HybridValues, JacobianFactor,
|
||||||
Ordering, noiseModel)
|
Ordering, noiseModel)
|
||||||
|
|
||||||
|
DEBUG_MARGINALS = False
|
||||||
|
|
||||||
|
|
||||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
"""Unit tests for HybridGaussianFactorGraph."""
|
"""Unit tests for HybridGaussianFactorGraph."""
|
||||||
|
@ -201,9 +203,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
# Estimate marginals using importance sampling.
|
# Estimate marginals using importance sampling.
|
||||||
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
||||||
proposal_density=proposal_density)
|
proposal_density=proposal_density)
|
||||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
if DEBUG_MARGINALS:
|
||||||
print(f"P(mode=0; Z) = {marginals[0]}")
|
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||||
print(f"P(mode=1; Z) = {marginals[1]}")
|
print(f"P(mode=0; Z) = {marginals[0]}")
|
||||||
|
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||||
|
|
||||||
# Check that the estimate is close to the true value.
|
# Check that the estimate is close to the true value.
|
||||||
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
||||||
|
@ -232,9 +235,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
# Estimate marginals using importance sampling.
|
# Estimate marginals using importance sampling.
|
||||||
marginals = self.estimate_marginals(target=true_posterior,
|
marginals = self.estimate_marginals(target=true_posterior,
|
||||||
proposal_density=proposal_density)
|
proposal_density=proposal_density)
|
||||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
if DEBUG_MARGINALS:
|
||||||
print(f"P(mode=0; z0) = {marginals[0]}")
|
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||||
print(f"P(mode=1; z0) = {marginals[1]}")
|
print(f"P(mode=0; z0) = {marginals[0]}")
|
||||||
|
print(f"P(mode=1; z0) = {marginals[1]}")
|
||||||
|
|
||||||
# Check that the estimate is close to the true value.
|
# Check that the estimate is close to the true value.
|
||||||
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
||||||
|
@ -247,7 +251,6 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
|
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
|
||||||
fg.probPrime(sample) > 0 else 0
|
fg.probPrime(sample) > 0 else 0
|
||||||
|
|
||||||
@unittest.skip
|
|
||||||
def test_ratio(self):
|
def test_ratio(self):
|
||||||
"""
|
"""
|
||||||
Given a tiny two variable hybrid model, with 2 measurements, test the
|
Given a tiny two variable hybrid model, with 2 measurements, test the
|
||||||
|
@ -283,9 +286,10 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
# Estimate marginals using importance sampling.
|
# Estimate marginals using importance sampling.
|
||||||
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
||||||
proposal_density=proposal_density)
|
proposal_density=proposal_density)
|
||||||
print(f"True mode: {values.atDiscrete(M(0))}")
|
if DEBUG_MARGINALS:
|
||||||
print(f"P(mode=0; Z) = {marginals[0]}")
|
print(f"True mode: {values.atDiscrete(M(0))}")
|
||||||
print(f"P(mode=1; Z) = {marginals[1]}")
|
print(f"P(mode=0; Z) = {marginals[0]}")
|
||||||
|
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||||
|
|
||||||
# Check that the estimate is close to the true value.
|
# Check that the estimate is close to the true value.
|
||||||
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
|
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
|
||||||
|
|
Loading…
Reference in New Issue