279 lines
9.3 KiB
Python
279 lines
9.3 KiB
Python
"""
|
|
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
|
|
Atlanta, Georgia 30332-0415
|
|
All Rights Reserved
|
|
|
|
See LICENSE for the license information
|
|
|
|
Unit tests for Discrete Factor Graphs.
|
|
Author: Frank Dellaert
|
|
"""
|
|
|
|
# pylint: disable=no-name-in-module, invalid-name
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
|
from gtsam.utils.test_case import GtsamTestCase
|
|
|
|
OrderingType = Ordering.OrderingType
|
|
|
|
|
|
class TestDiscreteFactorGraph(GtsamTestCase):
|
|
"""Tests for Discrete Factor Graphs."""
|
|
|
|
def test_evaluation(self):
|
|
"""Test constructing and evaluating a discrete factor graph."""
|
|
|
|
# Three keys
|
|
P1 = (0, 2)
|
|
P2 = (1, 2)
|
|
P3 = (2, 3)
|
|
|
|
# Create the DiscreteFactorGraph
|
|
graph = DiscreteFactorGraph()
|
|
|
|
# Add two unary factors (priors)
|
|
graph.add(P1, [0.9, 0.3])
|
|
graph.add(P2, "0.9 0.6")
|
|
|
|
# Add a binary factor
|
|
graph.add([P1, P2], "4 1 10 4")
|
|
|
|
# Instantiate Values
|
|
assignment = DiscreteValues()
|
|
assignment[0] = 1
|
|
assignment[1] = 1
|
|
|
|
# Check if graph evaluation works ( 0.3*0.6*4 )
|
|
self.assertAlmostEqual(.72, graph(assignment))
|
|
|
|
# Create a new test with third node and adding unary and ternary factor
|
|
graph.add(P3, "0.9 0.2 0.5")
|
|
keys = DiscreteKeys()
|
|
keys.push_back(P1)
|
|
keys.push_back(P2)
|
|
keys.push_back(P3)
|
|
graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
|
|
|
|
# Below assignment selects the 8th index in the ternary factor table
|
|
assignment[0] = 1
|
|
assignment[1] = 0
|
|
assignment[2] = 1
|
|
|
|
# Check if graph evaluation works (0.3*0.9*1*0.2*8)
|
|
self.assertAlmostEqual(4.32, graph(assignment))
|
|
|
|
# Below assignment selects the 3rd index in the ternary factor table
|
|
assignment[0] = 0
|
|
assignment[1] = 1
|
|
assignment[2] = 0
|
|
|
|
# Check if graph evaluation works (0.9*0.6*1*0.9*4)
|
|
self.assertAlmostEqual(1.944, graph(assignment))
|
|
|
|
# Check if graph product works
|
|
product = graph.product()
|
|
self.assertAlmostEqual(1.944, product(assignment))
|
|
|
|
def test_optimize(self):
|
|
"""Test constructing and optizing a discrete factor graph."""
|
|
|
|
# Three keys
|
|
C = (0, 2)
|
|
B = (1, 2)
|
|
A = (2, 2)
|
|
|
|
# A simple factor graph (A)-fAC-(C)-fBC-(B)
|
|
# with smoothness priors
|
|
graph = DiscreteFactorGraph()
|
|
graph.add([A, C], "3 1 1 3")
|
|
graph.add([C, B], "3 1 1 3")
|
|
|
|
# Test optimization
|
|
expectedValues = DiscreteValues()
|
|
expectedValues[0] = 0
|
|
expectedValues[1] = 0
|
|
expectedValues[2] = 0
|
|
actualValues = graph.optimize()
|
|
self.assertEqual(list(actualValues.items()),
|
|
list(expectedValues.items()))
|
|
|
|
def test_MPE(self):
|
|
"""Test maximum probable explanation (MPE): same as optimize."""
|
|
|
|
# Declare a bunch of keys
|
|
C, A, B = (0, 2), (1, 2), (2, 2)
|
|
|
|
# Create Factor graph
|
|
graph = DiscreteFactorGraph()
|
|
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
|
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
|
|
|
# We know MPE
|
|
mpe = DiscreteValues()
|
|
mpe[0] = 0
|
|
mpe[1] = 1
|
|
mpe[2] = 1
|
|
|
|
# Use maxProduct
|
|
dag = graph.maxProduct(OrderingType.COLAMD)
|
|
actualMPE = dag.argmax()
|
|
self.assertEqual(list(actualMPE.items()),
|
|
list(mpe.items()))
|
|
|
|
# All in one
|
|
actualMPE2 = graph.optimize()
|
|
self.assertEqual(list(actualMPE2.items()),
|
|
list(mpe.items()))
|
|
|
|
def test_sumProduct(self):
|
|
"""Test sumProduct."""
|
|
|
|
# Declare a bunch of keys
|
|
C, A, B = (0, 2), (1, 2), (2, 2)
|
|
|
|
# Create Factor graph
|
|
graph = DiscreteFactorGraph()
|
|
graph.add([C, A], "0.2 0.8 0.3 0.7")
|
|
graph.add([C, B], "0.1 0.9 0.4 0.6")
|
|
|
|
# We know MPE
|
|
mpe = DiscreteValues()
|
|
mpe[0] = 0
|
|
mpe[1] = 1
|
|
mpe[2] = 1
|
|
|
|
# Use default sumProduct
|
|
bayesNet = graph.sumProduct()
|
|
mpeProbability = bayesNet(mpe)
|
|
self.assertAlmostEqual(mpeProbability, 0.36) # regression
|
|
|
|
# Use sumProduct
|
|
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
|
|
OrderingType.CUSTOM]:
|
|
bayesNet = graph.sumProduct(ordering_type)
|
|
self.assertEqual(bayesNet(mpe), mpeProbability)
|
|
|
|
def test_MPE_chain(self):
|
|
"""
|
|
Test for numerical underflow in EliminateMPE on long chains.
|
|
Adapted from the toy problem of @pcl15423
|
|
Ref: https://github.com/borglab/gtsam/issues/1448
|
|
"""
|
|
num_states = 3
|
|
num_obs = 200
|
|
desired_state = 1
|
|
states = list(range(num_states))
|
|
|
|
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
|
|
def make_key(character, index, cardinality):
|
|
symbol = Symbol(character, index)
|
|
key = symbol.key()
|
|
return (key, cardinality)
|
|
|
|
X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
|
|
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
|
|
graph = DiscreteFactorGraph()
|
|
|
|
# Mostly identity transition matrix
|
|
transitions = np.eye(num_states)
|
|
|
|
# Needed otherwise mpe is always state 0?
|
|
transitions += 0.1/(num_states)
|
|
|
|
transition_cpt = []
|
|
for i in range(0, num_states):
|
|
transition_row = "/".join([str(x) for x in transitions[i]])
|
|
transition_cpt.append(transition_row)
|
|
transition_cpt = " ".join(transition_cpt)
|
|
|
|
for i in reversed(range(1, num_obs)):
|
|
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
|
|
graph.push_back(transition_conditional)
|
|
|
|
# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
|
|
# but all other states always give measurement num_obs
|
|
obs = np.zeros((num_states, num_obs+1))
|
|
obs[:,-1] = 1
|
|
obs[desired_state,0: -1] = 1
|
|
obs[desired_state,-1] = 0
|
|
obs_cpt_list = []
|
|
for i in range(0, num_states):
|
|
obs_row = "/".join([str(z) for z in obs[i]])
|
|
obs_cpt_list.append(obs_row)
|
|
obs_cpt = " ".join(obs_cpt_list)
|
|
|
|
# Contrived example where each measurement is its own index
|
|
for i in range(0, num_obs):
|
|
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
|
|
factor = obs_conditional.likelihood(i)
|
|
graph.push_back(factor)
|
|
|
|
mpe = graph.optimize()
|
|
vals = [mpe[X[i][0]] for i in range(num_obs)]
|
|
|
|
self.assertEqual(vals, [desired_state]*num_obs)
|
|
|
|
def test_sumProduct_chain(self):
|
|
"""
|
|
Test for numerical underflow in EliminateDiscrete on long chains.
|
|
Adapted from the toy problem of @pcl15423
|
|
Ref: https://github.com/borglab/gtsam/issues/1448
|
|
"""
|
|
num_states = 3
|
|
chain_length = 400
|
|
desired_state = 1
|
|
states = list(range(num_states))
|
|
|
|
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
|
|
def make_key(character, index, cardinality):
|
|
symbol = Symbol(character, index)
|
|
key = symbol.key()
|
|
return (key, cardinality)
|
|
|
|
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
|
|
graph = DiscreteFactorGraph()
|
|
|
|
# Construct test transition matrix
|
|
transitions = np.diag([1.0, 0.5, 0.1])
|
|
transitions += 0.1/(num_states)
|
|
|
|
# Ensure that the transition matrix is Markov (columns sum to 1)
|
|
transitions /= np.sum(transitions, axis=0)
|
|
|
|
# The stationary distribution is the eigenvector corresponding to eigenvalue 1
|
|
eigvals, eigvecs = np.linalg.eig(transitions)
|
|
stationary_idx = np.where(np.isclose(eigvals, 1.0))
|
|
stationary_dist = eigvecs[:, stationary_idx]
|
|
|
|
# Ensure that the stationary distribution is positive and normalized
|
|
stationary_dist /= np.sum(stationary_dist)
|
|
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
|
|
|
|
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
|
transitions = transitions.T
|
|
transition_cpt = []
|
|
for i in range(0, num_states):
|
|
transition_row = "/".join([str(x) for x in transitions[i]])
|
|
transition_cpt.append(transition_row)
|
|
transition_cpt = " ".join(transition_cpt)
|
|
|
|
for i in reversed(range(1, chain_length)):
|
|
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
|
|
graph.push_back(transition_conditional)
|
|
|
|
# Run sum product using natural ordering so the resulting Bayes net has the form:
|
|
# X_0 <- X_1 <- ... <- X_n
|
|
sum_product = graph.sumProduct(OrderingType.NATURAL)
|
|
|
|
# Get the DiscreteConditional representing the marginal on the last factor
|
|
last_marginal = sum_product.at(chain_length - 1)
|
|
|
|
# Ensure marginal probabilities are close to the stationary distribution
|
|
self.gtsamAssertEquals(expected, last_marginal)
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|