diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 1a93d89de..48f8fd322 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -348,18 +348,14 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, return ss.str(); } - // Print out header and construct argument for `cartesianProduct`. + // Print out header. ss << "|"; for (Key parent : parents()) { ss << "*" << keyFormatter(parent) << "*|"; } - size_t n = 1; - for (Key key : frontals()) { - size_t k = cardinalities_.at(key); - n *= k; - } - for (const auto& a : frontalAssignments()) { + auto frontalAssignments = this->frontalAssignments(); + for (const auto& a : frontalAssignments) { for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { size_t index = a.at(*it); ss << Translate(names, *it, index); @@ -370,6 +366,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, // Print out separator with alignment hints. ss << "|"; + size_t n = frontalAssignments.size(); for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|"; ss << "\n"; @@ -408,28 +405,37 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, // Print out header row. ss << " "; - for (auto& key : keys()) { - ss << "" << keyFormatter(key) << ""; + for (Key parent : parents()) { + ss << "" << keyFormatter(parent) << ""; } - ss << "value\n"; + auto frontalAssignments = this->frontalAssignments(); + for (const auto& a : frontalAssignments) { + for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << "" << Translate(names, *it, index) << ""; + } + } + ss << "\n"; // Finish header and start body. ss << " \n \n"; - // Print out all rows. - auto rows = enumerate(); - for (const auto& kv : rows) { - ss << " "; - auto assignment = kv.first; - for (auto& key : keys()) { - size_t index = assignment.at(key); - ss << "" << Translate(names, key, index) << ""; + size_t count = 0, n = frontalAssignments.size(); + for (const auto& a : allAssignments()) { + if (count == 0) { + ss << " "; + for (auto&& it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << "" << Translate(names, *it, index) << ""; + } } - ss << "" << kv.second << ""; // value - ss << "\n"; + ss << "" << operator()(a) << ""; // value + count = (count + 1) % n; + if (count == 0) ss << "\n"; } - ss << " \n\n"; + // Finish up + ss << " \n\n"; return ss.str(); } diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 0686b3920..0ba53c69a 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -182,13 +182,13 @@ TEST(DiscreteBayesNet, markdown) { string expected = "`DiscreteBayesNet` of size 2\n" "\n" - " *P(Asia)*:\n\n" + " *P(Asia):*\n\n" "|Asia|value|\n" "|:-:|:-:|\n" "|0|0.99|\n" "|1|0.01|\n" "\n" - " *P(Smoking|Asia)*:\n\n" + " *P(Smoking|Asia):*\n\n" "|*Asia*|0|1|\n" "|:-:|:-:|:-:|\n" "|0|0.8|0.2|\n" diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 6d2af3cff..3fb67a615 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -125,7 +125,7 @@ TEST(DiscreteConditional, markdown_prior) { DiscreteKey A(Symbol('x', 1), 3); DiscreteConditional conditional(A % "1/2/2"); string expected = - " *P(x1)*:\n\n" + " *P(x1):*\n\n" "|x1|value|\n" "|:-:|:-:|\n" "|0|0.2|\n" @@ -142,7 +142,7 @@ TEST(DiscreteConditional, markdown_prior_names) { DiscreteKey A(x1, 3); DiscreteConditional conditional(A % "1/2/2"); string expected = - " *P(x1)*:\n\n" + " *P(x1):*\n\n" "|x1|value|\n" "|:-:|:-:|\n" "|A0|0.2|\n" @@ -160,7 +160,7 @@ TEST(DiscreteConditional, markdown_multivalued) { DiscreteConditional conditional( A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); string expected = - " *P(a1|b1)*:\n\n" + " *P(a1|b1):*\n\n" "|*b1*|0|1|2|\n" "|:-:|:-:|:-:|:-:|\n" "|0|0.02|0.88|0.1|\n" @@ -178,7 +178,7 @@ TEST(DiscreteConditional, markdown) { DiscreteKey A(2, 2), B(1, 2), C(0, 3); DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = - " *P(A|B,C)*:\n\n" + " *P(A|B,C):*\n\n" "|*B*|*C*|T|F|\n" "|:-:|:-:|:-:|:-:|\n" "|-|Zero|0|1|\n" @@ -195,6 +195,36 @@ TEST(DiscreteConditional, markdown) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check html representation looks as expected, two parents + names. +TEST(DiscreteConditional, html) { + DiscreteKey A(2, 2), B(1, 2), C(0, 3); + DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); + string expected = + "
\n" + "

P(A|B,C):

\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
BCTF
-Zero01
-One0.250.75
-Two0.50.5
+Zero0.750.25
+One01
+Two10
\n" + "
"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.html(formatter, names); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 86bc303a9..0ae66c2d4 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -49,7 +49,7 @@ class TestDiscreteConditional(GtsamTestCase): conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0") expected = \ - " *P(A|B,C)*:\n\n" \ + " *P(A|B,C):*\n\n" \ "|*B*|*C*|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \ diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py index 5bf6a8d19..2c923589c 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -51,7 +51,7 @@ class TestDiscretePrior(GtsamTestCase): """Test the _repr_markdown_ method.""" prior = DiscretePrior(X, "2/3") - expected = " *P(0)*:\n\n" \ + expected = " *P(0):*\n\n" \ "|0|value|\n" \ "|:-:|:-:|\n" \ "|0|0.4|\n" \