Wrap () operators
parent
b2e3654960
commit
7257797a5f
|
@ -83,9 +83,6 @@ public:
|
|||
/// Find value for given assignment of values to variables
|
||||
virtual double operator()(const DiscreteValues&) const = 0;
|
||||
|
||||
/// Synonym for operator(), mostly for wrapper
|
||||
double evaluate(const DiscreteValues& values) const { return operator()(values); }
|
||||
|
||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||
|
||||
|
|
|
@ -136,9 +136,6 @@ public:
|
|||
*/
|
||||
double operator()(const DiscreteValues& values) const;
|
||||
|
||||
/// Synonym for operator(), mostly for wrapper
|
||||
double evaluate(const DiscreteValues& values) const { return operator()(values); }
|
||||
|
||||
/// print
|
||||
void print(
|
||||
const std::string& s = "DiscreteFactorGraph",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
//*************************************************************************
|
||||
// basis
|
||||
// discrete
|
||||
//*************************************************************************
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -26,7 +26,7 @@ class DiscreteFactor {
|
|||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
|
@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
@ -53,7 +53,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
const gtsam::DecisionTreeFactor& marginal,
|
||||
const gtsam::Ordering& orderedKeys);
|
||||
size_t size() const; // TODO(dellaert): why do I have to repeat???
|
||||
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
void print(string s = "Discrete Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
@ -86,7 +86,7 @@ class DiscreteBayesNet {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
};
|
||||
|
@ -98,7 +98,7 @@ class DiscreteBayesTree {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
|
@ -119,7 +119,7 @@ class DiscreteFactorGraph {
|
|||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||
|
||||
gtsam::DecisionTreeFactor product() const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteBayesNet eliminateSequential();
|
||||
|
|
|
@ -91,6 +91,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
self.assertEqual(list(actualMPE.items()),
|
||||
list(expectedMPE.items()))
|
||||
|
||||
# Check value for MPE is the same
|
||||
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
||||
|
||||
# add evidence, we were in Asia and we have dyspnea
|
||||
fg.add(Asia, "0 1")
|
||||
fg.add(Dyspnea, "0 1")
|
||||
|
|
|
@ -44,9 +44,9 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
assignment[1] = 1
|
||||
|
||||
# Check if graph evaluation works ( 0.3*0.6*4 )
|
||||
self.assertAlmostEqual(.72, graph.evaluate(assignment))
|
||||
self.assertAlmostEqual(.72, graph(assignment))
|
||||
|
||||
# Creating a new test with third node and adding unary and ternary factors on it
|
||||
# Create a new test with third node and adding unary and ternary factor
|
||||
graph.add(P3, "0.9 0.2 0.5")
|
||||
keys = DiscreteKeys()
|
||||
keys.push_back(P1)
|
||||
|
@ -54,25 +54,25 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
|||
keys.push_back(P3)
|
||||
graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
|
||||
|
||||
# Below assignment lead to selecting the 8th index in the ternary factor table
|
||||
# Below assignment selects the 8th index in the ternary factor table
|
||||
assignment[0] = 1
|
||||
assignment[1] = 0
|
||||
assignment[2] = 1
|
||||
|
||||
# Check if graph evaluation works (0.3*0.9*1*0.2*8)
|
||||
self.assertAlmostEqual(4.32, graph.evaluate(assignment))
|
||||
self.assertAlmostEqual(4.32, graph(assignment))
|
||||
|
||||
# Below assignment lead to selecting the 3rd index in the ternary factor table
|
||||
# Below assignment selects the 3rd index in the ternary factor table
|
||||
assignment[0] = 0
|
||||
assignment[1] = 1
|
||||
assignment[2] = 0
|
||||
|
||||
# Check if graph evaluation works (0.9*0.6*1*0.9*4)
|
||||
self.assertAlmostEqual(1.944, graph.evaluate(assignment))
|
||||
self.assertAlmostEqual(1.944, graph(assignment))
|
||||
|
||||
# Check if graph product works
|
||||
product = graph.product()
|
||||
self.assertAlmostEqual(1.944, product.evaluate(assignment))
|
||||
self.assertAlmostEqual(1.944, product(assignment))
|
||||
|
||||
def test_optimize(self):
|
||||
"""Test constructing and optizing a discrete factor graph."""
|
||||
|
|
Loading…
Reference in New Issue