diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 796c0c8c8..219f2d93e 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -34,16 +34,6 @@ namespace gtsam { return Base::equals(bn, tol); } - /* ************************************************************************* */ -// void DiscreteBayesNet::add_front(const Signature& s) { -// push_front(boost::make_shared(s)); -// } - - /* ************************************************************************* */ - void DiscreteBayesNet::add(const Signature& s) { - push_back(boost::make_shared(s)); - } - /* ************************************************************************* */ double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { // evaluate all conditionals and multiply diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 9f5f10388..4ffac95ed 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -71,12 +71,15 @@ namespace gtsam { /// @name Standard Interface /// @{ + // Add inherited versions of add. + using Base::add; + /** Add a DiscreteCondtional */ - void add(const Signature& s); - -// /** Add a DiscreteCondtional in front, when listing parents first*/ -// GTSAM_EXPORT void add_front(const Signature& s); - + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + //** evaluate for given DiscreteValues */ double evaluate(const DiscreteValues & values) const; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index c1de114eb..7abad4245 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -83,8 +83,8 @@ public: /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; - /// Synonym for operator(), mostly for wrapper - double evaluate(const DiscreteValues& values) const { return operator()(values); } + /// Synonym for operator(), mostly for wrapper + double evaluate(const DiscreteValues& values) const { return operator()(values); } /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index c041c7e8e..f829e4f7c 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -36,9 +36,8 @@ namespace gtsam { /// DiscreteKeys is a set of keys that can be assembled using the & operator struct DiscreteKeys: public std::vector { - /// Default constructor - DiscreteKeys() { - } + // Forward all constructors. + using std::vector::vector; /// Construct from a key DiscreteKeys(const DiscreteKey& key) { diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 8e67478db..47583c612 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -25,14 +25,10 @@ class DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool empty() const; + size_t size() const; double evaluate(const gtsam::DiscreteValues& values) const; }; -#include -class DiscreteConditional { - DiscreteConditional(); -}; - #include virtual class DecisionTreeFactor: gtsam::DiscreteFactor { DecisionTreeFactor(); @@ -45,18 +41,91 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? }; +#include +virtual class DiscreteConditional : gtsam::DecisionTreeFactor { + DiscreteConditional(); + DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal, + const gtsam::Ordering& orderedKeys); + size_t size() const; // TODO(dellaert): why do I have to repeat??? + double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + void print(string s = "Discrete Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + void printSignature( + string s = "Discrete Conditional: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + gtsam::DecisionTreeFactor* toFactor() const; + gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; + size_t solve(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(const gtsam::DiscreteValues& parentsValues) const; + void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; + void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; +}; + +#include +class DiscreteBayesNet { + DiscreteBayesNet(); + void add(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteConditional* at(size_t i) const; + void print(string s = "DiscreteBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void add(const gtsam::DiscreteConditional& s); + double evaluate(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + gtsam::DiscreteValues sample() const; +}; + +#include +class DiscreteBayesTree { + DiscreteBayesTree(); + void print(string s = "DiscreteBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + double evaluate(const gtsam::DiscreteValues& values) const; +}; + #include class DiscreteFactorGraph { DiscreteFactorGraph(); + DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); + void add(const gtsam::DiscreteKey& j, string table); void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); void add(const gtsam::DiscreteKeys& keys, string table); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteFactor* at(size_t i) const; + void print(string s = "") const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; - gtsam::KeySet keys() const; + gtsam::DecisionTreeFactor product() const; double evaluate(const gtsam::DiscreteValues& values) const; - DiscreteValues optimize() const; + gtsam::DiscreteValues optimize() const; + + gtsam::DiscreteBayesNet eliminateSequential(); + gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree eliminateMultifrontal(); + gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); }; } // namespace gtsam diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py new file mode 100644 index 000000000..2abc65715 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -0,0 +1,115 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes Nets. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, + DiscreteKeys, DiscreteValues, Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_constructor(self): + """Test constructing a Bayes net.""" + + bayesNet = DiscreteBayesNet() + Parent, Child = (0, 2), (1, 2) + empty = DiscreteKeys() + prior = DiscreteConditional(Parent, empty, "6/4") + bayesNet.add(prior) + + parents = DiscreteKeys() + parents.push_back(Parent) + conditional = DiscreteConditional(Child, parents, "7/3 8/2") + bayesNet.add(conditional) + + # Check conversion to factor graph: + fg = DiscreteFactorGraph(bayesNet) + self.assertEqual(fg.size(), 2) + self.assertEqual(fg.at(1).size(), 2) + + def test_Asia(self): + """Test full Asia example.""" + + Asia = (0, 2) + Smoking = (4, 2) + Tuberculosis = (3, 2) + LungCancer = (6, 2) + + Bronchitis = (7, 2) + Either = (5, 2) + XRay = (2, 2) + Dyspnea = (1, 2) + + def P(keys): + dks = DiscreteKeys() + for key in keys: + dks.push_back(key) + return dks + + asia = DiscreteBayesNet() + asia.add(Asia, P([]), "99/1") + asia.add(Smoking, P([]), "50/50") + + asia.add(Tuberculosis, P([Asia]), "99/1 95/5") + asia.add(LungCancer, P([Smoking]), "99/1 90/10") + asia.add(Bronchitis, P([Smoking]), "70/30 40/60") + + asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T") + + asia.add(XRay, P([Either]), "95/5 2/98") + asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9") + + # Convert to factor graph + fg = DiscreteFactorGraph(asia) + + # Create solver and eliminate + ordering = Ordering() + for j in range(8): + ordering.push_back(j) + chordal = fg.eliminateSequential(ordering) + expected2 = DiscreteConditional(Bronchitis, P([]), "11/9") + self.gtsamAssertEquals(chordal.at(7), expected2) + + # solve + actualMPE = chordal.optimize() + expectedMPE = DiscreteValues() + for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: + expectedMPE[key[0]] = 0 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + # add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1") + fg.add(Dyspnea, "0 1") + + # solve again, now with evidence + chordal2 = fg.eliminateSequential(ordering) + actualMPE2 = chordal2.optimize() + expectedMPE2 = DiscreteValues() + for key in [XRay, Tuberculosis, Either, LungCancer]: + expectedMPE2[key[0]] = 0 + for key in [Asia, Dyspnea, Smoking, Bronchitis]: + expectedMPE2[key[0]] = 1 + self.assertEqual(list(actualMPE2.items()), + list(expectedMPE2.items())) + + # now sample from it + actualSample = chordal2.sample() + self.assertEqual(len(actualSample), 8) + + +if __name__ == "__main__": + unittest.main()