diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 796c0c8c8..219f2d93e 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -34,16 +34,6 @@ namespace gtsam { return Base::equals(bn, tol); } - /* ************************************************************************* */ -// void DiscreteBayesNet::add_front(const Signature& s) { -// push_front(boost::make_shared(s)); -// } - - /* ************************************************************************* */ - void DiscreteBayesNet::add(const Signature& s) { - push_back(boost::make_shared(s)); - } - /* ************************************************************************* */ double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { // evaluate all conditionals and multiply diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 9f5f10388..2d92b72e8 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -71,15 +71,23 @@ namespace gtsam { /// @name Standard Interface /// @{ + // Add inherited versions of add. + using Base::add; + /** Add a DiscreteCondtional */ - void add(const Signature& s); - -// /** Add a DiscreteCondtional in front, when listing parents first*/ -// GTSAM_EXPORT void add_front(const Signature& s); - + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + //** evaluate for given DiscreteValues */ double evaluate(const DiscreteValues & values) const; + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } + /** * Solve the DiscreteBayesNet by back-substitution */ diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 655bcb9ee..42ec7d417 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -80,6 +80,12 @@ class GTSAM_EXPORT DiscreteBayesTree //** evaluate probability for given DiscreteValues */ double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } + }; } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index b2d47be3a..371b15ac0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -97,25 +97,41 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const { +Potentials::ADT DiscreteConditional::choose( + const DiscreteValues& parentsValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. ADT pFS(*this); - Key j; size_t value; - for(Key key: parents()) { + size_t value; + for (Key j : parents()) { try { - j = (key); value = parentsValues.at(j); - pFS = pFS.choose(j, value); + pFS = pFS.choose(j, value); // ADT keeps getting smaller. } catch (exception&) { cout << "Key: " << j << " Value: " << value << endl; parentsValues.print("parentsValues: "); - // pFS.print("pFS: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; } - return pFS; } +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( + const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues); + + // Convert ADT to factor. + if (nrFrontals() != 1) { + throw std::runtime_error("Expected only one frontal variable in choose."); + } + DiscreteKeys keys; + const Key frontalKey = keys_[0]; + size_t frontalCardinality = this->cardinality(frontalKey); + keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); + return boost::make_shared(keys, pFS); +} + /* ******************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // TODO: Abhijit asks: is this really the fastest way? He thinks it is. diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index b31f1d92b..06928e2e7 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -57,6 +57,33 @@ public: /** Construct from signature */ DiscreteConditional(const Signature& signature); + /** + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const Signature::Table& table) + : DiscreteConditional(Signature(key, parents, table)) {} + + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : DiscreteConditional(Signature(key, parents, spec)) {} + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); @@ -111,6 +138,10 @@ public: /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ ADT choose(const DiscreteValues& parentsValues) const; + /** Restrict to given parent values, returns DecisionTreeFactor */ + DecisionTreeFactor::shared_ptr chooseAsFactor( + const DiscreteValues& parentsValues) const; + /** * solve a conditional * @param parentsValues Known values of the parents diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index a8ba7ab03..ff0aaef19 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -130,8 +130,11 @@ public: /** return product of all factors as a single factor */ DecisionTreeFactor product() const; - /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ - double operator()(const DiscreteValues & values) const; + /** + * Evaluates the factor graph given values, returns the joint probability of + * the factor graph given specific instantiation of values + */ + double operator()(const DiscreteValues& values) const; /// print void print( diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index c041c7e8e..3462166f4 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -31,22 +31,24 @@ namespace gtsam { * Key type for discrete conditionals * Includes name and cardinality */ - typedef std::pair DiscreteKey; + using DiscreteKey = std::pair; /// DiscreteKeys is a set of keys that can be assembled using the & operator struct DiscreteKeys: public std::vector { - /// Default constructor - DiscreteKeys() { - } + // Forward all constructors. + using std::vector::vector; + + /// Constructor for serialization + GTSAM_EXPORT DiscreteKeys() : std::vector::vector() {} /// Construct from a key - DiscreteKeys(const DiscreteKey& key) { + GTSAM_EXPORT DiscreteKeys(const DiscreteKey& key) { push_back(key); } /// Construct from a vector of keys - DiscreteKeys(const std::vector& keys) : + GTSAM_EXPORT DiscreteKeys(const std::vector& keys) : std::vector(keys) { } diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 94b160a29..146555898 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -38,19 +38,7 @@ namespace gtsam { using boost::phoenix::push_back; // Special rows, true and false - Signature::Row createF() { - Signature::Row r(2); - r[0] = 1; - r[1] = 0; - return r; - } - Signature::Row createT() { - Signature::Row r(2); - r[0] = 0; - r[1] = 1; - return r; - } - Signature::Row T = createT(), F = createF(); + Signature::Row F{1, 0}, T{0, 1}; // Special tables (inefficient, but do we care for user input?) Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { @@ -69,40 +57,13 @@ namespace gtsam { table = or_ | and_ | rows; or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; - rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42 + rows = +(row | true_ | false_); row = qi::double_ >> +("/" >> qi::double_); true_ = qi::lit("T")[qi::_val = T]; false_ = qi::lit("F")[qi::_val = F]; } } grammar; - // Create simpler parsing function to avoid the issue of only parsing a single row - bool parse_table(const string& spec, Signature::Table& table) { - // check for OR, AND on whole phrase - It f = spec.begin(), l = spec.end(); - if (qi::parse(f, l, - qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) || - qi::parse(f, l, - qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)])) - return true; - - // tokenize into separate rows - istringstream iss(spec); - string token; - while (iss >> token) { - Signature::Row values; - It tf = token.begin(), tl = token.end(); - bool r = qi::parse(tf, tl, - qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) | - qi::lit("T")[ph::ref(values) = T] | - qi::lit("F")[ph::ref(values) = F] ); - if (!r) - return false; - table.push_back(values); - } - - return true; - } } // \namespace parser ostream& operator <<(ostream &os, const Signature::Row &row) { @@ -118,6 +79,18 @@ namespace gtsam { return os; } + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table) + : key_(key), parents_(parents) { + operator=(table); + } + + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : key_(key), parents_(parents) { + operator=(spec); + } + Signature::Signature(const DiscreteKey& key) : key_(key) { } @@ -166,14 +139,11 @@ namespace gtsam { Signature& Signature::operator=(const string& spec) { spec_.reset(spec); Table table; - // NOTE: using simpler parse function to ensure boost back compatibility -// parser::It f = spec.begin(), l = spec.end(); - bool success = // -// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar - parser::parse_table(spec, table); + parser::It f = spec.begin(), l = spec.end(); + bool success = + qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); if (success) { - for(Row& row: table) - normalize(row); + for (Row& row : table) normalize(row); table_.reset(table); } return *this; diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 6c59b5bff..ff83caa53 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -30,7 +30,7 @@ namespace gtsam { * The format is (Key % string) for nodes with no parents, * and (Key | Key, Key = string) for nodes with parents. * - * The string specifies a conditional probability spec in the 00 01 10 11 order. + * The string specifies a conditional probability table in 00 01 10 11 order. * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... * * For example, given the following keys @@ -45,9 +45,9 @@ namespace gtsam { * T|A = "99/1 95/5" * L|S = "99/1 90/10" * B|S = "70/30 40/60" - * E|T,L = "F F F 1" + * (E|T,L) = "F F F 1" * X|E = "95/5 2/98" - * D|E,B = "9/1 2/8 3/7 1/9" + * (D|E,B) = "9/1 2/8 3/7 1/9" */ class GTSAM_EXPORT Signature { @@ -72,45 +72,73 @@ namespace gtsam { boost::optional table_; public: + /** + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. + * + * Example: + * Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + * Signature sig(D, {E, B}, table); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table); - /** Constructor from DiscreteKey */ - Signature(const DiscreteKey& key); + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example (same CPT as above): + * Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec); - /** the variable key */ - const DiscreteKey& key() const { - return key_; - } + /** + * Construct from a single DiscreteKey. + * + * The resulting signature has no parents or CPT table. Typical use then + * either adds parents with | and , operators below, or assigns a table with + * operator=(). + */ + Signature(const DiscreteKey& key); - /** the parent keys */ - const DiscreteKeys& parents() const { - return parents_; - } + /** the variable key */ + const DiscreteKey& key() const { return key_; } - /** All keys, with variable key first */ - DiscreteKeys discreteKeys() const; + /** the parent keys */ + const DiscreteKeys& parents() const { return parents_; } - /** All key indices, with variable key first */ - KeyVector indices() const; + /** All keys, with variable key first */ + DiscreteKeys discreteKeys() const; - // the CPT as parsed, if successful - const boost::optional
& table() const { - return table_; - } + /** All key indices, with variable key first */ + KeyVector indices() const; - // the CPT as a vector of doubles, with key's values most rapidly changing - std::vector cpt() const; + // the CPT as parsed, if successful + const boost::optional
& table() const { return table_; } - /** Add a parent */ - Signature& operator,(const DiscreteKey& parent); + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; - /** Add the CPT spec - Fails in boost 1.40 */ - Signature& operator=(const std::string& spec); + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); - /** Add the CPT spec directly as a table */ - Signature& operator=(const Table& table); + /** Add the CPT spec */ + Signature& operator=(const std::string& spec); - /** provide streaming */ - GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s); + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os, + const Signature& s); }; /** @@ -122,7 +150,6 @@ namespace gtsam { /** * Helper function to create Signature objects * example: Signature s(D % "99/1"); - * Uses string parser, which requires BOOST 1.42 or higher */ GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i new file mode 100644 index 000000000..daea84e70 --- /dev/null +++ b/gtsam/discrete/discrete.i @@ -0,0 +1,131 @@ +//************************************************************************* +// discrete +//************************************************************************* + +namespace gtsam { + + +#include +class DiscreteKey {}; + +class DiscreteKeys { + DiscreteKeys(); + size_t size() const; + bool empty() const; + gtsam::DiscreteKey at(size_t n) const; + void push_back(const gtsam::DiscreteKey& point_pair); +}; + +// DiscreteValues is added in specializations/discrete.h as a std::map + +#include +class DiscreteFactor { + void print(string s = "DiscreteFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; + bool empty() const; + size_t size() const; + double operator()(const gtsam::DiscreteValues& values) const; +}; + +#include +virtual class DecisionTreeFactor: gtsam::DiscreteFactor { + DecisionTreeFactor(); + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + void print(string s = "DecisionTreeFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? +}; + +#include +virtual class DiscreteConditional : gtsam::DecisionTreeFactor { + DiscreteConditional(); + DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal, + const gtsam::Ordering& orderedKeys); + size_t size() const; // TODO(dellaert): why do I have to repeat??? + double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + void print(string s = "Discrete Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + void printSignature( + string s = "Discrete Conditional: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + gtsam::DecisionTreeFactor* toFactor() const; + gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; + size_t solve(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(const gtsam::DiscreteValues& parentsValues) const; + void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; + void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; +}; + +#include +class DiscreteBayesNet { + DiscreteBayesNet(); + void add(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteConditional* at(size_t i) const; + void print(string s = "DiscreteBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void add(const gtsam::DiscreteConditional& s); + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + gtsam::DiscreteValues sample() const; +}; + +#include +class DiscreteBayesTree { + DiscreteBayesTree(); + void print(string s = "DiscreteBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + double operator()(const gtsam::DiscreteValues& values) const; +}; + +#include +class DiscreteFactorGraph { + DiscreteFactorGraph(); + DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); + + void add(const gtsam::DiscreteKey& j, string table); + void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); + void add(const gtsam::DiscreteKeys& keys, string table); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteFactor* at(size_t i) const; + + void print(string s = "") const; + bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; + + gtsam::DecisionTreeFactor product() const; + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + + gtsam::DiscreteBayesNet eliminateSequential(); + gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree eliminateMultifrontal(); + gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); +}; + +} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3ac3ffc9e..79714217c 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -61,11 +61,10 @@ TEST(DiscreteConditional, constructors_alt_interface) { r2 += 2.0, 3.0; r3 += 1.0, 4.0; table += r1, r2, r3; - auto actual1 = boost::make_shared(X | Y = table); - EXPECT(actual1); + DiscreteConditional actual1(X, {Y}, table); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); - EXPECT(assert_equal(expected1, *actual1, 1e-9)); + EXPECT(assert_equal(expected1, actual1, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index 049c455f7..737bd8aef 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ TEST(testSignature, simple_conditional) { - Signature sig(X | Y = "1/1 2/3 1/4"); + Signature sig(X, {Y}, "1/1 2/3 1/4"); + CHECK(sig.table()); Signature::Table table = *sig.table(); vector row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; + LONGS_EQUAL(3, table.size()); CHECK(row[0] == table[0]); CHECK(row[1] == table[1]); CHECK(row[2] == table[2]); - DiscreteKey actKey = sig.key(); - LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + CHECK(sig.key() == X); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); + + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); } /* ************************************************************************* */ @@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) { table += row1, row2, row3; Signature sig(X | Y = table); - DiscreteKey actKey = sig.key(); - EXPECT_LONGS_EQUAL(X.first, actKey.first); + CHECK(sig.key() == X); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); +} + +/* ************************************************************************* */ +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2); + +// Make sure we can create all signatures for Asia network with constructor. +TEST(testSignature, all_examples) { + DiscreteKey X(6, 2); + Signature a(A, {}, "99/1"); + Signature s(S, {}, "50/50"); + Signature t(T, {A}, "99/1 95/5"); + Signature l(L, {S}, "99/1 90/10"); + Signature b(B, {S}, "70/30 40/60"); + Signature e(E, {T, L}, "F F F 1"); + Signature x(X, {E}, "95/5 2/98"); +} + +// Make sure we can create all signatures for Asia network with operator magic. +TEST(testSignature, all_examples_magic) { + DiscreteKey X(6, 2); + Signature a(A % "99/1"); + Signature s(S % "50/50"); + Signature t(T | A = "99/1 95/5"); + Signature l(L | S = "99/1 90/10"); + Signature b(B | S = "70/30 40/60"); + Signature e((E | T, L) = "F F F 1"); + Signature x(X | E = "95/5 2/98"); +} + +// Check example from docs. +TEST(testSignature, doxygen_example) { + Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + Signature d1(D, {E, B}, table); + Signature d2((D | E, B) = "9/1 2/8 3/7 1/9"); + Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9"); + EXPECT(*(d1.table()) == table); + EXPECT(*(d2.table()) == table); + EXPECT(*(d3.table()) == table); } /* ************************************************************************* */ diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index e2444a51a..d3b20e312 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -49,11 +49,13 @@ set(ignore gtsam::Pose3Vector gtsam::KeyVector gtsam::BinaryMeasurementsUnit3 + gtsam::DiscreteKey gtsam::KeyPairDoubleMap) set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/base/base.i + ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i ${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i ${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i ${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i diff --git a/python/gtsam/preamble/discrete.h b/python/gtsam/preamble/discrete.h new file mode 100644 index 000000000..608508c32 --- /dev/null +++ b/python/gtsam/preamble/discrete.h @@ -0,0 +1,16 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include + +PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys); diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h new file mode 100644 index 000000000..458a2ea4c --- /dev/null +++ b/python/gtsam/specializations/discrete.h @@ -0,0 +1,17 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `py::bind_vector` and similar machinery gives the std container a Python-like + * interface, but without the `` copying mechanism. Combined + * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, + * and saves one copy operation. + */ + +// Seems this is not a good idea with inherited stl +//py::bind_vector>(m_, "DiscreteKeys"); + +py::bind_map(m_, "DiscreteValues"); diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py new file mode 100644 index 000000000..bf09da193 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -0,0 +1,118 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes Nets. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, + DiscreteKeys, DiscreteValues, Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_constructor(self): + """Test constructing a Bayes net.""" + + bayesNet = DiscreteBayesNet() + Parent, Child = (0, 2), (1, 2) + empty = DiscreteKeys() + prior = DiscreteConditional(Parent, empty, "6/4") + bayesNet.add(prior) + + parents = DiscreteKeys() + parents.push_back(Parent) + conditional = DiscreteConditional(Child, parents, "7/3 8/2") + bayesNet.add(conditional) + + # Check conversion to factor graph: + fg = DiscreteFactorGraph(bayesNet) + self.assertEqual(fg.size(), 2) + self.assertEqual(fg.at(1).size(), 2) + + def test_Asia(self): + """Test full Asia example.""" + + Asia = (0, 2) + Smoking = (4, 2) + Tuberculosis = (3, 2) + LungCancer = (6, 2) + + Bronchitis = (7, 2) + Either = (5, 2) + XRay = (2, 2) + Dyspnea = (1, 2) + + def P(keys): + dks = DiscreteKeys() + for key in keys: + dks.push_back(key) + return dks + + asia = DiscreteBayesNet() + asia.add(Asia, P([]), "99/1") + asia.add(Smoking, P([]), "50/50") + + asia.add(Tuberculosis, P([Asia]), "99/1 95/5") + asia.add(LungCancer, P([Smoking]), "99/1 90/10") + asia.add(Bronchitis, P([Smoking]), "70/30 40/60") + + asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T") + + asia.add(XRay, P([Either]), "95/5 2/98") + asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9") + + # Convert to factor graph + fg = DiscreteFactorGraph(asia) + + # Create solver and eliminate + ordering = Ordering() + for j in range(8): + ordering.push_back(j) + chordal = fg.eliminateSequential(ordering) + expected2 = DiscreteConditional(Bronchitis, P([]), "11/9") + self.gtsamAssertEquals(chordal.at(7), expected2) + + # solve + actualMPE = chordal.optimize() + expectedMPE = DiscreteValues() + for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: + expectedMPE[key[0]] = 0 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + # Check value for MPE is the same + self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) + + # add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1") + fg.add(Dyspnea, "0 1") + + # solve again, now with evidence + chordal2 = fg.eliminateSequential(ordering) + actualMPE2 = chordal2.optimize() + expectedMPE2 = DiscreteValues() + for key in [XRay, Tuberculosis, Either, LungCancer]: + expectedMPE2[key[0]] = 0 + for key in [Asia, Dyspnea, Smoking, Bronchitis]: + expectedMPE2[key[0]] = 1 + self.assertEqual(list(actualMPE2.items()), + list(expectedMPE2.items())) + + # now sample from it + actualSample = chordal2.sample() + self.assertEqual(len(actualSample), 8) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py new file mode 100644 index 000000000..9dafff33f --- /dev/null +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -0,0 +1,122 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Factor Graphs. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteFactorGraph(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_evaluation(self): + """Test constructing and evaluating a discrete factor graph.""" + + # Three keys + P1 = (0, 2) + P2 = (1, 2) + P3 = (2, 3) + + # Create the DiscreteFactorGraph + graph = DiscreteFactorGraph() + + # Add two unary factors (priors) + graph.add(P1, "0.9 0.3") + graph.add(P2, "0.9 0.6") + + # Add a binary factor + graph.add(P1, P2, "4 1 10 4") + + # Instantiate Values + assignment = DiscreteValues() + assignment[0] = 1 + assignment[1] = 1 + + # Check if graph evaluation works ( 0.3*0.6*4 ) + self.assertAlmostEqual(.72, graph(assignment)) + + # Create a new test with third node and adding unary and ternary factor + graph.add(P3, "0.9 0.2 0.5") + keys = DiscreteKeys() + keys.push_back(P1) + keys.push_back(P2) + keys.push_back(P3) + graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12") + + # Below assignment selects the 8th index in the ternary factor table + assignment[0] = 1 + assignment[1] = 0 + assignment[2] = 1 + + # Check if graph evaluation works (0.3*0.9*1*0.2*8) + self.assertAlmostEqual(4.32, graph(assignment)) + + # Below assignment selects the 3rd index in the ternary factor table + assignment[0] = 0 + assignment[1] = 1 + assignment[2] = 0 + + # Check if graph evaluation works (0.9*0.6*1*0.9*4) + self.assertAlmostEqual(1.944, graph(assignment)) + + # Check if graph product works + product = graph.product() + self.assertAlmostEqual(1.944, product(assignment)) + + def test_optimize(self): + """Test constructing and optizing a discrete factor graph.""" + + # Three keys + C = (0, 2) + B = (1, 2) + A = (2, 2) + + # A simple factor graph (A)-fAC-(C)-fBC-(B) + # with smoothness priors + graph = DiscreteFactorGraph() + graph.add(A, C, "3 1 1 3") + graph.add(C, B, "3 1 1 3") + + # Test optimization + expectedValues = DiscreteValues() + expectedValues[0] = 0 + expectedValues[1] = 0 + expectedValues[2] = 0 + actualValues = graph.optimize() + self.assertEqual(list(actualValues.items()), + list(expectedValues.items())) + + def test_MPE(self): + """Test maximum probable explanation (MPE): same as optimize.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add(C, A, "0.2 0.8 0.3 0.7") + graph.add(C, B, "0.1 0.9 0.4 0.6") + + actualMPE = graph.optimize() + + expectedMPE = DiscreteValues() + expectedMPE[0] = 0 + expectedMPE[1] = 1 + expectedMPE[2] = 1 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + +if __name__ == "__main__": + unittest.main()