Additional DiscreteConditional constructors to support wrapper.

release/4.3a0
Frank Dellaert 2021-12-26 16:05:05 -05:00
parent a1b8f52da8
commit 3339517340
5 changed files with 45 additions and 26 deletions

View File

@ -23,6 +23,7 @@
#include <boost/shared_ptr.hpp>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteConditional.h>
namespace gtsam {
@ -75,6 +76,11 @@ namespace gtsam {
// Add inherited versions of add.
using Base::add;
/** Add a DiscretePrior using a table or a string */
void add(const DiscreteKey& key, const std::string& spec) {
emplace_shared<DiscretePrior>(key, spec);
}
/** Add a DiscreteCondtional */
template <typename... Args>
void add(Args&&... args) {

View File

@ -81,6 +81,20 @@ public:
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}
/// No-parent specialization; can also use DiscretePrior.
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {}
/// Single-parent specialization
DiscreteConditional(const DiscreteKey& key, const std::string& spec,
const DiscreteKey& parent1)
: DiscreteConditional(Signature(key, {parent1}, spec)) {}
/// Two-parent specialization
DiscreteConditional(const DiscreteKey& key, const std::string& spec,
const DiscreteKey& parent1, const DiscreteKey& parent2)
: DiscreteConditional(Signature(key, {parent1, parent2}, spec)) {}
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);

View File

@ -84,10 +84,17 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
};
#include <gtsam/discrete/DiscreteBayesNet.h>
class DiscreteBayesNet {
class DiscreteBayesNet {
DiscreteBayesNet();
void add(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents, string spec);
void add(const gtsam::DiscreteConditional& s);
void add(const gtsam::DiscreteKey& key, string spec);
void add(const gtsam::DiscreteKey& key, string spec,
const gtsam::DiscreteKey& parent1);
void add(const gtsam::DiscreteKey& key, string spec,
const gtsam::DiscreteKey& parent1,
const gtsam::DiscreteKey& parent2);
void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents,
string spec);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
@ -98,15 +105,13 @@ class DiscreteBayesNet {
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void add(const gtsam::DiscreteConditional& s);
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DefaultKeyFormatter) const;
};
#include <gtsam/discrete/DiscreteBayesTree.h>

View File

@ -75,8 +75,8 @@ TEST(DiscreteBayesNet, bayesNet) {
TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia;
asia.add(Asia % "99/1");
asia.add(Smoking % "50/50");
asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version
asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");

View File

@ -14,7 +14,7 @@ Author: Frank Dellaert
import unittest
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteValues, Ordering)
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase
@ -53,24 +53,18 @@ class TestDiscreteBayesNet(GtsamTestCase):
XRay = (2, 2)
Dyspnea = (1, 2)
def P(keys):
dks = DiscreteKeys()
for key in keys:
dks.push_back(key)
return dks
asia = DiscreteBayesNet()
asia.add(Asia, P([]), "99/1")
asia.add(Smoking, P([]), "50/50")
asia.add(Asia, "99/1")
asia.add(Smoking, "50/50")
asia.add(Tuberculosis, P([Asia]), "99/1 95/5")
asia.add(LungCancer, P([Smoking]), "99/1 90/10")
asia.add(Bronchitis, P([Smoking]), "70/30 40/60")
asia.add(Tuberculosis, "99/1 95/5", Asia)
asia.add(LungCancer, "99/1 90/10", Smoking)
asia.add(Bronchitis, "70/30 40/60", Smoking)
asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T")
asia.add(Either, "F T T T", Tuberculosis, LungCancer)
asia.add(XRay, P([Either]), "95/5 2/98")
asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9")
asia.add(XRay, "95/5 2/98", Either)
asia.add(Dyspnea, "9/1 2/8 3/7 1/9", Either, Bronchitis)
# Convert to factor graph
fg = DiscreteFactorGraph(asia)
@ -80,7 +74,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
for j in range(8):
ordering.push_back(j)
chordal = fg.eliminateSequential(ordering)
expected2 = DiscreteConditional(Bronchitis, P([]), "11/9")
expected2 = DiscretePrior(Bronchitis, "11/9")
self.gtsamAssertEquals(chordal.at(7), expected2)
# solve