New, non-fancy constructors

release/4.3a0
Frank Dellaert 2021-12-15 08:51:01 -05:00
parent fd7640b1b7
commit 4e5530b6d5
5 changed files with 141 additions and 50 deletions

View File

@ -57,6 +57,29 @@ public:
/** Construct from signature */ /** Construct from signature */
DiscreteConditional(const Signature& signature); DiscreteConditional(const Signature& signature);
/**
* Construct from key, parents, and a Table specifying the CPT.
*
* The first string is parsed to add a key and parents.
*
* Example: DiscreteConditional P(D, {B,E}, table);
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const Signature::Table& table)
: DiscreteConditional(Signature(key, parents, table)) {}
/**
* Construct from key, parents, and a string specifying the CPT.
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
*
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal); const DecisionTreeFactor& marginal);

View File

@ -79,6 +79,18 @@ namespace gtsam {
return os; return os;
} }
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table)
: key_(key), parents_(parents) {
operator=(table);
}
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: key_(key), parents_(parents) {
operator=(spec);
}
Signature::Signature(const DiscreteKey& key) : Signature::Signature(const DiscreteKey& key) :
key_(key) { key_(key) {
} }

View File

@ -45,9 +45,9 @@ namespace gtsam {
* T|A = "99/1 95/5" * T|A = "99/1 95/5"
* L|S = "99/1 90/10" * L|S = "99/1 90/10"
* B|S = "70/30 40/60" * B|S = "70/30 40/60"
* E|T,L = "F F F 1" * (E|T,L) = "F F F 1"
* X|E = "95/5 2/98" * X|E = "95/5 2/98"
* D|E,B = "9/1 2/8 3/7 1/9" * (D|E,B) = "9/1 2/8 3/7 1/9"
*/ */
class GTSAM_EXPORT Signature { class GTSAM_EXPORT Signature {
@ -72,19 +72,41 @@ namespace gtsam {
boost::optional<Table> table_; boost::optional<Table> table_;
public: public:
/**
* Construct from key, parents, and a Table specifying the CPT.
*
* The first string is parsed to add a key and parents.
*
* Example: Signature sig(D, {B,E}, table);
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table);
/** Constructor from DiscreteKey */ /**
* Construct from key, parents, and a string specifying the CPT.
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
*
* Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec);
/**
* Construct from a single DiscreteKey.
*
* The resulting signature has no parents or CPT table. Typical use then
* either adds parents with | and , operators below, or assigns a table with
* operator=().
*/
Signature(const DiscreteKey& key); Signature(const DiscreteKey& key);
/** the variable key */ /** the variable key */
const DiscreteKey& key() const { const DiscreteKey& key() const { return key_; }
return key_;
}
/** the parent keys */ /** the parent keys */
const DiscreteKeys& parents() const { const DiscreteKeys& parents() const { return parents_; }
return parents_;
}
/** All keys, with variable key first */ /** All keys, with variable key first */
DiscreteKeys discreteKeys() const; DiscreteKeys discreteKeys() const;
@ -93,9 +115,7 @@ namespace gtsam {
KeyVector indices() const; KeyVector indices() const;
// the CPT as parsed, if successful // the CPT as parsed, if successful
const boost::optional<Table>& table() const { const boost::optional<Table>& table() const { return table_; }
return table_;
}
// the CPT as a vector of doubles, with key's values most rapidly changing // the CPT as a vector of doubles, with key's values most rapidly changing
std::vector<double> cpt() const; std::vector<double> cpt() const;
@ -110,7 +130,8 @@ namespace gtsam {
Signature& operator=(const Table& table); Signature& operator=(const Table& table);
/** provide streaming */ /** provide streaming */
GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s); GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os,
const Signature& s);
}; };
/** /**

View File

@ -61,11 +61,10 @@ TEST(DiscreteConditional, constructors_alt_interface) {
r2 += 2.0, 3.0; r2 += 2.0, 3.0;
r3 += 1.0, 4.0; r3 += 1.0, 4.0;
table += r1, r2, r3; table += r1, r2, r3;
auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table); DiscreteConditional actual1(X, {Y}, table);
EXPECT(actual1);
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional expected1(1, f1); DiscreteConditional expected1(1, f1);
EXPECT(assert_equal(expected1, *actual1, 1e-9)); EXPECT(assert_equal(expected1, actual1, 1e-9));
DecisionTreeFactor f2( DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");

View File

@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(testSignature, simple_conditional) { TEST(testSignature, simple_conditional) {
Signature sig(X | Y = "1/1 2/3 1/4"); Signature sig(X, {Y}, "1/1 2/3 1/4");
CHECK(sig.table());
Signature::Table table = *sig.table(); Signature::Table table = *sig.table();
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
LONGS_EQUAL(3, table.size());
CHECK(row[0] == table[0]); CHECK(row[0] == table[0]);
CHECK(row[1] == table[1]); CHECK(row[1] == table[1]);
CHECK(row[2] == table[2]); CHECK(row[2] == table[2]);
DiscreteKey actKey = sig.key();
LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeys(); CHECK(sig.key() == X);
LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL(Y.first, actKeys.back().first);
vector<double> actCpt = sig.cpt(); DiscreteKeys keys = sig.discreteKeys();
EXPECT_LONGS_EQUAL(6, actCpt.size()); LONGS_EQUAL(2, keys.size());
CHECK(keys[0] == X);
CHECK(keys[1] == Y);
DiscreteKeys parents = sig.parents();
LONGS_EQUAL(1, parents.size());
CHECK(parents[0] == Y);
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -60,16 +65,47 @@ TEST(testSignature, simple_conditional_nonparser) {
table += row1, row2, row3; table += row1, row2, row3;
Signature sig(X | Y = table); Signature sig(X | Y = table);
DiscreteKey actKey = sig.key(); CHECK(sig.key() == X);
EXPECT_LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeys(); DiscreteKeys keys = sig.discreteKeys();
LONGS_EQUAL(2, actKeys.size()); LONGS_EQUAL(2, keys.size());
LONGS_EQUAL(X.first, actKeys.front().first); CHECK(keys[0] == X);
LONGS_EQUAL(Y.first, actKeys.back().first); CHECK(keys[1] == Y);
vector<double> actCpt = sig.cpt(); DiscreteKeys parents = sig.parents();
EXPECT_LONGS_EQUAL(6, actCpt.size()); LONGS_EQUAL(1, parents.size());
CHECK(parents[0] == Y);
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
}
/* ************************************************************************* */
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2);
// Make sure we can create all signatures for Asia network with constructor.
TEST(testSignature, all_examples) {
DiscreteKey X(6, 2);
Signature a(A, {}, "99/1");
Signature s(S, {}, "50/50");
Signature t(T, {A}, "99/1 95/5");
Signature l(L, {S}, "99/1 90/10");
Signature b(B, {S}, "70/30 40/60");
Signature e(E, {T, L}, "F F F 1");
Signature x(X, {E}, "95/5 2/98");
Signature d(D, {E, B}, "9/1 2/8 3/7 1/9");
}
// Make sure we can create all signatures for Asia network with operator magic.
TEST(testSignature, all_examples_magic) {
DiscreteKey X(6, 2);
Signature a(A % "99/1");
Signature s(S % "50/50");
Signature t(T | A = "99/1 95/5");
Signature l(L | S = "99/1 90/10");
Signature b(B | S = "70/30 40/60");
Signature e((E | T, L) = "F F F 1");
Signature x(X | E = "95/5 2/98");
Signature d((D | E, B) = "9/1 2/8 3/7 1/9");
} }
/* ************************************************************************* */ /* ************************************************************************* */