From 0710a8a8939fd5fc21445a9df01824095ec8424b Mon Sep 17 00:00:00 2001 From: Kevin Date: Wed, 8 Feb 2023 14:15:32 -0500 Subject: [PATCH 1/6] Add normalization trick to sum-product. --- gtsam/discrete/DiscreteFactorGraph.cpp | 6 ++ .../gtsam/tests/test_DiscreteFactorGraph.py | 65 +++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 2073164c3..7d07043b2 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -210,6 +210,12 @@ namespace gtsam { for (auto&& factor : factors) product = (*factor) * product; gttoc(product); + // Sum all the potentials by pretending all keys are frontal: + auto normalization = product.sum(product.size()); + + // Normalize the product factor to prevent underflow. + product = product / (*normalization); + // sum out frontals, this is the factor on the separator gttic(sum); DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 42dc807b0..6990a995e 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -216,5 +216,70 @@ class TestDiscreteFactorGraph(GtsamTestCase): self.assertEqual(vals, [desired_state]*num_obs) + def test_sumProduct_chain(self): + """ + Test for numerical underflow in EliminateDiscrete on long chains. 
+ Adapted from the toy problem of @pcl15423 + Ref: https://github.com/borglab/gtsam/issues/1448 + """ + num_states = 3 + num_obs = 200 + desired_state = 1 + states = list(range(num_states)) + + # Helper function to mimic the behavior of gtbook.Variables discrete_series function + def make_key(character, index, cardinality): + symbol = Symbol(character, index) + key = symbol.key() + return (key, cardinality) + + X = {index: make_key("X", index, len(states)) for index in range(num_obs)} + Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} + graph = DiscreteFactorGraph() + + # Mostly identity transition matrix + transitions = np.eye(num_states) + + # Needed otherwise mpe is always state 0? + transitions += 0.1/(num_states) + + transition_cpt = [] + for i in range(0, num_states): + transition_row = "/".join([str(x) for x in transitions[i]]) + transition_cpt.append(transition_row) + transition_cpt = " ".join(transition_cpt) + + for i in reversed(range(1, num_obs)): + transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + graph.push_back(transition_conditional) + + # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability + # but all other states always give measurement num_obs + obs = np.zeros((num_states, num_obs+1)) + obs[:,-1] = 1 + obs[desired_state,0: -1] = 1 + obs[desired_state,-1] = 0 + obs_cpt_list = [] + for i in range(0, num_states): + obs_row = "/".join([str(z) for z in obs[i]]) + obs_cpt_list.append(obs_row) + obs_cpt = " ".join(obs_cpt_list) + + # Contrived example where each measurement is its own index + for i in range(0, num_obs): + obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) + factor = obs_conditional.likelihood(i) + graph.push_back(factor) + + mpe = graph.optimize() + vals = [mpe[X[i][0]] for i in range(num_obs)] + sum_product = graph.sumProduct() + + print("This should have 9 potential assignments", sum_product.at(0)) + + print("This should 
have 9 potential assignments", sum_product.at(138)) + + self.assertEqual(vals, [desired_state]*num_obs) + if __name__ == "__main__": unittest.main() From 29358f826b0d7f8d74a5d4c916b56c6364771648 Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 9 Feb 2023 14:23:52 -0500 Subject: [PATCH 2/6] Patch discrete factor graph test to normalize expected result. --- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index bb23b7a83..8226a81dd 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -119,6 +119,9 @@ TEST(DiscreteFactorGraph, test) { // Check Factor CHECK(newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); + auto normalization = expectedFactor.sum(expectedFactor.size()); + // Ensure normalization is correct. + expectedFactor = expectedFactor / *normalization; EXPECT(assert_equal(expectedFactor, *newFactor)); // Test using elimination tree From 548509f28b40ca1112e79609cb974c44e5a2420e Mon Sep 17 00:00:00 2001 From: Kevin Date: Fri, 10 Feb 2023 17:02:27 -0500 Subject: [PATCH 3/6] First pass at underflow test for sum-product. 
--- .../gtsam/tests/test_DiscreteFactorGraph.py | 65 ++++++++++--------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 6990a995e..13e84c5b8 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -223,7 +223,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): Ref: https://github.com/borglab/gtsam/issues/1448 """ num_states = 3 - num_obs = 200 + chain_length = 400 desired_state = 1 states = list(range(num_states)) @@ -233,53 +233,54 @@ class TestDiscreteFactorGraph(GtsamTestCase): key = symbol.key() return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(num_obs)} - Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} + X = {index: make_key("X", index, len(states)) for index in range(chain_length)} graph = DiscreteFactorGraph() - # Mostly identity transition matrix - transitions = np.eye(num_states) - - # Needed otherwise mpe is always state 0? 
+ # Construct test transition matrix + transitions = np.diag([1.0, 0.5, 0.1]) transitions += 0.1/(num_states) + # Ensure that the transition matrix is Markov (columns sum to 1) + transitions /= np.sum(transitions, axis=0) + + # The stationary distribution is the eigenvector corresponding to eigenvalue 1 + eigvals, eigvecs = np.linalg.eig(transitions) + stationary_idx = np.where(np.isclose(eigvals, 1.0)) + stationary_dist = eigvecs[:, stationary_idx] + + # Ensure that the stationary distribution is positive and normalized + stationary_dist /= np.sum(stationary_dist) + expected = stationary_dist.flatten() + + # The transition matrix parsed by DiscreteConditional is a row-wise CPT + transitions = transitions.T transition_cpt = [] for i in range(0, num_states): transition_row = "/".join([str(x) for x in transitions[i]]) transition_cpt.append(transition_row) transition_cpt = " ".join(transition_cpt) - for i in reversed(range(1, num_obs)): + for i in reversed(range(1, chain_length)): transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) graph.push_back(transition_conditional) - # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability - # but all other states always give measurement num_obs - obs = np.zeros((num_states, num_obs+1)) - obs[:,-1] = 1 - obs[desired_state,0: -1] = 1 - obs[desired_state,-1] = 0 - obs_cpt_list = [] - for i in range(0, num_states): - obs_row = "/".join([str(z) for z in obs[i]]) - obs_cpt_list.append(obs_row) - obs_cpt = " ".join(obs_cpt_list) + # Run sum product using natural ordering so the resulting Bayes net has the form: + # X_0 <- X_1 <- ... 
<- X_n + sum_product = graph.sumProduct(OrderingType.NATURAL) - # Contrived example where each measurement is its own index - for i in range(0, num_obs): - obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) - factor = obs_conditional.likelihood(i) - graph.push_back(factor) + # Get the DiscreteConditional representing the marginal on the last factor + last_marginal = sum_product.at(chain_length - 1) - mpe = graph.optimize() - vals = [mpe[X[i][0]] for i in range(num_obs)] - sum_product = graph.sumProduct() + # Extract the actual marginal probabilities + assignment = DiscreteValues() + marginal_probs = [] + for i in range(num_states): + assignment[X[chain_length - 1][0]] = i + marginal_probs.append(last_marginal(assignment)) + marginal_probs = np.array(marginal_probs) - print("This should have 9 potential assignments", sum_product.at(0)) - - print("This should have 9 potential assignments", sum_product.at(138)) - - self.assertEqual(vals, [desired_state]*num_obs) + # Ensure marginal probabilities are close to the stationary distribution + self.gtsamAssertEquals(expected, marginal_probs) if __name__ == "__main__": unittest.main() From 70fa5681319330f33a7b0779c7a7d976a33e6d61 Mon Sep 17 00:00:00 2001 From: Kevin Date: Sat, 11 Feb 2023 19:59:16 -0500 Subject: [PATCH 4/6] Normalize products by max in discrete elimination. 
--- gtsam/discrete/DiscreteFactorGraph.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 7d07043b2..4ededbb8b 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -121,8 +121,8 @@ namespace gtsam { for (auto&& factor : factors) product = (*factor) * product; gttoc(product); - // Sum all the potentials by pretending all keys are frontal: - auto normalization = product.sum(product.size()); + // Max over all the potentials by pretending all keys are frontal: + auto normalization = product.max(product.size()); // Normalize the product factor to prevent underflow. product = product / (*normalization); @@ -210,8 +210,8 @@ namespace gtsam { for (auto&& factor : factors) product = (*factor) * product; gttoc(product); - // Sum all the potentials by pretending all keys are frontal: - auto normalization = product.sum(product.size()); + // Max over all the potentials by pretending all keys are frontal: + auto normalization = product.max(product.size()); // Normalize the product factor to prevent underflow. product = product / (*normalization); From 9fa2d30362bb6cf5a54b5de66268e4ce239251d2 Mon Sep 17 00:00:00 2001 From: Kevin Date: Sat, 11 Feb 2023 20:00:26 -0500 Subject: [PATCH 5/6] Test sum-product in discrete factor graph up to scale. 
--- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 8226a81dd..bbce5e8ce 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -108,7 +108,14 @@ TEST(DiscreteFactorGraph, test) { // Test EliminateDiscrete const Ordering frontalKeys{0}; - const auto [conditional, newFactor] = EliminateDiscrete(graph, frontalKeys); + const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); + + DecisionTreeFactor newFactor = *newFactorPtr; + + // Normalize newFactor by max for comparison with expected + auto normalization = newFactor.max(newFactor.size()); + + newFactor = newFactor / *normalization; // Check Conditional CHECK(conditional); @@ -117,12 +124,13 @@ TEST(DiscreteFactorGraph, test) { EXPECT(assert_equal(expectedConditional, *conditional)); // Check Factor - CHECK(newFactor); + CHECK(&newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); - auto normalization = expectedFactor.sum(expectedFactor.size()); + // Normalize by max. + normalization = expectedFactor.max(expectedFactor.size()); // Ensure normalization is correct. expectedFactor = expectedFactor / *normalization; - EXPECT(assert_equal(expectedFactor, *newFactor)); + EXPECT(assert_equal(expectedFactor, newFactor)); // Test using elimination tree const Ordering ordering{0, 1, 2}; From 92443f537861ec26075b25a08a7b2d2f30117b8c Mon Sep 17 00:00:00 2001 From: Kevin Date: Sat, 11 Feb 2023 20:04:43 -0500 Subject: [PATCH 6/6] Simplify sum-product test in Python. 
--- python/gtsam/tests/test_DiscreteFactorGraph.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 13e84c5b8..d725ceac8 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -14,7 +14,7 @@ Author: Frank Dellaert import unittest import numpy as np -from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol from gtsam.utils.test_case import GtsamTestCase OrderingType = Ordering.OrderingType @@ -250,7 +250,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure that the stationary distribution is positive and normalized stationary_dist /= np.sum(stationary_dist) - expected = stationary_dist.flatten() + expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten()) # The transition matrix parsed by DiscreteConditional is a row-wise CPT transitions = transitions.T @@ -271,16 +271,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Get the DiscreteConditional representing the marginal on the last factor last_marginal = sum_product.at(chain_length - 1) - # Extract the actual marginal probabilities - assignment = DiscreteValues() - marginal_probs = [] - for i in range(num_states): - assignment[X[chain_length - 1][0]] = i - marginal_probs.append(last_marginal(assignment)) - marginal_probs = np.array(marginal_probs) - # Ensure marginal probabilities are close to the stationary distribution - self.gtsamAssertEquals(expected, marginal_probs) + self.gtsamAssertEquals(expected, last_marginal) if __name__ == "__main__": unittest.main()