diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp index 3829a5c91..24bd0c0ba 100644 --- a/examples/UGM_small.cpp +++ b/examples/UGM_small.cpp @@ -50,7 +50,8 @@ int main(int argc, char** argv) { // Print the UGM distribution cout << "\nUGM distribution:" << endl; - auto allPosbValues = cartesianProduct(Cathy & Heather & Mark & Allison); + auto allPosbValues = + DiscreteValues::CartesianProduct(Cathy & Heather & Mark & Allison); for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteFactor::Values values = allPosbValues[i]; double prodPot = graph(values); diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index 3665d6dfa..cdbf0a2e9 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -19,32 +19,30 @@ #pragma once #include -#include #include - +#include +#include namespace gtsam { - /** - * An assignment from labels to value index (size_t). - * Assigns to each label a value. Implemented as a simple map. - * A discrete factor takes an Assignment and returns a value. - */ - template - class Assignment: public std::map { - public: - void print(const std::string& s = "Assignment: ") const { - std::cout << s << ": "; - for(const typename Assignment::value_type& keyValue: *this) - std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; - std::cout << std::endl; - } - - bool equals(const Assignment& other, double tol = 1e-9) const { - return (*this == other); - } - }; //Assignment +/** + * An assignment from labels to value index (size_t). + * Assigns to each label a value. Implemented as a simple map. + * A discrete factor takes an Assignment and returns a value. + */ +template +class Assignment : public std::map { + public: + void print(const std::string& s = "Assignment: ") const { + std::cout << s << ": "; + for (const typename Assignment::value_type& keyValue : *this) + std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + std::cout << std::endl; + } + bool equals(const Assignment& other, double tol = 1e-9) const { + return (*this == other); + } /** * @brief Get Cartesian product consisting all possible configurations @@ -58,29 +56,28 @@ namespace gtsam { * variables with each having cardinalities 4, we get 4096 possible * configurations!! */ - template - std::vector > cartesianProduct( - const std::vector >& keys) { - std::vector > allPossValues; - Assignment values; + template > + static std::vector CartesianProduct( + const std::vector>& keys) { + std::vector allPossValues; + Derived values; typedef std::pair DiscreteKey; - for(const DiscreteKey& key: keys) - values[key.first] = 0; //Initialize from 0 + for (const DiscreteKey& key : keys) + values[key.first] = 0; // Initialize from 0 while (1) { allPossValues.push_back(values); size_t j = 0; for (j = 0; j < keys.size(); j++) { L idx = keys[j].first; values[idx]++; - if (values[idx] < keys[j].second) - break; - //Wrap condition + if (values[idx] < keys[j].second) break; + // Wrap condition values[idx] = 0; } - if (j == keys.size()) - break; + if (j == keys.size()) break; } return allPossValues; } +}; // Assignment -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 3f40ddf3b..c50811a50 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -22,6 +22,7 @@ #include #include +#include using namespace std; @@ -150,9 +151,9 @@ namespace gtsam { for (auto& key : keys()) { pairs.emplace_back(key, cardinalities_.at(key)); } - // Reverse to make cartesianProduct output a more natural ordering. + // Reverse to make cartesian product output a more natural ordering. std::vector> rpairs(pairs.rbegin(), pairs.rend()); - const auto assignments = cartesianProduct(rpairs); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); // Construct unordered_map with values std::vector> result; @@ -212,7 +213,7 @@ namespace gtsam { auto assignment = kv.first; for (auto& key : keys()) { size_t index = assignment.at(key); - ss << Translate(names, key, index) << "|"; + ss << DiscreteValues::Translate(names, key, index) << "|"; } ss << kv.second << "|\n"; } @@ -244,7 +245,7 @@ namespace gtsam { auto assignment = kv.first; for (auto& key : keys()) { size_t index = assignment.at(key); - ss << "" << Translate(names, key, index) << ""; + ss << "" << DiscreteValues::Translate(names, key, index) << ""; } ss << "" << kv.second << ""; // value ss << "\n"; diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 512ff88b4..82018abea 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -180,26 +180,21 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( return likelihood(values); } -/* ******************************************************************************** */ +/* ************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { - // TODO: Abhijit asks: is this really the fastest way? He thinks it is. - ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + // TODO(Abhijit): is this really the fastest way? He thinks it is. + ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; double maxP = 0; - DiscreteKeys keys; - for(Key idx: frontals()) { - DiscreteKey dk(idx, cardinality(idx)); - keys & dk; - } // Get all Possible Configurations - const auto allPosbValues = cartesianProduct(keys); + const auto allPosbValues = frontalAssignments(); // Find the MPE - for(const auto& frontalVals: allPosbValues) { - double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; @@ -207,8 +202,8 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { } } - //set values (inPlace) to mpe - for(Key j: frontals()) { + // set values (inPlace) to mpe + for (Key j : frontals()) { (*values)[j] = mpe[j]; } } @@ -295,20 +290,20 @@ size_t DiscreteConditional::sample() const { } /* ************************************************************************* */ -vector> DiscreteConditional::frontalAssignments() const { +vector DiscreteConditional::frontalAssignments() const { vector> pairs; for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); vector> rpairs(pairs.rbegin(), pairs.rend()); - return cartesianProduct(rpairs); + return DiscreteValues::CartesianProduct(rpairs); } /* ************************************************************************* */ -vector> DiscreteConditional::allAssignments() const { +vector DiscreteConditional::allAssignments() const { vector> pairs; for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key)); for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); vector> rpairs(pairs.rbegin(), pairs.rend()); - return cartesianProduct(rpairs); + return DiscreteValues::CartesianProduct(rpairs); } /* ************************************************************************* */ @@ -358,7 +353,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, for (const auto& a : frontalAssignments) { for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { size_t index = a.at(*it); - ss << Translate(names, *it, index); + ss << DiscreteValues::Translate(names, *it, index); } ss << "|"; } @@ -377,7 +372,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, ss << "|"; for (auto&& it = beginParents(); it != endParents(); ++it) { size_t index = a.at(*it); - ss << Translate(names, *it, index) << "|"; + ss << DiscreteValues::Translate(names, *it, index) << "|"; } } ss << operator()(a) << "|"; @@ -413,7 +408,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, ss << ""; for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { size_t index = a.at(*it); - ss << Translate(names, *it, index); + ss << DiscreteValues::Translate(names, *it, index); } ss << ""; } @@ -429,7 +424,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, ss << " "; for (auto&& it = beginParents(); it != endParents(); ++it) { size_t index = a.at(*it); - ss << "" << Translate(names, *it, index) << ""; + ss << "" << DiscreteValues::Translate(names, *it, index) << ""; } } ss << "" << operator()(a) << ""; // value diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 3d258dbe7..4a83ff83a 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace gtsam { @@ -182,10 +183,10 @@ class GTSAM_EXPORT DiscreteConditional void sampleInPlace(DiscreteValues* parentsValues) const; /// Return all assignments for frontal variables. - std::vector> frontalAssignments() const; + std::vector frontalAssignments() const; /// Return all assignments for frontal *and* parent variables. - std::vector> allAssignments() const; + std::vector allAssignments() const; /// @} /// @name Wrapper support diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index 1a12ef405..0cf7f2a5e 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -25,14 +25,4 @@ using namespace std; namespace gtsam { -string DiscreteFactor::Translate(const Names& names, Key key, size_t index) { - if (names.empty()) { - stringstream ss; - ss << index; - return ss.str(); - } else { - return names.at(key)[index]; - } -} - } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 43486c5ae..8f39fbc23 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -22,6 +22,7 @@ #include #include +#include namespace gtsam { class DecisionTreeFactor; @@ -90,14 +91,11 @@ public: /// @{ /// Translation table from values to strings. - using Names = std::map>; - - /// Translate an integer index value for given key to a string. - static std::string Translate(const Names& names, Key key, size_t index); + using Names = DiscreteValues::Names; /** * @brief Render as markdown table - * + * * @param keyFormatter GTSAM-style Key formatter. * @param names optional, category names corresponding to choices. * @return std::string a markdown string. diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp new file mode 100644 index 000000000..9bd0572bd --- /dev/null +++ b/gtsam/discrete/DiscreteValues.cpp @@ -0,0 +1,86 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteValues.cpp + * @date January, 2022 + * @author Frank Dellaert + */ + +#include + +#include + +using std::cout; +using std::endl; +using std::string; +using std::stringstream; + +namespace gtsam { + +void DiscreteValues::print(const string& s, + const KeyFormatter& keyFormatter) const { + cout << s << ": "; + for (auto&& kv : *this) + cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; + cout << endl; +} + +string DiscreteValues::Translate(const Names& names, Key key, size_t index) { + if (names.empty()) { + stringstream ss; + ss << index; + return ss.str(); + } else { + return names.at(key)[index]; + } +} + +string DiscreteValues::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header and separator with alignment hints. + ss << "|Variable|value|\n|:-:|:-:|\n"; + + // Print out all rows. + for (const auto& kv : *this) { + ss << "|" << keyFormatter(kv.first) << "|" + << Translate(names, kv.first, kv.second) << "|\n"; + } + + return ss.str(); +} + +std::string DiscreteValues::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " \n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (const auto& kv : *this) { + ss << " "; + ss << ""; + ss << "\n"; + } + ss << " \n
Variablevalue
" << keyFormatter(kv.first) << "\'" + << Translate(names, kv.first, kv.second) << "
\n
"; + return ss.str(); +} +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index 2d9c8d3cf..3d0c54ab6 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -18,8 +18,11 @@ #pragma once #include +#include #include +#include + namespace gtsam { /** A map from keys to values @@ -34,25 +37,57 @@ namespace gtsam { */ class DiscreteValues : public Assignment { public: - using Assignment::Assignment; // all constructors + using Base = Assignment; // base class + + using Assignment::Assignment; // all constructors // Define the implicit default constructor. DiscreteValues() = default; // Construct from assignment. - DiscreteValues(const Assignment& a) : Assignment(a) {} + explicit DiscreteValues(const Base& a) : Base(a) {} void print(const std::string& s = "", - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { - std::cout << s << ": "; - for (const typename Assignment::value_type& keyValue : *this) - std::cout << "(" << keyFormatter(keyValue.first) << ", " - << keyValue.second << ")"; - std::cout << std::endl; + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + static std::vector CartesianProduct( + const DiscreteKeys& keys) { + return Base::CartesianProduct(keys); } + + /// @name Wrapper support + /// @{ + + /// Translation table from values to strings. + using Names = std::map>; + + /// Translate an integer index value for given key to a string. + static std::string Translate(const Names& names, Key key, size_t index); + + /** + * @brief Output as a markdown table. + * + * @param keyFormatter function that formats keys. + * @param names translation table for values. + * @return string markdown output. + */ + std::string markdown(const KeyFormatter& keyFormatter, + const Names& names) const; + + /** + * @brief Output as a html table. + * + * @param keyFormatter function that formats keys. + * @param names translation table for values. + * @return string html output. + */ + std::string html(const KeyFormatter& keyFormatter, const Names& names) const; + + /// @} }; // traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index edb5ea46c..26356be3d 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -101,10 +101,10 @@ TEST(DiscreteBayesTree, ThinTree) { auto R = self.bayesTree->roots().front(); // Check whether BN and BT give the same answer on all configurations - auto allPosbValues = - cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & - keys[5] & keys[6] & keys[7] & keys[8] & keys[9] & - keys[10] & keys[11] & keys[12] & keys[13] & keys[14]); + auto allPosbValues = DiscreteValues::CartesianProduct( + keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] & + keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] & + keys[14]); for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteValues x = allPosbValues[i]; double expected = self.bayesNet.evaluate(x); diff --git a/gtsam/discrete/tests/testDiscreteMarginals.cpp b/gtsam/discrete/tests/testDiscreteMarginals.cpp index e75016b68..3208f81c5 100644 --- a/gtsam/discrete/tests/testDiscreteMarginals.cpp +++ b/gtsam/discrete/tests/testDiscreteMarginals.cpp @@ -164,8 +164,8 @@ TEST_UNSAFE(DiscreteMarginals, truss2) { graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8"); // Calculate the marginals by brute force - auto allPosbValues = - cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]); + auto allPosbValues = DiscreteValues::CartesianProduct( + key[0] & key[1] & key[2] & key[3] & key[4]); Vector T = Z_5x1, F = Z_5x1; for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteValues x = allPosbValues[i]; diff --git a/gtsam/discrete/tests/testDiscreteValues.cpp b/gtsam/discrete/tests/testDiscreteValues.cpp new file mode 100644 index 000000000..5e7c0ac6f --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteValues.cpp @@ -0,0 +1,76 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteValues.cpp + * + * @date Jan, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +#include +using namespace boost::assign; + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(DiscreteValues, markdownWithValueFormatter) { + DiscreteValues values; + values[12] = 1; // A + values[5] = 0; // B + string expected = + "|Variable|value|\n" + "|:-:|:-:|\n" + "|B|-|\n" + "|A|One|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = values.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation with a value formatter. +TEST(DiscreteValues, htmlWithValueFormatter) { + DiscreteValues values; + values[12] = 1; // A + values[5] = 0; // B + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
Variablevalue
B'-
A'One
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = values.html(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */