enumerate

release/4.3a0
Frank Dellaert 2021-12-27 13:55:11 -05:00
parent c622dde7a7
commit 911819c7f2
5 changed files with 100 additions and 8 deletions

View File

@ -134,17 +134,33 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
pairs.emplace_back(key, cardinalities_.at(key));
}
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
}
return result;
}
/* ************************************************************************* */
std::string DecisionTreeFactor::markdown(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;
// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
for (auto& key : keys()) {
ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
}
ss << "value|\n";
@ -154,12 +170,12 @@ namespace gtsam {
ss << ":-:|\n";
// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
for (const auto& assignment : assignments) {
auto rows = enumerate();
for (const auto& kv : rows) {
ss << "|";
auto assignment = kv.first;
for (auto& key : keys()) ss << assignment.at(key) << "|";
ss << operator()(assignment) << "|\n";
ss << kv.second << "|\n";
}
return ss.str();
}

View File

@ -183,6 +183,9 @@ namespace gtsam {
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// @}
/// @name Wrapper support
/// @{

View File

@ -47,6 +47,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const;
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

View File

@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max)
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
}
/* ************************************************************************* */
// Check enumerate yields the correct list of assignment/value pairs.
TEST(DecisionTreeFactor, enumerate) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto actual = f.enumerate();
std::vector<std::pair<DiscreteValues, double>> expected;
DiscreteValues values;
for (size_t a : {0, 1, 2}) {
for (size_t b : {0, 1}) {
values[12] = a;
values[5] = b;
expected.emplace_back(values, f(values));
}
}
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DecisionTreeFactor, markdown) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
string expected =
"|A|B|value|\n"
"|:-:|:-:|:-:|\n"
@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) {
"|2|0|5|\n"
"|2|1|6|\n";
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
string actual = f1.markdown(formatter);
string actual = f.markdown(formatter);
EXPECT(actual == expected);
}

View File

@ -0,0 +1,54 @@
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for DecisionTreeFactors.
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
import unittest
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase
class TestDecisionTreeFactor(GtsamTestCase):
"""Tests for DecisionTreeFactors."""
def setUp(self):
A = (12, 3)
B = (5, 2)
self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6")
def test_enumerate(self):
actual = self.factor.enumerate()
_, values = zip(*actual)
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
def test_markdown(self):
"""Test whether the _repr_markdown_ method."""
expected = \
"|A|B|value|\n" \
"|:-:|:-:|:-:|\n" \
"|0|0|1|\n" \
"|0|1|2|\n" \
"|1|0|3|\n" \
"|1|1|4|\n" \
"|2|0|5|\n" \
"|2|1|6|\n"
def formatter(x: int):
return "A" if x == 12 else "B"
actual = self.factor._repr_markdown_(formatter)
self.assertEqual(actual, expected)
if __name__ == "__main__":
unittest.main()