Simply sum-product test in Python.

release/4.3a0
Kevin 2023-02-11 20:04:43 -05:00
parent 9fa2d30362
commit 92443f5378
1 changed files with 3 additions and 11 deletions

View File

@ -14,7 +14,7 @@ Author: Frank Dellaert
import unittest
import numpy as np
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam.utils.test_case import GtsamTestCase
OrderingType = Ordering.OrderingType
@ -250,7 +250,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist)
expected = stationary_dist.flatten()
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T
@ -271,16 +271,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Get the DiscreteConditional representing the marginal on the last factor
last_marginal = sum_product.at(chain_length - 1)
# Extract the actual marginal probabilities
assignment = DiscreteValues()
marginal_probs = []
for i in range(num_states):
assignment[X[chain_length - 1][0]] = i
marginal_probs.append(last_marginal(assignment))
marginal_probs = np.array(marginal_probs)
# Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, marginal_probs)
self.gtsamAssertEquals(expected, last_marginal)
if __name__ == "__main__":
unittest.main()