Simply sum-product test in Python.
parent
9fa2d30362
commit
92443f5378
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
OrderingType = Ordering.OrderingType
|
||||
|
@ -250,7 +250,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
|
||||
# Ensure that the stationary distribution is positive and normalized
|
||||
stationary_dist /= np.sum(stationary_dist)
|
||||
expected = stationary_dist.flatten()
|
||||
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
|
||||
|
||||
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
||||
transitions = transitions.T
|
||||
|
@ -271,16 +271,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
# Get the DiscreteConditional representing the marginal on the last factor
|
||||
last_marginal = sum_product.at(chain_length - 1)
|
||||
|
||||
# Extract the actual marginal probabilities
|
||||
assignment = DiscreteValues()
|
||||
marginal_probs = []
|
||||
for i in range(num_states):
|
||||
assignment[X[chain_length - 1][0]] = i
|
||||
marginal_probs.append(last_marginal(assignment))
|
||||
marginal_probs = np.array(marginal_probs)
|
||||
|
||||
# Ensure marginal probabilities are close to the stationary distribution
|
||||
self.gtsamAssertEquals(expected, marginal_probs)
|
||||
self.gtsamAssertEquals(expected, last_marginal)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue