diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 19d61c3b1..ff4c0694e 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -135,33 +135,31 @@ namespace gtsam { } /* ************************************************************************* */ - // Check markdown representation looks as expected. - std::string DecisionTreeFactor::_repr_markdown_() const { + std::string DecisionTreeFactor::_repr_markdown_( + const KeyFormatter& keyFormatter) const { std::stringstream ss; - // Print out header and calculate number of rows. + // Print out header and construct argument for `cartesianProduct`. + std::vector> pairs; ss << "|"; - for (auto& key : cardinalities_) { - size_t k = key.second; - ss << key.first << "(" << k << ")|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + pairs.emplace_back(key, cardinalities_.at(key)); } ss << "value|\n"; // Print out separator with alignment hints. - size_t n = cardinalities_.size(); ss << "|"; - for (size_t j = 0; j < n; j++) ss << ":-:|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; ss << ":-:|\n"; // Print out all rows. - std::vector> keys(cardinalities_.begin(), - cardinalities_.end()); - const auto assignments = cartesianProduct(keys); - for (auto &&assignment : assignments) { + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + for (const auto& assignment : assignments) { ss << "|"; - for (auto& kv : assignment) ss << kv.second << "|"; - const double value = operator()(assignment); - ss << value << "|\n"; + for (auto& key : keys()) ss << assignment.at(key) << "|"; + ss << operator()(assignment) << "|\n"; } return ss.str(); } diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 68d629e9c..308b2f9ca 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -167,7 +167,8 @@ namespace gtsam { /// @{ /// Render as markdown table. - std::string _repr_markdown_() const; + std::string _repr_markdown_( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; /// @} diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 3026d5f82..75d5efa9f 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -83,19 +83,21 @@ TEST( DecisionTreeFactor, sum_max) } /* ************************************************************************* */ +// Check markdown representation looks as expected. TEST(DecisionTreeFactor, markdown) { - DiscreteKey v0(0, 3), v1(1, 2); - DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f1(A & B, "1 2 3 4 5 6"); string expected = - "|0(3)|1(2)|value|\n" + "|A|B|value|\n" "|:-:|:-:|:-:|\n" "|0|0|1|\n" - "|1|0|3|\n" - "|2|0|5|\n" "|0|1|2|\n" + "|1|0|3|\n" "|1|1|4|\n" + "|2|0|5|\n" "|2|1|6|\n"; - string actual = f1._repr_markdown_(); + auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; + string actual = f1._repr_markdown_(formatter); EXPECT(actual == expected); }