factor_graph_from_bayes_net
parent
078e6b0b62
commit
23eec0bc6a
|
@ -6,7 +6,7 @@ All Rights Reserved
|
||||||
See LICENSE for the license information
|
See LICENSE for the license information
|
||||||
|
|
||||||
Unit tests for Hybrid Factor Graphs.
|
Unit tests for Hybrid Factor Graphs.
|
||||||
Author: Fan Jiang
|
Author: Fan Jiang, Varun Agrawal, Frank Dellaert
|
||||||
"""
|
"""
|
||||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
||||||
|
|
||||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
"""Unit tests for HybridGaussianFactorGraph."""
|
"""Unit tests for HybridGaussianFactorGraph."""
|
||||||
|
|
||||||
def test_create(self):
|
def test_create(self):
|
||||||
"""Test construction of hybrid factor graph."""
|
"""Test construction of hybrid factor graph."""
|
||||||
model = noiseModel.Unit.Create(3)
|
model = noiseModel.Unit.Create(3)
|
||||||
|
@ -117,23 +118,23 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
|
|
||||||
return bayesNet
|
return bayesNet
|
||||||
|
|
||||||
def test_tiny(self):
|
@staticmethod
|
||||||
"""Test a tiny two variable hybrid model."""
|
def factor_graph_from_bayes_net(bayesNet: gtsam.HybridBayesNet, sample: gtsam.HybridValues):
|
||||||
bayesNet = self.tiny()
|
"""Create a factor graph from the Bayes net with sampled measurements.
|
||||||
sample = bayesNet.sample()
|
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
|
||||||
# print(sample)
|
and thus represents the same joint probability as the Bayes net.
|
||||||
|
"""
|
||||||
# Create a factor graph from the Bayes net with sampled measurements.
|
|
||||||
fg = HybridGaussianFactorGraph()
|
fg = HybridGaussianFactorGraph()
|
||||||
conditional = bayesNet.atMixture(0)
|
num_measurements = bayesNet.size() - 2
|
||||||
measurement = gtsam.VectorValues()
|
for i in range(num_measurements):
|
||||||
measurement.insert(Z(0), sample.at(Z(0)))
|
conditional = bayesNet.atMixture(i)
|
||||||
factor = conditional.likelihood(measurement)
|
measurement = gtsam.VectorValues()
|
||||||
fg.push_back(factor)
|
measurement.insert(Z(i), sample.at(Z(i)))
|
||||||
fg.push_back(bayesNet.atGaussian(1))
|
factor = conditional.likelihood(measurement)
|
||||||
fg.push_back(bayesNet.atDiscrete(2))
|
fg.push_back(factor)
|
||||||
|
fg.push_back(bayesNet.atGaussian(num_measurements))
|
||||||
self.assertEqual(fg.size(), 3)
|
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
||||||
|
return fg
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calculate_ratio(bayesNet, fg, sample):
|
def calculate_ratio(bayesNet, fg, sample):
|
||||||
|
@ -143,6 +144,26 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
return bayesNet.evaluate(sample) / fg.probPrime(
|
return bayesNet.evaluate(sample) / fg.probPrime(
|
||||||
continuous, sample.discrete())
|
continuous, sample.discrete())
|
||||||
|
|
||||||
|
def test_tiny(self):
|
||||||
|
"""Test a tiny two variable hybrid model."""
|
||||||
|
bayesNet = self.tiny()
|
||||||
|
sample = bayesNet.sample()
|
||||||
|
# print(sample)
|
||||||
|
|
||||||
|
# TODO(dellaert): do importance sampling to get an estimate P(mode)
|
||||||
|
prior = self.tiny(num_measurements=0) # just P(x0)P(mode)
|
||||||
|
for s in range(100):
|
||||||
|
proposed = prior.sample()
|
||||||
|
print(proposed)
|
||||||
|
for i in range(2):
|
||||||
|
proposed.insert(Z(i), sample.at(Z(i)))
|
||||||
|
print(proposed)
|
||||||
|
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
|
||||||
|
print(weight)
|
||||||
|
|
||||||
|
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||||
|
self.assertEqual(fg.size(), 3)
|
||||||
|
|
||||||
def test_ratio(self):
|
def test_ratio(self):
|
||||||
"""
|
"""
|
||||||
Given a tiny two variable hybrid model, with 2 measurements,
|
Given a tiny two variable hybrid model, with 2 measurements,
|
||||||
|
@ -156,20 +177,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
sample: gtsam.HybridValues = bayesNet.sample()
|
sample: gtsam.HybridValues = bayesNet.sample()
|
||||||
# print(sample)
|
# print(sample)
|
||||||
|
|
||||||
# Create a factor graph from the Bayes net with sampled measurements.
|
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||||
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
|
|
||||||
# and thus represents the same joint probability as the Bayes net.
|
|
||||||
fg = HybridGaussianFactorGraph()
|
|
||||||
for i in range(2):
|
|
||||||
conditional = bayesNet.atMixture(i)
|
|
||||||
measurement = gtsam.VectorValues()
|
|
||||||
measurement.insert(Z(i), sample.at(Z(i)))
|
|
||||||
factor = conditional.likelihood(measurement)
|
|
||||||
fg.push_back(factor)
|
|
||||||
fg.push_back(bayesNet.atGaussian(2))
|
|
||||||
fg.push_back(bayesNet.atDiscrete(3))
|
|
||||||
|
|
||||||
# print(fg)
|
|
||||||
self.assertEqual(fg.size(), 4)
|
self.assertEqual(fg.size(), 4)
|
||||||
|
|
||||||
# Calculate ratio between Bayes net probability and the factor graph:
|
# Calculate ratio between Bayes net probability and the factor graph:
|
||||||
|
@ -186,9 +194,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
other = bayesNet.sample()
|
other = bayesNet.sample()
|
||||||
other.update(measurements)
|
other.update(measurements)
|
||||||
# print(other)
|
# print(other)
|
||||||
# ratio = self.calculate_ratio(bayesNet, fg, other)
|
ratio = self.calculate_ratio(bayesNet, fg, other)
|
||||||
# print(f"Ratio: {ratio}\n")
|
# print(f"Ratio: {ratio}\n")
|
||||||
# self.assertAlmostEqual(ratio, expected_ratio)
|
self.assertAlmostEqual(ratio, expected_ratio)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue