Merge pull request #967 from borglab/feature/discrete_wrapper
commit
0bab7b00c8
|
@ -34,16 +34,6 @@ namespace gtsam {
|
|||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// void DiscreteBayesNet::add_front(const Signature& s) {
|
||||
// push_front(boost::make_shared<DiscreteConditional>(s));
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteBayesNet::add(const Signature& s) {
|
||||
push_back(boost::make_shared<DiscreteConditional>(s));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
|
|
|
@ -71,15 +71,23 @@ namespace gtsam {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
// Add inherited versions of add.
|
||||
using Base::add;
|
||||
|
||||
/** Add a DiscreteCondtional */
|
||||
void add(const Signature& s);
|
||||
|
||||
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
||||
// GTSAM_EXPORT void add_front(const Signature& s);
|
||||
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
//** evaluate for given DiscreteValues */
|
||||
double evaluate(const DiscreteValues & values) const;
|
||||
|
||||
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||
double operator()(const DiscreteValues & values) const {
|
||||
return evaluate(values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Solve the DiscreteBayesNet by back-substitution
|
||||
*/
|
||||
|
|
|
@ -80,6 +80,12 @@ class GTSAM_EXPORT DiscreteBayesTree
|
|||
|
||||
//** evaluate probability for given DiscreteValues */
|
||||
double evaluate(const DiscreteValues& values) const;
|
||||
|
||||
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||
double operator()(const DiscreteValues & values) const {
|
||||
return evaluate(values);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -97,25 +97,41 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
|||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const {
|
||||
Potentials::ADT DiscreteConditional::choose(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
ADT pFS(*this);
|
||||
Key j; size_t value;
|
||||
for(Key key: parents()) {
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
j = (key);
|
||||
value = parentsValues.at(j);
|
||||
pFS = pFS.choose(j, value);
|
||||
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
cout << "Key: " << j << " Value: " << value << endl;
|
||||
parentsValues.print("parentsValues: ");
|
||||
// pFS.print("pFS: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
}
|
||||
|
||||
return pFS;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues);
|
||||
|
||||
// Convert ADT to factor.
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::runtime_error("Expected only one frontal variable in choose.");
|
||||
}
|
||||
DiscreteKeys keys;
|
||||
const Key frontalKey = keys_[0];
|
||||
size_t frontalCardinality = this->cardinality(frontalKey);
|
||||
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
|
||||
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
||||
|
|
|
@ -57,6 +57,33 @@ public:
|
|||
/** Construct from signature */
|
||||
DiscreteConditional(const Signature& signature);
|
||||
|
||||
/**
|
||||
* Construct from key, parents, and a Signature::Table specifying the
|
||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents.
|
||||
*
|
||||
* Example: DiscreteConditional P(D, {B,E}, table);
|
||||
*/
|
||||
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const Signature::Table& table)
|
||||
: DiscreteConditional(Signature(key, parents, table)) {}
|
||||
|
||||
/**
|
||||
* Construct from key, parents, and a string specifying the conditional
|
||||
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents. The second string
|
||||
* parses into a table.
|
||||
*
|
||||
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||
|
||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
@ -111,6 +138,10 @@ public:
|
|||
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
||||
ADT choose(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||
DecisionTreeFactor::shared_ptr chooseAsFactor(
|
||||
const DiscreteValues& parentsValues) const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @param parentsValues Known values of the parents
|
||||
|
|
|
@ -130,8 +130,11 @@ public:
|
|||
/** return product of all factors as a single factor */
|
||||
DecisionTreeFactor product() const;
|
||||
|
||||
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
|
||||
double operator()(const DiscreteValues & values) const;
|
||||
/**
|
||||
* Evaluates the factor graph given values, returns the joint probability of
|
||||
* the factor graph given specific instantiation of values
|
||||
*/
|
||||
double operator()(const DiscreteValues& values) const;
|
||||
|
||||
/// print
|
||||
void print(
|
||||
|
|
|
@ -31,22 +31,24 @@ namespace gtsam {
|
|||
* Key type for discrete conditionals
|
||||
* Includes name and cardinality
|
||||
*/
|
||||
typedef std::pair<Key,size_t> DiscreteKey;
|
||||
using DiscreteKey = std::pair<Key,size_t>;
|
||||
|
||||
/// DiscreteKeys is a set of keys that can be assembled using the & operator
|
||||
struct DiscreteKeys: public std::vector<DiscreteKey> {
|
||||
|
||||
/// Default constructor
|
||||
DiscreteKeys() {
|
||||
}
|
||||
// Forward all constructors.
|
||||
using std::vector<DiscreteKey>::vector;
|
||||
|
||||
/// Constructor for serialization
|
||||
GTSAM_EXPORT DiscreteKeys() : std::vector<DiscreteKey>::vector() {}
|
||||
|
||||
/// Construct from a key
|
||||
DiscreteKeys(const DiscreteKey& key) {
|
||||
GTSAM_EXPORT DiscreteKeys(const DiscreteKey& key) {
|
||||
push_back(key);
|
||||
}
|
||||
|
||||
/// Construct from a vector of keys
|
||||
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
|
||||
GTSAM_EXPORT DiscreteKeys(const std::vector<DiscreteKey>& keys) :
|
||||
std::vector<DiscreteKey>(keys) {
|
||||
}
|
||||
|
||||
|
|
|
@ -38,19 +38,7 @@ namespace gtsam {
|
|||
using boost::phoenix::push_back;
|
||||
|
||||
// Special rows, true and false
|
||||
Signature::Row createF() {
|
||||
Signature::Row r(2);
|
||||
r[0] = 1;
|
||||
r[1] = 0;
|
||||
return r;
|
||||
}
|
||||
Signature::Row createT() {
|
||||
Signature::Row r(2);
|
||||
r[0] = 0;
|
||||
r[1] = 1;
|
||||
return r;
|
||||
}
|
||||
Signature::Row T = createT(), F = createF();
|
||||
Signature::Row F{1, 0}, T{0, 1};
|
||||
|
||||
// Special tables (inefficient, but do we care for user input?)
|
||||
Signature::Table logic(bool ff, bool ft, bool tf, bool tt) {
|
||||
|
@ -69,40 +57,13 @@ namespace gtsam {
|
|||
table = or_ | and_ | rows;
|
||||
or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)];
|
||||
and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)];
|
||||
rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42
|
||||
rows = +(row | true_ | false_);
|
||||
row = qi::double_ >> +("/" >> qi::double_);
|
||||
true_ = qi::lit("T")[qi::_val = T];
|
||||
false_ = qi::lit("F")[qi::_val = F];
|
||||
}
|
||||
} grammar;
|
||||
|
||||
// Create simpler parsing function to avoid the issue of only parsing a single row
|
||||
bool parse_table(const string& spec, Signature::Table& table) {
|
||||
// check for OR, AND on whole phrase
|
||||
It f = spec.begin(), l = spec.end();
|
||||
if (qi::parse(f, l,
|
||||
qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) ||
|
||||
qi::parse(f, l,
|
||||
qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)]))
|
||||
return true;
|
||||
|
||||
// tokenize into separate rows
|
||||
istringstream iss(spec);
|
||||
string token;
|
||||
while (iss >> token) {
|
||||
Signature::Row values;
|
||||
It tf = token.begin(), tl = token.end();
|
||||
bool r = qi::parse(tf, tl,
|
||||
qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) |
|
||||
qi::lit("T")[ph::ref(values) = T] |
|
||||
qi::lit("F")[ph::ref(values) = F] );
|
||||
if (!r)
|
||||
return false;
|
||||
table.push_back(values);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // \namespace parser
|
||||
|
||||
ostream& operator <<(ostream &os, const Signature::Row &row) {
|
||||
|
@ -118,6 +79,18 @@ namespace gtsam {
|
|||
return os;
|
||||
}
|
||||
|
||||
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const Table& table)
|
||||
: key_(key), parents_(parents) {
|
||||
operator=(table);
|
||||
}
|
||||
|
||||
Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::string& spec)
|
||||
: key_(key), parents_(parents) {
|
||||
operator=(spec);
|
||||
}
|
||||
|
||||
Signature::Signature(const DiscreteKey& key) :
|
||||
key_(key) {
|
||||
}
|
||||
|
@ -166,14 +139,11 @@ namespace gtsam {
|
|||
Signature& Signature::operator=(const string& spec) {
|
||||
spec_.reset(spec);
|
||||
Table table;
|
||||
// NOTE: using simpler parse function to ensure boost back compatibility
|
||||
// parser::It f = spec.begin(), l = spec.end();
|
||||
bool success = //
|
||||
// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar
|
||||
parser::parse_table(spec, table);
|
||||
parser::It f = spec.begin(), l = spec.end();
|
||||
bool success =
|
||||
qi::phrase_parse(f, l, parser::grammar.table, qi::space, table);
|
||||
if (success) {
|
||||
for(Row& row: table)
|
||||
normalize(row);
|
||||
for (Row& row : table) normalize(row);
|
||||
table_.reset(table);
|
||||
}
|
||||
return *this;
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace gtsam {
|
|||
* The format is (Key % string) for nodes with no parents,
|
||||
* and (Key | Key, Key = string) for nodes with parents.
|
||||
*
|
||||
* The string specifies a conditional probability spec in the 00 01 10 11 order.
|
||||
* The string specifies a conditional probability table in 00 01 10 11 order.
|
||||
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
|
||||
*
|
||||
* For example, given the following keys
|
||||
|
@ -45,9 +45,9 @@ namespace gtsam {
|
|||
* T|A = "99/1 95/5"
|
||||
* L|S = "99/1 90/10"
|
||||
* B|S = "70/30 40/60"
|
||||
* E|T,L = "F F F 1"
|
||||
* (E|T,L) = "F F F 1"
|
||||
* X|E = "95/5 2/98"
|
||||
* D|E,B = "9/1 2/8 3/7 1/9"
|
||||
* (D|E,B) = "9/1 2/8 3/7 1/9"
|
||||
*/
|
||||
class GTSAM_EXPORT Signature {
|
||||
|
||||
|
@ -72,45 +72,73 @@ namespace gtsam {
|
|||
boost::optional<Table> table_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Construct from key, parents, and a Signature::Table specifying the
|
||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents.
|
||||
*
|
||||
* Example:
|
||||
* Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||
* Signature sig(D, {E, B}, table);
|
||||
*/
|
||||
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const Table& table);
|
||||
|
||||
/** Constructor from DiscreteKey */
|
||||
Signature(const DiscreteKey& key);
|
||||
/**
|
||||
* Construct from key, parents, and a string specifying the conditional
|
||||
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* The first string is parsed to add a key and parents. The second string
|
||||
* parses into a table.
|
||||
*
|
||||
* Example (same CPT as above):
|
||||
* Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::string& spec);
|
||||
|
||||
/** the variable key */
|
||||
const DiscreteKey& key() const {
|
||||
return key_;
|
||||
}
|
||||
/**
|
||||
* Construct from a single DiscreteKey.
|
||||
*
|
||||
* The resulting signature has no parents or CPT table. Typical use then
|
||||
* either adds parents with | and , operators below, or assigns a table with
|
||||
* operator=().
|
||||
*/
|
||||
Signature(const DiscreteKey& key);
|
||||
|
||||
/** the parent keys */
|
||||
const DiscreteKeys& parents() const {
|
||||
return parents_;
|
||||
}
|
||||
/** the variable key */
|
||||
const DiscreteKey& key() const { return key_; }
|
||||
|
||||
/** All keys, with variable key first */
|
||||
DiscreteKeys discreteKeys() const;
|
||||
/** the parent keys */
|
||||
const DiscreteKeys& parents() const { return parents_; }
|
||||
|
||||
/** All key indices, with variable key first */
|
||||
KeyVector indices() const;
|
||||
/** All keys, with variable key first */
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
// the CPT as parsed, if successful
|
||||
const boost::optional<Table>& table() const {
|
||||
return table_;
|
||||
}
|
||||
/** All key indices, with variable key first */
|
||||
KeyVector indices() const;
|
||||
|
||||
// the CPT as a vector of doubles, with key's values most rapidly changing
|
||||
std::vector<double> cpt() const;
|
||||
// the CPT as parsed, if successful
|
||||
const boost::optional<Table>& table() const { return table_; }
|
||||
|
||||
/** Add a parent */
|
||||
Signature& operator,(const DiscreteKey& parent);
|
||||
// the CPT as a vector of doubles, with key's values most rapidly changing
|
||||
std::vector<double> cpt() const;
|
||||
|
||||
/** Add the CPT spec - Fails in boost 1.40 */
|
||||
Signature& operator=(const std::string& spec);
|
||||
/** Add a parent */
|
||||
Signature& operator,(const DiscreteKey& parent);
|
||||
|
||||
/** Add the CPT spec directly as a table */
|
||||
Signature& operator=(const Table& table);
|
||||
/** Add the CPT spec */
|
||||
Signature& operator=(const std::string& spec);
|
||||
|
||||
/** provide streaming */
|
||||
GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s);
|
||||
/** Add the CPT spec directly as a table */
|
||||
Signature& operator=(const Table& table);
|
||||
|
||||
/** provide streaming */
|
||||
GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os,
|
||||
const Signature& s);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -122,7 +150,6 @@ namespace gtsam {
|
|||
/**
|
||||
* Helper function to create Signature objects
|
||||
* example: Signature s(D % "99/1");
|
||||
* Uses string parser, which requires BOOST 1.42 or higher
|
||||
*/
|
||||
GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent);
|
||||
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
//*************************************************************************
|
||||
// discrete
|
||||
//*************************************************************************
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
||||
#include<gtsam/discrete/DiscreteKey.h>
|
||||
class DiscreteKey {};
|
||||
|
||||
class DiscreteKeys {
|
||||
DiscreteKeys();
|
||||
size_t size() const;
|
||||
bool empty() const;
|
||||
gtsam::DiscreteKey at(size_t n) const;
|
||||
void push_back(const gtsam::DiscreteKey& point_pair);
|
||||
};
|
||||
|
||||
// DiscreteValues is added in specializations/discrete.h as a std::map
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
class DiscreteFactor {
|
||||
void print(string s = "DiscreteFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
||||
DecisionTreeFactor();
|
||||
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
||||
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
||||
void print(string s = "DecisionTreeFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||
DiscreteConditional();
|
||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal,
|
||||
const gtsam::Ordering& orderedKeys);
|
||||
size_t size() const; // TODO(dellaert): why do I have to repeat???
|
||||
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
void print(string s = "Discrete Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||
void printSignature(
|
||||
string s = "Discrete Conditional: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DecisionTreeFactor* toFactor() const;
|
||||
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
class DiscreteBayesNet {
|
||||
DiscreteBayesNet();
|
||||
void add(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteConditional* at(size_t i) const;
|
||||
void print(string s = "DiscreteBayesNet\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
class DiscreteBayesTree {
|
||||
DiscreteBayesTree();
|
||||
void print(string s = "DiscreteBayesTree\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
class DiscreteFactorGraph {
|
||||
DiscreteFactorGraph();
|
||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
|
||||
void add(const gtsam::DiscreteKey& j, string table);
|
||||
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
|
||||
void add(const gtsam::DiscreteKeys& keys, string table);
|
||||
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteFactor* at(size_t i) const;
|
||||
|
||||
void print(string s = "") const;
|
||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||
|
||||
gtsam::DecisionTreeFactor product() const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteBayesNet eliminateSequential();
|
||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
|
@ -61,11 +61,10 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
|||
r2 += 2.0, 3.0;
|
||||
r3 += 1.0, 4.0;
|
||||
table += r1, r2, r3;
|
||||
auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table);
|
||||
EXPECT(actual1);
|
||||
DiscreteConditional actual1(X, {Y}, table);
|
||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
DiscreteConditional expected1(1, f1);
|
||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||
EXPECT(assert_equal(expected1, actual1, 1e-9));
|
||||
|
||||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
|
|
|
@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
|
|||
|
||||
/* ************************************************************************* */
|
||||
TEST(testSignature, simple_conditional) {
|
||||
Signature sig(X | Y = "1/1 2/3 1/4");
|
||||
Signature sig(X, {Y}, "1/1 2/3 1/4");
|
||||
CHECK(sig.table());
|
||||
Signature::Table table = *sig.table();
|
||||
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
|
||||
LONGS_EQUAL(3, table.size());
|
||||
CHECK(row[0] == table[0]);
|
||||
CHECK(row[1] == table[1]);
|
||||
CHECK(row[2] == table[2]);
|
||||
DiscreteKey actKey = sig.key();
|
||||
LONGS_EQUAL(X.first, actKey.first);
|
||||
|
||||
DiscreteKeys actKeys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, actKeys.size());
|
||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||
CHECK(sig.key() == X);
|
||||
|
||||
vector<double> actCpt = sig.cpt();
|
||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||
DiscreteKeys keys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, keys.size());
|
||||
CHECK(keys[0] == X);
|
||||
CHECK(keys[1] == Y);
|
||||
|
||||
DiscreteKeys parents = sig.parents();
|
||||
LONGS_EQUAL(1, parents.size());
|
||||
CHECK(parents[0] == Y);
|
||||
|
||||
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) {
|
|||
table += row1, row2, row3;
|
||||
|
||||
Signature sig(X | Y = table);
|
||||
DiscreteKey actKey = sig.key();
|
||||
EXPECT_LONGS_EQUAL(X.first, actKey.first);
|
||||
CHECK(sig.key() == X);
|
||||
|
||||
DiscreteKeys actKeys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, actKeys.size());
|
||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||
DiscreteKeys keys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, keys.size());
|
||||
CHECK(keys[0] == X);
|
||||
CHECK(keys[1] == Y);
|
||||
|
||||
vector<double> actCpt = sig.cpt();
|
||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||
DiscreteKeys parents = sig.parents();
|
||||
LONGS_EQUAL(1, parents.size());
|
||||
CHECK(parents[0] == Y);
|
||||
|
||||
EXPECT_LONGS_EQUAL(6, sig.cpt().size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2);
|
||||
|
||||
// Make sure we can create all signatures for Asia network with constructor.
|
||||
TEST(testSignature, all_examples) {
|
||||
DiscreteKey X(6, 2);
|
||||
Signature a(A, {}, "99/1");
|
||||
Signature s(S, {}, "50/50");
|
||||
Signature t(T, {A}, "99/1 95/5");
|
||||
Signature l(L, {S}, "99/1 90/10");
|
||||
Signature b(B, {S}, "70/30 40/60");
|
||||
Signature e(E, {T, L}, "F F F 1");
|
||||
Signature x(X, {E}, "95/5 2/98");
|
||||
}
|
||||
|
||||
// Make sure we can create all signatures for Asia network with operator magic.
|
||||
TEST(testSignature, all_examples_magic) {
|
||||
DiscreteKey X(6, 2);
|
||||
Signature a(A % "99/1");
|
||||
Signature s(S % "50/50");
|
||||
Signature t(T | A = "99/1 95/5");
|
||||
Signature l(L | S = "99/1 90/10");
|
||||
Signature b(B | S = "70/30 40/60");
|
||||
Signature e((E | T, L) = "F F F 1");
|
||||
Signature x(X | E = "95/5 2/98");
|
||||
}
|
||||
|
||||
// Check example from docs.
|
||||
TEST(testSignature, doxygen_example) {
|
||||
Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||
Signature d1(D, {E, B}, table);
|
||||
Signature d2((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||
Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9");
|
||||
EXPECT(*(d1.table()) == table);
|
||||
EXPECT(*(d2.table()) == table);
|
||||
EXPECT(*(d3.table()) == table);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -49,11 +49,13 @@ set(ignore
|
|||
gtsam::Pose3Vector
|
||||
gtsam::KeyVector
|
||||
gtsam::BinaryMeasurementsUnit3
|
||||
gtsam::DiscreteKey
|
||||
gtsam::KeyPairDoubleMap)
|
||||
|
||||
set(interface_headers
|
||||
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
|
||||
${PROJECT_SOURCE_DIR}/gtsam/base/base.i
|
||||
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
|
||||
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
|
||||
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
|
||||
${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
/* Please refer to:
|
||||
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
|
||||
* These are required to save one copy operation on Python calls.
|
||||
*
|
||||
* NOTES
|
||||
* =================
|
||||
*
|
||||
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
|
||||
* automatic STL binding, such that the raw objects can be accessed in Python.
|
||||
* Without this they will be automatically converted to a Python object, and all
|
||||
* mutations on Python side will not be reflected on C++.
|
||||
*/
|
||||
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys);
|
|
@ -0,0 +1,17 @@
|
|||
/* Please refer to:
|
||||
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
|
||||
* These are required to save one copy operation on Python calls.
|
||||
*
|
||||
* NOTES
|
||||
* =================
|
||||
*
|
||||
* `py::bind_vector` and similar machinery gives the std container a Python-like
|
||||
* interface, but without the `<pybind11/stl.h>` copying mechanism. Combined
|
||||
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
|
||||
* and saves one copy operation.
|
||||
*/
|
||||
|
||||
// Seems this is not a good idea with inherited stl
|
||||
//py::bind_vector<std::vector<gtsam::DiscreteKey>>(m_, "DiscreteKeys");
|
||||
|
||||
py::bind_map<gtsam::DiscreteValues>(m_, "DiscreteValues");
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for Discrete Bayes Nets.
|
||||
Author: Frank Dellaert
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module, invalid-name
|
||||
|
||||
import unittest
|
||||
|
||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||
DiscreteKeys, DiscreteValues, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
"""Tests for Discrete Bayes Nets."""
|
||||
|
||||
def test_constructor(self):
|
||||
"""Test constructing a Bayes net."""
|
||||
|
||||
bayesNet = DiscreteBayesNet()
|
||||
Parent, Child = (0, 2), (1, 2)
|
||||
empty = DiscreteKeys()
|
||||
prior = DiscreteConditional(Parent, empty, "6/4")
|
||||
bayesNet.add(prior)
|
||||
|
||||
parents = DiscreteKeys()
|
||||
parents.push_back(Parent)
|
||||
conditional = DiscreteConditional(Child, parents, "7/3 8/2")
|
||||
bayesNet.add(conditional)
|
||||
|
||||
# Check conversion to factor graph:
|
||||
fg = DiscreteFactorGraph(bayesNet)
|
||||
self.assertEqual(fg.size(), 2)
|
||||
self.assertEqual(fg.at(1).size(), 2)
|
||||
|
||||
def test_Asia(self):
|
||||
"""Test full Asia example."""
|
||||
|
||||
Asia = (0, 2)
|
||||
Smoking = (4, 2)
|
||||
Tuberculosis = (3, 2)
|
||||
LungCancer = (6, 2)
|
||||
|
||||
Bronchitis = (7, 2)
|
||||
Either = (5, 2)
|
||||
XRay = (2, 2)
|
||||
Dyspnea = (1, 2)
|
||||
|
||||
def P(keys):
|
||||
dks = DiscreteKeys()
|
||||
for key in keys:
|
||||
dks.push_back(key)
|
||||
return dks
|
||||
|
||||
asia = DiscreteBayesNet()
|
||||
asia.add(Asia, P([]), "99/1")
|
||||
asia.add(Smoking, P([]), "50/50")
|
||||
|
||||
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
|
||||
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
|
||||
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
|
||||
|
||||
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
|
||||
|
||||
asia.add(XRay, P([Either]), "95/5 2/98")
|
||||
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
|
||||
|
||||
# Convert to factor graph
|
||||
fg = DiscreteFactorGraph(asia)
|
||||
|
||||
# Create solver and eliminate
|
||||
ordering = Ordering()
|
||||
for j in range(8):
|
||||
ordering.push_back(j)
|
||||
chordal = fg.eliminateSequential(ordering)
|
||||
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
|
||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||
|
||||
# solve
|
||||
actualMPE = chordal.optimize()
|
||||
expectedMPE = DiscreteValues()
|
||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||
expectedMPE[key[0]] = 0
|
||||
self.assertEqual(list(actualMPE.items()),
|
||||
list(expectedMPE.items()))
|
||||
|
||||
# Check value for MPE is the same
|
||||
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
||||
|
||||
# add evidence, we were in Asia and we have dyspnea
|
||||
fg.add(Asia, "0 1")
|
||||
fg.add(Dyspnea, "0 1")
|
||||
|
||||
# solve again, now with evidence
|
||||
chordal2 = fg.eliminateSequential(ordering)
|
||||
actualMPE2 = chordal2.optimize()
|
||||
expectedMPE2 = DiscreteValues()
|
||||
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
||||
expectedMPE2[key[0]] = 0
|
||||
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
||||
expectedMPE2[key[0]] = 1
|
||||
self.assertEqual(list(actualMPE2.items()),
|
||||
list(expectedMPE2.items()))
|
||||
|
||||
# now sample from it
|
||||
actualSample = chordal2.sample()
|
||||
self.assertEqual(len(actualSample), 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,122 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for Discrete Factor Graphs.
|
||||
Author: Frank Dellaert
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module, invalid-name
|
||||
|
||||
import unittest
|
||||
|
||||
from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDiscreteFactorGraph(GtsamTestCase):
|
||||
"""Tests for Discrete Factor Graphs."""
|
||||
|
||||
def test_evaluation(self):
|
||||
"""Test constructing and evaluating a discrete factor graph."""
|
||||
|
||||
# Three keys
|
||||
P1 = (0, 2)
|
||||
P2 = (1, 2)
|
||||
P3 = (2, 3)
|
||||
|
||||
# Create the DiscreteFactorGraph
|
||||
graph = DiscreteFactorGraph()
|
||||
|
||||
# Add two unary factors (priors)
|
||||
graph.add(P1, "0.9 0.3")
|
||||
graph.add(P2, "0.9 0.6")
|
||||
|
||||
# Add a binary factor
|
||||
graph.add(P1, P2, "4 1 10 4")
|
||||
|
||||
# Instantiate Values
|
||||
assignment = DiscreteValues()
|
||||
assignment[0] = 1
|
||||
assignment[1] = 1
|
||||
|
||||
# Check if graph evaluation works ( 0.3*0.6*4 )
|
||||
self.assertAlmostEqual(.72, graph(assignment))
|
||||
|
||||
# Create a new test with third node and adding unary and ternary factor
|
||||
graph.add(P3, "0.9 0.2 0.5")
|
||||
keys = DiscreteKeys()
|
||||
keys.push_back(P1)
|
||||
keys.push_back(P2)
|
||||
keys.push_back(P3)
|
||||
graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
|
||||
|
||||
# Below assignment selects the 8th index in the ternary factor table
|
||||
assignment[0] = 1
|
||||
assignment[1] = 0
|
||||
assignment[2] = 1
|
||||
|
||||
# Check if graph evaluation works (0.3*0.9*1*0.2*8)
|
||||
self.assertAlmostEqual(4.32, graph(assignment))
|
||||
|
||||
# Below assignment selects the 3rd index in the ternary factor table
|
||||
assignment[0] = 0
|
||||
assignment[1] = 1
|
||||
assignment[2] = 0
|
||||
|
||||
# Check if graph evaluation works (0.9*0.6*1*0.9*4)
|
||||
self.assertAlmostEqual(1.944, graph(assignment))
|
||||
|
||||
# Check if graph product works
|
||||
product = graph.product()
|
||||
self.assertAlmostEqual(1.944, product(assignment))
|
||||
|
||||
def test_optimize(self):
|
||||
"""Test constructing and optizing a discrete factor graph."""
|
||||
|
||||
# Three keys
|
||||
C = (0, 2)
|
||||
B = (1, 2)
|
||||
A = (2, 2)
|
||||
|
||||
# A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||
# with smoothness priors
|
||||
graph = DiscreteFactorGraph()
|
||||
graph.add(A, C, "3 1 1 3")
|
||||
graph.add(C, B, "3 1 1 3")
|
||||
|
||||
# Test optimization
|
||||
expectedValues = DiscreteValues()
|
||||
expectedValues[0] = 0
|
||||
expectedValues[1] = 0
|
||||
expectedValues[2] = 0
|
||||
actualValues = graph.optimize()
|
||||
self.assertEqual(list(actualValues.items()),
|
||||
list(expectedValues.items()))
|
||||
|
||||
def test_MPE(self):
|
||||
"""Test maximum probable explanation (MPE): same as optimize."""
|
||||
|
||||
# Declare a bunch of keys
|
||||
C, A, B = (0, 2), (1, 2), (2, 2)
|
||||
|
||||
# Create Factor graph
|
||||
graph = DiscreteFactorGraph()
|
||||
graph.add(C, A, "0.2 0.8 0.3 0.7")
|
||||
graph.add(C, B, "0.1 0.9 0.4 0.6")
|
||||
|
||||
actualMPE = graph.optimize()
|
||||
|
||||
expectedMPE = DiscreteValues()
|
||||
expectedMPE[0] = 0
|
||||
expectedMPE[1] = 1
|
||||
expectedMPE[2] = 1
|
||||
self.assertEqual(list(actualMPE.items()),
|
||||
list(expectedMPE.items()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue