Merge pull request #1453 from keevindoherty/hotfix/sumproduct
Add normalization to sum-product, avoiding underflow.release/4.3a0
commit
72cdcf8ce3
|
@ -121,8 +121,8 @@ namespace gtsam {
|
|||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// Sum all the potentials by pretending all keys are frontal:
|
||||
auto normalization = product.sum(product.size());
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto normalization = product.max(product.size());
|
||||
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / (*normalization);
|
||||
|
@ -210,6 +210,12 @@ namespace gtsam {
|
|||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto normalization = product.max(product.size());
|
||||
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / (*normalization);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
|
|
|
@ -108,7 +108,14 @@ TEST(DiscreteFactorGraph, test) {
|
|||
|
||||
// Test EliminateDiscrete
|
||||
const Ordering frontalKeys{0};
|
||||
const auto [conditional, newFactor] = EliminateDiscrete(graph, frontalKeys);
|
||||
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
DecisionTreeFactor newFactor = *newFactorPtr;
|
||||
|
||||
// Normalize newFactor by max for comparison with expected
|
||||
auto normalization = newFactor.max(newFactor.size());
|
||||
|
||||
newFactor = newFactor / *normalization;
|
||||
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
|
@ -117,9 +124,13 @@ TEST(DiscreteFactorGraph, test) {
|
|||
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||
|
||||
// Check Factor
|
||||
CHECK(newFactor);
|
||||
CHECK(&newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||
// Normalize by max.
|
||||
normalization = expectedFactor.max(expectedFactor.size());
|
||||
// Ensure normalization is correct.
|
||||
expectedFactor = expectedFactor / *normalization;
|
||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||
|
||||
// Test using elimination tree
|
||||
const Ordering ordering{0, 1, 2};
|
||||
|
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
OrderingType = Ordering.OrderingType
|
||||
|
@ -216,5 +216,63 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
|
||||
self.assertEqual(vals, [desired_state]*num_obs)
|
||||
|
||||
def test_sumProduct_chain(self):
|
||||
"""
|
||||
Test for numerical underflow in EliminateDiscrete on long chains.
|
||||
Adapted from the toy problem of @pcl15423
|
||||
Ref: https://github.com/borglab/gtsam/issues/1448
|
||||
"""
|
||||
num_states = 3
|
||||
chain_length = 400
|
||||
desired_state = 1
|
||||
states = list(range(num_states))
|
||||
|
||||
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
|
||||
def make_key(character, index, cardinality):
|
||||
symbol = Symbol(character, index)
|
||||
key = symbol.key()
|
||||
return (key, cardinality)
|
||||
|
||||
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
|
||||
graph = DiscreteFactorGraph()
|
||||
|
||||
# Construct test transition matrix
|
||||
transitions = np.diag([1.0, 0.5, 0.1])
|
||||
transitions += 0.1/(num_states)
|
||||
|
||||
# Ensure that the transition matrix is Markov (columns sum to 1)
|
||||
transitions /= np.sum(transitions, axis=0)
|
||||
|
||||
# The stationary distribution is the eigenvector corresponding to eigenvalue 1
|
||||
eigvals, eigvecs = np.linalg.eig(transitions)
|
||||
stationary_idx = np.where(np.isclose(eigvals, 1.0))
|
||||
stationary_dist = eigvecs[:, stationary_idx]
|
||||
|
||||
# Ensure that the stationary distribution is positive and normalized
|
||||
stationary_dist /= np.sum(stationary_dist)
|
||||
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
|
||||
|
||||
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
||||
transitions = transitions.T
|
||||
transition_cpt = []
|
||||
for i in range(0, num_states):
|
||||
transition_row = "/".join([str(x) for x in transitions[i]])
|
||||
transition_cpt.append(transition_row)
|
||||
transition_cpt = " ".join(transition_cpt)
|
||||
|
||||
for i in reversed(range(1, chain_length)):
|
||||
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
|
||||
graph.push_back(transition_conditional)
|
||||
|
||||
# Run sum product using natural ordering so the resulting Bayes net has the form:
|
||||
# X_0 <- X_1 <- ... <- X_n
|
||||
sum_product = graph.sumProduct(OrderingType.NATURAL)
|
||||
|
||||
# Get the DiscreteConditional representing the marginal on the last factor
|
||||
last_marginal = sum_product.at(chain_length - 1)
|
||||
|
||||
# Ensure marginal probabilities are close to the stationary distribution
|
||||
self.gtsamAssertEquals(expected, last_marginal)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue