From 911819c7f2d4acd8c6076e63551b094ed82380f5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Dec 2021 13:55:11 -0500 Subject: [PATCH] enumerate --- gtsam/discrete/DecisionTreeFactor.cpp | 28 +++++++--- gtsam/discrete/DecisionTreeFactor.h | 3 ++ gtsam/discrete/discrete.i | 1 + .../discrete/tests/testDecisionTreeFactor.cpp | 22 +++++++- python/gtsam/tests/test_DecisionTreeFactor.py | 54 +++++++++++++++++++ 5 files changed, 100 insertions(+), 8 deletions(-) create mode 100644 python/gtsam/tests/test_DecisionTreeFactor.py diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 50b21fc76..768a37ab4 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -134,17 +134,33 @@ namespace gtsam { return boost::make_shared(dkeys, result); } + /* ************************************************************************* */ + std::vector> DecisionTreeFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs; + for (auto& key : keys()) { + pairs.emplace_back(key, cardinalities_.at(key)); + } + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + /* ************************************************************************* */ std::string DecisionTreeFactor::markdown( const KeyFormatter& keyFormatter) const { std::stringstream ss; // Print out header and construct argument for `cartesianProduct`. - std::vector> pairs; ss << "|"; for (auto& key : keys()) { ss << keyFormatter(key) << "|"; - pairs.emplace_back(key, cardinalities_.at(key)); } ss << "value|\n"; @@ -154,12 +170,12 @@ namespace gtsam { ss << ":-:|\n"; // Print out all rows. - std::vector> rpairs(pairs.rbegin(), pairs.rend()); - const auto assignments = cartesianProduct(rpairs); - for (const auto& assignment : assignments) { + auto rows = enumerate(); + for (const auto& kv : rows) { ss << "|"; + auto assignment = kv.first; for (auto& key : keys()) ss << assignment.at(key) << "|"; - ss << operator()(assignment) << "|\n"; + ss << kv.second << "|\n"; } return ss.str(); } diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index b5a037119..ad81d9eb1 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -183,6 +183,9 @@ namespace gtsam { // Potentials::reduceWithInverse(inverseReduction); // } + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index daacee371..3b39374cb 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -47,6 +47,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; string dot(bool showZero = false) const; + std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; }; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 542d16b29..6af7ca731 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max) DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); } +/* ************************************************************************* */ +// Check enumerate yields the correct list of assignment/value pairs. +TEST(DecisionTreeFactor, enumerate) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + auto actual = f.enumerate(); + std::vector> expected; + DiscreteValues values; + for (size_t a : {0, 1, 2}) { + for (size_t b : {0, 1}) { + values[12] = a; + values[5] = b; + expected.emplace_back(values, f(values)); + } + } + EXPECT(actual == expected); +} + /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DecisionTreeFactor, markdown) { DiscreteKey A(12, 3), B(5, 2); - DecisionTreeFactor f1(A & B, "1 2 3 4 5 6"); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); string expected = "|A|B|value|\n" "|:-:|:-:|:-:|\n" @@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) { "|2|0|5|\n" "|2|1|6|\n"; auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; - string actual = f1.markdown(formatter); + string actual = f.markdown(formatter); EXPECT(actual == expected); } diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py new file mode 100644 index 000000000..586d1d142 --- /dev/null +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -0,0 +1,54 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for DecisionTreeFactors. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys +from gtsam.utils.test_case import GtsamTestCase + + +class TestDecisionTreeFactor(GtsamTestCase): + """Tests for DecisionTreeFactors.""" + + def setUp(self): + A = (12, 3) + B = (5, 2) + self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6") + + def test_enumerate(self): + actual = self.factor.enumerate() + _, values = zip(*actual) + self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + expected = \ + "|A|B|value|\n" \ + "|:-:|:-:|:-:|\n" \ + "|0|0|1|\n" \ + "|0|1|2|\n" \ + "|1|0|3|\n" \ + "|1|1|4|\n" \ + "|2|0|5|\n" \ + "|2|1|6|\n" + + def formatter(x: int): + return "A" if x == 12 else "B" + + actual = self.factor._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main()