markdown for Bayes nets

release/4.3a0
Frank Dellaert 2021-12-24 12:33:23 -05:00
parent 3bdb585185
commit edadd352af
4 changed files with 50 additions and 6 deletions

View File

@ -38,7 +38,7 @@ namespace gtsam {
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply
double result = 1.0; double result = 1.0;
for(DiscreteConditional::shared_ptr conditional: *this) for(const DiscreteConditional::shared_ptr& conditional: *this)
result *= (*conditional)(values); result *= (*conditional)(values);
return result; return result;
} }
@ -61,5 +61,15 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */
std::string DiscreteBayesNet::_repr_markdown_(
const KeyFormatter& keyFormatter) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for(const DiscreteConditional::shared_ptr& conditional: *this)
ss << conditional->_repr_markdown_(keyFormatter) << endl;
return ss.str();
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace } // namespace

View File

@ -13,6 +13,7 @@
* @file DiscreteBayesNet.h * @file DiscreteBayesNet.h
* @date Feb 15, 2011 * @date Feb 15, 2011
* @author Duy-Nguyen Ta * @author Duy-Nguyen Ta
* @author Frank dellaert
*/ */
#pragma once #pragma once
@ -97,6 +98,14 @@ namespace gtsam {
DiscreteValues sample() const; DiscreteValues sample() const;
///@} ///@}
/// @name Wrapper support
/// @{
/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -93,6 +93,8 @@ class DiscreteBayesNet {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
}; };
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>

View File

@ -38,6 +38,9 @@ using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) { TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet; DiscreteBayesNet bayesNet;
@ -71,8 +74,6 @@ TEST(DiscreteBayesNet, bayesNet) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) { TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia; DiscreteBayesNet asia;
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
asia.add(Asia % "99/1"); asia.add(Asia % "99/1");
asia.add(Smoking % "50/50"); asia.add(Smoking % "50/50");
@ -151,9 +152,6 @@ TEST(DiscreteBayesNet, Sugar) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) { TEST(DiscreteBayesNet, Dot) {
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Either(5, 2);
DiscreteBayesNet fragment; DiscreteBayesNet fragment;
fragment.add(Asia % "99/1"); fragment.add(Asia % "99/1");
fragment.add(Smoking % "50/50"); fragment.add(Smoking % "50/50");
@ -172,6 +170,31 @@ TEST(DiscreteBayesNet, Dot) {
"}"); "}");
} }
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteBayesNet, markdown) {
DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking | Asia = "8/2 7/3");
string expected =
"`DiscreteBayesNet` of size 2\n"
"\n"
" $P(Asia)$:\n"
"|0|1|\n"
"|:-:|:-:|\n"
"|0.99|0.01|\n"
"\n"
" $P(Smoking|Asia)$:\n"
"|Asia|0|1|\n"
"|:-:|:-:|:-:|\n"
"|0|0.8|0.2|\n"
"|1|0.7|0.3|\n\n";
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
string actual = fragment._repr_markdown_(formatter);
EXPECT(actual == expected);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;