Additional DiscreteConditional constructors to support wrapper.
parent
a1b8f52da8
commit
3339517340
|
|
@ -23,6 +23,7 @@
|
|||
#include <boost/shared_ptr.hpp>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/discrete/DiscretePrior.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
|
@ -75,6 +76,11 @@ namespace gtsam {
|
|||
// Add inherited versions of add.
|
||||
using Base::add;
|
||||
|
||||
/** Add a DiscretePrior using a table or a string */
|
||||
void add(const DiscreteKey& key, const std::string& spec) {
|
||||
emplace_shared<DiscretePrior>(key, spec);
|
||||
}
|
||||
|
||||
/** Add a DiscreteCondtional */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
|
|
|
|||
|
|
@ -81,6 +81,20 @@ public:
|
|||
const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, parents, spec)) {}
|
||||
|
||||
/// No-parent specialization; can also use DiscretePrior.
|
||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||
|
||||
/// Single-parent specialization
|
||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec,
|
||||
const DiscreteKey& parent1)
|
||||
: DiscreteConditional(Signature(key, {parent1}, spec)) {}
|
||||
|
||||
/// Two-parent specialization
|
||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec,
|
||||
const DiscreteKey& parent1, const DiscreteKey& parent2)
|
||||
: DiscreteConditional(Signature(key, {parent1, parent2}, spec)) {}
|
||||
|
||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
|
|
|||
|
|
@ -84,10 +84,17 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
|
|||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
class DiscreteBayesNet {
|
||||
class DiscreteBayesNet {
|
||||
DiscreteBayesNet();
|
||||
void add(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
void add(const gtsam::DiscreteKey& key, string spec);
|
||||
void add(const gtsam::DiscreteKey& key, string spec,
|
||||
const gtsam::DiscreteKey& parent1);
|
||||
void add(const gtsam::DiscreteKey& key, string spec,
|
||||
const gtsam::DiscreteKey& parent1,
|
||||
const gtsam::DiscreteKey& parent2);
|
||||
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
|
||||
string spec);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
|
|
@ -98,15 +105,13 @@ class DiscreteBayesNet {
|
|||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void add(const gtsam::DiscreteConditional& s);
|
||||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
|
|
|
|||
|
|
@ -75,8 +75,8 @@ TEST(DiscreteBayesNet, bayesNet) {
|
|||
TEST(DiscreteBayesNet, Asia) {
|
||||
DiscreteBayesNet asia;
|
||||
|
||||
asia.add(Asia % "99/1");
|
||||
asia.add(Smoking % "50/50");
|
||||
asia.add(Asia, "99/1");
|
||||
asia.add(Smoking % "50/50"); // Signature version
|
||||
|
||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
|||
import unittest
|
||||
|
||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||
DiscreteKeys, DiscreteValues, Ordering)
|
||||
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
|
||||
|
|
@ -53,24 +53,18 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
XRay = (2, 2)
|
||||
Dyspnea = (1, 2)
|
||||
|
||||
def P(keys):
|
||||
dks = DiscreteKeys()
|
||||
for key in keys:
|
||||
dks.push_back(key)
|
||||
return dks
|
||||
|
||||
asia = DiscreteBayesNet()
|
||||
asia.add(Asia, P([]), "99/1")
|
||||
asia.add(Smoking, P([]), "50/50")
|
||||
asia.add(Asia, "99/1")
|
||||
asia.add(Smoking, "50/50")
|
||||
|
||||
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
|
||||
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
|
||||
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
|
||||
asia.add(Tuberculosis, "99/1 95/5", Asia)
|
||||
asia.add(LungCancer, "99/1 90/10", Smoking)
|
||||
asia.add(Bronchitis, "70/30 40/60", Smoking)
|
||||
|
||||
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
|
||||
asia.add(Either, "F T T T", Tuberculosis, LungCancer)
|
||||
|
||||
asia.add(XRay, P([Either]), "95/5 2/98")
|
||||
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
|
||||
asia.add(XRay, "95/5 2/98", Either)
|
||||
asia.add(Dyspnea, "9/1 2/8 3/7 1/9", Either, Bronchitis)
|
||||
|
||||
# Convert to factor graph
|
||||
fg = DiscreteFactorGraph(asia)
|
||||
|
|
@ -80,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
for j in range(8):
|
||||
ordering.push_back(j)
|
||||
chordal = fg.eliminateSequential(ordering)
|
||||
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
|
||||
expected2 = DiscretePrior(Bronchitis, "11/9")
|
||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||
|
||||
# solve
|
||||
|
|
|
|||
Loading…
Reference in New Issue