diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 4ffac95ed..2d92b72e8 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -83,6 +83,11 @@ namespace gtsam { //** evaluate for given DiscreteValues */ double evaluate(const DiscreteValues & values) const; + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } + /** * Solve the DiscreteBayesNet by back-substitution */ diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 655bcb9ee..42ec7d417 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -80,6 +80,12 @@ class GTSAM_EXPORT DiscreteBayesTree //** evaluate probability for given DiscreteValues */ double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } + }; } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index be268afaf..06928e2e7 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -58,24 +58,28 @@ public: DiscreteConditional(const Signature& signature); /** - * Construct from key, parents, and a Table specifying the CPT. - * - * The first string is parsed to add a key and parents. - * - * Example: DiscreteConditional P(D, {B,E}, table); - */ + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, const Signature::Table& table) : DiscreteConditional(Signature(key, parents, table)) {} /** - * Construct from key, parents, and a string specifying the CPT. - * - * The first string is parsed to add a key and parents. The second string - * parses into a table. - * - * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); - */ + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, const std::string& spec) : DiscreteConditional(Signature(key, parents, spec)) {} diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 7abad4245..e2be94b94 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -83,9 +83,6 @@ public: /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; - /// Synonym for operator(), mostly for wrapper - double evaluate(const DiscreteValues& values) const { return operator()(values); } - /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 472702231..ff0aaef19 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -136,9 +136,6 @@ public: */ double operator()(const DiscreteValues& values) const; - /// Synonym for operator(), mostly for wrapper - double evaluate(const DiscreteValues& values) const { return operator()(values); } - /// print void print( const std::string& s = "DiscreteFactorGraph", diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 05f10ed23..ff83caa53 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -30,7 +30,7 @@ namespace gtsam { * The format is (Key % string) for nodes with no parents, * and (Key | Key, Key = string) for nodes with parents. * - * The string specifies a conditional probability spec in the 00 01 10 11 order. + * The string specifies a conditional probability table in 00 01 10 11 order. * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... * * For example, given the following keys @@ -73,22 +73,29 @@ namespace gtsam { public: /** - * Construct from key, parents, and a Table specifying the CPT. + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... * * The first string is parsed to add a key and parents. - * - * Example: Signature sig(D, {B,E}, table); + * + * Example: + * Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + * Signature sig(D, {E, B}, table); */ Signature(const DiscreteKey& key, const DiscreteKeys& parents, const Table& table); /** - * Construct from key, parents, and a string specifying the CPT. + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... * * The first string is parsed to add a key and parents. The second string * parses into a table. - * - * Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + * + * Example (same CPT as above): + * Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); */ Signature(const DiscreteKey& key, const DiscreteKeys& parents, const std::string& spec); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 47583c612..daea84e70 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -1,5 +1,5 @@ //************************************************************************* -// basis +// discrete //************************************************************************* namespace gtsam { @@ -26,7 +26,7 @@ class DiscreteFactor { bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool empty() const; size_t size() const; - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; #include @@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; - double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? }; #include @@ -53,7 +53,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); size_t size() const; // TODO(dellaert): why do I have to repeat??? - double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -86,7 +86,7 @@ class DiscreteBayesNet { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; void add(const gtsam::DiscreteConditional& s); - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues sample() const; }; @@ -98,7 +98,7 @@ class DiscreteBayesTree { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; #include @@ -119,7 +119,7 @@ class DiscreteFactorGraph { bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; gtsam::DecisionTreeFactor product() const; - double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteBayesNet eliminateSequential(); diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index fd15eb36c..737bd8aef 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -92,7 +92,6 @@ TEST(testSignature, all_examples) { Signature b(B, {S}, "70/30 40/60"); Signature e(E, {T, L}, "F F F 1"); Signature x(X, {E}, "95/5 2/98"); - Signature d(D, {E, B}, "9/1 2/8 3/7 1/9"); } // Make sure we can create all signatures for Asia network with operator magic. @@ -105,7 +104,17 @@ TEST(testSignature, all_examples_magic) { Signature b(B | S = "70/30 40/60"); Signature e((E | T, L) = "F F F 1"); Signature x(X | E = "95/5 2/98"); - Signature d((D | E, B) = "9/1 2/8 3/7 1/9"); +} + +// Check example from docs. +TEST(testSignature, doxygen_example) { + Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + Signature d1(D, {E, B}, table); + Signature d2((D | E, B) = "9/1 2/8 3/7 1/9"); + Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9"); + EXPECT(*(d1.table()) == table); + EXPECT(*(d2.table()) == table); + EXPECT(*(d3.table()) == table); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index 2abc65715..bf09da193 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -91,6 +91,9 @@ class TestDiscreteBayesNet(GtsamTestCase): self.assertEqual(list(actualMPE.items()), list(expectedMPE.items())) + # Check value for MPE is the same + self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) + # add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1") fg.add(Dyspnea, "0 1") diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index e73e9dc7b..9dafff33f 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -44,9 +44,9 @@ class TestDiscreteFactorGraph(GtsamTestCase): assignment[1] = 1 # Check if graph evaluation works ( 0.3*0.6*4 ) - self.assertAlmostEqual(.72, graph.evaluate(assignment)) + self.assertAlmostEqual(.72, graph(assignment)) - # Creating a new test with third node and adding unary and ternary factors on it + # Create a new test with third node and adding unary and ternary factor graph.add(P3, "0.9 0.2 0.5") keys = DiscreteKeys() keys.push_back(P1) @@ -54,25 +54,25 @@ class TestDiscreteFactorGraph(GtsamTestCase): keys.push_back(P3) graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12") - # Below assignment lead to selecting the 8th index in the ternary factor table + # Below assignment selects the 8th index in the ternary factor table assignment[0] = 1 assignment[1] = 0 assignment[2] = 1 # Check if graph evaluation works (0.3*0.9*1*0.2*8) - self.assertAlmostEqual(4.32, graph.evaluate(assignment)) + self.assertAlmostEqual(4.32, graph(assignment)) - # Below assignment lead to selecting the 3rd index in the ternary factor table + # Below assignment selects the 3rd index in the ternary factor table assignment[0] = 0 assignment[1] = 1 assignment[2] = 0 # Check if graph evaluation works (0.9*0.6*1*0.9*4) - self.assertAlmostEqual(1.944, graph.evaluate(assignment)) + self.assertAlmostEqual(1.944, graph(assignment)) # Check if graph product works product = graph.product() - self.assertAlmostEqual(1.944, product.evaluate(assignment)) + self.assertAlmostEqual(1.944, product(assignment)) def test_optimize(self): """Test constructing and optizing a discrete factor graph."""