Merge pull request #1449 from keevindoherty/hotfix/maxproduct
Add normalization to max-product, avoiding underflow.release/4.3a0
commit
4f4c6eba7e
|
@ -121,6 +121,12 @@ namespace gtsam {
|
||||||
for (auto&& factor : factors) product = (*factor) * product;
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
|
// Sum all the potentials by pretending all keys are frontal:
|
||||||
|
auto normalization = product.sum(product.size());
|
||||||
|
|
||||||
|
// Normalize the product factor to prevent underflow.
|
||||||
|
product = product / (*normalization);
|
||||||
|
|
||||||
// max out frontals, this is the factor on the separator
|
// max out frontals, this is the factor on the separator
|
||||||
gttic(max);
|
gttic(max);
|
||||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||||
|
|
|
@ -13,7 +13,8 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering
|
import numpy as np
|
||||||
|
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
OrderingType = Ordering.OrderingType
|
OrderingType = Ordering.OrderingType
|
||||||
|
@ -155,6 +156,65 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
bayesNet = graph.sumProduct(ordering_type)
|
bayesNet = graph.sumProduct(ordering_type)
|
||||||
self.assertEqual(bayesNet(mpe), mpeProbability)
|
self.assertEqual(bayesNet(mpe), mpeProbability)
|
||||||
|
|
||||||
|
def test_MPE_chain(self):
|
||||||
|
"""
|
||||||
|
Test for numerical underflow in EliminateMPE on long chains.
|
||||||
|
Adapted from the toy problem of @pcl15423
|
||||||
|
Ref: https://github.com/borglab/gtsam/issues/1448
|
||||||
|
"""
|
||||||
|
num_states = 3
|
||||||
|
num_obs = 200
|
||||||
|
desired_state = 1
|
||||||
|
states = list(range(num_states))
|
||||||
|
|
||||||
|
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
|
||||||
|
def make_key(character, index, cardinality):
|
||||||
|
symbol = Symbol(character, index)
|
||||||
|
key = symbol.key()
|
||||||
|
return (key, cardinality)
|
||||||
|
|
||||||
|
X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
|
||||||
|
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
|
||||||
|
graph = DiscreteFactorGraph()
|
||||||
|
|
||||||
|
# Mostly identity transition matrix
|
||||||
|
transitions = np.eye(num_states)
|
||||||
|
|
||||||
|
# Needed otherwise mpe is always state 0?
|
||||||
|
transitions += 0.1/(num_states)
|
||||||
|
|
||||||
|
transition_cpt = []
|
||||||
|
for i in range(0, num_states):
|
||||||
|
transition_row = "/".join([str(x) for x in transitions[i]])
|
||||||
|
transition_cpt.append(transition_row)
|
||||||
|
transition_cpt = " ".join(transition_cpt)
|
||||||
|
|
||||||
|
for i in reversed(range(1, num_obs)):
|
||||||
|
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
|
||||||
|
graph.push_back(transition_conditional)
|
||||||
|
|
||||||
|
# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
|
||||||
|
# but all other states always give measurement num_obs
|
||||||
|
obs = np.zeros((num_states, num_obs+1))
|
||||||
|
obs[:,-1] = 1
|
||||||
|
obs[desired_state,0: -1] = 1
|
||||||
|
obs[desired_state,-1] = 0
|
||||||
|
obs_cpt_list = []
|
||||||
|
for i in range(0, num_states):
|
||||||
|
obs_row = "/".join([str(z) for z in obs[i]])
|
||||||
|
obs_cpt_list.append(obs_row)
|
||||||
|
obs_cpt = " ".join(obs_cpt_list)
|
||||||
|
|
||||||
|
# Contrived example where each measurement is its own index
|
||||||
|
for i in range(0, num_obs):
|
||||||
|
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
|
||||||
|
factor = obs_conditional.likelihood(i)
|
||||||
|
graph.push_back(factor)
|
||||||
|
|
||||||
|
mpe = graph.optimize()
|
||||||
|
vals = [mpe[X[i][0]] for i in range(num_obs)]
|
||||||
|
|
||||||
|
self.assertEqual(vals, [desired_state]*num_obs)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue