From 09fa002bd76ab98abfc32b0b1579e6642710124e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 25 Jan 2022 17:31:49 -0500 Subject: [PATCH] Python --- gtsam/discrete/discrete.i | 5 ++ gtsam/nonlinear/nonlinear.i | 5 ++ .../gtsam/tests/test_DiscreteFactorGraph.py | 52 ++++++++++++++++--- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 3f2c3e060..0dcbcc1cf 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -277,7 +277,12 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; + gtsam::DiscreteLookupDAG sumProduct(); + gtsam::DiscreteLookupDAG sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteLookupDAG sumProduct(const gtsam::Ordering& ordering); + gtsam::DiscreteLookupDAG maxProduct(); + gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet eliminateSequential(); diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index b6ab086c4..a6883d38b 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -111,6 +111,11 @@ size_t mrsymbolIndex(size_t key); #include class Ordering { + /// Type of ordering to use + enum OrderingType { + COLAMD, METIS, NATURAL, CUSTOM + }; + // Standard Constructors and Named Constructors Ordering(); Ordering(const gtsam::Ordering& other); diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 1ba145e09..ef85fc753 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -13,9 +13,11 @@ Author: Frank Dellaert import unittest -from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering from gtsam.utils.test_case import GtsamTestCase +OrderingType = Ordering.OrderingType + class TestDiscreteFactorGraph(GtsamTestCase): """Tests for Discrete Factor Graphs.""" @@ -108,14 +110,50 @@ class TestDiscreteFactorGraph(GtsamTestCase): graph.add([C, A], "0.2 0.8 0.3 0.7") graph.add([C, B], "0.1 0.9 0.4 0.6") - actualMPE = graph.optimize() + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 - expectedMPE = DiscreteValues() - expectedMPE[0] = 0 - expectedMPE[1] = 1 - expectedMPE[2] = 1 + # Use maxProduct + dag = graph.maxProduct(OrderingType.COLAMD) + actualMPE = dag.argmax() self.assertEqual(list(actualMPE.items()), - list(expectedMPE.items())) + list(mpe.items())) + + # All in one + actualMPE2 = graph.optimize() + self.assertEqual(list(actualMPE2.items()), + list(mpe.items())) + + def test_sumProduct(self): + """Test sumProduct.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 + + # Use default sumProduct + bayesNet = graph.sumProduct() + mpeProbability = bayesNet(mpe) + self.assertAlmostEqual(mpeProbability, 0.36) # regression + + # Use sumProduct + for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, + OrderingType.CUSTOM]: + bayesNet = graph.sumProduct(ordering_type) + self.assertEqual(bayesNet(mpe), mpeProbability) if __name__ == "__main__":