Removed all specialized constructors, because wrapper is awesome!
parent
911819c7f2
commit
93e9756ef0
|
|
@ -141,6 +141,7 @@ namespace gtsam {
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
pairs.emplace_back(key, cardinalities_.at(key));
|
pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
}
|
}
|
||||||
|
// Reverse to make cartesianProduct output a more natural ordering.
|
||||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
const auto assignments = cartesianProduct(rpairs);
|
const auto assignments = cartesianProduct(rpairs);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
try {
|
try {
|
||||||
value = parentsValues.at(j);
|
value = parentsValues.at(j);
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
} catch (exception&) {
|
} catch (std::out_of_range&) {
|
||||||
parentsValues.print("parentsValues: ");
|
parentsValues.print("parentsValues: ");
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||||
};
|
};
|
||||||
|
|
@ -251,7 +251,11 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
assert(nrFrontals() == 1);
|
if (nrFrontals() != 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::sample can only be called on single variable "
|
||||||
|
"conditionals");
|
||||||
|
}
|
||||||
Key key = firstFrontalKey();
|
Key key = firstFrontalKey();
|
||||||
size_t nj = cardinality(key);
|
size_t nj = cardinality(key);
|
||||||
vector<double> p(nj);
|
vector<double> p(nj);
|
||||||
|
|
|
||||||
|
|
@ -85,16 +85,6 @@ public:
|
||||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
/// Single-parent specialization
|
|
||||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec,
|
|
||||||
const DiscreteKey& parent1)
|
|
||||||
: DiscreteConditional(Signature(key, {parent1}, spec)) {}
|
|
||||||
|
|
||||||
/// Two-parent specialization
|
|
||||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec,
|
|
||||||
const DiscreteKey& parent1, const DiscreteKey& parent2)
|
|
||||||
: DiscreteConditional(Signature(key, {parent1, parent2}, spec)) {}
|
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal);
|
const DecisionTreeFactor& marginal);
|
||||||
|
|
|
||||||
|
|
@ -57,13 +57,10 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional();
|
DiscreteConditional();
|
||||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||||
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
|
DiscreteConditional(const gtsam::DiscreteKey& key, string spec);
|
||||||
DiscreteConditional(const gtsam::DiscreteKey& key, string spec,
|
|
||||||
const gtsam::DiscreteKey& parent1);
|
|
||||||
DiscreteConditional(const gtsam::DiscreteKey& key, string spec,
|
|
||||||
const gtsam::DiscreteKey& parent1,
|
|
||||||
const gtsam::DiscreteKey& parent2);
|
|
||||||
DiscreteConditional(const gtsam::DiscreteKey& key,
|
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||||
const gtsam::DiscreteKeys& parents, string spec);
|
const gtsam::DiscreteKeys& parents, string spec);
|
||||||
|
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||||
|
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
const gtsam::DecisionTreeFactor& marginal);
|
const gtsam::DecisionTreeFactor& marginal);
|
||||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
|
|
@ -109,13 +106,10 @@ class DiscreteBayesNet {
|
||||||
DiscreteBayesNet();
|
DiscreteBayesNet();
|
||||||
void add(const gtsam::DiscreteConditional& s);
|
void add(const gtsam::DiscreteConditional& s);
|
||||||
void add(const gtsam::DiscreteKey& key, string spec);
|
void add(const gtsam::DiscreteKey& key, string spec);
|
||||||
void add(const gtsam::DiscreteKey& key, string spec,
|
|
||||||
const gtsam::DiscreteKey& parent1);
|
|
||||||
void add(const gtsam::DiscreteKey& key, string spec,
|
|
||||||
const gtsam::DiscreteKey& parent1,
|
|
||||||
const gtsam::DiscreteKey& parent2);
|
|
||||||
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
|
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
|
||||||
string spec);
|
string spec);
|
||||||
|
void add(const gtsam::DiscreteKey& key,
|
||||||
|
const std::vector<gtsam::DiscreteKey>& parents, string spec);
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
gtsam::KeySet keys() const;
|
gtsam::KeySet keys() const;
|
||||||
|
|
|
||||||
|
|
@ -57,14 +57,14 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
asia.add(Asia, "99/1")
|
asia.add(Asia, "99/1")
|
||||||
asia.add(Smoking, "50/50")
|
asia.add(Smoking, "50/50")
|
||||||
|
|
||||||
asia.add(Tuberculosis, "99/1 95/5", Asia)
|
asia.add(Tuberculosis, [Asia], "99/1 95/5")
|
||||||
asia.add(LungCancer, "99/1 90/10", Smoking)
|
asia.add(LungCancer, [Smoking], "99/1 90/10")
|
||||||
asia.add(Bronchitis, "70/30 40/60", Smoking)
|
asia.add(Bronchitis, [Smoking], "70/30 40/60")
|
||||||
|
|
||||||
asia.add(Either, "F T T T", Tuberculosis, LungCancer)
|
asia.add(Either, [Tuberculosis, LungCancer], "F T T T")
|
||||||
|
|
||||||
asia.add(XRay, "95/5 2/98", Either)
|
asia.add(XRay, [Either], "95/5 2/98")
|
||||||
asia.add(Dyspnea, "9/1 2/8 3/7 1/9", Either, Bronchitis)
|
asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9")
|
||||||
|
|
||||||
# Convert to factor graph
|
# Convert to factor graph
|
||||||
fg = DiscreteFactorGraph(asia)
|
fg = DiscreteFactorGraph(asia)
|
||||||
|
|
|
||||||
|
|
@ -14,20 +14,10 @@ Author: Frank Dellaert
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||||
DiscreteConditional, DiscreteFactorGraph, DiscreteKeys,
|
DiscreteConditional, DiscreteFactorGraph, Ordering)
|
||||||
Ordering)
|
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
def P(*args):
|
|
||||||
""" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs."""
|
|
||||||
# TODO: We can make life easier by providing variable argument functions in C++ itself.
|
|
||||||
dks = DiscreteKeys()
|
|
||||||
for key in args:
|
|
||||||
dks.push_back(key)
|
|
||||||
return dks
|
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
|
||||||
|
|
@ -40,25 +30,25 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# Create thin-tree Bayesnet.
|
# Create thin-tree Bayesnet.
|
||||||
bayesNet = DiscreteBayesNet()
|
bayesNet = DiscreteBayesNet()
|
||||||
|
|
||||||
bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1")
|
bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
|
||||||
bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4")
|
bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4")
|
||||||
bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1")
|
bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1")
|
||||||
bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1")
|
bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1")
|
||||||
|
|
||||||
bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1")
|
bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1")
|
||||||
bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4")
|
bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4")
|
||||||
bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1")
|
bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1")
|
||||||
bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1")
|
bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1")
|
||||||
|
|
||||||
bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1")
|
bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1")
|
||||||
bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4")
|
bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4")
|
||||||
bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1")
|
bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1")
|
||||||
bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1")
|
bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1")
|
||||||
|
|
||||||
bayesNet.add(keys[12], P(keys[14]), "3/1 3/1")
|
bayesNet.add(keys[12], [keys[14]], "3/1 3/1")
|
||||||
bayesNet.add(keys[13], P(keys[14]), "1/3 3/1")
|
bayesNet.add(keys[13], [keys[14]], "1/3 3/1")
|
||||||
|
|
||||||
bayesNet.add(keys[14], P(), "1/3")
|
bayesNet.add(keys[14], "1/3")
|
||||||
|
|
||||||
# Create a factor graph out of the Bayes net.
|
# Create a factor graph out of the Bayes net.
|
||||||
factorGraph = DiscreteFactorGraph(bayesNet)
|
factorGraph = DiscreteFactorGraph(bayesNet)
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class TestDiscreteConditional(GtsamTestCase):
|
||||||
def test_likelihood(self):
|
def test_likelihood(self):
|
||||||
X = (0, 2)
|
X = (0, 2)
|
||||||
Y = (1, 3)
|
Y = (1, 3)
|
||||||
conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y)
|
conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5")
|
||||||
|
|
||||||
actual0 = conditional.likelihood(0)
|
actual0 = conditional.likelihood(0)
|
||||||
expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
|
expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue