Avoid some copy/paste
parent
19773153bc
commit
615196e415
|
@ -0,0 +1,35 @@
|
||||||
|
import numpy as np
|
||||||
|
from gtsam import Symbol
|
||||||
|
|
||||||
|
|
||||||
|
def make_key(character, index, cardinality):
|
||||||
|
"""
|
||||||
|
Helper function to mimic the behavior of gtbook.Variables discrete_series function.
|
||||||
|
"""
|
||||||
|
symbol = Symbol(character, index)
|
||||||
|
key = symbol.key()
|
||||||
|
return (key, cardinality)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_transition_cpt(num_states, transitions=None):
|
||||||
|
"""
|
||||||
|
Generate a row-wise CPT for a transition matrix.
|
||||||
|
"""
|
||||||
|
if transitions is None:
|
||||||
|
# Default to identity matrix with slight regularization
|
||||||
|
transitions = np.eye(num_states) + 0.1 / num_states
|
||||||
|
|
||||||
|
# Ensure transitions sum to 1 if not already normalized
|
||||||
|
transitions /= np.sum(transitions, axis=1, keepdims=True)
|
||||||
|
return " ".join(["/".join(map(str, row)) for row in transitions])
|
||||||
|
|
||||||
|
|
||||||
|
def generate_observation_cpt(num_states, num_obs, desired_state):
|
||||||
|
"""
|
||||||
|
Generate a row-wise CPT for observations with contrived probabilities.
|
||||||
|
"""
|
||||||
|
obs = np.zeros((num_states, num_obs + 1))
|
||||||
|
obs[:, -1] = 1 # All states default to measurement num_obs
|
||||||
|
obs[desired_state, 0:-1] = 1
|
||||||
|
obs[desired_state, -1] = 0
|
||||||
|
return " ".join(["/".join(map(str, row)) for row in obs])
|
|
@ -15,10 +15,16 @@ import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt
|
||||||
|
|
||||||
from gtsam import (DecisionTreeFactor, DiscreteConditional,
|
from gtsam import (
|
||||||
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
|
DecisionTreeFactor,
|
||||||
Symbol)
|
DiscreteConditional,
|
||||||
|
DiscreteFactorGraph,
|
||||||
|
DiscreteKeys,
|
||||||
|
DiscreteValues,
|
||||||
|
Ordering,
|
||||||
|
)
|
||||||
|
|
||||||
OrderingType = Ordering.OrderingType
|
OrderingType = Ordering.OrderingType
|
||||||
|
|
||||||
|
@ -50,7 +56,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
assignment[1] = 1
|
assignment[1] = 1
|
||||||
|
|
||||||
# Check if graph evaluation works ( 0.3*0.6*4 )
|
# Check if graph evaluation works ( 0.3*0.6*4 )
|
||||||
self.assertAlmostEqual(.72, graph(assignment))
|
self.assertAlmostEqual(0.72, graph(assignment))
|
||||||
|
|
||||||
# Create a new test with third node and adding unary and ternary factor
|
# Create a new test with third node and adding unary and ternary factor
|
||||||
graph.add(P3, "0.9 0.2 0.5")
|
graph.add(P3, "0.9 0.2 0.5")
|
||||||
|
@ -100,8 +106,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
expectedValues[1] = 0
|
expectedValues[1] = 0
|
||||||
expectedValues[2] = 0
|
expectedValues[2] = 0
|
||||||
actualValues = graph.optimize()
|
actualValues = graph.optimize()
|
||||||
self.assertEqual(list(actualValues.items()),
|
self.assertEqual(list(actualValues.items()), list(expectedValues.items()))
|
||||||
list(expectedValues.items()))
|
|
||||||
|
|
||||||
def test_MPE(self):
|
def test_MPE(self):
|
||||||
"""Test maximum probable explanation (MPE): same as optimize."""
|
"""Test maximum probable explanation (MPE): same as optimize."""
|
||||||
|
@ -123,13 +128,11 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
# Use maxProduct
|
# Use maxProduct
|
||||||
dag = graph.maxProduct(OrderingType.COLAMD)
|
dag = graph.maxProduct(OrderingType.COLAMD)
|
||||||
actualMPE = dag.argmax()
|
actualMPE = dag.argmax()
|
||||||
self.assertEqual(list(actualMPE.items()),
|
self.assertEqual(list(actualMPE.items()), list(mpe.items()))
|
||||||
list(mpe.items()))
|
|
||||||
|
|
||||||
# All in one
|
# All in one
|
||||||
actualMPE2 = graph.optimize()
|
actualMPE2 = graph.optimize()
|
||||||
self.assertEqual(list(actualMPE2.items()),
|
self.assertEqual(list(actualMPE2.items()), list(mpe.items()))
|
||||||
list(mpe.items()))
|
|
||||||
|
|
||||||
def test_sumProduct(self):
|
def test_sumProduct(self):
|
||||||
"""Test sumProduct."""
|
"""Test sumProduct."""
|
||||||
|
@ -154,11 +157,17 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
self.assertAlmostEqual(mpeProbability, 0.36) # regression
|
self.assertAlmostEqual(mpeProbability, 0.36) # regression
|
||||||
|
|
||||||
# Use sumProduct
|
# Use sumProduct
|
||||||
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
|
for ordering_type in [
|
||||||
OrderingType.CUSTOM]:
|
OrderingType.COLAMD,
|
||||||
|
OrderingType.METIS,
|
||||||
|
OrderingType.NATURAL,
|
||||||
|
OrderingType.CUSTOM,
|
||||||
|
]:
|
||||||
bayesNet = graph.sumProduct(ordering_type)
|
bayesNet = graph.sumProduct(ordering_type)
|
||||||
self.assertEqual(bayesNet(mpe), mpeProbability)
|
self.assertEqual(bayesNet(mpe), mpeProbability)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChains(GtsamTestCase):
|
||||||
def test_MPE_chain(self):
|
def test_MPE_chain(self):
|
||||||
"""
|
"""
|
||||||
Test for numerical underflow in EliminateMPE on long chains.
|
Test for numerical underflow in EliminateMPE on long chains.
|
||||||
|
@ -170,46 +179,22 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
desired_state = 1
|
desired_state = 1
|
||||||
states = list(range(num_states))
|
states = list(range(num_states))
|
||||||
|
|
||||||
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
|
|
||||||
def make_key(character, index, cardinality):
|
|
||||||
symbol = Symbol(character, index)
|
|
||||||
key = symbol.key()
|
|
||||||
return (key, cardinality)
|
|
||||||
|
|
||||||
X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
|
X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
|
||||||
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
|
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
|
||||||
graph = DiscreteFactorGraph()
|
graph = DiscreteFactorGraph()
|
||||||
|
|
||||||
# Mostly identity transition matrix
|
transition_cpt = generate_transition_cpt(num_states)
|
||||||
transitions = np.eye(num_states)
|
|
||||||
|
|
||||||
# Needed otherwise mpe is always state 0?
|
|
||||||
transitions += 0.1/(num_states)
|
|
||||||
|
|
||||||
transition_cpt = []
|
|
||||||
for i in range(0, num_states):
|
|
||||||
transition_row = "/".join([str(x) for x in transitions[i]])
|
|
||||||
transition_cpt.append(transition_row)
|
|
||||||
transition_cpt = " ".join(transition_cpt)
|
|
||||||
|
|
||||||
for i in reversed(range(1, num_obs)):
|
for i in reversed(range(1, num_obs)):
|
||||||
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
|
transition_conditional = DiscreteConditional(
|
||||||
|
X[i], [X[i - 1]], transition_cpt
|
||||||
|
)
|
||||||
graph.push_back(transition_conditional)
|
graph.push_back(transition_conditional)
|
||||||
|
|
||||||
# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
|
# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
|
||||||
# but all other states always give measurement num_obs
|
# but all other states always give measurement num_obs
|
||||||
obs = np.zeros((num_states, num_obs+1))
|
obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
|
||||||
obs[:,-1] = 1
|
|
||||||
obs[desired_state,0: -1] = 1
|
|
||||||
obs[desired_state,-1] = 0
|
|
||||||
obs_cpt_list = []
|
|
||||||
for i in range(0, num_states):
|
|
||||||
obs_row = "/".join([str(z) for z in obs[i]])
|
|
||||||
obs_cpt_list.append(obs_row)
|
|
||||||
obs_cpt = " ".join(obs_cpt_list)
|
|
||||||
|
|
||||||
# Contrived example where each measurement is its own index
|
# Contrived example where each measurement is its own index
|
||||||
for i in range(0, num_obs):
|
for i in range(num_obs):
|
||||||
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
|
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
|
||||||
factor = obs_conditional.likelihood(i)
|
factor = obs_conditional.likelihood(i)
|
||||||
graph.push_back(factor)
|
graph.push_back(factor)
|
||||||
|
@ -217,7 +202,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
mpe = graph.optimize()
|
mpe = graph.optimize()
|
||||||
vals = [mpe[X[i][0]] for i in range(num_obs)]
|
vals = [mpe[X[i][0]] for i in range(num_obs)]
|
||||||
|
|
||||||
self.assertEqual(vals, [desired_state]*num_obs)
|
self.assertEqual(vals, [desired_state] * num_obs)
|
||||||
|
|
||||||
def test_sumProduct_chain(self):
|
def test_sumProduct_chain(self):
|
||||||
"""
|
"""
|
||||||
|
@ -227,15 +212,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
"""
|
"""
|
||||||
num_states = 3
|
num_states = 3
|
||||||
chain_length = 400
|
chain_length = 400
|
||||||
desired_state = 1
|
|
||||||
states = list(range(num_states))
|
states = list(range(num_states))
|
||||||
|
|
||||||
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
|
|
||||||
def make_key(character, index, cardinality):
|
|
||||||
symbol = Symbol(character, index)
|
|
||||||
key = symbol.key()
|
|
||||||
return (key, cardinality)
|
|
||||||
|
|
||||||
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
|
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
|
||||||
graph = DiscreteFactorGraph()
|
graph = DiscreteFactorGraph()
|
||||||
|
|
||||||
|
@ -253,18 +231,15 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
|
|
||||||
# Ensure that the stationary distribution is positive and normalized
|
# Ensure that the stationary distribution is positive and normalized
|
||||||
stationary_dist /= np.sum(stationary_dist)
|
stationary_dist /= np.sum(stationary_dist)
|
||||||
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
|
expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel())
|
||||||
|
|
||||||
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
||||||
transitions = transitions.T
|
transition_cpt = generate_transition_cpt(num_states, transitions.T)
|
||||||
transition_cpt = []
|
|
||||||
for i in range(0, num_states):
|
|
||||||
transition_row = "/".join([str(x) for x in transitions[i]])
|
|
||||||
transition_cpt.append(transition_row)
|
|
||||||
transition_cpt = " ".join(transition_cpt)
|
|
||||||
|
|
||||||
for i in reversed(range(1, chain_length)):
|
for i in reversed(range(1, chain_length)):
|
||||||
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
|
transition_conditional = DiscreteConditional(
|
||||||
|
X[i], [X[i - 1]], transition_cpt
|
||||||
|
)
|
||||||
graph.push_back(transition_conditional)
|
graph.push_back(transition_conditional)
|
||||||
|
|
||||||
# Run sum product using natural ordering so the resulting Bayes net has the form:
|
# Run sum product using natural ordering so the resulting Bayes net has the form:
|
||||||
|
@ -277,5 +252,6 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
# Ensure marginal probabilities are close to the stationary distribution
|
# Ensure marginal probabilities are close to the stationary distribution
|
||||||
self.gtsamAssertEquals(expected, last_marginal)
|
self.gtsamAssertEquals(expected, last_marginal)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue