Simply sum-product test in Python.

release/4.3a0
Kevin 2023-02-11 20:04:43 -05:00
parent 9fa2d30362
commit 92443f5378
1 changed files with 3 additions and 11 deletions

View File

@ -14,7 +14,7 @@ Author: Frank Dellaert
import unittest import unittest
import numpy as np import numpy as np
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
OrderingType = Ordering.OrderingType OrderingType = Ordering.OrderingType
@ -250,7 +250,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Ensure that the stationary distribution is positive and normalized # Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist) stationary_dist /= np.sum(stationary_dist)
expected = stationary_dist.flatten() expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
# The transition matrix parsed by DiscreteConditional is a row-wise CPT # The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T transitions = transitions.T
@ -271,16 +271,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Get the DiscreteConditional representing the marginal on the last factor # Get the DiscreteConditional representing the marginal on the last factor
last_marginal = sum_product.at(chain_length - 1) last_marginal = sum_product.at(chain_length - 1)
# Extract the actual marginal probabilities
assignment = DiscreteValues()
marginal_probs = []
for i in range(num_states):
assignment[X[chain_length - 1][0]] = i
marginal_probs.append(last_marginal(assignment))
marginal_probs = np.array(marginal_probs)
# Ensure marginal probabilities are close to the stationary distribution # Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, marginal_probs) self.gtsamAssertEquals(expected, last_marginal)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()