From 8eb623b58f821c7dbaf7ddda3ac6cc7af54a3f5c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 21:34:22 -0500 Subject: [PATCH] Added an optional names argument for discrete markdown renderers --- gtsam/discrete/DecisionTreeFactor.cpp | 11 ++++++---- gtsam/discrete/DecisionTreeFactor.h | 12 +++++++--- gtsam/discrete/DiscreteBayesNet.cpp | 5 +++-- gtsam/discrete/DiscreteBayesNet.h | 4 ++-- gtsam/discrete/DiscreteBayesTree.cpp | 5 +++-- gtsam/discrete/DiscreteBayesTree.h | 4 ++-- gtsam/discrete/DiscreteConditional.cpp | 15 ++++++++----- gtsam/discrete/DiscreteConditional.h | 4 ++-- gtsam/discrete/DiscreteFactor.cpp | 15 +++++++++++-- gtsam/discrete/DiscreteFactor.h | 15 +++++++++++-- gtsam/discrete/DiscreteFactorGraph.cpp | 21 ++++++++++-------- gtsam/discrete/DiscreteFactorGraph.h | 15 ++++++++++--- gtsam/discrete/discrete.i | 10 +++++++++ .../discrete/tests/testDecisionTreeFactor.cpp | 21 ++++++++++++++++++ .../tests/testDiscreteConditional.cpp | 22 ++++++++++--------- gtsam_unstable/discrete/Constraint.h | 4 ++-- 16 files changed, 133 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 75018cf92..4f3e3f7f1 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -179,9 +179,9 @@ namespace gtsam { } /* ************************************************************************* */ - std::string DecisionTreeFactor::markdown( - const KeyFormatter& keyFormatter) const { - std::stringstream ss; + string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; // Print out header and construct argument for `cartesianProduct`. ss << "|"; @@ -200,7 +200,10 @@ namespace gtsam { for (const auto& kv : rows) { ss << "|"; auto assignment = kv.first; - for (auto& key : keys()) ss << assignment.at(key) << "|"; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << Translate(names, key, index) << "|"; + } ss << kv.second << "|\n"; } return ss.str(); diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 46509db82..f8832c223 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -192,9 +192,15 @@ namespace gtsam { std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, bool showZero = true) const; - /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index d9fba630e..510fb5638 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -63,12 +63,13 @@ namespace gtsam { /* ************************************************************************* */ std::string DiscreteBayesNet::markdown( - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteBayesNet` of size " << size() << endl << endl; for(const DiscreteConditional::shared_ptr& conditional: *this) - ss << conditional->markdown(keyFormatter) << endl; + ss << conditional->markdown(keyFormatter, names) << endl; return ss.str(); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index aed4cec0a..5332b51dd 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -108,8 +108,8 @@ namespace gtsam { /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 8a9186d05..07d6e0f0e 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -57,13 +57,14 @@ namespace gtsam { /* **************************************************************************/ std::string DiscreteBayesTree::markdown( - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, size_t& indent) { - ss << "\n" << clique->conditional()->markdown(keyFormatter); + ss << "\n" << clique->conditional()->markdown(keyFormatter, names); return indent + 1; }; size_t indent; diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 12d6017cc..6189f25d5 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -93,8 +93,8 @@ class GTSAM_EXPORT DiscreteBayesTree /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} }; diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 46d5509e0..203f00f89 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -283,8 +283,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const { } /* ************************************************************************* */ -std::string DiscreteConditional::markdown( - const KeyFormatter& keyFormatter) const { +std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { std::stringstream ss; // Print out signature. @@ -331,7 +331,10 @@ std::string DiscreteConditional::markdown( pairs.rend() - nrParents()); const auto frontal_assignments = cartesianProduct(slatnorf); for (const auto& a : frontal_assignments) { - for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it); + for (it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index); + } ss << "|"; } ss << "\n"; @@ -348,8 +351,10 @@ std::string DiscreteConditional::markdown( for (const auto& a : assignments) { if (count == 0) { ss << "|"; - for (it = beginParents(); it != endParents(); ++it) - ss << a.at(*it) << "|"; + for (it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index) << "|"; + } } ss << operator()(a) << "|"; count = (count + 1) % n; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index d21e3ae26..1cad927e9 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -180,8 +180,8 @@ public: /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} }; diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index c101653d2..1a12ef405 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -19,9 +19,20 @@ #include +#include + using namespace std; namespace gtsam { -/* ************************************************************************* */ -} // namespace gtsam +string DiscreteFactor::Translate(const Names& names, Key key, size_t index) { + if (names.empty()) { + stringstream ss; + ss << index; + return ss.str(); + } else { + return names.at(key)[index]; + } +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 76ed703bb..70545a5ca 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -89,9 +89,20 @@ public: /// @name Wrapper support /// @{ - /// Render as markdown table. + using Names = std::map>; + + static std::string Translate(const Names& names, Key key, size_t index); + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ virtual std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0; + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; /// @} }; diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index bd84e1364..be046d290 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -16,15 +16,17 @@ * @author Frank Dellaert */ -//#define ENABLE_TIMING -#include -#include #include +#include #include +#include #include -#include #include -#include +#include + +using std::vector; +using std::string; +using std::map; namespace gtsam { @@ -64,7 +66,7 @@ namespace gtsam { } /* ************************************************************************* */ - void DiscreteFactorGraph::print(const std::string& s, + void DiscreteFactorGraph::print(const string& s, const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; @@ -130,14 +132,15 @@ namespace gtsam { } /* ************************************************************************* */ - std::string DiscreteFactorGraph::markdown( - const KeyFormatter& keyFormatter) const { + string DiscreteFactorGraph::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; for (size_t i = 0; i < factors_.size(); i++) { ss << "factor " << i << ":\n"; - ss << factors_[i]->markdown(keyFormatter) << endl; + ss << factors_[i]->markdown(keyFormatter, names) << endl; } return ss.str(); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 6856493f7..9aa04d649 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -24,7 +24,10 @@ #include #include #include + #include +#include +#include namespace gtsam { @@ -140,9 +143,15 @@ public: /// @name Wrapper support /// @{ - /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} }; // \ DiscreteFactorGraph diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5bd4a2913..e298deaf1 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -52,6 +52,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -88,6 +90,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -130,6 +134,8 @@ class DiscreteBayesNet { gtsam::DiscreteValues sample() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -164,6 +170,8 @@ class DiscreteBayesTree { string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -211,6 +219,8 @@ class DiscreteFactorGraph { string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 6af7ca731..c4e5f06bb 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -119,6 +119,27 @@ TEST(DecisionTreeFactor, markdown) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(DecisionTreeFactor, markdownWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|Zero|-|1|\n" + "|Zero|+|2|\n" + "|One|-|3|\n" + "|One|+|4|\n" + "|Two|-|5|\n" + "|Two|+|6|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 00ae1acd0..90da07cdc 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -161,17 +161,19 @@ TEST(DiscreteConditional, markdown) { DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = " *P(A|B,C)*:\n\n" - "|B|C|0|1|\n" + "|B|C|T|F|\n" "|:-:|:-:|:-:|:-:|\n" - "|0|0|0|1|\n" - "|0|1|0.25|0.75|\n" - "|0|2|0.5|0.5|\n" - "|1|0|0.75|0.25|\n" - "|1|1|0|1|\n" - "|1|2|1|0|\n"; - vector names{"C", "B", "A"}; - auto formatter = [names](Key key) { return names[key]; }; - string actual = conditional.markdown(formatter); + "|-|Zero|0|1|\n" + "|-|One|0.25|0.75|\n" + "|-|Two|0.5|0.5|\n" + "|+|Zero|0.75|0.25|\n" + "|+|One|0|1|\n" + "|+|Two|1|0|\n"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.markdown(formatter, names); EXPECT(actual == expected); } diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 5c21028a0..85748f054 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -86,8 +86,8 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor { /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override { return (boost::format("`Constraint` on %1% variables\n") % (size())).str(); }