Merge pull request #1453 from keevindoherty/hotfix/sumproduct

Add normalization to sum-product, avoiding underflow.
release/4.3a0
Frank Dellaert 2023-02-11 22:46:50 -08:00 committed by GitHub
commit 72cdcf8ce3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 6 deletions

View File

@ -121,8 +121,8 @@ namespace gtsam {
for (auto&& factor : factors) product = (*factor) * product; for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// Sum all the potentials by pretending all keys are frontal: // Max over all the potentials by pretending all keys are frontal:
auto normalization = product.sum(product.size()); auto normalization = product.max(product.size());
// Normalize the product factor to prevent underflow. // Normalize the product factor to prevent underflow.
product = product / (*normalization); product = product / (*normalization);
@ -210,6 +210,12 @@ namespace gtsam {
for (auto&& factor : factors) product = (*factor) * product; for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*normalization);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
gttic(sum); gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);

View File

@ -108,7 +108,14 @@ TEST(DiscreteFactorGraph, test) {
// Test EliminateDiscrete // Test EliminateDiscrete
const Ordering frontalKeys{0}; const Ordering frontalKeys{0};
const auto [conditional, newFactor] = EliminateDiscrete(graph, frontalKeys); const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
DecisionTreeFactor newFactor = *newFactorPtr;
// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());
newFactor = newFactor / *normalization;
// Check Conditional // Check Conditional
CHECK(conditional); CHECK(conditional);
@ -117,9 +124,13 @@ TEST(DiscreteFactorGraph, test) {
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
// Check Factor // Check Factor
CHECK(newFactor); CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor)); // Normalize by max.
normalization = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct.
expectedFactor = expectedFactor / *normalization;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree // Test using elimination tree
const Ordering ordering{0, 1, 2}; const Ordering ordering{0, 1, 2};

View File

@ -14,7 +14,7 @@ Author: Frank Dellaert
import unittest import unittest
import numpy as np import numpy as np
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
OrderingType = Ordering.OrderingType OrderingType = Ordering.OrderingType
@ -216,5 +216,63 @@ class TestDiscreteFactorGraph(GtsamTestCase):
self.assertEqual(vals, [desired_state]*num_obs) self.assertEqual(vals, [desired_state]*num_obs)
def test_sumProduct_chain(self):
"""
Test for numerical underflow in EliminateDiscrete on long chains.
Adapted from the toy problem of @pcl15423
Ref: https://github.com/borglab/gtsam/issues/1448
"""
num_states = 3
chain_length = 400
desired_state = 1
states = list(range(num_states))
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
def make_key(character, index, cardinality):
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
graph = DiscreteFactorGraph()
# Construct test transition matrix
transitions = np.diag([1.0, 0.5, 0.1])
transitions += 0.1/(num_states)
# Ensure that the transition matrix is Markov (columns sum to 1)
transitions /= np.sum(transitions, axis=0)
# The stationary distribution is the eigenvector corresponding to eigenvalue 1
eigvals, eigvecs = np.linalg.eig(transitions)
stationary_idx = np.where(np.isclose(eigvals, 1.0))
stationary_dist = eigvecs[:, stationary_idx]
# Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist)
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T
transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)
for i in reversed(range(1, chain_length)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
graph.push_back(transition_conditional)
# Run sum product using natural ordering so the resulting Bayes net has the form:
# X_0 <- X_1 <- ... <- X_n
sum_product = graph.sumProduct(OrderingType.NATURAL)
# Get the DiscreteConditional representing the marginal on the last factor
last_marginal = sum_product.at(chain_length - 1)
# Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, last_marginal)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()