First pass at underflow test for sum-product.

release/4.3a0
Kevin 2023-02-10 17:02:27 -05:00
parent 29358f826b
commit 548509f28b
1 changed files with 33 additions and 32 deletions

View File

@ -223,7 +223,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
Ref: https://github.com/borglab/gtsam/issues/1448
"""
num_states = 3
num_obs = 200
chain_length = 400
desired_state = 1
states = list(range(num_states))
@ -233,53 +233,54 @@ class TestDiscreteFactorGraph(GtsamTestCase):
key = symbol.key()
return (key, cardinality)
X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
graph = DiscreteFactorGraph()
# Mostly identity transition matrix
transitions = np.eye(num_states)
# Needed otherwise mpe is always state 0?
# Construct test transition matrix
transitions = np.diag([1.0, 0.5, 0.1])
transitions += 0.1/(num_states)
# Ensure that the transition matrix is Markov (columns sum to 1)
transitions /= np.sum(transitions, axis=0)
# The stationary distribution is the eigenvector corresponding to eigenvalue 1
eigvals, eigvecs = np.linalg.eig(transitions)
stationary_idx = np.where(np.isclose(eigvals, 1.0))
stationary_dist = eigvecs[:, stationary_idx]
# Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist)
expected = stationary_dist.flatten()
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T
transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)
for i in reversed(range(1, num_obs)):
for i in reversed(range(1, chain_length)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
graph.push_back(transition_conditional)
# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
# but all other states always give measurement num_obs
obs = np.zeros((num_states, num_obs+1))
obs[:,-1] = 1
obs[desired_state,0: -1] = 1
obs[desired_state,-1] = 0
obs_cpt_list = []
for i in range(0, num_states):
obs_row = "/".join([str(z) for z in obs[i]])
obs_cpt_list.append(obs_row)
obs_cpt = " ".join(obs_cpt_list)
# Run sum product using natural ordering so the resulting Bayes net has the form:
# X_0 <- X_1 <- ... <- X_n
sum_product = graph.sumProduct(OrderingType.NATURAL)
# Contrived example where each measurement is its own index
for i in range(0, num_obs):
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
factor = obs_conditional.likelihood(i)
graph.push_back(factor)
# Get the DiscreteConditional representing the marginal on the last factor
last_marginal = sum_product.at(chain_length - 1)
mpe = graph.optimize()
vals = [mpe[X[i][0]] for i in range(num_obs)]
sum_product = graph.sumProduct()
# Extract the actual marginal probabilities
assignment = DiscreteValues()
marginal_probs = []
for i in range(num_states):
assignment[X[chain_length - 1][0]] = i
marginal_probs.append(last_marginal(assignment))
marginal_probs = np.array(marginal_probs)
print("This should have 9 potential assignments", sum_product.at(0))
print("This should have 9 potential assignments", sum_product.at(138))
self.assertEqual(vals, [desired_state]*num_obs)
# Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, marginal_probs)
if __name__ == "__main__":
unittest.main()