Simply sum-product test in Python.
parent
9fa2d30362
commit
92443f5378
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
OrderingType = Ordering.OrderingType
|
OrderingType = Ordering.OrderingType
|
||||||
|
@ -250,7 +250,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
|
|
||||||
# Ensure that the stationary distribution is positive and normalized
|
# Ensure that the stationary distribution is positive and normalized
|
||||||
stationary_dist /= np.sum(stationary_dist)
|
stationary_dist /= np.sum(stationary_dist)
|
||||||
expected = stationary_dist.flatten()
|
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
|
||||||
|
|
||||||
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
||||||
transitions = transitions.T
|
transitions = transitions.T
|
||||||
|
@ -271,16 +271,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
# Get the DiscreteConditional representing the marginal on the last factor
|
# Get the DiscreteConditional representing the marginal on the last factor
|
||||||
last_marginal = sum_product.at(chain_length - 1)
|
last_marginal = sum_product.at(chain_length - 1)
|
||||||
|
|
||||||
# Extract the actual marginal probabilities
|
|
||||||
assignment = DiscreteValues()
|
|
||||||
marginal_probs = []
|
|
||||||
for i in range(num_states):
|
|
||||||
assignment[X[chain_length - 1][0]] = i
|
|
||||||
marginal_probs.append(last_marginal(assignment))
|
|
||||||
marginal_probs = np.array(marginal_probs)
|
|
||||||
|
|
||||||
# Ensure marginal probabilities are close to the stationary distribution
|
# Ensure marginal probabilities are close to the stationary distribution
|
||||||
self.gtsamAssertEquals(expected, marginal_probs)
|
self.gtsamAssertEquals(expected, last_marginal)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue