Avoid some copy/paste

release/4.3a0
Frank Dellaert 2025-01-28 15:58:15 -05:00
parent 19773153bc
commit 615196e415
2 changed files with 69 additions and 58 deletions

View File

@ -0,0 +1,35 @@
import numpy as np
from gtsam import Symbol
def make_key(character, index, cardinality):
"""
Helper function to mimic the behavior of gtbook.Variables discrete_series function.
"""
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)
def generate_transition_cpt(num_states, transitions=None):
"""
Generate a row-wise CPT for a transition matrix.
"""
if transitions is None:
# Default to identity matrix with slight regularization
transitions = np.eye(num_states) + 0.1 / num_states
# Ensure transitions sum to 1 if not already normalized
transitions /= np.sum(transitions, axis=1, keepdims=True)
return " ".join(["/".join(map(str, row)) for row in transitions])
def generate_observation_cpt(num_states, num_obs, desired_state):
"""
Generate a row-wise CPT for observations with contrived probabilities.
"""
obs = np.zeros((num_states, num_obs + 1))
obs[:, -1] = 1 # All states default to measurement num_obs
obs[desired_state, 0:-1] = 1
obs[desired_state, -1] = 0
return " ".join(["/".join(map(str, row)) for row in obs])

View File

@ -15,10 +15,16 @@ import unittest
import numpy as np import numpy as np
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt
from gtsam import (DecisionTreeFactor, DiscreteConditional, from gtsam import (
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, DecisionTreeFactor,
Symbol) DiscreteConditional,
DiscreteFactorGraph,
DiscreteKeys,
DiscreteValues,
Ordering,
)
OrderingType = Ordering.OrderingType OrderingType = Ordering.OrderingType
@ -50,7 +56,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
assignment[1] = 1 assignment[1] = 1
# Check if graph evaluation works ( 0.3*0.6*4 ) # Check if graph evaluation works ( 0.3*0.6*4 )
self.assertAlmostEqual(.72, graph(assignment)) self.assertAlmostEqual(0.72, graph(assignment))
# Create a new test with third node and adding unary and ternary factor # Create a new test with third node and adding unary and ternary factor
graph.add(P3, "0.9 0.2 0.5") graph.add(P3, "0.9 0.2 0.5")
@ -100,8 +106,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
expectedValues[1] = 0 expectedValues[1] = 0
expectedValues[2] = 0 expectedValues[2] = 0
actualValues = graph.optimize() actualValues = graph.optimize()
self.assertEqual(list(actualValues.items()), self.assertEqual(list(actualValues.items()), list(expectedValues.items()))
list(expectedValues.items()))
def test_MPE(self): def test_MPE(self):
"""Test maximum probable explanation (MPE): same as optimize.""" """Test maximum probable explanation (MPE): same as optimize."""
@ -123,13 +128,11 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Use maxProduct # Use maxProduct
dag = graph.maxProduct(OrderingType.COLAMD) dag = graph.maxProduct(OrderingType.COLAMD)
actualMPE = dag.argmax() actualMPE = dag.argmax()
self.assertEqual(list(actualMPE.items()), self.assertEqual(list(actualMPE.items()), list(mpe.items()))
list(mpe.items()))
# All in one # All in one
actualMPE2 = graph.optimize() actualMPE2 = graph.optimize()
self.assertEqual(list(actualMPE2.items()), self.assertEqual(list(actualMPE2.items()), list(mpe.items()))
list(mpe.items()))
def test_sumProduct(self): def test_sumProduct(self):
"""Test sumProduct.""" """Test sumProduct."""
@ -154,11 +157,17 @@ class TestDiscreteFactorGraph(GtsamTestCase):
self.assertAlmostEqual(mpeProbability, 0.36) # regression self.assertAlmostEqual(mpeProbability, 0.36) # regression
# Use sumProduct # Use sumProduct
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, for ordering_type in [
OrderingType.CUSTOM]: OrderingType.COLAMD,
OrderingType.METIS,
OrderingType.NATURAL,
OrderingType.CUSTOM,
]:
bayesNet = graph.sumProduct(ordering_type) bayesNet = graph.sumProduct(ordering_type)
self.assertEqual(bayesNet(mpe), mpeProbability) self.assertEqual(bayesNet(mpe), mpeProbability)
class TestChains(GtsamTestCase):
def test_MPE_chain(self): def test_MPE_chain(self):
""" """
Test for numerical underflow in EliminateMPE on long chains. Test for numerical underflow in EliminateMPE on long chains.
@ -170,46 +179,22 @@ class TestDiscreteFactorGraph(GtsamTestCase):
desired_state = 1 desired_state = 1
states = list(range(num_states)) states = list(range(num_states))
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
def make_key(character, index, cardinality):
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)
X = {index: make_key("X", index, len(states)) for index in range(num_obs)} X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
graph = DiscreteFactorGraph() graph = DiscreteFactorGraph()
# Mostly identity transition matrix transition_cpt = generate_transition_cpt(num_states)
transitions = np.eye(num_states)
# Needed otherwise mpe is always state 0?
transitions += 0.1/(num_states)
transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)
for i in reversed(range(1, num_obs)): for i in reversed(range(1, num_obs)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional) graph.push_back(transition_conditional)
# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
# but all other states always give measurement num_obs # but all other states always give measurement num_obs
obs = np.zeros((num_states, num_obs+1)) obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
obs[:,-1] = 1
obs[desired_state,0: -1] = 1
obs[desired_state,-1] = 0
obs_cpt_list = []
for i in range(0, num_states):
obs_row = "/".join([str(z) for z in obs[i]])
obs_cpt_list.append(obs_row)
obs_cpt = " ".join(obs_cpt_list)
# Contrived example where each measurement is its own index # Contrived example where each measurement is its own index
for i in range(0, num_obs): for i in range(num_obs):
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
factor = obs_conditional.likelihood(i) factor = obs_conditional.likelihood(i)
graph.push_back(factor) graph.push_back(factor)
@ -217,7 +202,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
mpe = graph.optimize() mpe = graph.optimize()
vals = [mpe[X[i][0]] for i in range(num_obs)] vals = [mpe[X[i][0]] for i in range(num_obs)]
self.assertEqual(vals, [desired_state]*num_obs) self.assertEqual(vals, [desired_state] * num_obs)
def test_sumProduct_chain(self): def test_sumProduct_chain(self):
""" """
@ -227,15 +212,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
""" """
num_states = 3 num_states = 3
chain_length = 400 chain_length = 400
desired_state = 1
states = list(range(num_states)) states = list(range(num_states))
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
def make_key(character, index, cardinality):
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)
X = {index: make_key("X", index, len(states)) for index in range(chain_length)} X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
graph = DiscreteFactorGraph() graph = DiscreteFactorGraph()
@ -253,18 +231,15 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Ensure that the stationary distribution is positive and normalized # Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist) stationary_dist /= np.sum(stationary_dist)
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten()) expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel())
# The transition matrix parsed by DiscreteConditional is a row-wise CPT # The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T transition_cpt = generate_transition_cpt(num_states, transitions.T)
transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)
for i in reversed(range(1, chain_length)): for i in reversed(range(1, chain_length)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional) graph.push_back(transition_conditional)
# Run sum product using natural ordering so the resulting Bayes net has the form: # Run sum product using natural ordering so the resulting Bayes net has the form:
@ -277,5 +252,6 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Ensure marginal probabilities are close to the stationary distribution # Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, last_marginal) self.gtsamAssertEquals(expected, last_marginal)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()