"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved

See LICENSE for the license information

Unit tests for Discrete Bayes Nets.
Author: Frank Dellaert
"""
|
|
|
|
# pylint: disable=no-name-in-module, invalid-name
|
|
|
|
import unittest
|
|
|
|
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
|
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
|
from gtsam.utils.test_case import GtsamTestCase
|
|
|
|
# Some keys: discrete keys of the Asia network, written as
# (variable index, cardinality) tuples. Every variable here is binary,
# so each cardinality is 2. Listed in index order.
Asia = (0, 2)
Dyspnea = (1, 2)
XRay = (2, 2)
Tuberculosis = (3, 2)
Smoking = (4, 2)
Either = (5, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
|
|
|
|
|
|
class TestDiscreteBayesNet(GtsamTestCase):
    """Unit tests for building, eliminating, optimizing and sampling
    discrete Bayes nets."""

    def test_constructor(self):
        """Build a two-node Bayes net by hand and convert it to a factor graph."""
        # (variable index, cardinality) pairs for a binary parent and child.
        parent_key = (0, 2)
        child_key = (1, 2)

        net = DiscreteBayesNet()

        # Root node: a prior on the parent, conditioned on no keys.
        net.add(DiscreteConditional(parent_key, DiscreteKeys(), "6/4"))

        # Child node conditioned on the parent; the spec has two rows,
        # matching the two values the binary parent can take.
        conditioning_keys = DiscreteKeys()
        conditioning_keys.push_back(parent_key)
        net.add(DiscreteConditional(child_key, conditioning_keys, "7/3 8/2"))

        # Each conditional becomes one factor; the child's factor spans two keys.
        graph = DiscreteFactorGraph(net)
        self.assertEqual(graph.size(), 2)
        self.assertEqual(graph.at(1).size(), 2)

    def test_Asia(self):
        """Eliminate, optimize and sample the full Asia network."""
        network = DiscreteBayesNet()

        # Parentless variables.
        network.add(Asia, "99/1")
        network.add(Smoking, "50/50")

        # Variables with a single parent.
        network.add(Tuberculosis, [Asia], "99/1 95/5")
        network.add(LungCancer, [Smoking], "99/1 90/10")
        network.add(Bronchitis, [Smoking], "70/30 40/60")

        # Either combines Tuberculosis and LungCancer via an "F T T T" table.
        network.add(Either, [Tuberculosis, LungCancer], "F T T T")

        # Observable symptoms.
        network.add(XRay, [Either], "95/5 2/98")
        network.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")

        graph = DiscreteFactorGraph(network)

        # Eliminate keys 0..7 in order; the last conditional is therefore on
        # key 7 (Bronchitis) alone. Its marginal is
        # 0.5 * (70/30) + 0.5 * (40/60) = 55/45 = 11/9.
        ordering = Ordering()
        for index in range(8):
            ordering.push_back(index)
        chordal = graph.eliminateSequential(ordering)
        expected_marginal = DiscreteDistribution(Bronchitis, "11/9")
        self.gtsamAssertEquals(chordal.at(7), expected_marginal)

        # With no evidence, the MPE assigns 0 to every variable.
        mpe = graph.optimize()
        expected_mpe = DiscreteValues()
        for index, _ in (Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either,
                         LungCancer, Bronchitis):
            expected_mpe[index] = 0
        self.assertEqual(list(mpe.items()), list(expected_mpe.items()))

        # The Bayes net and the factor graph agree on the MPE's value.
        self.assertAlmostEqual(network(mpe), graph(mpe))

        # Evidence: we were in Asia and we have dyspnea.
        graph.add(Asia, "0 1")
        graph.add(Dyspnea, "0 1")

        # Re-solving with evidence sets Asia, Dyspnea, Smoking and
        # Bronchitis to 1 and leaves the remaining variables at 0.
        mpe_with_evidence = graph.optimize()
        expected_with_evidence = DiscreteValues()
        for index, _ in (XRay, Tuberculosis, Either, LungCancer):
            expected_with_evidence[index] = 0
        for index, _ in (Asia, Dyspnea, Smoking, Bronchitis):
            expected_with_evidence[index] = 1
        self.assertEqual(list(mpe_with_evidence.items()),
                         list(expected_with_evidence.items()))

        # Sampling the re-eliminated net assigns all eight variables.
        sample = graph.eliminateSequential(ordering).sample()
        self.assertEqual(len(sample), 8)

    def test_fragment(self):
        """Sample an Asia fragment given values for its parentless inputs."""
        # Conditionals are added children-first (reverse topological order).
        fragment = DiscreteBayesNet()
        fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
        fragment.add(Tuberculosis, [Asia], "99/1 95/5")
        fragment.add(LungCancer, [Smoking], "99/1 90/10")

        # Supply only Asia and Smoking; the other variables are left missing.
        given = DiscreteValues()
        for index, _ in (Asia, Smoking):
            given[index] = 0

        # The result holds the two given values plus the three sampled ones.
        sample = fragment.sample(given)
        self.assertEqual(len(sample), 5)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the tests defined in this module when it is executed as a script.
    unittest.main()
|