Merge pull request #987 from borglab/feature/markdown2

release/4.3a0
Varun Agrawal 2021-12-25 09:30:57 -05:00 committed by GitHub
commit a4f1cf328f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 227 additions and 86 deletions

View File

@ -168,7 +168,7 @@ namespace gtsam {
/// Render as markdown table. /// Render as markdown table.
std::string _repr_markdown_( std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/// @} /// @}

View File

@ -38,7 +38,7 @@ namespace gtsam {
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply
double result = 1.0; double result = 1.0;
for(DiscreteConditional::shared_ptr conditional: *this) for(const DiscreteConditional::shared_ptr& conditional: *this)
result *= (*conditional)(values); result *= (*conditional)(values);
return result; return result;
} }
@ -61,5 +61,15 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */
std::string DiscreteBayesNet::_repr_markdown_(
const KeyFormatter& keyFormatter) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for(const DiscreteConditional::shared_ptr& conditional: *this)
ss << conditional->_repr_markdown_(keyFormatter) << endl;
return ss.str();
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace } // namespace

View File

@ -13,6 +13,7 @@
* @file DiscreteBayesNet.h * @file DiscreteBayesNet.h
* @date Feb 15, 2011 * @date Feb 15, 2011
* @author Duy-Nguyen Ta * @author Duy-Nguyen Ta
* @author Frank dellaert
*/ */
#pragma once #pragma once
@ -97,6 +98,14 @@ namespace gtsam {
DiscreteValues sample() const; DiscreteValues sample() const;
///@} ///@}
/// @name Wrapper support
/// @{
/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -55,8 +55,21 @@ namespace gtsam {
return result; return result;
} }
} // \namespace gtsam /* **************************************************************************/
std::string DiscreteBayesTree::_repr_markdown_(
const KeyFormatter& keyFormatter) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl;
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
size_t& indent) {
ss << "\n" << clique->conditional()->_repr_markdown_(keyFormatter);
return indent + 1;
};
size_t indent;
treeTraversal::DepthFirstForest(*this, indent, visitor);
return ss.str();
}
/* **************************************************************************/
} // namespace gtsam

View File

@ -72,6 +72,8 @@ class GTSAM_EXPORT DiscreteBayesTree
typedef DiscreteBayesTree This; typedef DiscreteBayesTree This;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
/// @name Standard interface
/// @{
/** Default constructor, creates an empty Bayes tree */ /** Default constructor, creates an empty Bayes tree */
DiscreteBayesTree() {} DiscreteBayesTree() {}
@ -82,10 +84,19 @@ class GTSAM_EXPORT DiscreteBayesTree
double evaluate(const DiscreteValues& values) const; double evaluate(const DiscreteValues& values) const;
//** (Preferred) sugar for the above for given DiscreteValues */ //** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues & values) const { double operator()(const DiscreteValues& values) const {
return evaluate(values); return evaluate(values);
} }
/// @}
/// @name Wrapper support
/// @{
/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -229,11 +229,22 @@ std::string DiscreteConditional::_repr_markdown_(
// Print out signature. // Print out signature.
ss << " $P("; ss << " $P(";
for(Key key: frontals())
ss << keyFormatter(key);
if (nrParents() > 0)
ss << "|";
bool first = true; bool first = true;
for (Key key : frontals()) {
if (!first) ss << ",";
ss << keyFormatter(key);
first = false;
}
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << ")$:" << std::endl;
ss << DecisionTreeFactor::_repr_markdown_();
return ss.str();
}
// We have parents, continue signature and do custom print.
ss << "|";
first = true;
for (Key parent : parents()) { for (Key parent : parents()) {
if (!first) ss << ","; if (!first) ss << ",";
ss << keyFormatter(parent); ss << keyFormatter(parent);
@ -256,9 +267,8 @@ std::string DiscreteConditional::_repr_markdown_(
pairs.emplace_back(key, k); pairs.emplace_back(key, k);
n *= k; n *= k;
} }
size_t nrParents = size() - nrFrontals_;
std::vector<std::pair<Key, size_t>> slatnorf(pairs.rbegin(), std::vector<std::pair<Key, size_t>> slatnorf(pairs.rbegin(),
pairs.rend() - nrParents); pairs.rend() - nrParents());
const auto frontal_assignments = cartesianProduct(slatnorf); const auto frontal_assignments = cartesianProduct(slatnorf);
for (const auto& a : frontal_assignments) { for (const auto& a : frontal_assignments) {
for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it); for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it);
@ -268,7 +278,7 @@ std::string DiscreteConditional::_repr_markdown_(
// Print out separator with alignment hints. // Print out separator with alignment hints.
ss << "|"; ss << "|";
for (size_t j = 0; j < nrParents + n; j++) ss << ":-:|"; for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
ss << "\n"; ss << "\n";
// Print out all rows. // Print out all rows.

View File

@ -172,7 +172,7 @@ public:
/// Render as markdown table. /// Render as markdown table.
std::string _repr_markdown_( std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/// @} /// @}
}; };

View File

@ -88,6 +88,14 @@ public:
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// @}
/// @name Wrapper support
/// @{
/// Render as markdown table.
virtual std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0;
/// @} /// @}
}; };
// DiscreteFactor // DiscreteFactor

View File

