Wrapped and test Discrete Bayes Nets
parent
8f4b15b780
commit
a4dab12bb0
|
@ -34,16 +34,6 @@ namespace gtsam {
|
|||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// void DiscreteBayesNet::add_front(const Signature& s) {
|
||||
// push_front(boost::make_shared<DiscreteConditional>(s));
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteBayesNet::add(const Signature& s) {
|
||||
push_back(boost::make_shared<DiscreteConditional>(s));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
|
|
|
@ -71,11 +71,14 @@ namespace gtsam {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteCondtional */
|
||||
void add(const Signature& s);
|
||||
// Add inherited versions of add.
|
||||
using Base::add;
|
||||
|
||||
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
||||
// GTSAM_EXPORT void add_front(const Signature& s);
|
||||
/** Add a DiscreteCondtional */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
//** evaluate for given DiscreteValues */
|
||||
double evaluate(const DiscreteValues & values) const;
|
||||
|
|
|
@ -36,9 +36,8 @@ namespace gtsam {
|
|||
/// DiscreteKeys is a set of keys that can be assembled using the & operator
|
||||
struct DiscreteKeys: public std::vector<DiscreteKey> {
|
||||
|
||||
/// Default constructor
|
||||
DiscreteKeys() {
|
||||
}
|
||||
// Forward all constructors.
|
||||
using std::vector<DiscreteKey>::vector;
|
||||
|
||||
/// Construct from a key
|
||||
DiscreteKeys(const DiscreteKey& key) {
|
||||
|
|
|
@ -25,14 +25,10 @@ class DiscreteFactor {
|
|||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
class DiscreteConditional {
|
||||
DiscreteConditional();
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
||||
DecisionTreeFactor();
|
||||
|
@ -45,18 +41,91 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
|||
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||
DiscreteConditional();
|
||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal);
|
||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal,
|
||||
const gtsam::Ordering& orderedKeys);
|
||||
size_t size() const; // TODO(dellaert): why do I have to repeat???
|
||||
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
void print(string s = "Discrete Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||
void printSignature(
|
||||
string s = "Discrete Conditional: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DecisionTreeFactor* toFactor() const;
|
||||
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
class DiscreteBayesNet {
|
||||
DiscreteBayesNet();
|
||||
void add(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteConditional* at(size_t i) const;
|
||||
void print(string s = "DiscreteBayesNet\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
class DiscreteBayesTree {
|
||||
DiscreteBayesTree();
|
||||
void print(string s = "DiscreteBayesTree\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
class DiscreteFactorGraph {
|
||||
DiscreteFactorGraph();
|
||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||
|
||||
void add(const gtsam::DiscreteKey& j, string table);
|
||||
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
|
||||
void add(const gtsam::DiscreteKeys& keys, string table);
|
||||
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteFactor* at(size_t i) const;
|
||||
|
||||
void print(string s = "") const;
|
||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||
gtsam::KeySet keys() const;
|
||||
|
||||
gtsam::DecisionTreeFactor product() const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteBayesNet eliminateSequential();
|
||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
||||
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Unit tests for Discrete Bayes Nets.
|
||||
Author: Frank Dellaert
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module, invalid-name
|
||||
|
||||
import unittest
|
||||
|
||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||
DiscreteKeys, DiscreteValues, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
"""Tests for Discrete Bayes Nets."""
|
||||
|
||||
def test_constructor(self):
|
||||
"""Test constructing a Bayes net."""
|
||||
|
||||
bayesNet = DiscreteBayesNet()
|
||||
Parent, Child = (0, 2), (1, 2)
|
||||
empty = DiscreteKeys()
|
||||
prior = DiscreteConditional(Parent, empty, "6/4")
|
||||
bayesNet.add(prior)
|
||||
|
||||
parents = DiscreteKeys()
|
||||
parents.push_back(Parent)
|
||||
conditional = DiscreteConditional(Child, parents, "7/3 8/2")
|
||||
bayesNet.add(conditional)
|
||||
|
||||
# Check conversion to factor graph:
|
||||
fg = DiscreteFactorGraph(bayesNet)
|
||||
self.assertEqual(fg.size(), 2)
|
||||
self.assertEqual(fg.at(1).size(), 2)
|
||||
|
||||
def test_Asia(self):
|
||||
"""Test full Asia example."""
|
||||
|
||||
Asia = (0, 2)
|
||||
Smoking = (4, 2)
|
||||
Tuberculosis = (3, 2)
|
||||
LungCancer = (6, 2)
|
||||
|
||||
Bronchitis = (7, 2)
|
||||
Either = (5, 2)
|
||||
XRay = (2, 2)
|
||||
Dyspnea = (1, 2)
|
||||
|
||||
def P(keys):
|
||||
dks = DiscreteKeys()
|
||||
for key in keys:
|
||||
dks.push_back(key)
|
||||
return dks
|
||||
|
||||
asia = DiscreteBayesNet()
|
||||
asia.add(Asia, P([]), "99/1")
|
||||
asia.add(Smoking, P([]), "50/50")
|
||||
|
||||
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
|
||||
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
|
||||
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
|
||||
|
||||
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
|
||||
|
||||
asia.add(XRay, P([Either]), "95/5 2/98")
|
||||
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
|
||||
|
||||
# Convert to factor graph
|
||||
fg = DiscreteFactorGraph(asia)
|
||||
|
||||
# Create solver and eliminate
|
||||
ordering = Ordering()
|
||||
for j in range(8):
|
||||
ordering.push_back(j)
|
||||
chordal = fg.eliminateSequential(ordering)
|
||||
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
|
||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||
|
||||
# solve
|
||||
actualMPE = chordal.optimize()
|
||||
expectedMPE = DiscreteValues()
|
||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||
expectedMPE[key[0]] = 0
|
||||
self.assertEqual(list(actualMPE.items()),
|
||||
list(expectedMPE.items()))
|
||||
|
||||
# add evidence, we were in Asia and we have dyspnea
|
||||
fg.add(Asia, "0 1")
|
||||
fg.add(Dyspnea, "0 1")
|
||||
|
||||
# solve again, now with evidence
|
||||
chordal2 = fg.eliminateSequential(ordering)
|
||||
actualMPE2 = chordal2.optimize()
|
||||
expectedMPE2 = DiscreteValues()
|
||||
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
||||
expectedMPE2[key[0]] = 0
|
||||
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
||||
expectedMPE2[key[0]] = 1
|
||||
self.assertEqual(list(actualMPE2.items()),
|
||||
list(expectedMPE2.items()))
|
||||
|
||||
# now sample from it
|
||||
actualSample = chordal2.sample()
|
||||
self.assertEqual(len(actualSample), 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue