Merge pull request #967 from borglab/feature/discrete_wrapper

release/4.3a0
Frank Dellaert 2021-12-16 21:20:40 -05:00 committed by GitHub
commit 0bab7b00c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 633 additions and 130 deletions

View File

@ -34,16 +34,6 @@ namespace gtsam {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
// void DiscreteBayesNet::add_front(const Signature& s) {
// push_front(boost::make_shared<DiscreteConditional>(s));
// }
/* ************************************************************************* */
void DiscreteBayesNet::add(const Signature& s) {
push_back(boost::make_shared<DiscreteConditional>(s));
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply

View File

@ -71,15 +71,23 @@ namespace gtsam {
/// @name Standard Interface
/// @{
// Add inherited versions of add.
using Base::add;
/** Add a DiscreteCondtional */
void add(const Signature& s);
// /** Add a DiscreteCondtional in front, when listing parents first*/
// GTSAM_EXPORT void add_front(const Signature& s);
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
}
//** evaluate for given DiscreteValues */
double evaluate(const DiscreteValues & values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues & values) const {
return evaluate(values);
}
/**
* Solve the DiscreteBayesNet by back-substitution
*/

View File

@ -80,6 +80,12 @@ class GTSAM_EXPORT DiscreteBayesTree
//** evaluate probability for given DiscreteValues */
double evaluate(const DiscreteValues& values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues & values) const {
return evaluate(values);
}
};
} // namespace gtsam

View File

@ -97,25 +97,41 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
}
/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const {
Potentials::ADT DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT pFS(*this);
Key j; size_t value;
for(Key key: parents()) {
size_t value;
for (Key j : parents()) {
try {
j = (key);
value = parentsValues.at(j);
pFS = pFS.choose(j, value);
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
cout << "Key: " << j << " Value: " << value << endl;
parentsValues.print("parentsValues: ");
// pFS.print("pFS: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
return pFS;
}
/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues);
// Convert ADT to factor.
if (nrFrontals() != 1) {
throw std::runtime_error("Expected only one frontal variable in choose.");
}
DiscreteKeys keys;
const Key frontalKey = keys_[0];
size_t frontalCardinality = this->cardinality(frontalKey);
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
}
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.

View File

@ -57,6 +57,33 @@ public:
/** Construct from signature */
DiscreteConditional(const Signature& signature);
/**
* Construct from key, parents, and a Signature::Table specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents.
*
* Example: DiscreteConditional P(D, {B,E}, table);
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const Signature::Table& table)
: DiscreteConditional(Signature(key, parents, table)) {}
/**
* Construct from key, parents, and a string specifying the conditional
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
* be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
*
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
@ -111,6 +138,10 @@ public:
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const DiscreteValues& parentsValues) const;
/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr chooseAsFactor(
const DiscreteValues& parentsValues) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents

View File

@ -130,8 +130,11 @@ public:
/** return product of all factors as a single factor */
DecisionTreeFactor product() const;
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
double operator()(const DiscreteValues & values) const;
/**
* Evaluates the factor graph given values, returns the joint probability of
* the factor graph given specific instantiation of values
*/
double operator()(const DiscreteValues& values) const;
/// print
void print(

View File

@ -31,22 +31,24 @@ namespace gtsam {
* Key type for discrete conditionals
* Includes name and cardinality
*/
typedef std::pair<Key,size_t> DiscreteKey;
using DiscreteKey = std::pair<Key,size_t>;
/// DiscreteKeys is a set of keys that can be assembled using the & operator
struct DiscreteKeys: public std::vector<DiscreteKey> {
/// Default constructor
DiscreteKeys() {
}
// Forward all constructors.
using std::vector<DiscreteKey>::vector;
/// Constructor for serialization
GTSAM_EXPORT DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
/// Construct from a key
DiscreteKeys(const DiscreteKey& key) {
GTSAM_EXPORT DiscreteKeys(const DiscreteKey& key) {
push_back(key);
}
/// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
GTSAM_EXPORT DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) {
}

View File

@ -38,19 +38,7 @@ namespace gtsam {
using boost::phoenix::push_back;
// Special rows, true and false
Signature::Row createF() {
Signature::Row r(2);
r[0] = 1;
r[1] = 0;
return r;
}
Signature::Row createT() {
Signature::Row r(2);
r[0] = 0;
r[1] = 1;
return r;
}
Signature::Row T = createT(), F = createF();
Signature::Row F{1, 0}, T{0, 1};
// Special tables (inefficient, but do we care for user input?)
Signature::Table logic(bool ff, bool ft, bool tf, bool tt) {
@ -69,40 +57,13 @@ namespace gtsam {
table = or_ | and_ | rows;
or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)];
and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)];
rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42
rows = +(row | true_ | false_);
row = qi::double_ >> +("/" >> qi::double_);
true_ = qi::lit("T")[qi::_val = T];
false_ = qi::lit("F")[qi::_val = F];
}
} grammar;
// Create simpler parsing function to avoid the issue of only parsing a single row
bool parse_table(const string& spec, Signature::Table& table) {
// check for OR, AND on whole phrase
It f = spec.begin(), l = spec.end();
if (qi::parse(f, l,
qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) ||
qi::parse(f, l,
qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)]))
return true;
// tokenize into separate rows
istringstream iss(spec);
string token;
while (iss >> token) {
Signature::Row values;
It tf = token.begin(), tl = token.end();
bool r = qi::parse(tf, tl,
qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) |
qi::lit("T")[ph::ref(values) = T] |
qi::lit("F")[ph::ref(values) = F] );
if (!r)
return false;
table.push_back(values);
}
return true;
}
} // \namespace parser
ostream& operator <<(ostream &os, const Signature::Row &row) {
@ -118,6 +79,18 @@ namespace gtsam {
return os;
}
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table)
: key_(key), parents_(parents) {
operator=(table);
}
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: key_(key), parents_(parents) {
operator=(spec);
}
Signature::Signature(const DiscreteKey& key) :
key_(key) {
}
@ -166,14 +139,11 @@ namespace gtsam {
Signature& Signature::operator=(const string& spec) {
spec_.reset(spec);
Table table;
// NOTE: using simpler parse function to ensure boost back compatibility
// parser::It f = spec.begin(), l = spec.end();
bool success = //
// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar
parser::parse_table(spec, table);
parser::It f = spec.begin(), l = spec.end();
bool success =
qi::phrase_parse(f, l, parser::grammar.table, qi::space, table);
if (success) {
for(Row& row: table)
normalize(row);
for (Row& row : table) normalize(row);
table_.reset(table);
}
return *this;

View File

@ -30,7 +30,7 @@ namespace gtsam {
* The format is (Key % string) for nodes with no parents,
* and (Key | Key, Key = string) for nodes with parents.
*
* The string specifies a conditional probability spec in the 00 01 10 11 order.
* The string specifies a conditional probability table in 00 01 10 11 order.
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
*
* For example, given the following keys
@ -45,9 +45,9 @@ namespace gtsam {
* T|A = "99/1 95/5"
* L|S = "99/1 90/10"
* B|S = "70/30 40/60"
* E|T,L = "F F F 1"
* (E|T,L) = "F F F 1"
* X|E = "95/5 2/98"
* D|E,B = "9/1 2/8 3/7 1/9"
* (D|E,B) = "9/1 2/8 3/7 1/9"
*/
class GTSAM_EXPORT Signature {
@ -72,45 +72,73 @@ namespace gtsam {
boost::optional<Table> table_;
public:
/**
* Construct from key, parents, and a Signature::Table specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents.
*
* Example:
* Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
* Signature sig(D, {E, B}, table);
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table);
/** Constructor from DiscreteKey */
Signature(const DiscreteKey& key);
/**
* Construct from key, parents, and a string specifying the conditional
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
* be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
*
* Example (same CPT as above):
* Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec);
/** the variable key */
const DiscreteKey& key() const {
return key_;
}
/**
* Construct from a single DiscreteKey.
*
* The resulting signature has no parents or CPT table. Typical use then
* either adds parents with | and , operators below, or assigns a table with
* operator=().
*/
Signature(const DiscreteKey& key);
/** the parent keys */
const DiscreteKeys& parents() const {
return parents_;
}
/** the variable key */
const DiscreteKey& key() const { return key_; }
/** All keys, with variable key first */
DiscreteKeys discreteKeys() const;
/** the parent keys */
const DiscreteKeys& parents() const { return parents_; }
/** All key indices, with variable key first */
KeyVector indices() const;
/** All keys, with variable key first */
DiscreteKeys discreteKeys() const;
// the CPT as parsed, if successful
const boost::optional<Table>& table() const {
return table_;
}
/** All key indices, with variable key first */
KeyVector indices() const;
// the CPT as a vector of doubles, with key's values most rapidly changing
std::vector<double> cpt() const;
// the CPT as parsed, if successful
const boost::optional<Table>& table() const { return table_; }
/** Add a parent */
Signature& operator,(const DiscreteKey& parent);
// the CPT as a vector of doubles, with key's values most rapidly changing
std::vector<double> cpt() const;
/** Add the CPT spec - Fails in boost 1.40 */
Signature& operator=(const std::string& spec);
/** Add a parent */
Signature& operator,(const DiscreteKey& parent);
/** Add the CPT spec directly as a table */
Signature& operator=(const Table& table);
/** Add the CPT spec */
Signature& operator=(const std::string& spec);
/** provide streaming */
GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s);
/** Add the CPT spec directly as a table */
Signature& operator=(const Table& table);
/** provide streaming */
GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os,
const Signature& s);
};
/**
@ -122,7 +150,6 @@ namespace gtsam {
/**
* Helper function to create Signature objects
* example: Signature s(D % "99/1");
* Uses string parser, which requires BOOST 1.42 or higher
*/
GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent);

131
gtsam/discrete/discrete.i Normal file
View File

@ -0,0 +1,131 @@
//*************************************************************************
// discrete
//*************************************************************************
namespace gtsam {
#include<gtsam/discrete/DiscreteKey.h>
class DiscreteKey {};
class DiscreteKeys {
DiscreteKeys();
size_t size() const;
bool empty() const;
gtsam::DiscreteKey at(size_t n) const;
void push_back(const gtsam::DiscreteKey& point_pair);
};
// DiscreteValues is added in specializations/discrete.h as a std::map
#include <gtsam/discrete/DiscreteFactor.h>
class DiscreteFactor {
void print(string s = "DiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
double operator()(const gtsam::DiscreteValues& values) const;
};
#include <gtsam/discrete/DecisionTreeFactor.h>
virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
};
#include <gtsam/discrete/DiscreteConditional.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
DiscreteConditional(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
size_t size() const; // TODO(dellaert): why do I have to repeat???
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
void printSignature(
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* toFactor() const;
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
};
#include <gtsam/discrete/DiscreteBayesNet.h>
class DiscreteBayesNet {
DiscreteBayesNet();
void add(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteConditional* at(size_t i) const;
void print(string s = "DiscreteBayesNet\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void add(const gtsam::DiscreteConditional& s);
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const;
};
#include <gtsam/discrete/DiscreteBayesTree.h>
class DiscreteBayesTree {
DiscreteBayesTree();
void print(string s = "DiscreteBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const;
};
#include <gtsam/discrete/DiscreteFactorGraph.h>
class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
void add(const gtsam::DiscreteKey& j, string table);
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
void add(const gtsam::DiscreteKeys& keys, string table);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteFactor* at(size_t i) const;
void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
};
} // namespace gtsam

View File

@ -61,11 +61,10 @@ TEST(DiscreteConditional, constructors_alt_interface) {
r2 += 2.0, 3.0;
r3 += 1.0, 4.0;
table += r1, r2, r3;
auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table);
EXPECT(actual1);
DiscreteConditional actual1(X, {Y}, table);
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional expected1(1, f1);
EXPECT(assert_equal(expected1, *actual1, 1e-9));
EXPECT(assert_equal(expected1, actual1, 1e-9));
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");

View File

@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
/* ************************************************************************* */
TEST(testSignature, simple_conditional) {
Signature sig(X | Y = "1/1 2/3 1/4");
Signature sig(X, {Y}, "1/1 2/3 1/4");
CHECK(sig.table());
Signature::Table table = *sig.table();
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
LONGS_EQUAL(3, table.size());
CHECK(row[0] == table[0]);
CHECK(row[1] == table[1]);
CHECK(row[2] == table[2]);
DiscreteKey actKey = sig.key();
LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL(Y.first, actKeys.back().first);
CHECK(sig.key() == X);
vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, actCpt.size());
DiscreteKeys keys = sig.discreteKeys();
LONGS_EQUAL(2, keys.size());
CHECK(keys[0] == X);
CHECK(keys[1] == Y);
DiscreteKeys parents = sig.parents();
LONGS_EQUAL(1, parents.size());
CHECK(parents[0] == Y);
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
}
/* ************************************************************************* */
@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) {
table += row1, row2, row3;
Signature sig(X | Y = table);
DiscreteKey actKey = sig.key();
EXPECT_LONGS_EQUAL(X.first, actKey.first);
CHECK(sig.key() == X);
DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL(Y.first, actKeys.back().first);
DiscreteKeys keys = sig.discreteKeys();
LONGS_EQUAL(2, keys.size());
CHECK(keys[0] == X);
CHECK(keys[1] == Y);
vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, actCpt.size());
DiscreteKeys parents = sig.parents();
LONGS_EQUAL(1, parents.size());
CHECK(parents[0] == Y);
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
}
/* ************************************************************************* */
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2);
// Make sure we can create all signatures for Asia network with constructor.
TEST(testSignature, all_examples) {
DiscreteKey X(6, 2);
Signature a(A, {}, "99/1");
Signature s(S, {}, "50/50");
Signature t(T, {A}, "99/1 95/5");
Signature l(L, {S}, "99/1 90/10");
Signature b(B, {S}, "70/30 40/60");
Signature e(E, {T, L}, "F F F 1");
Signature x(X, {E}, "95/5 2/98");
}
// Make sure we can create all signatures for Asia network with operator magic.
TEST(testSignature, all_examples_magic) {
DiscreteKey X(6, 2);
Signature a(A % "99/1");
Signature s(S % "50/50");
Signature t(T | A = "99/1 95/5");
Signature l(L | S = "99/1 90/10");
Signature b(B | S = "70/30 40/60");
Signature e((E | T, L) = "F F F 1");
Signature x(X | E = "95/5 2/98");
}
// Check example from docs.
TEST(testSignature, doxygen_example) {
Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
Signature d1(D, {E, B}, table);
Signature d2((D | E, B) = "9/1 2/8 3/7 1/9");
Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9");
EXPECT(*(d1.table()) == table);
EXPECT(*(d2.table()) == table);
EXPECT(*(d3.table()) == table);
}
/* ************************************************************************* */

View File

@ -49,11 +49,13 @@ set(ignore
gtsam::Pose3Vector
gtsam::KeyVector
gtsam::BinaryMeasurementsUnit3
gtsam::DiscreteKey
gtsam::KeyPairDoubleMap)
set(interface_headers
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
${PROJECT_SOURCE_DIR}/gtsam/base/base.i
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i

View File

@ -0,0 +1,16 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
* automatic STL binding, such that the raw objects can be accessed in Python.
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/
#include <pybind11/stl.h>
PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys);

View File

@ -0,0 +1,17 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `py::bind_vector` and similar machinery gives the std container a Python-like
* interface, but without the `<pybind11/stl.h>` copying mechanism. Combined
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
* and saves one copy operation.
*/
// Seems this is not a good idea with inherited stl
//py::bind_vector<std::vector<gtsam::DiscreteKey>>(m_, "DiscreteKeys");
py::bind_map<gtsam::DiscreteValues>(m_, "DiscreteValues");

View File

@ -0,0 +1,118 @@
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Bayes Nets.
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
import unittest
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase
class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""
def test_constructor(self):
"""Test constructing a Bayes net."""
bayesNet = DiscreteBayesNet()
Parent, Child = (0, 2), (1, 2)
empty = DiscreteKeys()
prior = DiscreteConditional(Parent, empty, "6/4")
bayesNet.add(prior)
parents = DiscreteKeys()
parents.push_back(Parent)
conditional = DiscreteConditional(Child, parents, "7/3 8/2")
bayesNet.add(conditional)
# Check conversion to factor graph:
fg = DiscreteFactorGraph(bayesNet)
self.assertEqual(fg.size(), 2)
self.assertEqual(fg.at(1).size(), 2)
def test_Asia(self):
"""Test full Asia example."""
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
def P(keys):
dks = DiscreteKeys()
for key in keys:
dks.push_back(key)
return dks
asia = DiscreteBayesNet()
asia.add(Asia, P([]), "99/1")
asia.add(Smoking, P([]), "50/50")
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
asia.add(XRay, P([Either]), "95/5 2/98")
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
# Convert to factor graph
fg = DiscreteFactorGraph(asia)
# Create solver and eliminate
ordering = Ordering()
for j in range(8):
ordering.push_back(j)
chordal = fg.eliminateSequential(ordering)
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
self.gtsamAssertEquals(chordal.at(7), expected2)
# solve
actualMPE = chordal.optimize()
expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
expectedMPE[key[0]] = 0
self.assertEqual(list(actualMPE.items()),
list(expectedMPE.items()))
# Check value for MPE is the same
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
# add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1")
fg.add(Dyspnea, "0 1")
# solve again, now with evidence
chordal2 = fg.eliminateSequential(ordering)
actualMPE2 = chordal2.optimize()
expectedMPE2 = DiscreteValues()
for key in [XRay, Tuberculosis, Either, LungCancer]:
expectedMPE2[key[0]] = 0
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
expectedMPE2[key[0]] = 1
self.assertEqual(list(actualMPE2.items()),
list(expectedMPE2.items()))
# now sample from it
actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,122 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Factor Graphs.
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
import unittest
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues
from gtsam.utils.test_case import GtsamTestCase
class TestDiscreteFactorGraph(GtsamTestCase):
"""Tests for Discrete Factor Graphs."""
def test_evaluation(self):
"""Test constructing and evaluating a discrete factor graph."""
# Three keys
P1 = (0, 2)
P2 = (1, 2)
P3 = (2, 3)
# Create the DiscreteFactorGraph
graph = DiscreteFactorGraph()
# Add two unary factors (priors)
graph.add(P1, "0.9 0.3")
graph.add(P2, "0.9 0.6")
# Add a binary factor
graph.add(P1, P2, "4 1 10 4")
# Instantiate Values
assignment = DiscreteValues()
assignment[0] = 1
assignment[1] = 1
# Check if graph evaluation works ( 0.3*0.6*4 )
self.assertAlmostEqual(.72, graph(assignment))
# Create a new test with third node and adding unary and ternary factor
graph.add(P3, "0.9 0.2 0.5")
keys = DiscreteKeys()
keys.push_back(P1)
keys.push_back(P2)
keys.push_back(P3)
graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
# Below assignment selects the 8th index in the ternary factor table
assignment[0] = 1
assignment[1] = 0
assignment[2] = 1
# Check if graph evaluation works (0.3*0.9*1*0.2*8)
self.assertAlmostEqual(4.32, graph(assignment))
# Below assignment selects the 3rd index in the ternary factor table
assignment[0] = 0
assignment[1] = 1
assignment[2] = 0
# Check if graph evaluation works (0.9*0.6*1*0.9*4)
self.assertAlmostEqual(1.944, graph(assignment))
# Check if graph product works
product = graph.product()
self.assertAlmostEqual(1.944, product(assignment))
def test_optimize(self):
"""Test constructing and optizing a discrete factor graph."""
# Three keys
C = (0, 2)
B = (1, 2)
A = (2, 2)
# A simple factor graph (A)-fAC-(C)-fBC-(B)
# with smoothness priors
graph = DiscreteFactorGraph()
graph.add(A, C, "3 1 1 3")
graph.add(C, B, "3 1 1 3")
# Test optimization
expectedValues = DiscreteValues()
expectedValues[0] = 0
expectedValues[1] = 0
expectedValues[2] = 0
actualValues = graph.optimize()
self.assertEqual(list(actualValues.items()),
list(expectedValues.items()))
def test_MPE(self):
"""Test maximum probable explanation (MPE): same as optimize."""
# Declare a bunch of keys
C, A, B = (0, 2), (1, 2), (2, 2)
# Create Factor graph
graph = DiscreteFactorGraph()
graph.add(C, A, "0.2 0.8 0.3 0.7")
graph.add(C, B, "0.1 0.9 0.4 0.6")
actualMPE = graph.optimize()
expectedMPE = DiscreteValues()
expectedMPE[0] = 0
expectedMPE[1] = 1
expectedMPE[2] = 1
self.assertEqual(list(actualMPE.items()),
list(expectedMPE.items()))
if __name__ == "__main__":
unittest.main()