@ -129,6 +129,18 @@ namespace gtsam {
return std::make_pair(cond, sum); return std::make_pair(cond, sum);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace std::string DiscreteFactorGraph::_repr_markdown_(
const KeyFormatter& keyFormatter) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
for (size_t i = 0; i < factors_.size(); i++) {
ss << "factor " << i << ":\n";
ss << factors_[i]->_repr_markdown_(keyFormatter) << endl;
}
return ss.str();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -154,6 +154,14 @@ public:
// /** Apply a reduction, which is a remapping of variable indices. */ // /** Apply a reduction, which is a remapping of variable indices. */
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
/// @name Wrapper support
/// @{
/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph
/// traits /// traits

View File

@ -93,6 +93,8 @@ class DiscreteBayesNet {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
}; };
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
@ -124,6 +126,9 @@ class DiscreteBayesTree {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
}; };
#include <gtsam/inference/DotWriter.h> #include <gtsam/inference/DotWriter.h>
@ -164,6 +169,9 @@ class DiscreteFactorGraph {
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal(); gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -38,6 +38,9 @@ using namespace boost::assign;
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) { TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet; DiscreteBayesNet bayesNet;
@ -71,8 +74,6 @@ TEST(DiscreteBayesNet, bayesNet) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) { TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia; DiscreteBayesNet asia;
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
asia.add(Asia % "99/1"); asia.add(Asia % "99/1");
asia.add(Smoking % "50/50"); asia.add(Smoking % "50/50");
@ -151,9 +152,6 @@ TEST(DiscreteBayesNet, Sugar) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) { TEST(DiscreteBayesNet, Dot) {
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Either(5, 2);
DiscreteBayesNet fragment; DiscreteBayesNet fragment;
fragment.add(Asia % "99/1"); fragment.add(Asia % "99/1");
fragment.add(Smoking % "50/50"); fragment.add(Smoking % "50/50");
@ -172,6 +170,32 @@ TEST(DiscreteBayesNet, Dot) {
"}"); "}");
} }
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteBayesNet, markdown) {
DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking | Asia = "8/2 7/3");
string expected =
"`DiscreteBayesNet` of size 2\n"
"\n"
" $P(Asia)$:\n"
"|0|value|\n"
"|:-:|:-:|\n"
"|0|0.99|\n"
"|1|0.01|\n"
"\n"
" $P(Smoking|Asia)$:\n"
"|Asia|0|1|\n"
"|:-:|:-:|:-:|\n"
"|0|0.8|0.2|\n"
"|1|0.7|0.3|\n\n";
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
string actual = fragment._repr_markdown_(formatter);
EXPECT(actual == expected);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -110,13 +110,15 @@ TEST(DiscreteConditional, Combine) {
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation looks as expected, no parents. // Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) { TEST(DiscreteConditional, markdown_prior) {
DiscreteKey A(Symbol('x', 1), 2); DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/3"); DiscreteConditional conditional(A % "1/2/2");
string expected = string expected =
" $P(x1)$:\n" " $P(x1)$:\n"
"|0|1|\n" "|x1|value|\n"
"|:-:|:-:|\n" "|:-:|:-:|\n"
"|0.25|0.75|\n"; "|0|0.2|\n"
"|1|0.4|\n"
"|2|0.4|\n";
string actual = conditional._repr_markdown_(); string actual = conditional._repr_markdown_();
EXPECT(actual == expected); EXPECT(actual == expected);
} }

View File

@ -361,11 +361,9 @@ cout << unicorns;
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteFactorGraph, Dot) { TEST(DiscreteFactorGraph, Dot) {
// Declare a bunch of keys
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
@ -384,6 +382,44 @@ TEST(DiscreteFactorGraph, Dot) {
EXPECT(actual == expected); EXPECT(actual == expected);
} }
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteFactorGraph, markdown) {
// Create Factor graph
DiscreteFactorGraph graph;
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
string expected =
"`DiscreteFactorGraph` of size 2\n"
"\n"
"factor 0:\n"
"|C|A|value|\n"
"|:-:|:-:|:-:|\n"
"|0|0|0.2|\n"
"|0|1|0.8|\n"
"|1|0|0.3|\n"
"|1|1|0.7|\n"
"\n"
"factor 1:\n"
"|C|B|value|\n"
"|:-:|:-:|:-:|\n"
"|0|0|0.1|\n"
"|0|1|0.9|\n"
"|1|0|0.4|\n"
"|1|1|0.6|\n\n";
vector<string> names{"C", "A", "B"};
auto formatter = [names](Key key) { return names[key]; };
string actual = graph._repr_markdown_(formatter);
EXPECT(actual == expected);
// Make sure values are correctly displayed.
DiscreteValues values;
values[0] = 1;
values[1] = 0;
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -22,6 +22,7 @@
#include <gtsam_unstable/dllexport.h> #include <gtsam_unstable/dllexport.h>
#include <boost/assign.hpp> #include <boost/assign.hpp>
#include <boost/format.hpp>
#include <map> #include <map>
namespace gtsam { namespace gtsam {
@ -81,6 +82,16 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor {
/// Partially apply known values, domain version /// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0; virtual shared_ptr partiallyApply(const Domains&) const = 0;
/// @} /// @}
/// @name Wrapper support
/// @{
/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
return (boost::format("`Constraint` on %1% variables\n") % (size())).str();
}
/// @}
}; };
// DiscreteFactor // DiscreteFactor

File diff suppressed because one or more lines are too long