enumerate
parent
c622dde7a7
commit
911819c7f2
|
|
@ -134,17 +134,33 @@ namespace gtsam {
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
|
||||||
|
// Get all possible assignments
|
||||||
|
std::vector<std::pair<Key, size_t>> pairs;
|
||||||
|
for (auto& key : keys()) {
|
||||||
|
pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
}
|
||||||
|
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
|
const auto assignments = cartesianProduct(rpairs);
|
||||||
|
|
||||||
|
// Construct unordered_map with values
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> result;
|
||||||
|
for (const auto& assignment : assignments) {
|
||||||
|
result.emplace_back(assignment, operator()(assignment));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DecisionTreeFactor::markdown(
|
std::string DecisionTreeFactor::markdown(
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter) const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
|
||||||
// Print out header and construct argument for `cartesianProduct`.
|
// Print out header and construct argument for `cartesianProduct`.
|
||||||
std::vector<std::pair<Key, size_t>> pairs;
|
|
||||||
ss << "|";
|
ss << "|";
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
ss << keyFormatter(key) << "|";
|
ss << keyFormatter(key) << "|";
|
||||||
pairs.emplace_back(key, cardinalities_.at(key));
|
|
||||||
}
|
}
|
||||||
ss << "value|\n";
|
ss << "value|\n";
|
||||||
|
|
||||||
|
|
@ -154,12 +170,12 @@ namespace gtsam {
|
||||||
ss << ":-:|\n";
|
ss << ":-:|\n";
|
||||||
|
|
||||||
// Print out all rows.
|
// Print out all rows.
|
||||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
auto rows = enumerate();
|
||||||
const auto assignments = cartesianProduct(rpairs);
|
for (const auto& kv : rows) {
|
||||||
for (const auto& assignment : assignments) {
|
|
||||||
ss << "|";
|
ss << "|";
|
||||||
|
auto assignment = kv.first;
|
||||||
for (auto& key : keys()) ss << assignment.at(key) << "|";
|
for (auto& key : keys()) ss << assignment.at(key) << "|";
|
||||||
ss << operator()(assignment) << "|\n";
|
ss << kv.second << "|\n";
|
||||||
}
|
}
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -183,6 +183,9 @@ namespace gtsam {
|
||||||
// Potentials::reduceWithInverse(inverseReduction);
|
// Potentials::reduceWithInverse(inverseReduction);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
/// Enumerate all values into a map from values to double.
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
string dot(bool showZero = false) const;
|
string dot(bool showZero = false) const;
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max)
|
||||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check enumerate yields the correct list of assignment/value pairs.
|
||||||
|
TEST(DecisionTreeFactor, enumerate) {
|
||||||
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
|
auto actual = f.enumerate();
|
||||||
|
std::vector<std::pair<DiscreteValues, double>> expected;
|
||||||
|
DiscreteValues values;
|
||||||
|
for (size_t a : {0, 1, 2}) {
|
||||||
|
for (size_t b : {0, 1}) {
|
||||||
|
values[12] = a;
|
||||||
|
values[5] = b;
|
||||||
|
expected.emplace_back(values, f(values));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT(actual == expected);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected.
|
// Check markdown representation looks as expected.
|
||||||
TEST(DecisionTreeFactor, markdown) {
|
TEST(DecisionTreeFactor, markdown) {
|
||||||
DiscreteKey A(12, 3), B(5, 2);
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
string expected =
|
string expected =
|
||||||
"|A|B|value|\n"
|
"|A|B|value|\n"
|
||||||
"|:-:|:-:|:-:|\n"
|
"|:-:|:-:|:-:|\n"
|
||||||
|
|
@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) {
|
||||||
"|2|0|5|\n"
|
"|2|0|5|\n"
|
||||||
"|2|1|6|\n";
|
"|2|1|6|\n";
|
||||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
string actual = f1.markdown(formatter);
|
string actual = f.markdown(formatter);
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""
|
||||||
|
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||||
|
Atlanta, Georgia 30332-0415
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
See LICENSE for the license information
|
||||||
|
|
||||||
|
Unit tests for DecisionTreeFactors.
|
||||||
|
Author: Frank Dellaert
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=no-name-in-module, invalid-name
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
|
||||||
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecisionTreeFactor(GtsamTestCase):
|
||||||
|
"""Tests for DecisionTreeFactors."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
A = (12, 3)
|
||||||
|
B = (5, 2)
|
||||||
|
self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6")
|
||||||
|
|
||||||
|
def test_enumerate(self):
|
||||||
|
actual = self.factor.enumerate()
|
||||||
|
_, values = zip(*actual)
|
||||||
|
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||||
|
|
||||||
|
def test_markdown(self):
|
||||||
|
"""Test whether the _repr_markdown_ method."""
|
||||||
|
|
||||||
|
expected = \
|
||||||
|
"|A|B|value|\n" \
|
||||||
|
"|:-:|:-:|:-:|\n" \
|
||||||
|
"|0|0|1|\n" \
|
||||||
|
"|0|1|2|\n" \
|
||||||
|
"|1|0|3|\n" \
|
||||||
|
"|1|1|4|\n" \
|
||||||
|
"|2|0|5|\n" \
|
||||||
|
"|2|1|6|\n"
|
||||||
|
|
||||||
|
def formatter(x: int):
|
||||||
|
return "A" if x == 12 else "B"
|
||||||
|
|
||||||
|
actual = self.factor._repr_markdown_(formatter)
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue