diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 6990a995e..13e84c5b8 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -223,7 +223,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): Ref: https://github.com/borglab/gtsam/issues/1448 """ num_states = 3 - num_obs = 200 + chain_length = 400 desired_state = 1 states = list(range(num_states)) @@ -233,53 +233,54 @@ class TestDiscreteFactorGraph(GtsamTestCase): key = symbol.key() return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(num_obs)} - Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} + X = {index: make_key("X", index, len(states)) for index in range(chain_length)} graph = DiscreteFactorGraph() - # Mostly identity transition matrix - transitions = np.eye(num_states) - - # Needed otherwise mpe is always state 0? + # Construct test transition matrix + transitions = np.diag([1.0, 0.5, 0.1]) transitions += 0.1/(num_states) + # Ensure that the transition matrix is Markov (columns sum to 1) + transitions /= np.sum(transitions, axis=0) + + # The stationary distribution is the eigenvector corresponding to eigenvalue 1 + eigvals, eigvecs = np.linalg.eig(transitions) + stationary_idx = np.where(np.isclose(eigvals, 1.0)) + stationary_dist = eigvecs[:, stationary_idx] + + # Ensure that the stationary distribution is positive and normalized + stationary_dist /= np.sum(stationary_dist) + expected = stationary_dist.flatten() + + # The transition matrix parsed by DiscreteConditional is a row-wise CPT + transitions = transitions.T transition_cpt = [] for i in range(0, num_states): transition_row = "/".join([str(x) for x in transitions[i]]) transition_cpt.append(transition_row) transition_cpt = " ".join(transition_cpt) - for i in reversed(range(1, num_obs)): + for i in reversed(range(1, chain_length)): transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) graph.push_back(transition_conditional) - # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability - # but all other states always give measurement num_obs - obs = np.zeros((num_states, num_obs+1)) - obs[:,-1] = 1 - obs[desired_state,0: -1] = 1 - obs[desired_state,-1] = 0 - obs_cpt_list = [] - for i in range(0, num_states): - obs_row = "/".join([str(z) for z in obs[i]]) - obs_cpt_list.append(obs_row) - obs_cpt = " ".join(obs_cpt_list) + # Run sum product using natural ordering so the resulting Bayes net has the form: + # X_0 <- X_1 <- ... <- X_n + sum_product = graph.sumProduct(OrderingType.NATURAL) - # Contrived example where each measurement is its own index - for i in range(0, num_obs): - obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) - factor = obs_conditional.likelihood(i) - graph.push_back(factor) + # Get the DiscreteConditional representing the marginal on the last factor + last_marginal = sum_product.at(chain_length - 1) - mpe = graph.optimize() - vals = [mpe[X[i][0]] for i in range(num_obs)] - sum_product = graph.sumProduct() + # Extract the actual marginal probabilities + assignment = DiscreteValues() + marginal_probs = [] + for i in range(num_states): + assignment[X[chain_length - 1][0]] = i + marginal_probs.append(last_marginal(assignment)) + marginal_probs = np.array(marginal_probs) - print("This should have 9 potential assignments", sum_product.at(0)) - - print("This should have 9 potential assignments", sum_product.at(138)) - - self.assertEqual(vals, [desired_state]*num_obs) + # Ensure marginal probabilities are close to the stationary distribution + self.gtsamAssertEquals(expected, marginal_probs) if __name__ == "__main__": unittest.main()