diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp
index 1a93d89de..48f8fd322 100644
--- a/gtsam/discrete/DiscreteConditional.cpp
+++ b/gtsam/discrete/DiscreteConditional.cpp
@@ -348,18 +348,14 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
return ss.str();
}
- // Print out header and construct argument for `cartesianProduct`.
+ // Print out header.
ss << "|";
for (Key parent : parents()) {
ss << "*" << keyFormatter(parent) << "*|";
}
- size_t n = 1;
- for (Key key : frontals()) {
- size_t k = cardinalities_.at(key);
- n *= k;
- }
- for (const auto& a : frontalAssignments()) {
+ auto frontalAssignments = this->frontalAssignments();
+ for (const auto& a : frontalAssignments) {
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
size_t index = a.at(*it);
ss << Translate(names, *it, index);
@@ -370,6 +366,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
// Print out separator with alignment hints.
ss << "|";
+ size_t n = frontalAssignments.size();
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
ss << "\n";
@@ -408,28 +405,37 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
// Print out header row.
ss << "
";
- for (auto& key : keys()) {
- ss << "" << keyFormatter(key) << " | ";
+ for (Key parent : parents()) {
+ ss << "" << keyFormatter(parent) << " | ";
}
- ss << "value |
\n";
+ auto frontalAssignments = this->frontalAssignments();
+ for (const auto& a : frontalAssignments) {
+ for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
+ size_t index = a.at(*it);
+ ss << "" << Translate(names, *it, index) << " | ";
+ }
+ }
+ ss << "\n";
// Finish header and start body.
ss << " \n \n";
- // Print out all rows.
- auto rows = enumerate();
- for (const auto& kv : rows) {
- ss << " ";
- auto assignment = kv.first;
- for (auto& key : keys()) {
- size_t index = assignment.at(key);
- ss << "" << Translate(names, key, index) << " | ";
+ size_t count = 0, n = frontalAssignments.size();
+ for (const auto& a : allAssignments()) {
+ if (count == 0) {
+ ss << "
";
+ for (auto&& it = beginParents(); it != endParents(); ++it) {
+ size_t index = a.at(*it);
+ ss << "" << Translate(names, *it, index) << " | ";
+ }
}
- ss << "" << kv.second << " | "; // value
- ss << "
\n";
+ ss << "" << operator()(a) << " | "; // value
+ count = (count + 1) % n;
+ if (count == 0) ss << "\n";
}
- ss << " \n\n";
+ // Finish up
+ ss << " \n\n";
return ss.str();
}
diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp
index 0686b3920..0ba53c69a 100644
--- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp
+++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp
@@ -182,13 +182,13 @@ TEST(DiscreteBayesNet, markdown) {
string expected =
"`DiscreteBayesNet` of size 2\n"
"\n"
- " *P(Asia)*:\n\n"
+ " *P(Asia):*\n\n"
"|Asia|value|\n"
"|:-:|:-:|\n"
"|0|0.99|\n"
"|1|0.01|\n"
"\n"
- " *P(Smoking|Asia)*:\n\n"
+ " *P(Smoking|Asia):*\n\n"
"|*Asia*|0|1|\n"
"|:-:|:-:|:-:|\n"
"|0|0.8|0.2|\n"
diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp
index 6d2af3cff..3fb67a615 100644
--- a/gtsam/discrete/tests/testDiscreteConditional.cpp
+++ b/gtsam/discrete/tests/testDiscreteConditional.cpp
@@ -125,7 +125,7 @@ TEST(DiscreteConditional, markdown_prior) {
DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");
string expected =
- " *P(x1)*:\n\n"
+ " *P(x1):*\n\n"
"|x1|value|\n"
"|:-:|:-:|\n"
"|0|0.2|\n"
@@ -142,7 +142,7 @@ TEST(DiscreteConditional, markdown_prior_names) {
DiscreteKey A(x1, 3);
DiscreteConditional conditional(A % "1/2/2");
string expected =
- " *P(x1)*:\n\n"
+ " *P(x1):*\n\n"
"|x1|value|\n"
"|:-:|:-:|\n"
"|A0|0.2|\n"
@@ -160,7 +160,7 @@ TEST(DiscreteConditional, markdown_multivalued) {
DiscreteConditional conditional(
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
string expected =
- " *P(a1|b1)*:\n\n"
+ " *P(a1|b1):*\n\n"
"|*b1*|0|1|2|\n"
"|:-:|:-:|:-:|:-:|\n"
"|0|0.02|0.88|0.1|\n"
@@ -178,7 +178,7 @@ TEST(DiscreteConditional, markdown) {
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
string expected =
- " *P(A|B,C)*:\n\n"
+ " *P(A|B,C):*\n\n"
"|*B*|*C*|T|F|\n"
"|:-:|:-:|:-:|:-:|\n"
"|-|Zero|0|1|\n"
@@ -195,6 +195,36 @@ TEST(DiscreteConditional, markdown) {
EXPECT(actual == expected);
}
+/* ************************************************************************* */
+// Check html representation looks as expected, two parents + names.
+TEST(DiscreteConditional, html) {
+ DiscreteKey A(2, 2), B(1, 2), C(0, 3);
+ DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
+ string expected =
+ "\n"
+ "
P(A|B,C):
\n"
+ "
\n"
+ " \n"
+ " B | C | T | F |
\n"
+ " \n"
+ " \n"
+ " - | Zero | 0 | 1 |
\n"
+ " - | One | 0.25 | 0.75 |
\n"
+ " - | Two | 0.5 | 0.5 |
\n"
+ " + | Zero | 0.75 | 0.25 |
\n"
+ " + | One | 0 | 1 |
\n"
+ " + | Two | 1 | 0 |
\n"
+ " \n"
+ "
\n"
+ "
";
+ vector keyNames{"C", "B", "A"};
+ auto formatter = [keyNames](Key key) { return keyNames[key]; };
+ DecisionTreeFactor::Names names{
+ {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}};
+ string actual = conditional.html(formatter, names);
+ EXPECT(actual == expected);
+}
+
/* ************************************************************************* */
int main() {
TestResult tr;
diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py
index 86bc303a9..0ae66c2d4 100644
--- a/python/gtsam/tests/test_DiscreteConditional.py
+++ b/python/gtsam/tests/test_DiscreteConditional.py
@@ -49,7 +49,7 @@ class TestDiscreteConditional(GtsamTestCase):
conditional = DiscreteConditional(A, parents,
"0/1 1/3 1/1 3/1 0/1 1/0")
expected = \
- " *P(A|B,C)*:\n\n" \
+ " *P(A|B,C):*\n\n" \
"|*B*|*C*|0|1|\n" \
"|:-:|:-:|:-:|:-:|\n" \
"|0|0|0|1|\n" \
diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py
index 5bf6a8d19..2c923589c 100644
--- a/python/gtsam/tests/test_DiscretePrior.py
+++ b/python/gtsam/tests/test_DiscretePrior.py
@@ -51,7 +51,7 @@ class TestDiscretePrior(GtsamTestCase):
"""Test the _repr_markdown_ method."""
prior = DiscretePrior(X, "2/3")
- expected = " *P(0)*:\n\n" \
+ expected = " *P(0):*\n\n" \
"|0|value|\n" \
"|:-:|:-:|\n" \
"|0|0.4|\n" \