enumerate
parent
c622dde7a7
commit
911819c7f2
|
|
@ -134,17 +134,33 @@ namespace gtsam {
|
|||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
|
||||
// Get all possible assignments
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
for (auto& key : keys()) {
|
||||
pairs.emplace_back(key, cardinalities_.at(key));
|
||||
}
|
||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
const auto assignments = cartesianProduct(rpairs);
|
||||
|
||||
// Construct unordered_map with values
|
||||
std::vector<std::pair<DiscreteValues, double>> result;
|
||||
for (const auto& assignment : assignments) {
|
||||
result.emplace_back(assignment, operator()(assignment));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string DecisionTreeFactor::markdown(
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
|
||||
// Print out header and construct argument for `cartesianProduct`.
|
||||
std::vector<std::pair<Key, size_t>> pairs;
|
||||
ss << "|";
|
||||
for (auto& key : keys()) {
|
||||
ss << keyFormatter(key) << "|";
|
||||
pairs.emplace_back(key, cardinalities_.at(key));
|
||||
}
|
||||
ss << "value|\n";
|
||||
|
||||
|
|
@ -154,12 +170,12 @@ namespace gtsam {
|
|||
ss << ":-:|\n";
|
||||
|
||||
// Print out all rows.
|
||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
const auto assignments = cartesianProduct(rpairs);
|
||||
for (const auto& assignment : assignments) {
|
||||
auto rows = enumerate();
|
||||
for (const auto& kv : rows) {
|
||||
ss << "|";
|
||||
auto assignment = kv.first;
|
||||
for (auto& key : keys()) ss << assignment.at(key) << "|";
|
||||
ss << operator()(assignment) << "|\n";
|
||||
ss << kv.second << "|\n";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -183,6 +183,9 @@ namespace gtsam {
|
|||
// Potentials::reduceWithInverse(inverseReduction);
|
||||
// }
|
||||
|
||||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
|||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
string dot(bool showZero = false) const;
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max)
|
|||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check enumerate yields the correct list of assignment/value pairs.
|
||||
TEST(DecisionTreeFactor, enumerate) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
auto actual = f.enumerate();
|
||||
std::vector<std::pair<DiscreteValues, double>> expected;
|
||||
DiscreteValues values;
|
||||
for (size_t a : {0, 1, 2}) {
|
||||
for (size_t b : {0, 1}) {
|
||||
values[12] = a;
|
||||
values[5] = b;
|
||||
expected.emplace_back(values, f(values));
|
||||
}
|
||||
}
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected.
|
||||
TEST(DecisionTreeFactor, markdown) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
|
||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||
string expected =
|
||||
"|A|B|value|\n"
|
||||
"|:-:|:-:|:-:|\n"
|
||||
|
|
@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) {
|
|||
"|2|0|5|\n"
|
||||
"|2|1|6|\n";
|
||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||
string actual = f1.markdown(formatter);
|
||||
string actual = f.markdown(formatter);
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for DecisionTreeFactors.
|
||||
Author: Frank Dellaert
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module, invalid-name
|
||||
|
||||
import unittest
|
||||
|
||||
from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDecisionTreeFactor(GtsamTestCase):
|
||||
"""Tests for DecisionTreeFactors."""
|
||||
|
||||
def setUp(self):
|
||||
A = (12, 3)
|
||||
B = (5, 2)
|
||||
self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6")
|
||||
|
||||
def test_enumerate(self):
|
||||
actual = self.factor.enumerate()
|
||||
_, values = zip(*actual)
|
||||
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||
|
||||
def test_markdown(self):
|
||||
"""Test whether the _repr_markdown_ method."""
|
||||
|
||||
expected = \
|
||||
"|A|B|value|\n" \
|
||||
"|:-:|:-:|:-:|\n" \
|
||||
"|0|0|1|\n" \
|
||||
"|0|1|2|\n" \
|
||||
"|1|0|3|\n" \
|
||||
"|1|1|4|\n" \
|
||||
"|2|0|5|\n" \
|
||||
"|2|1|6|\n"
|
||||
|
||||
def formatter(x: int):
|
||||
return "A" if x == 12 else "B"
|
||||
|
||||
actual = self.factor._repr_markdown_(formatter)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue