From 19773153bc24107d744611fb71c953f2abf3ef45 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 15:57:56 -0500 Subject: [PATCH 1/3] Wrapper --- gtsam/discrete/DiscreteSearch.h | 2 ++ gtsam/discrete/discrete.i | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index b610955b2..db3dd5f03 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -161,4 +161,6 @@ class GTSAM_EXPORT DiscreteSearch { double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. std::vector slots_; ///< The slots to fill in the search. }; + +using DiscreteSearchSolution = DiscreteSearch::Solution; // for wrapping } // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b84ac69a0..5e4d8d22d 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -464,4 +464,29 @@ class DiscreteJunctionTree { const gtsam::DiscreteCluster& operator[](size_t i) const; }; +#include +class DiscreteSearchSolution { + double error; + gtsam::DiscreteValues assignment; + DiscreteSearchSolution(double error, const gtsam::DiscreteValues& assignment); +}; + +class DiscreteSearch { + static DiscreteSearch FromFactorGraph(const gtsam::DiscreteFactorGraph& factorGraph, + const gtsam::Ordering& ordering, + bool buildJunctionTree = false); + + DiscreteSearch(const gtsam::DiscreteEliminationTree& etree); + DiscreteSearch(const gtsam::DiscreteJunctionTree& junctionTree); + DiscreteSearch(const gtsam::DiscreteBayesNet& bayesNet); + DiscreteSearch(const gtsam::DiscreteBayesTree& bayesTree); + + void print(string name = "DiscreteSearch: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + + double lowerBound() const; + + std::vector run(size_t K = 1) const; +}; + } // namespace gtsam From 615196e41504961b792c1630986c1fdf8f229ade Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 15:58:15 -0500 Subject: [PATCH 2/3] Avoid some copy/paste --- python/gtsam/tests/dfg_utils.py | 35 +++++++ .../gtsam/tests/test_DiscreteFactorGraph.py | 92 +++++++------------ 2 files changed, 69 insertions(+), 58 deletions(-) create mode 100644 python/gtsam/tests/dfg_utils.py diff --git a/python/gtsam/tests/dfg_utils.py b/python/gtsam/tests/dfg_utils.py new file mode 100644 index 000000000..9ad521fd4 --- /dev/null +++ b/python/gtsam/tests/dfg_utils.py @@ -0,0 +1,35 @@ +import numpy as np +from gtsam import Symbol + + +def make_key(character, index, cardinality): + """ + Helper function to mimic the behavior of gtbook.Variables discrete_series function. + """ + symbol = Symbol(character, index) + key = symbol.key() + return (key, cardinality) + + +def generate_transition_cpt(num_states, transitions=None): + """ + Generate a row-wise CPT for a transition matrix. + """ + if transitions is None: + # Default to identity matrix with slight regularization + transitions = np.eye(num_states) + 0.1 / num_states + + # Ensure transitions sum to 1 if not already normalized + transitions /= np.sum(transitions, axis=1, keepdims=True) + return " ".join(["/".join(map(str, row)) for row in transitions]) + + +def generate_observation_cpt(num_states, num_obs, desired_state): + """ + Generate a row-wise CPT for observations with contrived probabilities. + """ + obs = np.zeros((num_states, num_obs + 1)) + obs[:, -1] = 1 # All states default to measurement num_obs + obs[desired_state, 0:-1] = 1 + obs[desired_state, -1] = 0 + return " ".join(["/".join(map(str, row)) for row in obs]) diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 3053087b4..521eeefa6 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -15,10 +15,16 @@ import unittest import numpy as np from gtsam.utils.test_case import GtsamTestCase +from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt -from gtsam import (DecisionTreeFactor, DiscreteConditional, - DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, - Symbol) +from gtsam import ( + DecisionTreeFactor, + DiscreteConditional, + DiscreteFactorGraph, + DiscreteKeys, + DiscreteValues, + Ordering, +) OrderingType = Ordering.OrderingType @@ -50,7 +56,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[1] = 1 # Check if graph evaluation works ( 0.3*0.6*4 ) - self.assertAlmostEqual(.72, graph(assignment)) + self.assertAlmostEqual(0.72, graph(assignment)) # Create a new test with third node and adding unary and ternary factor graph.add(P3, "0.9 0.2 0.5") @@ -100,8 +106,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): expectedValues[1] = 0 expectedValues[2] = 0 actualValues = graph.optimize() - self.assertEqual(list(actualValues.items()), - list(expectedValues.items())) + self.assertEqual(list(actualValues.items()), list(expectedValues.items())) def test_MPE(self): """Test maximum probable explanation (MPE): same as optimize.""" @@ -123,13 +128,11 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Use maxProduct dag = graph.maxProduct(OrderingType.COLAMD) actualMPE = dag.argmax() - self.assertEqual(list(actualMPE.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE.items()), list(mpe.items())) # All in one actualMPE2 = graph.optimize() - self.assertEqual(list(actualMPE2.items()), - list(mpe.items())) + self.assertEqual(list(actualMPE2.items()), list(mpe.items())) def test_sumProduct(self): """Test sumProduct.""" @@ -154,11 +157,17 @@ class TestDiscreteFactorGraph(GtsamTestCase): self.assertAlmostEqual(mpeProbability, 0.36) # regression # Use sumProduct - for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, - OrderingType.CUSTOM]: + for ordering_type in [ + OrderingType.COLAMD, + OrderingType.METIS, + OrderingType.NATURAL, + OrderingType.CUSTOM, + ]: bayesNet = graph.sumProduct(ordering_type) self.assertEqual(bayesNet(mpe), mpeProbability) + +class TestChains(GtsamTestCase): def test_MPE_chain(self): """ Test for numerical underflow in EliminateMPE on long chains. @@ -170,46 +179,22 @@ class TestDiscreteFactorGraph(GtsamTestCase): desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(num_obs)} Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} graph = DiscreteFactorGraph() - # Mostly identity transition matrix - transitions = np.eye(num_states) - - # Needed otherwise mpe is always state 0? - transitions += 0.1/(num_states) - - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) - + transition_cpt = generate_transition_cpt(num_states) for i in reversed(range(1, num_obs)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Contrived example such that the desired state gives measurements [0, num_obs) with equal probability # but all other states always give measurement num_obs - obs = np.zeros((num_states, num_obs+1)) - obs[:,-1] = 1 - obs[desired_state,0: -1] = 1 - obs[desired_state,-1] = 0 - obs_cpt_list = [] - for i in range(0, num_states): - obs_row = "/".join([str(z) for z in obs[i]]) - obs_cpt_list.append(obs_row) - obs_cpt = " ".join(obs_cpt_list) - + obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state) # Contrived example where each measurement is its own index - for i in range(0, num_obs): + for i in range(num_obs): obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) factor = obs_conditional.likelihood(i) graph.push_back(factor) @@ -217,7 +202,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): mpe = graph.optimize() vals = [mpe[X[i][0]] for i in range(num_obs)] - self.assertEqual(vals, [desired_state]*num_obs) + self.assertEqual(vals, [desired_state] * num_obs) def test_sumProduct_chain(self): """ @@ -227,15 +212,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): """ num_states = 3 chain_length = 400 - desired_state = 1 states = list(range(num_states)) - # Helper function to mimic the behavior of gtbook.Variables discrete_series function - def make_key(character, index, cardinality): - symbol = Symbol(character, index) - key = symbol.key() - return (key, cardinality) - X = {index: make_key("X", index, len(states)) for index in range(chain_length)} graph = DiscreteFactorGraph() @@ -253,18 +231,15 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure that the stationary distribution is positive and normalized stationary_dist /= np.sum(stationary_dist) - expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten()) + expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel()) # The transition matrix parsed by DiscreteConditional is a row-wise CPT - transitions = transitions.T - transition_cpt = [] - for i in range(0, num_states): - transition_row = "/".join([str(x) for x in transitions[i]]) - transition_cpt.append(transition_row) - transition_cpt = " ".join(transition_cpt) + transition_cpt = generate_transition_cpt(num_states, transitions.T) for i in reversed(range(1, chain_length)): - transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt) + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) graph.push_back(transition_conditional) # Run sum product using natural ordering so the resulting Bayes net has the form: @@ -277,5 +252,6 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Ensure marginal probabilities are close to the stationary distribution self.gtsamAssertEquals(expected, last_marginal) + if __name__ == "__main__": unittest.main() From 5e5a67d85316a6f41df542d8940f802c663bd4d9 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 15:58:33 -0500 Subject: [PATCH 3/3] Test search with long chain --- python/gtsam/tests/test_DiscreteSearch.py | 84 +++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 python/gtsam/tests/test_DiscreteSearch.py diff --git a/python/gtsam/tests/test_DiscreteSearch.py b/python/gtsam/tests/test_DiscreteSearch.py new file mode 100644 index 000000000..d0077f6db --- /dev/null +++ b/python/gtsam/tests/test_DiscreteSearch.py @@ -0,0 +1,84 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Search. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from dfg_utils import generate_observation_cpt, generate_transition_cpt, make_key +from gtsam.utils.test_case import GtsamTestCase + +from gtsam import ( + DiscreteConditional, + DiscreteFactorGraph, + DiscreteSearch, + Ordering, + DefaultKeyFormatter, +) + +OrderingType = Ordering.OrderingType + + +class TestDiscreteSearch(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_MPE_chain(self): + """ + Test for numerical underflow in EliminateMPE on long chains. + Adapted from the toy problem of @pcl15423 + Ref: https://github.com/borglab/gtsam/issues/1448 + """ + num_states = 3 + num_obs = 200 + desired_state = 1 + states = list(range(num_states)) + + X = {index: make_key("X", index, len(states)) for index in range(num_obs)} + Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)} + graph = DiscreteFactorGraph() + + transition_cpt = generate_transition_cpt(num_states) + for i in reversed(range(1, num_obs)): + transition_conditional = DiscreteConditional( + X[i], [X[i - 1]], transition_cpt + ) + graph.push_back(transition_conditional) + + # Contrived example such that the desired state gives measurements [0, num_obs) with equal + # probability but all other states always give measurement num_obs + obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state) + # Contrived example where each measurement is its own index + for i in range(num_obs): + obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt) + factor = obs_conditional.likelihood(i) + graph.push_back(factor) + + # Check MPE + mpe = graph.optimize() + vals = [mpe[X[i][0]] for i in range(num_obs)] + self.assertEqual(vals, [desired_state] * num_obs) + + # Create an ordering: + ordering = Ordering() + for i in reversed(range(num_obs)): + ordering.push_back(X[i][0]) + + # Now do Search + search = DiscreteSearch.FromFactorGraph(graph, ordering) + solutions = search.run(K=1) + mpe2 = solutions[0].assignment + # print({DefaultKeyFormatter(key): value for key, value in mpe2.items()}) + vals = [mpe2[X[i][0]] for i in range(num_obs)] + self.assertEqual(vals, [desired_state] * num_obs) + + +if __name__ == "__main__": + unittest.main()