diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index e59552007..95054bcdb 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -104,7 +104,7 @@ namespace gtsam { return ADT::operator()(values); } - /// Evaluate probability density, sugar. + /// Evaluate probability distribution, sugar. double operator()(const DiscreteValues& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index f65508a1a..b0c3dfb7d 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -63,7 +63,7 @@ class GTSAM_EXPORT DiscreteBayesTreeClique /* ************************************************************************* */ /** - * @brief A Bayes tree representing a Discrete density. + * @brief A Bayes tree representing a Discrete distribution. * @ingroup discrete */ class GTSAM_EXPORT DiscreteBayesTree diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index c40573ee8..5abc094fb 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -266,7 +266,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator - // Get the correct conditional density + // Get the correct conditional distribution ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index 7bb2cc6f8..b97e60805 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -41,7 +41,7 @@ class DiscreteMarginals { DiscreteMarginals() {} /** Construct a marginals class. - * @param graph The factor graph defining the full joint density on all variables. + * @param graph The factor graph defining the full joint distribution on all variables. */ DiscreteMarginals(const DiscreteFactorGraph& graph) { bayesTree_ = graph.eliminateMultifrontal(); diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index bc045e8c2..b666a2f4b 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -16,178 +16,129 @@ * @date Feb 27, 2011 */ -#include - #include "Signature.h" -#include // for parsing -#include // for qi::_val +#include +#include + +#include "gtsam/discrete/SignatureParser.h" namespace gtsam { - using namespace std; +using namespace std; - namespace qi = boost::spirit::qi; - namespace ph = boost::phoenix; +ostream& operator<<(ostream& os, const Signature::Row& row) { + os << row[0]; + for (size_t i = 1; i < row.size(); i++) os << " " << row[i]; + return os; +} - // parser for strings of form "99/1 80/20" etc... - namespace parser { - typedef string::const_iterator It; - using boost::phoenix::val; - using boost::phoenix::ref; - using boost::phoenix::push_back; +ostream& operator<<(ostream& os, const Signature::Table& table) { + for (size_t i = 0; i < table.size(); i++) os << table[i] << endl; + return os; +} - // Special rows, true and false - Signature::Row F{1, 0}, T{0, 1}; +Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table) + : key_(key), parents_(parents) { + operator=(table); +} - // Special tables (inefficient, but do we care for user input?) - Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { - Signature::Table t(4); - t[0] = ff ? T : F; - t[1] = ft ? T : F; - t[2] = tf ? T : F; - t[3] = tt ? T : F; - return t; - } +Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : key_(key), parents_(parents) { + operator=(spec); +} - struct Grammar { - qi::rule table, or_, and_, rows; - qi::rule true_, false_, row; - Grammar() { - table = or_ | and_ | rows; - or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; - and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; - rows = +(row | true_ | false_); - row = qi::double_ >> +("/" >> qi::double_); - true_ = qi::lit("T")[qi::_val = T]; - false_ = qi::lit("F")[qi::_val = F]; - } - } grammar; +Signature::Signature(const DiscreteKey& key) : key_(key) {} - } // \namespace parser +DiscreteKeys Signature::discreteKeys() const { + DiscreteKeys keys; + keys.push_back(key_); + for (const DiscreteKey& key : parents_) keys.push_back(key); + return keys; +} - ostream& operator <<(ostream &os, const Signature::Row &row) { - os << row[0]; - for (size_t i = 1; i < row.size(); i++) - os << " " << row[i]; - return os; - } +KeyVector Signature::indices() const { + KeyVector js; + js.push_back(key_.first); + for (const DiscreteKey& key : parents_) js.push_back(key.first); + return js; +} - ostream& operator <<(ostream &os, const Signature::Table &table) { - for (size_t i = 0; i < table.size(); i++) - os << table[i] << endl; - return os; - } - - Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, - const Table& table) - : key_(key), parents_(parents) { - operator=(table); - } - - Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, - const std::string& spec) - : key_(key), parents_(parents) { - operator=(spec); - } - - Signature::Signature(const DiscreteKey& key) : - key_(key) { - } - - DiscreteKeys Signature::discreteKeys() const { - DiscreteKeys keys; - keys.push_back(key_); - for (const DiscreteKey& key : parents_) keys.push_back(key); - return keys; - } - - KeyVector Signature::indices() const { - KeyVector js; - js.push_back(key_.first); - for (const DiscreteKey& key : parents_) js.push_back(key.first); - return js; - } - - vector Signature::cpt() const { - vector cpt; - if (table_) { - const size_t nrStates = table_->at(0).size(); - for (size_t j = 0; j < nrStates; j++) { - for (const Row& row : *table_) { - assert(row.size() == nrStates); - cpt.push_back(row[j]); - } +vector Signature::cpt() const { + vector cpt; + if (table_) { + const size_t nrStates = table_->at(0).size(); + for (size_t j = 0; j < nrStates; j++) { + for (const Row& row : *table_) { + assert(row.size() == nrStates); + cpt.push_back(row[j]); } } - return cpt; } + return cpt; +} - Signature& Signature::operator,(const DiscreteKey& parent) { - parents_.push_back(parent); - return *this; +Signature &Signature::operator,(const DiscreteKey& parent) { + parents_.push_back(parent); + return *this; +} + +static void normalize(Signature::Row& row) { + double sum = 0; + for (size_t i = 0; i < row.size(); i++) sum += row[i]; + for (size_t i = 0; i < row.size(); i++) row[i] /= sum; +} + +Signature& Signature::operator=(const string& spec) { + spec_ = spec; + auto table = SignatureParser::Parse(spec); + if (table) { + for (Row& row : *table) normalize(row); + table_ = *table; } + return *this; +} - static void normalize(Signature::Row& row) { - double sum = 0; - for (size_t i = 0; i < row.size(); i++) - sum += row[i]; - for (size_t i = 0; i < row.size(); i++) - row[i] /= sum; +Signature& Signature::operator=(const Table& t) { + Table table = t; + for (Row& row : table) normalize(row); + table_ = table; + return *this; +} + +ostream& operator<<(ostream& os, const Signature& s) { + os << s.key_.first; + if (s.parents_.empty()) { + os << " % "; + } else { + os << " | " << s.parents_[0].first; + for (size_t i = 1; i < s.parents_.size(); i++) + os << " && " << s.parents_[i].first; + os << " = "; } + os << (s.spec_ ? *s.spec_ : "no spec") << endl; + if (s.table_) + os << (*s.table_); + else + os << "spec could not be parsed" << endl; + return os; +} - Signature& Signature::operator=(const string& spec) { - spec_ = spec; - Table table; - parser::It f = spec.begin(), l = spec.end(); - bool success = - qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); - if (success) { - for (Row& row : table) normalize(row); - table_ = table; - } - return *this; - } +Signature operator|(const DiscreteKey& key, const DiscreteKey& parent) { + Signature s(key); + return s, parent; +} - Signature& Signature::operator=(const Table& t) { - Table table = t; - for(Row& row: table) - normalize(row); - table_ = table; - return *this; - } +Signature operator%(const DiscreteKey& key, const string& parent) { + Signature s(key); + return s = parent; +} - ostream& operator <<(ostream &os, const Signature &s) { - os << s.key_.first; - if (s.parents_.empty()) { - os << " % "; - } else { - os << " | " << s.parents_[0].first; - for (size_t i = 1; i < s.parents_.size(); i++) - os << " && " << s.parents_[i].first; - os << " = "; - } - os << (s.spec_ ? *s.spec_ : "no spec") << endl; - if (s.table_) - os << (*s.table_); - else - os << "spec could not be parsed" << endl; - return os; - } +Signature operator%(const DiscreteKey& key, const Signature::Table& parent) { + Signature s(key); + return s = parent; +} - Signature operator|(const DiscreteKey& key, const DiscreteKey& parent) { - Signature s(key); - return s, parent; - } - - Signature operator%(const DiscreteKey& key, const string& parent) { - Signature s(key); - return s = parent; - } - - Signature operator%(const DiscreteKey& key, const Signature::Table& parent) { - Signature s(key); - return s = parent; - } - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 963843ab2..df175bc10 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -25,7 +25,7 @@ namespace gtsam { /** - * Signature for a discrete conditional density, used to construct conditionals. + * Signature for a discrete conditional distribution, used to construct conditionals. * * The format is (Key % string) for nodes with no parents, * and (Key | Key, Key = string) for nodes with parents. diff --git a/gtsam/discrete/SignatureParser.cpp b/gtsam/discrete/SignatureParser.cpp new file mode 100644 index 000000000..ab259d0ea --- /dev/null +++ b/gtsam/discrete/SignatureParser.cpp @@ -0,0 +1,117 @@ +#include + +#include +#include +#include +#include +namespace gtsam { + +using Row = std::vector; +using Table = std::vector; + +inline static Row ParseTrueRow() { return {0, 1}; } + +inline static Row ParseFalseRow() { return {1, 0}; } + +inline static Table ParseOr() { + return {ParseFalseRow(), ParseTrueRow(), ParseTrueRow(), ParseTrueRow()}; +} + +inline static Table ParseAnd() { + return {ParseFalseRow(), ParseFalseRow(), ParseFalseRow(), ParseTrueRow()}; +} + +std::optional static ParseConditional(const std::string& token) { + // Expect something like a/b/c + std::istringstream iss2(token); + Row row; + try { + // if the string has no / then return std::nullopt + if (std::count(token.begin(), token.end(), '/') == 0) return std::nullopt; + // split the word on the '/' character + for (std::string s; std::getline(iss2, s, '/');) { + // can throw exception + row.push_back(std::stod(s)); + } + } catch (...) { + return std::nullopt; + } + return row; +} + +std::optional static ParseConditionalTable( + const std::vector& tokens) { + Table table; + // loop over the words + // for each word, split it into doubles using a stringstream + for (const auto& word : tokens) { + // If the string word is F or T then the row is {0,1} or {1,0} respectively + if (word == "F") { + table.push_back(ParseFalseRow()); + } else if (word == "T") { + table.push_back(ParseTrueRow()); + } else { + // Expect something like a/b/c + if (auto row = ParseConditional(word)) { + table.push_back(*row); + } else { + // stop parsing if we encounter an error + return std::nullopt; + } + } + } + return table; +} + +std::vector static Tokenize(const std::string& str) { + std::istringstream iss(str); + std::vector tokens; + for (std::string s; iss >> s;) { + tokens.push_back(s); + } + return tokens; +} + +std::optional
SignatureParser::Parse(const std::string& str) { + // check if string is just whitespace + if (std::all_of(str.begin(), str.end(), isspace)) { + return std::nullopt; + } + + // return std::nullopt if the string is empty + if (str.empty()) { + return std::nullopt; + } + + // tokenize the str on whitespace + std::vector tokens = Tokenize(str); + + // if the first token is "OR", return the OR table + if (tokens[0] == "OR") { + // if there are more tokens, return std::nullopt + if (tokens.size() > 1) { + return std::nullopt; + } + return ParseOr(); + } + + // if the first token is "AND", return the AND table + if (tokens[0] == "AND") { + // if there are more tokens, return std::nullopt + if (tokens.size() > 1) { + return std::nullopt; + } + return ParseAnd(); + } + + // otherwise then parse the conditional table + auto table = ParseConditionalTable(tokens); + // return std::nullopt if the table is empty + if (!table || table->empty()) { + return std::nullopt; + } + // the boost::phoenix parser did not return an error if we could not fully + // parse a string it just returned whatever it could parse + return table; +} +} // namespace gtsam diff --git a/gtsam/discrete/SignatureParser.h b/gtsam/discrete/SignatureParser.h new file mode 100644 index 000000000..e6b402e44 --- /dev/null +++ b/gtsam/discrete/SignatureParser.h @@ -0,0 +1,56 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SignatureParser.h + * @brief Parser for conditional distribution signatures. + * @author Kartik Arcot + * @date January 2023 + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { +/** + * @brief A simple parser that replaces the boost spirit parser. + * + * It is meant to parse strings like "1/1 2/3 1/4". Every word of the form + * "a/b/c/..." is parsed as a row, and the rows are stored in a table. + * + * A `Row` is a vector of doubles, and a `Table` is a vector of rows. + * + * Examples: the parser is able to parse the following strings: + * "1/1 2/3 1/4": + * {{1,1},{2,3},{1,4}} + * "1/1 2/3 1/4 1/1 2/3 1/4" : + * {{1,1},{2,3},{1,4},{1,1},{2,3},{1,4}} + * "1/2/3 2/3/4 1/2/3 2/3/4 1/2/3 2/3/4 1/2/3 2/3/4 1/2/3" : + * {{1,2,3},{2,3,4},{1,2,3},{2,3,4},{1,2,3},{2,3,4},{1,2,3},{2,3,4},{1,2,3}} + * + * If the string has un-parsable elements the parser will fail with nullopt: + * "1/2 sdf" : nullopt ! + * + * It also fails if the string is empty: + * "": nullopt ! + * + * Also fails if the rows are not of the same size. + */ +struct SignatureParser { + using Row = std::vector; + using Table = std::vector; + + static std::optional
Parse(const std::string& str); +}; +} // namespace gtsam diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index 02e7a1d10..bf39e8417 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -29,7 +29,7 @@ using namespace gtsam; DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ -TEST(testSignature, simple_conditional) { +TEST(testSignature, SimpleConditional) { Signature sig(X, {Y}, "1/1 2/3 1/4"); CHECK(sig.table()); Signature::Table table = *sig.table(); @@ -54,7 +54,7 @@ TEST(testSignature, simple_conditional) { } /* ************************************************************************* */ -TEST(testSignature, simple_conditional_nonparser) { +TEST(testSignature, SimpleConditionalNonparser) { Signature::Row row1{1, 1}, row2{2, 3}, row3{1, 4}; Signature::Table table{row1, row2, row3}; @@ -77,7 +77,7 @@ TEST(testSignature, simple_conditional_nonparser) { DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2); // Make sure we can create all signatures for Asia network with constructor. -TEST(testSignature, all_examples) { +TEST(testSignature, AllExamples) { DiscreteKey X(6, 2); Signature a(A, {}, "99/1"); Signature s(S, {}, "50/50"); @@ -89,7 +89,7 @@ TEST(testSignature, all_examples) { } // Make sure we can create all signatures for Asia network with operator magic. -TEST(testSignature, all_examples_magic) { +TEST(testSignature, AllExamplesMagic) { DiscreteKey X(6, 2); Signature a(A % "99/1"); Signature s(S % "50/50"); @@ -101,7 +101,7 @@ TEST(testSignature, all_examples_magic) { } // Check example from docs. -TEST(testSignature, doxygen_example) { +TEST(testSignature, DoxygenExample) { Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; Signature d1(D, {E, B}, table); Signature d2((D | E, B) = "9/1 2/8 3/7 1/9"); diff --git a/gtsam/discrete/tests/testSignatureParser.cpp b/gtsam/discrete/tests/testSignatureParser.cpp new file mode 100644 index 000000000..5ae71442e --- /dev/null +++ b/gtsam/discrete/tests/testSignatureParser.cpp @@ -0,0 +1,98 @@ +/** + * Unit tests for the SimpleParser class. + * @file testSimpleParser.cpp + */ + +#include +#include +#include + +using namespace gtsam; + +/* ************************************************************************* */ +// Simple test case +bool compareTables(const SignatureParser::Table& table1, + const SignatureParser::Table& table2) { + if (table1.size() != table2.size()) { + return false; + } + for (size_t i = 0; i < table1.size(); ++i) { + if (table1[i].size() != table2[i].size()) { + return false; + } + for (size_t j = 0; j < table1[i].size(); ++j) { + if (table1[i][j] != table2[i][j]) { + return false; + } + } + } + return true; +} + +/* ************************************************************************* */ +// Simple test case +TEST(SimpleParser, Simple) { + SignatureParser::Table expectedTable{{1, 1}, {2, 3}, {1, 4}}; + const auto table = SignatureParser::Parse("1/1 2/3 1/4"); + CHECK(table); + // compare the tables + EXPECT(compareTables(*table, expectedTable)); +} + +/* ************************************************************************* */ +// Test case with each row having 3 elements +TEST(SimpleParser, ThreeElements) { + SignatureParser::Table expectedTable{{1, 1, 1}, {2, 3, 2}, {1, 4, 3}}; + const auto table = SignatureParser::Parse("1/1/1 2/3/2 1/4/3"); + CHECK(table); + // compare the tables + EXPECT(compareTables(*table, expectedTable)); +} + +/* ************************************************************************* */ +// A test case to check if we can parse a signature with 'T' and 'F' +TEST(SimpleParser, TAndF) { + SignatureParser::Table expectedTable{{1, 0}, {1, 0}, {1, 0}, {0, 1}}; + const auto table = SignatureParser::Parse("F F F T"); + CHECK(table); + // compare the tables + EXPECT(compareTables(*table, expectedTable)); +} + +/* ************************************************************************* */ +// A test to parse {F F F 1} +TEST(SimpleParser, FFF1) { + SignatureParser::Table expectedTable{{1, 0}, {1, 0}, {1, 0}}; + const auto table = SignatureParser::Parse("F F F"); + CHECK(table); + // compare the tables + EXPECT(compareTables(*table, expectedTable)); +} + +/* ************************************************************************* */ +// Expect false if the string is empty +TEST(SimpleParser, emptyString) { + const auto table = SignatureParser::Parse(""); + EXPECT(!table); +} + +/* ************************************************************************* */ +// Expect false if gibberish +TEST(SimpleParser, Gibberish) { + const auto table = SignatureParser::Parse("sdf 22/3"); + EXPECT(!table); +} + +// If Gibberish is in the middle, it should not parse. +TEST(SimpleParser, GibberishInMiddle) { + SignatureParser::Table expectedTable{{1, 1}, {2, 3}}; + const auto table = SignatureParser::Parse("1/1 2/3 sdf 1/4"); + EXPECT(!table); +} + +/* ************************************************************************* */ + +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +}