diff --git a/python/gtsam/tests/dfg_utils.py b/python/gtsam/tests/dfg_utils.py new file mode 100644 index 000000000..9ad521fd4 --- /dev/null +++ b/python/gtsam/tests/dfg_utils.py @@ -0,0 +1,35 @@ +import numpy as np +from gtsam import Symbol + + +def make_key(character, index, cardinality): + """ + Helper function to mimic the behavior of gtbook.Variables discrete_series function. + """ + symbol = Symbol(character, index) + key = symbol.key() + return (key, cardinality) + + +def generate_transition_cpt(num_states, transitions=None): + """ + Generate a row-wise CPT for a transition matrix. + """ + if transitions is None: + # Default to identity matrix with slight regularization + transitions = np.eye(num_states) + 0.1 / num_states + + # Ensure transitions sum to 1 if not already normalized + transitions /= np.sum(transitions, axis=1, keepdims=True) + return " ".join(["/".join(map(str, row)) for row in transitions]) + + +def generate_observation_cpt(num_states, num_obs, desired_state): + """ + Generate a row-wise CPT for observations with contrived probabilities. + """ + obs = np.zeros((num_states, num_obs + 1)) + obs[:, -1] = 1 # All states default to measurement num_obs + obs[desired_state, 0:-1] = 1 + obs[desired_state, -1] = 0 + return " ".join(["/".join(map(str, row)) for row in obs]) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 3053087b4..521eeefa6 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -15,10 +15,16 @@ import unittest import numpy as np from gtsam.utils.test_case import GtsamTestCase +from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt -from gtsam import (DecisionTreeFactor, DiscreteConditional, - DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, - Symbol) +from gtsam import ( + DecisionTreeFactor, + DiscreteConditional, + DiscreteFactorGraph, + DiscreteKeys, + DiscreteValues, + Ordering, +) OrderingType = Ordering.OrderingType @@ -50,7 +56,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[1] = 1 # Check if graph evaluation works ( 0.3*0.6*4 ) - self.assertAlmostEqual(.72, graph(assignment)) + self.assertAlmostEqual(0.72, graph(assignment)) # Create a new test with third node and adding unary and ternary factor graph.add(P3, "0.9 0.2 0.5") @@ -100,8 +106,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): expectedValues[1] = 0 expectedValues[2] = 0 actualValues = graph.optimize() - self.assertEqual(list(actualValues.items()), - list(expectedValues.items())) + self.assertEqual(list(actualValues.items()), list(expectedValues.items())) def test_MPE(self): """Test maximum probable explanation (MPE): same as optimize.""" @@ -123,13 +128,11 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Use maxProduct dag = graph.maxProduct(OrderingType.COLAMD) actualMPE = dag.argmax() - self.assertEqual(list(actualMPE.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE.items()), list(mpe.items())) # All in one actualMPE2 = graph.optimize() - self.assertEqual(list(actualMPE2.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE2.items()), list(mpe.items())) def test_sumProduct(self): """Test sumProduct.""" @@ -154,11 +157,17 @@ class TestDiscreteFactorGraph(GtsamTestCase): self.assertAlmostEqual(mpeProbability, 0.36) # regression # Use sumProduct - for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, - OrderingType.CUSTOM]: + for ordering_type in [ + OrderingType.COLAMD, + OrderingType.METIS, + OrderingType.NATURAL, + OrderingType.CUSTOM, + ]: bayesNet = graph.sumProduct(ordering_type) self.assertEqual(bayesNet(mpe), mpeProbability) + +class TestChains(GtsamTestCase): def test_MPE_chain(self): """ Test for numerical underflow in EliminateMPE on long chains. @@ -170,46 +179,22 @@ class TestDiscreteFactorGraph(GtsamTestCase): desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(num_obs)} Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} graph = DiscreteFactorGraph() - # Mostly identity transition matrix - transitions = np.eye(num_states) - - # Needed otherwise mpe is always state 0? - transitions += 0.1/(num_states) - - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) - + transition_cpt = generate_transition_cpt(num_states) for i in reversed(range(1, num_obs)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability # but all other states always give measurement num_obs - obs = np.zeros((num_states, num_obs+1)) - obs[:,-1] = 1 - obs[desired_state,0: -1] = 1 - obs[desired_state,-1] = 0 - obs_cpt_list = [] - for i in range(0, num_states): - obs_row = "/".join([str(z) for z in obs[i]]) - obs_cpt_list.append(obs_row) - obs_cpt = " ".join(obs_cpt_list) - + obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state) # Contrived example where each measurement is its own index - for i in range(0, num_obs): + for i in range(num_obs): obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) factor = obs_conditional.likelihood(i) graph.push_back(factor) @@ -217,7 +202,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): mpe = graph.optimize() vals = [mpe[X[i][0]] for i in range(num_obs)] - self.assertEqual(vals, [desired_state]*num_obs) + self.assertEqual(vals, [desired_state] * num_obs) def test_sumProduct_chain(self): """ @@ -227,15 +212,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): """ num_states = 3 chain_length = 400 - desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(chain_length)} graph = DiscreteFactorGraph() @@ -253,18 +231,15 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure that the stationary distribution is positive and normalized stationary_dist /= np.sum(stationary_dist) - expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten()) + expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel()) # The transition matrix parsed by DiscreteConditional is a row-wise CPT - transitions = transitions.T - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) + transition_cpt = generate_transition_cpt(num_states, transitions.T) for i in reversed(range(1, chain_length)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Run sum product using natural ordering so the resulting Bayes net has the form: @@ -277,5 +252,6 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure marginal probabilities are close to the stationary distribution self.gtsamAssertEquals(expected, last_marginal) + if __name__ == "__main__": unittest.main()