Merge branch 'feature/discrete_wrapper' into feature/discrete_wrapper_2
commit
7401b6e0c2
|
|
@ -83,6 +83,11 @@ namespace gtsam {
|
||||||
//** evaluate for given DiscreteValues */
|
//** evaluate for given DiscreteValues */
|
||||||
double evaluate(const DiscreteValues & values) const;
|
double evaluate(const DiscreteValues & values) const;
|
||||||
|
|
||||||
|
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||||
|
double operator()(const DiscreteValues & values) const {
|
||||||
|
return evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Solve the DiscreteBayesNet by back-substitution
|
* Solve the DiscreteBayesNet by back-substitution
|
||||||
*/
|
*/
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,12 @@ class GTSAM_EXPORT DiscreteBayesTree
|
||||||
|
|
||||||
//** evaluate probability for given DiscreteValues */
|
//** evaluate probability for given DiscreteValues */
|
||||||
double evaluate(const DiscreteValues& values) const;
|
double evaluate(const DiscreteValues& values) const;
|
||||||
|
|
||||||
|
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||||
|
double operator()(const DiscreteValues & values) const {
|
||||||
|
return evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -58,24 +58,28 @@ public:
|
||||||
DiscreteConditional(const Signature& signature);
|
DiscreteConditional(const Signature& signature);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a Table specifying the CPT.
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
*
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
* The first string is parsed to add a key and parents.
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
*
|
*
|
||||||
* Example: DiscreteConditional P(D, {B,E}, table);
|
* The first string is parsed to add a key and parents.
|
||||||
*/
|
*
|
||||||
|
* Example: DiscreteConditional P(D, {B,E}, table);
|
||||||
|
*/
|
||||||
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
const Signature::Table& table)
|
const Signature::Table& table)
|
||||||
: DiscreteConditional(Signature(key, parents, table)) {}
|
: DiscreteConditional(Signature(key, parents, table)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a string specifying the CPT.
|
* Construct from key, parents, and a string specifying the conditional
|
||||||
*
|
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||||
* The first string is parsed to add a key and parents. The second string
|
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
* parses into a table.
|
*
|
||||||
*
|
* The first string is parsed to add a key and parents. The second string
|
||||||
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
* parses into a table.
|
||||||
*/
|
*
|
||||||
|
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||||
|
*/
|
||||||
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
const std::string& spec)
|
const std::string& spec)
|
||||||
: DiscreteConditional(Signature(key, parents, spec)) {}
|
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||||
|
|
|
||||||
|
|
@ -83,9 +83,6 @@ public:
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const DiscreteValues&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
/// Synonym for operator(), mostly for wrapper
|
|
||||||
double evaluate(const DiscreteValues& values) const { return operator()(values); }
|
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -136,9 +136,6 @@ public:
|
||||||
*/
|
*/
|
||||||
double operator()(const DiscreteValues& values) const;
|
double operator()(const DiscreteValues& values) const;
|
||||||
|
|
||||||
/// Synonym for operator(), mostly for wrapper
|
|
||||||
double evaluate(const DiscreteValues& values) const { return operator()(values); }
|
|
||||||
|
|
||||||
/// print
|
/// print
|
||||||
void print(
|
void print(
|
||||||
const std::string& s = "DiscreteFactorGraph",
|
const std::string& s = "DiscreteFactorGraph",
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ namespace gtsam {
|
||||||
* The format is (Key % string) for nodes with no parents,
|
* The format is (Key % string) for nodes with no parents,
|
||||||
* and (Key | Key, Key = string) for nodes with parents.
|
* and (Key | Key, Key = string) for nodes with parents.
|
||||||
*
|
*
|
||||||
* The string specifies a conditional probability spec in the 00 01 10 11 order.
|
* The string specifies a conditional probability table in 00 01 10 11 order.
|
||||||
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
|
* For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
|
||||||
*
|
*
|
||||||
* For example, given the following keys
|
* For example, given the following keys
|
||||||
|
|
@ -73,22 +73,29 @@ namespace gtsam {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a Table specifying the CPT.
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
*
|
*
|
||||||
* The first string is parsed to add a key and parents.
|
* The first string is parsed to add a key and parents.
|
||||||
*
|
*
|
||||||
* Example: Signature sig(D, {B,E}, table);
|
* Example:
|
||||||
|
* Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||||
|
* Signature sig(D, {E, B}, table);
|
||||||
*/
|
*/
|
||||||
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
const Table& table);
|
const Table& table);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a string specifying the CPT.
|
* Construct from key, parents, and a string specifying the conditional
|
||||||
|
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||||
|
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
*
|
*
|
||||||
* The first string is parsed to add a key and parents. The second string
|
* The first string is parsed to add a key and parents. The second string
|
||||||
* parses into a table.
|
* parses into a table.
|
||||||
*
|
*
|
||||||
* Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
|
* Example (same CPT as above):
|
||||||
|
* Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||||
*/
|
*/
|
||||||
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
const std::string& spec);
|
const std::string& spec);
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
//*************************************************************************
|
//*************************************************************************
|
||||||
// basis
|
// discrete
|
||||||
//*************************************************************************
|
//*************************************************************************
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
@ -26,7 +26,7 @@ class DiscreteFactor {
|
||||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
|
@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
|
@ -53,7 +53,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
const gtsam::DecisionTreeFactor& marginal,
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
const gtsam::Ordering& orderedKeys);
|
const gtsam::Ordering& orderedKeys);
|
||||||
size_t size() const; // TODO(dellaert): why do I have to repeat???
|
size_t size() const; // TODO(dellaert): why do I have to repeat???
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||||
void print(string s = "Discrete Conditional\n",
|
void print(string s = "Discrete Conditional\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
@ -86,7 +86,7 @@ class DiscreteBayesNet {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void add(const gtsam::DiscreteConditional& s);
|
void add(const gtsam::DiscreteConditional& s);
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
gtsam::DiscreteValues sample() const;
|
gtsam::DiscreteValues sample() const;
|
||||||
};
|
};
|
||||||
|
|
@ -98,7 +98,7 @@ class DiscreteBayesTree {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
|
@ -119,7 +119,7 @@ class DiscreteFactorGraph {
|
||||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||||
|
|
||||||
gtsam::DecisionTreeFactor product() const;
|
gtsam::DecisionTreeFactor product() const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet eliminateSequential();
|
gtsam::DiscreteBayesNet eliminateSequential();
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,6 @@ TEST(testSignature, all_examples) {
|
||||||
Signature b(B, {S}, "70/30 40/60");
|
Signature b(B, {S}, "70/30 40/60");
|
||||||
Signature e(E, {T, L}, "F F F 1");
|
Signature e(E, {T, L}, "F F F 1");
|
||||||
Signature x(X, {E}, "95/5 2/98");
|
Signature x(X, {E}, "95/5 2/98");
|
||||||
Signature d(D, {E, B}, "9/1 2/8 3/7 1/9");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we can create all signatures for Asia network with operator magic.
|
// Make sure we can create all signatures for Asia network with operator magic.
|
||||||
|
|
@ -105,7 +104,17 @@ TEST(testSignature, all_examples_magic) {
|
||||||
Signature b(B | S = "70/30 40/60");
|
Signature b(B | S = "70/30 40/60");
|
||||||
Signature e((E | T, L) = "F F F 1");
|
Signature e((E | T, L) = "F F F 1");
|
||||||
Signature x(X | E = "95/5 2/98");
|
Signature x(X | E = "95/5 2/98");
|
||||||
Signature d((D | E, B) = "9/1 2/8 3/7 1/9");
|
}
|
||||||
|
|
||||||
|
// Check example from docs.
|
||||||
|
TEST(testSignature, doxygen_example) {
|
||||||
|
Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
|
||||||
|
Signature d1(D, {E, B}, table);
|
||||||
|
Signature d2((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||||
|
Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9");
|
||||||
|
EXPECT(*(d1.table()) == table);
|
||||||
|
EXPECT(*(d2.table()) == table);
|
||||||
|
EXPECT(*(d3.table()) == table);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
self.assertEqual(list(actualMPE.items()),
|
self.assertEqual(list(actualMPE.items()),
|
||||||
list(expectedMPE.items()))
|
list(expectedMPE.items()))
|
||||||
|
|
||||||
|
# Check value for MPE is the same
|
||||||
|
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
||||||
|
|
||||||
# add evidence, we were in Asia and we have dyspnea
|
# add evidence, we were in Asia and we have dyspnea
|
||||||
fg.add(Asia, "0 1")
|
fg.add(Asia, "0 1")
|
||||||
fg.add(Dyspnea, "0 1")
|
fg.add(Dyspnea, "0 1")
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,9 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
assignment[1] = 1
|
assignment[1] = 1
|
||||||
|
|
||||||
# Check if graph evaluation works ( 0.3*0.6*4 )
|
# Check if graph evaluation works ( 0.3*0.6*4 )
|
||||||
self.assertAlmostEqual(.72, graph.evaluate(assignment))
|
self.assertAlmostEqual(.72, graph(assignment))
|
||||||
|
|
||||||
# Creating a new test with third node and adding unary and ternary factors on it
|
# Create a new test with third node and adding unary and ternary factor
|
||||||
graph.add(P3, "0.9 0.2 0.5")
|
graph.add(P3, "0.9 0.2 0.5")
|
||||||
keys = DiscreteKeys()
|
keys = DiscreteKeys()
|
||||||
keys.push_back(P1)
|
keys.push_back(P1)
|
||||||
|
|
@ -54,25 +54,25 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
keys.push_back(P3)
|
keys.push_back(P3)
|
||||||
graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
|
graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
|
||||||
|
|
||||||
# Below assignment lead to selecting the 8th index in the ternary factor table
|
# Below assignment selects the 8th index in the ternary factor table
|
||||||
assignment[0] = 1
|
assignment[0] = 1
|
||||||
assignment[1] = 0
|
assignment[1] = 0
|
||||||
assignment[2] = 1
|
assignment[2] = 1
|
||||||
|
|
||||||
# Check if graph evaluation works (0.3*0.9*1*0.2*8)
|
# Check if graph evaluation works (0.3*0.9*1*0.2*8)
|
||||||
self.assertAlmostEqual(4.32, graph.evaluate(assignment))
|
self.assertAlmostEqual(4.32, graph(assignment))
|
||||||
|
|
||||||
# Below assignment lead to selecting the 3rd index in the ternary factor table
|
# Below assignment selects the 3rd index in the ternary factor table
|
||||||
assignment[0] = 0
|
assignment[0] = 0
|
||||||
assignment[1] = 1
|
assignment[1] = 1
|
||||||
assignment[2] = 0
|
assignment[2] = 0
|
||||||
|
|
||||||
# Check if graph evaluation works (0.9*0.6*1*0.9*4)
|
# Check if graph evaluation works (0.9*0.6*1*0.9*4)
|
||||||
self.assertAlmostEqual(1.944, graph.evaluate(assignment))
|
self.assertAlmostEqual(1.944, graph(assignment))
|
||||||
|
|
||||||
# Check if graph product works
|
# Check if graph product works
|
||||||
product = graph.product()
|
product = graph.product()
|
||||||
self.assertAlmostEqual(1.944, product.evaluate(assignment))
|
self.assertAlmostEqual(1.944, product(assignment))
|
||||||
|
|
||||||
def test_optimize(self):
|
def test_optimize(self):
|
||||||
"""Test constructing and optizing a discrete factor graph."""
|
"""Test constructing and optizing a discrete factor graph."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue