diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index d78eed08f..aed4cec0a 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace gtsam { @@ -75,6 +76,11 @@ namespace gtsam { // Add inherited versions of add. using Base::add; + /** Add a DiscretePrior using a table or a string */ + void add(const DiscreteKey& key, const std::string& spec) { + emplace_shared(key, spec); + } + /** Add a DiscreteCondtional */ template void add(Args&&... args) { diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index e8cf6afe0..e95dfb515 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -81,6 +81,20 @@ public: const std::string& spec) : DiscreteConditional(Signature(key, parents, spec)) {} + /// No-parent specialization; can also use DiscretePrior. + DiscreteConditional(const DiscreteKey& key, const std::string& spec) + : DiscreteConditional(Signature(key, {}, spec)) {} + + /// Single-parent specialization + DiscreteConditional(const DiscreteKey& key, const std::string& spec, + const DiscreteKey& parent1) + : DiscreteConditional(Signature(key, {parent1}, spec)) {} + + /// Two-parent specialization + DiscreteConditional(const DiscreteKey& key, const std::string& spec, + const DiscreteKey& parent1, const DiscreteKey& parent2) + : DiscreteConditional(Signature(key, {parent1, parent2}, spec)) {} + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 9782480c3..f2e7456d8 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -84,10 +84,17 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { }; #include -class DiscreteBayesNet { +class DiscreteBayesNet { DiscreteBayesNet(); - void add(const gtsam::DiscreteKey& key, - const gtsam::DiscreteKeys& parents, string spec); + void add(const gtsam::DiscreteConditional& s); + void add(const gtsam::DiscreteKey& key, string spec); + void add(const gtsam::DiscreteKey& key, string spec, + const gtsam::DiscreteKey& parent1); + void add(const gtsam::DiscreteKey& key, string spec, + const gtsam::DiscreteKey& parent1, + const gtsam::DiscreteKey& parent2); + void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, + string spec); bool empty() const; size_t size() const; gtsam::KeySet keys() const; @@ -98,15 +105,13 @@ class DiscreteBayesNet { bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; - void saveGraph(string s, - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - void add(const gtsam::DiscreteConditional& s); + void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues sample() const; string markdown(const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; + gtsam::DefaultKeyFormatter) const; }; #include diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index ea5816566..7a5f180ad 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -75,8 +75,8 @@ TEST(DiscreteBayesNet, bayesNet) { TEST(DiscreteBayesNet, Asia) { DiscreteBayesNet asia; - asia.add(Asia % "99/1"); - asia.add(Smoking % "50/50"); + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version asia.add(Tuberculosis | Asia = "99/1 95/5"); asia.add(LungCancer | Smoking = "99/1 90/10"); diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index bf09da193..706cdf93d 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -14,7 +14,7 @@ Author: Frank Dellaert import unittest from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, - DiscreteKeys, DiscreteValues, Ordering) + DiscreteKeys, DiscretePrior, DiscreteValues, Ordering) from gtsam.utils.test_case import GtsamTestCase @@ -53,24 +53,18 @@ class TestDiscreteBayesNet(GtsamTestCase): XRay = (2, 2) Dyspnea = (1, 2) - def P(keys): - dks = DiscreteKeys() - for key in keys: - dks.push_back(key) - return dks - asia = DiscreteBayesNet() - asia.add(Asia, P([]), "99/1") - asia.add(Smoking, P([]), "50/50") + asia.add(Asia, "99/1") + asia.add(Smoking, "50/50") - asia.add(Tuberculosis, P([Asia]), "99/1 95/5") - asia.add(LungCancer, P([Smoking]), "99/1 90/10") - asia.add(Bronchitis, P([Smoking]), "70/30 40/60") + asia.add(Tuberculosis, "99/1 95/5", Asia) + asia.add(LungCancer, "99/1 90/10", Smoking) + asia.add(Bronchitis, "70/30 40/60", Smoking) - asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T") + asia.add(Either, "F T T T", Tuberculosis, LungCancer) - asia.add(XRay, P([Either]), "95/5 2/98") - asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9") + asia.add(XRay, "95/5 2/98", Either) + asia.add(Dyspnea, "9/1 2/8 3/7 1/9", Either, Bronchitis) # Convert to factor graph fg = DiscreteFactorGraph(asia) @@ -80,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase): for j in range(8): ordering.push_back(j) chordal = fg.eliminateSequential(ordering) - expected2 = DiscreteConditional(Bronchitis, P([]), "11/9") + expected2 = DiscretePrior(Bronchitis, "11/9") self.gtsamAssertEquals(chordal.at(7), expected2) # solve