Wrapped and test Discrete Bayes Nets
parent
8f4b15b780
commit
a4dab12bb0
|
@ -34,16 +34,6 @@ namespace gtsam {
|
||||||
return Base::equals(bn, tol);
|
return Base::equals(bn, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
// void DiscreteBayesNet::add_front(const Signature& s) {
|
|
||||||
// push_front(boost::make_shared<DiscreteConditional>(s));
|
|
||||||
// }
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
void DiscreteBayesNet::add(const Signature& s) {
|
|
||||||
push_back(boost::make_shared<DiscreteConditional>(s));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
||||||
// evaluate all conditionals and multiply
|
// evaluate all conditionals and multiply
|
||||||
|
|
|
@ -71,12 +71,15 @@ namespace gtsam {
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
// Add inherited versions of add.
|
||||||
|
using Base::add;
|
||||||
|
|
||||||
/** Add a DiscreteCondtional */
|
/** Add a DiscreteCondtional */
|
||||||
void add(const Signature& s);
|
template <typename... Args>
|
||||||
|
void add(Args&&... args) {
|
||||||
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
|
||||||
// GTSAM_EXPORT void add_front(const Signature& s);
|
}
|
||||||
|
|
||||||
//** evaluate for given DiscreteValues */
|
//** evaluate for given DiscreteValues */
|
||||||
double evaluate(const DiscreteValues & values) const;
|
double evaluate(const DiscreteValues & values) const;
|
||||||
|
|
||||||
|
|
|
@ -83,8 +83,8 @@ public:
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const DiscreteValues&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
/// Synonym for operator(), mostly for wrapper
|
/// Synonym for operator(), mostly for wrapper
|
||||||
double evaluate(const DiscreteValues& values) const { return operator()(values); }
|
double evaluate(const DiscreteValues& values) const { return operator()(values); }
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
|
|
@ -36,9 +36,8 @@ namespace gtsam {
|
||||||
/// DiscreteKeys is a set of keys that can be assembled using the & operator
|
/// DiscreteKeys is a set of keys that can be assembled using the & operator
|
||||||
struct DiscreteKeys: public std::vector<DiscreteKey> {
|
struct DiscreteKeys: public std::vector<DiscreteKey> {
|
||||||
|
|
||||||
/// Default constructor
|
// Forward all constructors.
|
||||||
DiscreteKeys() {
|
using std::vector<DiscreteKey>::vector;
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct from a key
|
/// Construct from a key
|
||||||
DiscreteKeys(const DiscreteKey& key) {
|
DiscreteKeys(const DiscreteKey& key) {
|
||||||
|
|
|
@ -25,14 +25,10 @@ class DiscreteFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
|
||||||
class DiscreteConditional {
|
|
||||||
DiscreteConditional();
|
|
||||||
};
|
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
||||||
DecisionTreeFactor();
|
DecisionTreeFactor();
|
||||||
|
@ -45,18 +41,91 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
|
DiscreteConditional();
|
||||||
|
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||||
|
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||||
|
const gtsam::DiscreteKeys& parents, string spec);
|
||||||
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
|
const gtsam::DecisionTreeFactor& marginal);
|
||||||
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
|
const gtsam::Ordering& orderedKeys);
|
||||||
|
size_t size() const; // TODO(dellaert): why do I have to repeat???
|
||||||
|
double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||||
|
void print(string s = "Discrete Conditional\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||||
|
void printSignature(
|
||||||
|
string s = "Discrete Conditional: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
gtsam::DecisionTreeFactor* toFactor() const;
|
||||||
|
gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||||
|
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
class DiscreteBayesNet {
|
||||||
|
DiscreteBayesNet();
|
||||||
|
void add(const gtsam::DiscreteKey& key,
|
||||||
|
const gtsam::DiscreteKeys& parents, string spec);
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::DiscreteConditional* at(size_t i) const;
|
||||||
|
void print(string s = "DiscreteBayesNet\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||||
|
void saveGraph(string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
void add(const gtsam::DiscreteConditional& s);
|
||||||
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
gtsam::DiscreteValues sample() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
class DiscreteBayesTree {
|
||||||
|
DiscreteBayesTree();
|
||||||
|
void print(string s = "DiscreteBayesTree\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||||
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
class DiscreteFactorGraph {
|
class DiscreteFactorGraph {
|
||||||
DiscreteFactorGraph();
|
DiscreteFactorGraph();
|
||||||
|
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
void add(const gtsam::DiscreteKey& j, string table);
|
void add(const gtsam::DiscreteKey& j, string table);
|
||||||
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
|
void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table);
|
||||||
void add(const gtsam::DiscreteKeys& keys, string table);
|
void add(const gtsam::DiscreteKeys& keys, string table);
|
||||||
|
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::DiscreteFactor* at(size_t i) const;
|
||||||
|
|
||||||
void print(string s = "") const;
|
void print(string s = "") const;
|
||||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||||
gtsam::KeySet keys() const;
|
|
||||||
gtsam::DecisionTreeFactor product() const;
|
gtsam::DecisionTreeFactor product() const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
|
gtsam::DiscreteBayesNet eliminateSequential();
|
||||||
|
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
|
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
||||||
|
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""
|
||||||
|
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||||
|
Atlanta, Georgia 30332-0415
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
See LICENSE for the license information
|
||||||
|
|
||||||
|
Unit tests for Discrete Bayes Nets.
|
||||||
|
Author: Frank Dellaert
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=no-name-in-module, invalid-name
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||||
|
DiscreteKeys, DiscreteValues, Ordering)
|
||||||
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
|
||||||
|
def test_constructor(self):
|
||||||
|
"""Test constructing a Bayes net."""
|
||||||
|
|
||||||
|
bayesNet = DiscreteBayesNet()
|
||||||
|
Parent, Child = (0, 2), (1, 2)
|
||||||
|
empty = DiscreteKeys()
|
||||||
|
prior = DiscreteConditional(Parent, empty, "6/4")
|
||||||
|
bayesNet.add(prior)
|
||||||
|
|
||||||
|
parents = DiscreteKeys()
|
||||||
|
parents.push_back(Parent)
|
||||||
|
conditional = DiscreteConditional(Child, parents, "7/3 8/2")
|
||||||
|
bayesNet.add(conditional)
|
||||||
|
|
||||||
|
# Check conversion to factor graph:
|
||||||
|
fg = DiscreteFactorGraph(bayesNet)
|
||||||
|
self.assertEqual(fg.size(), 2)
|
||||||
|
self.assertEqual(fg.at(1).size(), 2)
|
||||||
|
|
||||||
|
def test_Asia(self):
|
||||||
|
"""Test full Asia example."""
|
||||||
|
|
||||||
|
Asia = (0, 2)
|
||||||
|
Smoking = (4, 2)
|
||||||
|
Tuberculosis = (3, 2)
|
||||||
|
LungCancer = (6, 2)
|
||||||
|
|
||||||
|
Bronchitis = (7, 2)
|
||||||
|
Either = (5, 2)
|
||||||
|
XRay = (2, 2)
|
||||||
|
Dyspnea = (1, 2)
|
||||||
|
|
||||||
|
def P(keys):
|
||||||
|
dks = DiscreteKeys()
|
||||||
|
for key in keys:
|
||||||
|
dks.push_back(key)
|
||||||
|
return dks
|
||||||
|
|
||||||
|
asia = DiscreteBayesNet()
|
||||||
|
asia.add(Asia, P([]), "99/1")
|
||||||
|
asia.add(Smoking, P([]), "50/50")
|
||||||
|
|
||||||
|
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
|
||||||
|
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
|
||||||
|
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
|
||||||
|
|
||||||
|
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
|
||||||
|
|
||||||
|
asia.add(XRay, P([Either]), "95/5 2/98")
|
||||||
|
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
|
||||||
|
|
||||||
|
# Convert to factor graph
|
||||||
|
fg = DiscreteFactorGraph(asia)
|
||||||
|
|
||||||
|
# Create solver and eliminate
|
||||||
|
ordering = Ordering()
|
||||||
|
for j in range(8):
|
||||||
|
ordering.push_back(j)
|
||||||
|
chordal = fg.eliminateSequential(ordering)
|
||||||
|
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
|
||||||
|
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||||
|
|
||||||
|
# solve
|
||||||
|
actualMPE = chordal.optimize()
|
||||||
|
expectedMPE = DiscreteValues()
|
||||||
|
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||||
|
expectedMPE[key[0]] = 0
|
||||||
|
self.assertEqual(list(actualMPE.items()),
|
||||||
|
list(expectedMPE.items()))
|
||||||
|
|
||||||
|
# add evidence, we were in Asia and we have dyspnea
|
||||||
|
fg.add(Asia, "0 1")
|
||||||
|
fg.add(Dyspnea, "0 1")
|
||||||
|
|
||||||
|
# solve again, now with evidence
|
||||||
|
chordal2 = fg.eliminateSequential(ordering)
|
||||||
|
actualMPE2 = chordal2.optimize()
|
||||||
|
expectedMPE2 = DiscreteValues()
|
||||||
|
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
||||||
|
expectedMPE2[key[0]] = 0
|
||||||
|
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
||||||
|
expectedMPE2[key[0]] = 1
|
||||||
|
self.assertEqual(list(actualMPE2.items()),
|
||||||
|
list(expectedMPE2.items()))
|
||||||
|
|
||||||
|
# now sample from it
|
||||||
|
actualSample = chordal2.sample()
|
||||||
|
self.assertEqual(len(actualSample), 8)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue