enumerate

release/4.3a0
Frank Dellaert 2021-12-27 13:55:11 -05:00
parent c622dde7a7
commit 911819c7f2
5 changed files with 100 additions and 8 deletions

View File

@ -134,17 +134,33 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************* */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
pairs.emplace_back(key, cardinalities_.at(key));
}
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
std::string DecisionTreeFactor::markdown( std::string DecisionTreeFactor::markdown(
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {
std::stringstream ss; std::stringstream ss;
// Print out header and construct argument for `cartesianProduct`. // Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|"; ss << "|";
for (auto& key : keys()) { for (auto& key : keys()) {
ss << keyFormatter(key) << "|"; ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
} }
ss << "value|\n"; ss << "value|\n";
@ -154,12 +170,12 @@ namespace gtsam {
ss << ":-:|\n"; ss << ":-:|\n";
// Print out all rows. // Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); auto rows = enumerate();
const auto assignments = cartesianProduct(rpairs); for (const auto& kv : rows) {
for (const auto& assignment : assignments) {
ss << "|"; ss << "|";
auto assignment = kv.first;
for (auto& key : keys()) ss << assignment.at(key) << "|"; for (auto& key : keys()) ss << assignment.at(key) << "|";
ss << operator()(assignment) << "|\n"; ss << kv.second << "|\n";
} }
return ss.str(); return ss.str();
} }

View File

@ -183,6 +183,9 @@ namespace gtsam {
// Potentials::reduceWithInverse(inverseReduction); // Potentials::reduceWithInverse(inverseReduction);
// } // }
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -47,6 +47,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const; string dot(bool showZero = false) const;
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
}; };

View File

@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max)
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
} }
/* ************************************************************************* */
// Check enumerate yields the correct list of assignment/value pairs.
TEST(DecisionTreeFactor, enumerate) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto actual = f.enumerate();
std::vector<std::pair<DiscreteValues, double>> expected;
DiscreteValues values;
for (size_t a : {0, 1, 2}) {
for (size_t b : {0, 1}) {
values[12] = a;
values[5] = b;
expected.emplace_back(values, f(values));
}
}
EXPECT(actual == expected);
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation looks as expected. // Check markdown representation looks as expected.
TEST(DecisionTreeFactor, markdown) { TEST(DecisionTreeFactor, markdown) {
DiscreteKey A(12, 3), B(5, 2); DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6"); DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
string expected = string expected =
"|A|B|value|\n" "|A|B|value|\n"
"|:-:|:-:|:-:|\n" "|:-:|:-:|:-:|\n"
@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) {
"|2|0|5|\n" "|2|0|5|\n"
"|2|1|6|\n"; "|2|1|6|\n";
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
string actual = f1.markdown(formatter); string actual = f.markdown(formatter);
EXPECT(actual == expected); EXPECT(actual == expected);
} }

View File

@ -0,0 +1,54 @@
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for DecisionTreeFactors.
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
import unittest
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase
class TestDecisionTreeFactor(GtsamTestCase):
"""Tests for DecisionTreeFactors."""
def setUp(self):
A = (12, 3)
B = (5, 2)
self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6")
def test_enumerate(self):
actual = self.factor.enumerate()
_, values = zip(*actual)
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def test_markdown(self):
"""Test whether the _repr_markdown_ method."""
expected = \
"|A|B|value|\n" \
"|:-:|:-:|:-:|\n" \
"|0|0|1|\n" \
"|0|1|2|\n" \
"|1|0|3|\n" \
"|1|1|4|\n" \
"|2|0|5|\n" \
"|2|1|6|\n"
def formatter(x: int):
return "A" if x == 12 else "B"
actual = self.factor._repr_markdown_(formatter)
self.assertEqual(actual, expected)
if __name__ == "__main__":
unittest.main()