diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index b31f1d92b..bd6549c91 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -57,6 +57,29 @@ public: /** Construct from signature */ DiscreteConditional(const Signature& signature); + /** + * Construct from key, parents, and a Table specifying the CPT. + * + * The first string is parsed to add a key and parents. + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const Signature::Table& table) + : DiscreteConditional(Signature(key, parents, table)) {} + + /** + * Construct from key, parents, and a string specifying the CPT. + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : DiscreteConditional(Signature(key, parents, spec)) {} + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 361fc0b0a..146555898 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -79,6 +79,18 @@ namespace gtsam { return os; } + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table) + : key_(key), parents_(parents) { + operator=(table); + } + + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : key_(key), parents_(parents) { + operator=(spec); + } + Signature::Signature(const DiscreteKey& key) : key_(key) { } diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 2a8248171..05f10ed23 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -45,9 +45,9 @@ namespace gtsam { * T|A = "99/1 95/5" * L|S = "99/1 90/10" * B|S = "70/30 40/60" - * E|T,L = "F F F 1" + * (E|T,L) = "F F F 1" * X|E = "95/5 2/98" - * D|E,B = "9/1 2/8 3/7 1/9" + * (D|E,B) = "9/1 2/8 3/7 1/9" */ class GTSAM_EXPORT Signature { @@ -72,45 +72,66 @@ namespace gtsam { boost::optional table_; public: + /** + * Construct from key, parents, and a Table specifying the CPT. + * + * The first string is parsed to add a key and parents. + * + * Example: Signature sig(D, {B,E}, table); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table); - /** Constructor from DiscreteKey */ - Signature(const DiscreteKey& key); + /** + * Construct from key, parents, and a string specifying the CPT. + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec); - /** the variable key */ - const DiscreteKey& key() const { - return key_; - } + /** + * Construct from a single DiscreteKey. + * + * The resulting signature has no parents or CPT table. Typical use then + * either adds parents with | and , operators below, or assigns a table with + * operator=(). + */ + Signature(const DiscreteKey& key); - /** the parent keys */ - const DiscreteKeys& parents() const { - return parents_; - } + /** the variable key */ + const DiscreteKey& key() const { return key_; } - /** All keys, with variable key first */ - DiscreteKeys discreteKeys() const; + /** the parent keys */ + const DiscreteKeys& parents() const { return parents_; } - /** All key indices, with variable key first */ - KeyVector indices() const; + /** All keys, with variable key first */ + DiscreteKeys discreteKeys() const; - // the CPT as parsed, if successful - const boost::optional
& table() const { - return table_; - } + /** All key indices, with variable key first */ + KeyVector indices() const; - // the CPT as a vector of doubles, with key's values most rapidly changing - std::vector cpt() const; + // the CPT as parsed, if successful + const boost::optional
& table() const { return table_; } - /** Add a parent */ - Signature& operator,(const DiscreteKey& parent); + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; - /** Add the CPT spec */ - Signature& operator=(const std::string& spec); + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); - /** Add the CPT spec directly as a table */ - Signature& operator=(const Table& table); + /** Add the CPT spec */ + Signature& operator=(const std::string& spec); - /** provide streaming */ - GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s); + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os, + const Signature& s); }; /** diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3ac3ffc9e..79714217c 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -61,11 +61,10 @@ TEST(DiscreteConditional, constructors_alt_interface) { r2 += 2.0, 3.0; r3 += 1.0, 4.0; table += r1, r2, r3; - auto actual1 = boost::make_shared(X | Y = table); - EXPECT(actual1); + DiscreteConditional actual1(X, {Y}, table); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); - EXPECT(assert_equal(expected1, *actual1, 1e-9)); + EXPECT(assert_equal(expected1, actual1, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index 049c455f7..fd15eb36c 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ TEST(testSignature, simple_conditional) { - Signature sig(X | Y = "1/1 2/3 1/4"); + Signature sig(X, {Y}, "1/1 2/3 1/4"); + CHECK(sig.table()); Signature::Table table = *sig.table(); vector row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; + LONGS_EQUAL(3, table.size()); CHECK(row[0] == table[0]); CHECK(row[1] == table[1]); CHECK(row[2] == table[2]); - DiscreteKey actKey = sig.key(); - LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + CHECK(sig.key() == X); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); + + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); } /* ************************************************************************* */ @@ -60,16 +65,47 @@ TEST(testSignature, simple_conditional_nonparser) { table += row1, row2, row3; Signature sig(X | Y = table); - DiscreteKey actKey = sig.key(); - EXPECT_LONGS_EQUAL(X.first, actKey.first); + CHECK(sig.key() == X); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); +} + +/* ************************************************************************* */ +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2); + +// Make sure we can create all signatures for Asia network with constructor. +TEST(testSignature, all_examples) { + DiscreteKey X(6, 2); + Signature a(A, {}, "99/1"); + Signature s(S, {}, "50/50"); + Signature t(T, {A}, "99/1 95/5"); + Signature l(L, {S}, "99/1 90/10"); + Signature b(B, {S}, "70/30 40/60"); + Signature e(E, {T, L}, "F F F 1"); + Signature x(X, {E}, "95/5 2/98"); + Signature d(D, {E, B}, "9/1 2/8 3/7 1/9"); +} + +// Make sure we can create all signatures for Asia network with operator magic. +TEST(testSignature, all_examples_magic) { + DiscreteKey X(6, 2); + Signature a(A % "99/1"); + Signature s(S % "50/50"); + Signature t(T | A = "99/1 95/5"); + Signature l(L | S = "99/1 90/10"); + Signature b(B | S = "70/30 40/60"); + Signature e((E | T, L) = "F F F 1"); + Signature x(X | E = "95/5 2/98"); + Signature d((D | E, B) = "9/1 2/8 3/7 1/9"); } /* ************************************************************************* */