# File: gtsam/python/gtsam/tests/test_DiscreteBayesNet.py
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Bayes Nets.
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
import unittest
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase
# Discrete keys of the "Asia" network used below, written as
# (key, cardinality) pairs. The keys are the integers 0-7 (the tests use
# key[0] as the variable id) and every variable is binary.
Asia = (0, 2)
Dyspnea = (1, 2)
XRay = (2, 2)
Tuberculosis = (3, 2)
Smoking = (4, 2)
Either = (5, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
class TestDiscreteBayesNet(GtsamTestCase):
    """Tests for Discrete Bayes Nets."""

    @staticmethod
    def _as_dict(values):
        """Return the (key -> value) assignments of a DiscreteValues as a dict.

        Comparing plain dicts keeps the assertions independent of the
        iteration order of DiscreteValues and gives readable diffs on failure.
        """
        return dict(values.items())

    def test_constructor(self):
        """Test constructing a Bayes net."""
        bayesNet = DiscreteBayesNet()
        Parent, Child = (0, 2), (1, 2)

        # Prior P(Parent) with no parents.
        empty = DiscreteKeys()
        prior = DiscreteConditional(Parent, empty, "6/4")
        bayesNet.add(prior)

        # Conditional P(Child | Parent), one row per Parent value.
        parents = DiscreteKeys()
        parents.push_back(Parent)
        conditional = DiscreteConditional(Child, parents, "7/3 8/2")
        bayesNet.add(conditional)

        # Check conversion to factor graph: one factor per conditional, and
        # the factor built from P(Child | Parent) involves two keys.
        fg = DiscreteFactorGraph(bayesNet)
        self.assertEqual(fg.size(), 2)
        self.assertEqual(fg.at(1).size(), 2)

    def test_Asia(self):
        """Test full Asia example."""
        asia = DiscreteBayesNet()
        asia.add(Asia, "99/1")
        asia.add(Smoking, "50/50")

        asia.add(Tuberculosis, [Asia], "99/1 95/5")
        asia.add(LungCancer, [Smoking], "99/1 90/10")
        asia.add(Bronchitis, [Smoking], "70/30 40/60")

        # Deterministic OR: Either is false only when both parents are false.
        asia.add(Either, [Tuberculosis, LungCancer], "F T T T")

        asia.add(XRay, [Either], "95/5 2/98")
        asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")

        # Convert to factor graph
        fg = DiscreteFactorGraph(asia)

        # Create solver and eliminate in key order 0..7; the last
        # conditional (index 7) is then the marginal on Bronchitis.
        ordering = Ordering()
        for j in range(8):
            ordering.push_back(j)
        chordal = fg.eliminateSequential(ordering)
        expected2 = DiscreteDistribution(Bronchitis, "11/9")
        self.gtsamAssertEquals(chordal.at(7), expected2)

        # solve: without evidence every variable's MPE value is 0
        actualMPE = fg.optimize()
        expectedMPE = DiscreteValues()
        for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either,
                    LungCancer, Bronchitis]:
            expectedMPE[key[0]] = 0
        self.assertEqual(self._as_dict(actualMPE),
                         self._as_dict(expectedMPE))

        # Check value for MPE is the same
        self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))

        # add evidence, we were in Asia and we have dyspnea: the factor
        # "0 1" gives zero weight to value 0 and weight 1 to value 1.
        fg.add(Asia, "0 1")
        fg.add(Dyspnea, "0 1")

        # solve again, now with evidence
        actualMPE2 = fg.optimize()
        expectedMPE2 = DiscreteValues()
        for key in [XRay, Tuberculosis, Either, LungCancer]:
            expectedMPE2[key[0]] = 0
        for key in [Asia, Dyspnea, Smoking, Bronchitis]:
            expectedMPE2[key[0]] = 1
        # Compare as dicts: expectedMPE2 was filled out of key order, so an
        # ordered comparison would depend on DiscreteValues' iteration order.
        self.assertEqual(self._as_dict(actualMPE2),
                         self._as_dict(expectedMPE2))

        # now sample from the posterior given the evidence
        chordal2 = fg.eliminateSequential(ordering)
        actualSample = chordal2.sample()
        self.assertEqual(len(actualSample), 8)
        # Value 0 has zero posterior probability for the observed variables,
        # so any sample must agree with the evidence.
        self.assertEqual(actualSample[Asia[0]], 1)
        self.assertEqual(actualSample[Dyspnea[0]], 1)

    def test_fragment(self):
        """Test sampling from an Asia fragment given partial values."""
        # Create a reverse-topologically sorted fragment:
        fragment = DiscreteBayesNet()
        fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
        fragment.add(Tuberculosis, [Asia], "99/1 95/5")
        fragment.add(LungCancer, [Smoking], "99/1 90/10")

        # Create assignment with missing values: only the root variables
        # Asia and Smoking are given.
        given = DiscreteValues()
        for key in [Asia, Smoking]:
            given[key[0]] = 0

        # Now sample from fragment: the result holds the 2 given values plus
        # one sampled value for each of the 3 conditionals.
        actual = fragment.sample(given)
        self.assertEqual(len(actual), 5)

        # The given values are carried through unchanged.
        for key in [Asia, Smoking]:
            self.assertEqual(actual[key[0]], 0)

        # Either is a deterministic OR of its sampled parents.
        expectedEither = int(actual[Tuberculosis[0]] == 1
                             or actual[LungCancer[0]] == 1)
        self.assertEqual(actual[Either[0]], expectedEither)
if __name__ == "__main__":
    # When executed as a script, discover and run every TestCase in this
    # module (module="__main__" is unittest.main's documented default).
    unittest.main(module="__main__")