Wrapped and test Discrete Bayes Nets

release/4.3a0
Frank Dellaert 2021-12-15 21:56:41 -05:00
parent 8f4b15b780
commit a4dab12bb0
6 changed files with 203 additions and 27 deletions

View File

@ -34,16 +34,6 @@ namespace gtsam {
return Base::equals(bn, tol); return Base::equals(bn, tol);
} }
/* ************************************************************************* */
// void DiscreteBayesNet::add_front(const Signature& s) {
// push_front(boost::make_shared<DiscreteConditional>(s));
// }
/* ************************************************************************* */
void DiscreteBayesNet::add(const Signature& s) {
push_back(boost::make_shared<DiscreteConditional>(s));
}
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply

View File

@ -71,12 +71,15 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
// Add inherited versions of add.
using Base::add;
/** Add a DiscreteCondtional */ /** Add a DiscreteCondtional */
void add(const Signature& s); template <typename... Args>
void add(Args&&... args) {
// /** Add a DiscreteCondtional in front, when listing parents first*/ emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
// GTSAM_EXPORT void add_front(const Signature& s); }
//** evaluate for given DiscreteValues */ //** evaluate for given DiscreteValues */
double evaluate(const DiscreteValues & values) const; double evaluate(const DiscreteValues & values) const;

View File

@ -83,8 +83,8 @@ public:
/// Find value for given assignment of values to variables /// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
/// Synonym for operator(), mostly for wrapper /// Synonym for operator(), mostly for wrapper
double evaluate(const DiscreteValues& values) const { return operator()(values); } double evaluate(const DiscreteValues& values) const { return operator()(values); }
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

View File

@ -36,9 +36,8 @@ namespace gtsam {
/// DiscreteKeys is a set of keys that can be assembled using the & operator /// DiscreteKeys is a set of keys that can be assembled using the & operator
struct DiscreteKeys: public std::vector<DiscreteKey> { struct DiscreteKeys: public std::vector<DiscreteKey> {
/// Default constructor // Forward all constructors.
DiscreteKeys() { using std::vector<DiscreteKey>::vector;
}
/// Construct from a key /// Construct from a key
DiscreteKeys(const DiscreteKey& key) { DiscreteKeys(const DiscreteKey& key) {

View File

@ -25,14 +25,10 @@ class DiscreteFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
bool empty() const; bool empty() const;
size_t size() const;
double evaluate(const gtsam::DiscreteValues& values) const; double evaluate(const gtsam::DiscreteValues& values) const;
}; };
#include <gtsam/discrete/DiscreteConditional.h>
class DiscreteConditional {
DiscreteConditional();
};
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
virtual class DecisionTreeFactor: gtsam::DiscreteFactor { virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
DecisionTreeFactor(); DecisionTreeFactor();
@ -45,18 +41,91 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
}; };
#include <gtsam/discrete/DiscreteConditional.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
DiscreteConditional(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal);
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
size_t size() const; // TODO(dellaert): why do I have to repeat???
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
void printSignature(
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* toFactor() const;
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
};
#include <gtsam/discrete/DiscreteBayesNet.h>
class DiscreteBayesNet {
DiscreteBayesNet();
void add(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteConditional* at(size_t i) const;
void print(string s = "DiscreteBayesNet\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void add(const gtsam::DiscreteConditional& s);
double evaluate(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const;
};
#include <gtsam/discrete/DiscreteBayesTree.h>
class DiscreteBayesTree {
DiscreteBayesTree();
void print(string s = "DiscreteBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
double evaluate(const gtsam::DiscreteValues& values) const;
};
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
class DiscreteFactorGraph { class DiscreteFactorGraph {
DiscreteFactorGraph(); DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
void add(const gtsam::DiscreteKey& j, string table); void add(const gtsam::DiscreteKey& j, string table);
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
void add(const gtsam::DiscreteKeys& keys, string table); void add(const gtsam::DiscreteKeys& keys, string table);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteFactor* at(size_t i) const;
void print(string s = "") const; void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
gtsam::KeySet keys() const;
gtsam::DecisionTreeFactor product() const; gtsam::DecisionTreeFactor product() const;
double evaluate(const gtsam::DiscreteValues& values) const; double evaluate(const gtsam::DiscreteValues& values) const;
DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -0,0 +1,115 @@
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Bayes Nets.
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
import unittest
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase
class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""
def test_constructor(self):
"""Test constructing a Bayes net."""
bayesNet = DiscreteBayesNet()
Parent, Child = (0, 2), (1, 2)
empty = DiscreteKeys()
prior = DiscreteConditional(Parent, empty, "6/4")
bayesNet.add(prior)
parents = DiscreteKeys()
parents.push_back(Parent)
conditional = DiscreteConditional(Child, parents, "7/3 8/2")
bayesNet.add(conditional)
# Check conversion to factor graph:
fg = DiscreteFactorGraph(bayesNet)
self.assertEqual(fg.size(), 2)
self.assertEqual(fg.at(1).size(), 2)
def test_Asia(self):
"""Test full Asia example."""
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
def P(keys):
dks = DiscreteKeys()
for key in keys:
dks.push_back(key)
return dks
asia = DiscreteBayesNet()
asia.add(Asia, P([]), "99/1")
asia.add(Smoking, P([]), "50/50")
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
asia.add(XRay, P([Either]), "95/5 2/98")
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
# Convert to factor graph
fg = DiscreteFactorGraph(asia)
# Create solver and eliminate
ordering = Ordering()
for j in range(8):
ordering.push_back(j)
chordal = fg.eliminateSequential(ordering)
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
self.gtsamAssertEquals(chordal.at(7), expected2)
# solve
actualMPE = chordal.optimize()
expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
expectedMPE[key[0]] = 0
self.assertEqual(list(actualMPE.items()),
list(expectedMPE.items()))
# add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1")
fg.add(Dyspnea, "0 1")
# solve again, now with evidence
chordal2 = fg.eliminateSequential(ordering)
actualMPE2 = chordal2.optimize()
expectedMPE2 = DiscreteValues()
for key in [XRay, Tuberculosis, Either, LungCancer]:
expectedMPE2[key[0]] = 0
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
expectedMPE2[key[0]] = 1
self.assertEqual(list(actualMPE2.items()),
list(expectedMPE2.items()))
# now sample from it
actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8)
if __name__ == "__main__":
unittest.main()