diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 951c0b6ca..1a93d89de 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -29,9 +29,12 @@ #include #include #include +#include using namespace std; - +using std::stringstream; +using std::vector; +using std::pair; namespace gtsam { // Instantiate base class @@ -292,55 +295,72 @@ size_t DiscreteConditional::sample() const { } /* ************************************************************************* */ -std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, - const Names& names) const { - std::stringstream ss; +vector> DiscreteConditional::frontalAssignments() const { + vector> pairs; + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return cartesianProduct(rpairs); +} - // Print out signature. - ss << " *P("; +/* ************************************************************************* */ +vector> DiscreteConditional::allAssignments() const { + vector> pairs; + for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key)); + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return cartesianProduct(rpairs); +} + +/* ************************************************************************* */ +// Print out signature. +static void streamSignature(const DiscreteConditional& conditional, + const KeyFormatter& keyFormatter, + stringstream* ss) { + *ss << "P("; bool first = true; - for (Key key : frontals()) { - if (!first) ss << ","; - ss << keyFormatter(key); + for (Key key : conditional.frontals()) { + if (!first) *ss << ","; + *ss << keyFormatter(key); first = false; } + if (conditional.nrParents() > 0) { + *ss << "|"; + bool first = true; + for (Key parent : conditional.parents()) { + if (!first) *ss << ","; + *ss << keyFormatter(parent); + first = false; + } + } + *ss << "):"; +} + +/* ************************************************************************* */ +std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + ss << " *"; + streamSignature(*this, keyFormatter, &ss); + ss << "*\n" << std::endl; if (nrParents() == 0) { - // We have no parents, call factor method. - ss << ")*:\n" << std::endl; + // We have no parents, call factor method. ss << DecisionTreeFactor::markdown(keyFormatter, names); return ss.str(); } - // We have parents, continue signature and do custom print. - ss << "|"; - first = true; - for (Key parent : parents()) { - if (!first) ss << ","; - ss << keyFormatter(parent); - first = false; - } - ss << ")*:\n" << std::endl; - // Print out header and construct argument for `cartesianProduct`. - std::vector> pairs; ss << "|"; - const_iterator it; - for(Key parent: parents()) { + for (Key parent : parents()) { ss << "*" << keyFormatter(parent) << "*|"; - pairs.emplace_back(parent, cardinalities_.at(parent)); } size_t n = 1; - for(Key key: frontals()) { + for (Key key : frontals()) { size_t k = cardinalities_.at(key); - pairs.emplace_back(key, k); n *= k; } - std::vector> slatnorf(pairs.rbegin(), - pairs.rend() - nrParents()); - const auto frontal_assignments = cartesianProduct(slatnorf); - for (const auto& a : frontal_assignments) { - for (it = beginFrontals(); it != endFrontals(); ++it) { + for (const auto& a : frontalAssignments()) { + for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { size_t index = a.at(*it); ss << Translate(names, *it, index); } @@ -354,13 +374,11 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, ss << "\n"; // Print out all rows. - std::vector> rpairs(pairs.rbegin(), pairs.rend()); - const auto assignments = cartesianProduct(rpairs); size_t count = 0; - for (const auto& a : assignments) { + for (const auto& a : allAssignments()) { if (count == 0) { ss << "|"; - for (it = beginParents(); it != endParents(); ++it) { + for (auto&& it = beginParents(); it != endParents(); ++it) { size_t index = a.at(*it); ss << Translate(names, *it, index) << "|"; } @@ -371,6 +389,50 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, } return ss.str(); } + +/* ************************************************************************ */ +string DiscreteConditional::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + ss << "
\n

"; + streamSignature(*this, keyFormatter, &ss); + ss << "

\n"; + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << DecisionTreeFactor::html(keyFormatter, names); + return ss.str(); + } + + // Print out preamble. + ss << "\n \n"; + + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + auto rows = enumerate(); + for (const auto& kv : rows) { + ss << " "; + auto assignment = kv.first; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << ""; + } + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << Translate(names, key, index) << "" << kv.second << "
\n
"; + + return ss.str(); +} + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 4c2e964fd..3d258dbe7 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -21,9 +21,9 @@ #include #include #include -#include -#include +#include +#include #include namespace gtsam { @@ -32,24 +32,24 @@ namespace gtsam { * Discrete Conditional Density * Derives from DecisionTreeFactor */ -class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, - public Conditional { - -public: +class GTSAM_EXPORT DiscreteConditional + : public DecisionTreeFactor, + public Conditional { + public: // typedefs needed to play nice with gtsam - typedef DiscreteConditional This; ///< Typedef to this class - typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class - typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class - typedef Conditional BaseConditional; ///< Typedef to our conditional base class + typedef DiscreteConditional This; ///< Typedef to this class + typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class + typedef Conditional + BaseConditional; ///< Typedef to our conditional base class - using Values = DiscreteValues; ///< backwards compatibility + using Values = DiscreteValues; ///< backwards compatibility /// @name Standard Constructors /// @{ /** default constructor needed for serialization */ - DiscreteConditional() { - } + DiscreteConditional() {} /** constructor from factor */ DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); @@ -87,29 +87,33 @@ public: /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal); + const DecisionTreeFactor& marginal); /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys); + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys); /** * Combine several conditional into a single one. - * The conditionals must be given in increasing order, meaning that the parents - * of any conditional may not include a conditional coming before it. - * @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr. - * @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr. + * The conditionals must be given in increasing order, meaning that the + * parents of any conditional may not include a conditional coming before it. + * @param firstConditional Iterator to the first conditional to combine, must + * dereference to a shared_ptr. + * @param lastConditional Iterator to after the last conditional to combine, + * must dereference to a shared_ptr. * */ - template + template static shared_ptr Combine(ITERATOR firstConditional, - ITERATOR lastConditional); + ITERATOR lastConditional); /// @} /// @name Testable /// @{ /// GTSAM-style print - void print(const std::string& s = "Discrete Conditional: ", + void print( + const std::string& s = "Discrete Conditional: ", const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// GTSAM-style equals @@ -161,7 +165,6 @@ public: */ size_t sample(const DiscreteValues& parentsValues) const; - /// Single parent version. size_t sample(size_t parent_value) const; @@ -178,6 +181,12 @@ public: /// sample in place, stores result in partial solution void sampleInPlace(DiscreteValues* parentsValues) const; + /// Return all assignments for frontal variables. + std::vector> frontalAssignments() const; + + /// Return all assignments for frontal *and* parent variables. + std::vector> allAssignments() const; + /// @} /// @name Wrapper support /// @{ @@ -186,15 +195,20 @@ public: std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Names& names = {}) const override; + /// Render as html table. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + /// @} }; // DiscreteConditional // traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; /* ************************************************************************* */ -template +template DiscreteConditional::shared_ptr DiscreteConditional::Combine( ITERATOR firstConditional, ITERATOR lastConditional) { // TODO: check for being a clique @@ -203,7 +217,7 @@ DiscreteConditional::shared_ptr DiscreteConditional::Combine( size_t nrFrontals = 0; DecisionTreeFactor product; for (ITERATOR it = firstConditional; it != lastConditional; - ++it, ++nrFrontals) { + ++it, ++nrFrontals) { DiscreteConditional::shared_ptr c = *it; DecisionTreeFactor::shared_ptr factor = c->toFactor(); product = (*factor) * product; @@ -212,5 +226,4 @@ DiscreteConditional::shared_ptr DiscreteConditional::Combine( return boost::make_shared(nrFrontals, product); } -} // gtsam - +} // namespace gtsam