Merge pull request #402 from borglab/feature/bayesnet_example
Two new discrete examplesrelease/4.3a0
commit
52fcdb51a8
|
@ -0,0 +1,83 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteBayesNetExample.cpp
|
||||||
|
* @brief Discrete Bayes Net example with famous Asia Bayes Network
|
||||||
|
* @author Frank Dellaert
|
||||||
|
* @date JULY 10, 2020
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||||
|
#include <gtsam/inference/BayesNet-inst.h>
|
||||||
|
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
DiscreteBayesNet asia;
|
||||||
|
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
|
||||||
|
Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
||||||
|
asia.add(Asia % "99/1");
|
||||||
|
asia.add(Smoking % "50/50");
|
||||||
|
|
||||||
|
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||||
|
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||||
|
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
||||||
|
|
||||||
|
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
|
||||||
|
asia.add(XRay | Either = "95/5 2/98");
|
||||||
|
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
||||||
|
|
||||||
|
// print
|
||||||
|
vector<string> pretty = {"Asia", "Dyspnea", "XRay", "Tuberculosis",
|
||||||
|
"Smoking", "Either", "LungCancer", "Bronchitis"};
|
||||||
|
auto formatter = [pretty](Key key) { return pretty[key]; };
|
||||||
|
asia.print("Asia", formatter);
|
||||||
|
|
||||||
|
// Convert to factor graph
|
||||||
|
DiscreteFactorGraph fg(asia);
|
||||||
|
|
||||||
|
// Create solver and eliminate
|
||||||
|
Ordering ordering;
|
||||||
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||||
|
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||||
|
|
||||||
|
// solve
|
||||||
|
DiscreteFactor::sharedValues mpe = chordal->optimize();
|
||||||
|
GTSAM_PRINT(*mpe);
|
||||||
|
|
||||||
|
// We can also build a Bayes tree (directed junction tree).
|
||||||
|
// The elimination order above will do fine:
|
||||||
|
auto bayesTree = fg.eliminateMultifrontal(ordering);
|
||||||
|
bayesTree->print("bayesTree", formatter);
|
||||||
|
|
||||||
|
// add evidence, we were in Asia and we have dyspnea
|
||||||
|
fg.add(Asia, "0 1");
|
||||||
|
fg.add(Dyspnea, "0 1");
|
||||||
|
|
||||||
|
// solve again, now with evidence
|
||||||
|
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||||
|
DiscreteFactor::sharedValues mpe2 = chordal2->optimize();
|
||||||
|
GTSAM_PRINT(*mpe2);
|
||||||
|
|
||||||
|
// We can also sample from it
|
||||||
|
cout << "\n10 samples:" << endl;
|
||||||
|
for (size_t i = 0; i < 10; i++) {
|
||||||
|
DiscreteFactor::sharedValues sample = chordal2->sample();
|
||||||
|
GTSAM_PRINT(*sample);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file DiscreteBayesNet_graph.cpp
|
* @file DiscreteBayesNet_FG.cpp
|
||||||
* @brief Discrete Bayes Net example using Factor Graphs
|
* @brief Discrete Bayes Net example using Factor Graphs
|
||||||
* @author Abhijit
|
* @author Abhijit
|
||||||
* @date Jun 4, 2012
|
* @date Jun 4, 2012
|
||||||
|
|
|
@ -0,0 +1,94 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteBayesNetExample.cpp
|
||||||
|
* @brief Hidden Markov Model example, discrete.
|
||||||
|
* @author Frank Dellaert
|
||||||
|
* @date July 12, 2020
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||||
|
#include <gtsam/inference/BayesNet-inst.h>
|
||||||
|
|
||||||
|
#include <iomanip>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
const int nrNodes = 4;
|
||||||
|
const size_t nrStates = 3;
|
||||||
|
|
||||||
|
// Define variables as well as ordering
|
||||||
|
Ordering ordering;
|
||||||
|
vector<DiscreteKey> keys;
|
||||||
|
for (int k = 0; k < nrNodes; k++) {
|
||||||
|
DiscreteKey key_i(k, nrStates);
|
||||||
|
keys.push_back(key_i);
|
||||||
|
ordering.emplace_back(k);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HMM as a DiscreteBayesNet
|
||||||
|
DiscreteBayesNet hmm;
|
||||||
|
|
||||||
|
// Define backbone
|
||||||
|
const string transition = "8/1/1 1/8/1 1/1/8";
|
||||||
|
for (int k = 1; k < nrNodes; k++) {
|
||||||
|
hmm.add(keys[k] | keys[k - 1] = transition);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some measurements, not needed for all time steps!
|
||||||
|
hmm.add(keys[0] % "7/2/1");
|
||||||
|
hmm.add(keys[1] % "1/9/0");
|
||||||
|
hmm.add(keys.back() % "5/4/1");
|
||||||
|
|
||||||
|
// print
|
||||||
|
hmm.print("HMM");
|
||||||
|
|
||||||
|
// Convert to factor graph
|
||||||
|
DiscreteFactorGraph factorGraph(hmm);
|
||||||
|
|
||||||
|
// Create solver and eliminate
|
||||||
|
// This will create a DAG ordered with arrow of time reversed
|
||||||
|
DiscreteBayesNet::shared_ptr chordal =
|
||||||
|
factorGraph.eliminateSequential(ordering);
|
||||||
|
chordal->print("Eliminated");
|
||||||
|
|
||||||
|
// solve
|
||||||
|
DiscreteFactor::sharedValues mpe = chordal->optimize();
|
||||||
|
GTSAM_PRINT(*mpe);
|
||||||
|
|
||||||
|
// We can also sample from it
|
||||||
|
cout << "\n10 samples:" << endl;
|
||||||
|
for (size_t k = 0; k < 10; k++) {
|
||||||
|
DiscreteFactor::sharedValues sample = chordal->sample();
|
||||||
|
GTSAM_PRINT(*sample);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Or compute the marginals. This re-eliminates the FG into a Bayes tree
|
||||||
|
cout << "\nComputing Node Marginals .." << endl;
|
||||||
|
DiscreteMarginals marginals(factorGraph);
|
||||||
|
for (int k = 0; k < nrNodes; k++) {
|
||||||
|
Vector margProbs = marginals.marginalProbabilities(keys[k]);
|
||||||
|
stringstream ss;
|
||||||
|
ss << "marginal " << k;
|
||||||
|
print(margProbs, ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(frank): put in the glue to have DiscreteMarginals produce *arbitrary*
|
||||||
|
// joints efficiently, by the Bayes tree shortcut magic. All the code is there
|
||||||
|
// but it's not yet connected.
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -27,6 +27,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
@ -61,16 +62,26 @@ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const Signature& signature) :
|
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
||||||
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional(
|
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
||||||
1) {
|
BaseConditional(1) {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
void DiscreteConditional::print(const std::string& s,
|
void DiscreteConditional::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
std::cout << s << std::endl;
|
cout << s << " P( ";
|
||||||
Potentials::print(s);
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
if (nrParents()) {
|
||||||
|
cout << "| ";
|
||||||
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << ")";
|
||||||
|
Potentials::print("");
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
|
@ -173,55 +184,28 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
||||||
|
static mt19937 rng(2); // random number generator
|
||||||
static mt19937 rng(2); // random number generator
|
|
||||||
|
|
||||||
bool debug = ISDEBUG("DiscreteConditional::sample");
|
|
||||||
|
|
||||||
// Get the correct conditional density
|
// Get the correct conditional density
|
||||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||||
if (debug)
|
|
||||||
GTSAM_PRINT(pFS);
|
|
||||||
|
|
||||||
// get cumulative distribution function (cdf)
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
// TODO, only works for one key now, seems horribly slow this way
|
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key key = firstFrontalKey();
|
||||||
size_t nj = cardinality(j);
|
size_t nj = cardinality(key);
|
||||||
vector<double> cdf(nj);
|
vector<double> p(nj);
|
||||||
Values frontals;
|
Values frontals;
|
||||||
double sum = 0;
|
|
||||||
for (size_t value = 0; value < nj; value++) {
|
for (size_t value = 0; value < nj; value++) {
|
||||||
frontals[j] = value;
|
frontals[key] = value;
|
||||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
sum += pValueS; // accumulate
|
if (p[value] == 1.0) {
|
||||||
if (debug)
|
return value; // shortcut exit
|
||||||
cout << sum << " ";
|
|
||||||
if (pValueS == 1) {
|
|
||||||
if (debug)
|
|
||||||
cout << "--> " << value << endl;
|
|
||||||
return value; // shortcut exit
|
|
||||||
}
|
}
|
||||||
cdf[value] = sum;
|
|
||||||
}
|
}
|
||||||
|
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||||
// inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html
|
return distribution(rng);
|
||||||
uniform_real_distribution<double> dist(0, cdf.back());
|
|
||||||
size_t sampled = lower_bound(cdf.begin(), cdf.end(), dist(rng)) - cdf.begin();
|
|
||||||
if (debug)
|
|
||||||
cout << "-> " << sampled << endl;
|
|
||||||
|
|
||||||
return sampled;
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
|
||||||
//void DiscreteConditional::permuteWithInverse(
|
|
||||||
// const Permutation& inversePermutation) {
|
|
||||||
// IndexConditionalOrdered::permuteWithInverse(inversePermutation);
|
|
||||||
// Potentials::permuteWithInverse(inversePermutation);
|
|
||||||
//}
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
|
|
||||||
}// namespace
|
}// namespace
|
||||||
|
|
|
@ -15,50 +15,52 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/Potentials.h>
|
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
|
#include <gtsam/discrete/Potentials.h>
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// explicit instantiation
|
// explicit instantiation
|
||||||
template class DecisionTree<Key, double> ;
|
template class DecisionTree<Key, double>;
|
||||||
template class AlgebraicDecisionTree<Key> ;
|
template class AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double Potentials::safe_div(const double& a, const double& b) {
|
double Potentials::safe_div(const double& a, const double& b) {
|
||||||
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
|
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
|
||||||
// The use for safe_div is when we divide the product factor by the sum factor.
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
// If the product or sum is zero, we accord zero probability to the event.
|
// factor. If the product or sum is zero, we accord zero probability to the
|
||||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
// event.
|
||||||
}
|
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||||
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ********************************************************************************
|
||||||
Potentials::Potentials() :
|
*/
|
||||||
ADT(1.0) {
|
Potentials::Potentials() : ADT(1.0) {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ********************************************************************************
|
||||||
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) :
|
*/
|
||||||
ADT(decisionTree), cardinalities_(keys.cardinalities()) {
|
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
|
||||||
}
|
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
bool Potentials::equals(const Potentials& other, double tol) const {
|
bool Potentials::equals(const Potentials& other, double tol) const {
|
||||||
return ADT::equals(other, tol);
|
return ADT::equals(other, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void Potentials::print(const string& s,
|
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
const KeyFormatter& formatter) const {
|
cout << s << "\n Cardinalities: {";
|
||||||
cout << s << "\n Cardinalities: ";
|
for (const DiscreteKey& key : cardinalities_)
|
||||||
for(const DiscreteKey& key: cardinalities_)
|
cout << formatter(key.first) << ":" << key.second << ", ";
|
||||||
cout << formatter(key.first) << "=" << formatter(key.second) << " ";
|
cout << "}" << endl;
|
||||||
cout << endl;
|
ADT::print(" ");
|
||||||
ADT::print(" ");
|
}
|
||||||
}
|
|
||||||
//
|
//
|
||||||
// /* ************************************************************************* */
|
// /* ************************************************************************* */
|
||||||
// template<class P>
|
// template<class P>
|
||||||
|
@ -95,4 +97,4 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -122,28 +122,30 @@ namespace gtsam {
|
||||||
key_(key) {
|
key_(key) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteKeys Signature::discreteKeysParentsFirst() const {
|
DiscreteKeys Signature::discreteKeys() const {
|
||||||
DiscreteKeys keys;
|
DiscreteKeys keys;
|
||||||
for(const DiscreteKey& key: parents_)
|
|
||||||
keys.push_back(key);
|
|
||||||
keys.push_back(key_);
|
keys.push_back(key_);
|
||||||
|
for (const DiscreteKey& key : parents_) keys.push_back(key);
|
||||||
return keys;
|
return keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
KeyVector Signature::indices() const {
|
KeyVector Signature::indices() const {
|
||||||
KeyVector js;
|
KeyVector js;
|
||||||
js.push_back(key_.first);
|
js.push_back(key_.first);
|
||||||
for(const DiscreteKey& key: parents_)
|
for (const DiscreteKey& key : parents_) js.push_back(key.first);
|
||||||
js.push_back(key.first);
|
|
||||||
return js;
|
return js;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<double> Signature::cpt() const {
|
vector<double> Signature::cpt() const {
|
||||||
vector<double> cpt;
|
vector<double> cpt;
|
||||||
if (table_) {
|
if (table_) {
|
||||||
for(const Row& row: *table_)
|
const size_t nrStates = table_->at(0).size();
|
||||||
for(const double& x: row)
|
for (size_t j = 0; j < nrStates; j++) {
|
||||||
cpt.push_back(x);
|
for (const Row& row : *table_) {
|
||||||
|
assert(row.size() == nrStates);
|
||||||
|
cpt.push_back(row[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return cpt;
|
return cpt;
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,8 +86,8 @@ namespace gtsam {
|
||||||
return parents_;
|
return parents_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** All keys, with variable key last */
|
/** All keys, with variable key first */
|
||||||
DiscreteKeys discreteKeysParentsFirst() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/** All key indices, with variable key first */
|
/** All key indices, with variable key first */
|
||||||
KeyVector indices() const;
|
KeyVector indices() const;
|
||||||
|
|
|
@ -132,7 +132,7 @@ TEST(ADT, example3)
|
||||||
|
|
||||||
/** Convert Signature into CPT */
|
/** Convert Signature into CPT */
|
||||||
ADT create(const Signature& signature) {
|
ADT create(const Signature& signature) {
|
||||||
ADT p(signature.discreteKeysParentsFirst(), signature.cpt());
|
ADT p(signature.discreteKeys(), signature.cpt());
|
||||||
static size_t count = 0;
|
static size_t count = 0;
|
||||||
const DiscreteKey& key = signature.key();
|
const DiscreteKey& key = signature.key();
|
||||||
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||||
|
@ -181,19 +181,20 @@ TEST(ADT, joint)
|
||||||
dot(joint, "Asia-ASTLBEX");
|
dot(joint, "Asia-ASTLBEX");
|
||||||
joint = apply(joint, pD, &mul);
|
joint = apply(joint, pD, &mul);
|
||||||
dot(joint, "Asia-ASTLBEXD");
|
dot(joint, "Asia-ASTLBEXD");
|
||||||
EXPECT_LONGS_EQUAL(346, (long)muls);
|
EXPECT_LONGS_EQUAL(346, muls);
|
||||||
gttoc_(asiaJoint);
|
gttoc_(asiaJoint);
|
||||||
tictoc_getNode(asiaJointNode, asiaJoint);
|
tictoc_getNode(asiaJointNode, asiaJoint);
|
||||||
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
||||||
tictoc_reset_();
|
tictoc_reset_();
|
||||||
printCounts("Asia joint");
|
printCounts("Asia joint");
|
||||||
|
|
||||||
|
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
|
||||||
ADT pASTL = pA;
|
ADT pASTL = pA;
|
||||||
pASTL = apply(pASTL, pS, &mul);
|
pASTL = apply(pASTL, pS, &mul);
|
||||||
pASTL = apply(pASTL, pT, &mul);
|
pASTL = apply(pASTL, pT, &mul);
|
||||||
pASTL = apply(pASTL, pL, &mul);
|
pASTL = apply(pASTL, pL, &mul);
|
||||||
|
|
||||||
// test combine
|
// test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L)
|
||||||
ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
|
ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
|
||||||
EXPECT(assert_equal(pA, fAa));
|
EXPECT(assert_equal(pA, fAa));
|
||||||
ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);
|
ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);
|
||||||
|
|
|
@ -18,110 +18,135 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <boost/assign/std/map.hpp>
|
|
||||||
#include <boost/assign/list_inserter.hpp>
|
#include <boost/assign/list_inserter.hpp>
|
||||||
|
#include <boost/assign/std/map.hpp>
|
||||||
|
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, Asia)
|
TEST(DiscreteBayesNet, bayesNet) {
|
||||||
{
|
DiscreteBayesNet bayesNet;
|
||||||
|
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||||
|
|
||||||
|
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
||||||
|
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
|
||||||
|
(Potentials::ADT)*prior));
|
||||||
|
bayesNet.push_back(prior);
|
||||||
|
|
||||||
|
auto conditional =
|
||||||
|
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
||||||
|
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
||||||
|
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
||||||
|
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
|
||||||
|
bayesNet.push_back(conditional);
|
||||||
|
|
||||||
|
DiscreteFactorGraph fg(bayesNet);
|
||||||
|
LONGS_EQUAL(2, fg.back()->size());
|
||||||
|
|
||||||
|
// Check the marginals
|
||||||
|
const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2};
|
||||||
|
DiscreteMarginals marginals(fg);
|
||||||
|
for (size_t j = 0; j < 2; j++) {
|
||||||
|
Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2));
|
||||||
|
EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3);
|
||||||
|
EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesNet, Asia) {
|
||||||
DiscreteBayesNet asia;
|
DiscreteBayesNet asia;
|
||||||
// DiscreteKey A("Asia"), S("Smoking"), T("Tuberculosis"), L("LungCancer"), B(
|
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
|
||||||
// "Bronchitis"), E("Either"), X("XRay"), D("Dyspnoea");
|
Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
||||||
DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2);
|
|
||||||
|
|
||||||
// TODO: make a version that doesn't use the parser
|
asia.add(Asia % "99/1");
|
||||||
asia.add(A % "99/1");
|
asia.add(Smoking % "50/50");
|
||||||
asia.add(S % "50/50");
|
|
||||||
|
|
||||||
asia.add(T | A = "99/1 95/5");
|
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
||||||
asia.add(L | S = "99/1 90/10");
|
asia.add(LungCancer | Smoking = "99/1 90/10");
|
||||||
asia.add(B | S = "70/30 40/60");
|
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
||||||
|
|
||||||
asia.add((E | T, L) = "F T T T");
|
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
|
||||||
asia.add(X | E = "95/5 2/98");
|
asia.add(XRay | Either = "95/5 2/98");
|
||||||
// next lines are same as asia.add((D | E, B) = "9/1 2/8 3/7 1/9");
|
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
||||||
DiscreteConditional::shared_ptr actual =
|
|
||||||
boost::make_shared<DiscreteConditional>((D | E, B) = "9/1 2/8 3/7 1/9");
|
|
||||||
asia.push_back(actual);
|
|
||||||
// GTSAM_PRINT(asia);
|
|
||||||
|
|
||||||
// Convert to factor graph
|
// Convert to factor graph
|
||||||
DiscreteFactorGraph fg(asia);
|
DiscreteFactorGraph fg(asia);
|
||||||
// GTSAM_PRINT(fg);
|
LONGS_EQUAL(3, fg.back()->size());
|
||||||
LONGS_EQUAL(3,fg.back()->size());
|
|
||||||
Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9");
|
// Check the marginals we know (of the parent-less nodes)
|
||||||
CHECK(assert_equal(expected,(Potentials::ADT)*actual));
|
DiscreteMarginals marginals(fg);
|
||||||
|
Vector2 va(0.99, 0.01), vs(0.5, 0.5);
|
||||||
|
EXPECT(assert_equal(va, marginals.marginalProbabilities(Asia)));
|
||||||
|
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
|
||||||
|
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||||
// GTSAM_PRINT(*chordal);
|
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||||
DiscreteConditional expected2(B % "11/9");
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
CHECK(assert_equal(expected2,*chordal->back()));
|
|
||||||
|
|
||||||
// solve
|
// solve
|
||||||
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
||||||
DiscreteFactor::Values expectedMPE;
|
DiscreteFactor::Values expectedMPE;
|
||||||
insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first,
|
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
||||||
0)(E.first, 0)(L.first, 0)(B.first, 0);
|
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
||||||
|
LungCancer.first, 0)(Bronchitis.first, 0);
|
||||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||||
|
|
||||||
// add evidence, we were in Asia and we have Dispnoea
|
// add evidence, we were in Asia and we have dyspnea
|
||||||
fg.add(A, "0 1");
|
fg.add(Asia, "0 1");
|
||||||
fg.add(D, "0 1");
|
fg.add(Dyspnea, "0 1");
|
||||||
// fg.product().dot("fg");
|
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||||
// GTSAM_PRINT(*chordal2);
|
|
||||||
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
|
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
|
||||||
DiscreteFactor::Values expectedMPE2;
|
DiscreteFactor::Values expectedMPE2;
|
||||||
insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first,
|
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
||||||
1)(E.first, 0)(L.first, 0)(B.first, 1);
|
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
||||||
|
LungCancer.first, 0)(Bronchitis.first, 1);
|
||||||
EXPECT(assert_equal(expectedMPE2, *actualMPE2));
|
EXPECT(assert_equal(expectedMPE2, *actualMPE2));
|
||||||
|
|
||||||
// now sample from it
|
// now sample from it
|
||||||
DiscreteFactor::Values expectedSample;
|
DiscreteFactor::Values expectedSample;
|
||||||
SETDEBUG("DiscreteConditional::sample", false);
|
SETDEBUG("DiscreteConditional::sample", false);
|
||||||
insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 1)(T.first, 0)(
|
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
|
||||||
S.first, 1)(E.first, 1)(L.first, 1)(B.first, 0);
|
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
|
||||||
|
LungCancer.first, 1)(Bronchitis.first, 0);
|
||||||
DiscreteFactor::sharedValues actualSample = chordal2->sample();
|
DiscreteFactor::sharedValues actualSample = chordal2->sample();
|
||||||
EXPECT(assert_equal(expectedSample, *actualSample));
|
EXPECT(assert_equal(expectedSample, *actualSample));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE(DiscreteBayesNet, Sugar)
|
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
|
||||||
{
|
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
||||||
DiscreteKey T(0,2), L(1,2), E(2,2), D(3,2), C(8,3), S(7,2);
|
|
||||||
|
|
||||||
DiscreteBayesNet bn;
|
DiscreteBayesNet bn;
|
||||||
|
|
||||||
// test some mistakes
|
|
||||||
// add(bn, D);
|
|
||||||
// add(bn, D | E);
|
|
||||||
// add(bn, D | E = "blah");
|
|
||||||
|
|
||||||
// try logic
|
// try logic
|
||||||
bn.add((E | T, L) = "OR");
|
bn.add((E | T, L) = "OR");
|
||||||
bn.add((E | T, L) = "AND");
|
bn.add((E | T, L) = "AND");
|
||||||
|
|
||||||
// // try multivalued
|
// try multivalued
|
||||||
bn.add(C % "1/1/2");
|
bn.add(C % "1/1/2");
|
||||||
bn.add(C | S = "1/1/2 5/2/3");
|
bn.add(C | S = "1/1/2 5/2/3");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -130,4 +155,3 @@ int main() {
|
||||||
return TestRegistry::runAllTests(tr);
|
return TestRegistry::runAllTests(tr);
|
||||||
}
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
|
|
@ -80,6 +80,12 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check frontals and parents
|
||||||
|
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
|
||||||
|
auto clique_i = (*bayesTree)[i];
|
||||||
|
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||||
|
}
|
||||||
|
|
||||||
auto R = bayesTree->roots().front();
|
auto R = bayesTree->roots().front();
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
// Check whether BN and BT give the same answer on all configurations
|
||||||
|
@ -104,16 +110,22 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
double px = bayesTree->evaluate(x);
|
double px = bayesTree->evaluate(x);
|
||||||
for (size_t i = 0; i < 15; i++)
|
for (size_t i = 0; i < 15; i++)
|
||||||
if (x[i]) marginals[i] += px;
|
if (x[i]) marginals[i] += px;
|
||||||
if (x[12] && x[14]) joint_12_14 += px;
|
if (x[12] && x[14]) {
|
||||||
if (x[9] && x[12] && x[14]) joint_9_12_14 += px;
|
joint_12_14 += px;
|
||||||
if (x[8] && x[12] && x[14]) joint_8_12_14 += px;
|
if (x[9]) joint_9_12_14 += px;
|
||||||
|
if (x[8]) joint_8_12_14 += px;
|
||||||
|
}
|
||||||
if (x[8] && x[12]) joint_8_12 += px;
|
if (x[8] && x[12]) joint_8_12 += px;
|
||||||
if (x[8] && x[2]) joint82 += px;
|
if (x[2]) {
|
||||||
if (x[1] && x[2]) joint12 += px;
|
if (x[8]) joint82 += px;
|
||||||
if (x[2] && x[4]) joint24 += px;
|
if (x[1]) joint12 += px;
|
||||||
if (x[4] && x[5]) joint45 += px;
|
}
|
||||||
if (x[4] && x[6]) joint46 += px;
|
if (x[4]) {
|
||||||
if (x[4] && x[11]) joint_4_11 += px;
|
if (x[2]) joint24 += px;
|
||||||
|
if (x[5]) joint45 += px;
|
||||||
|
if (x[6]) joint46 += px;
|
||||||
|
if (x[11]) joint_4_11 += px;
|
||||||
|
}
|
||||||
if (x[11] && x[13]) {
|
if (x[11] && x[13]) {
|
||||||
joint_11_13 += px;
|
joint_11_13 += px;
|
||||||
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
||||||
|
@ -129,50 +141,50 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||||
|
|
||||||
// check separator marginal P(S0)
|
// check separator marginal P(S0)
|
||||||
auto c = (*bayesTree)[0];
|
auto clique = (*bayesTree)[0];
|
||||||
DiscreteFactorGraph separatorMarginal0 =
|
DiscreteFactorGraph separatorMarginal0 =
|
||||||
c->separatorMarginal(EliminateDiscrete);
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||||
|
|
||||||
// check separator marginal P(S9), should be P(14)
|
// check separator marginal P(S9), should be P(14)
|
||||||
c = (*bayesTree)[9];
|
clique = (*bayesTree)[9];
|
||||||
DiscreteFactorGraph separatorMarginal9 =
|
DiscreteFactorGraph separatorMarginal9 =
|
||||||
c->separatorMarginal(EliminateDiscrete);
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||||
|
|
||||||
// check separator marginal of root, should be empty
|
// check separator marginal of root, should be empty
|
||||||
c = (*bayesTree)[11];
|
clique = (*bayesTree)[11];
|
||||||
DiscreteFactorGraph separatorMarginal11 =
|
DiscreteFactorGraph separatorMarginal11 =
|
||||||
c->separatorMarginal(EliminateDiscrete);
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||||
|
|
||||||
// check shortcut P(S9||R) to root
|
// check shortcut P(S9||R) to root
|
||||||
c = (*bayesTree)[9];
|
clique = (*bayesTree)[9];
|
||||||
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
LONGS_EQUAL(1, shortcut.size());
|
LONGS_EQUAL(1, shortcut.size());
|
||||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// check shortcut P(S8||R) to root
|
// check shortcut P(S8||R) to root
|
||||||
c = (*bayesTree)[8];
|
clique = (*bayesTree)[8];
|
||||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// check shortcut P(S2||R) to root
|
// check shortcut P(S2||R) to root
|
||||||
c = (*bayesTree)[2];
|
clique = (*bayesTree)[2];
|
||||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// check shortcut P(S0||R) to root
|
// check shortcut P(S0||R) to root
|
||||||
c = (*bayesTree)[0];
|
clique = (*bayesTree)[0];
|
||||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
|
||||||
// calculate all shortcuts to root
|
// calculate all shortcuts to root
|
||||||
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
||||||
for (auto c : cliques) {
|
for (auto clique : cliques) {
|
||||||
DiscreteBayesNet shortcut = c.second->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||||
if (debug) {
|
if (debug) {
|
||||||
c.second->conditional_->printSignature();
|
clique.second->conditional_->printSignature();
|
||||||
shortcut.print("shortcut:");
|
shortcut.print("shortcut:");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,9 +16,9 @@
|
||||||
* @date Feb 14, 2011
|
* @date Feb 14, 2011
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
@ -36,6 +36,11 @@ TEST( DiscreteConditional, constructors)
|
||||||
DiscreteConditional::shared_ptr expected1 = //
|
DiscreteConditional::shared_ptr expected1 = //
|
||||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
||||||
EXPECT(expected1);
|
EXPECT(expected1);
|
||||||
|
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||||
|
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||||
|
EXPECT(expected1->endParents() == expected1->end());
|
||||||
|
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional actual1(1, f1);
|
||||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
||||||
|
@ -43,71 +48,68 @@ TEST( DiscreteConditional, constructors)
|
||||||
DecisionTreeFactor f2(X & Y & Z,
|
DecisionTreeFactor f2(X & Y & Z,
|
||||||
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor();
|
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
||||||
// EXPECT(assert_equal(f2, *actual2factor, 1e-9));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteConditional, constructors_alt_interface)
|
TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
{
|
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
|
||||||
|
|
||||||
Signature::Table table;
|
Signature::Table table;
|
||||||
Signature::Row r1, r2, r3;
|
Signature::Row r1, r2, r3;
|
||||||
r1 += 1.0, 1.0; r2 += 2.0, 3.0; r3 += 1.0, 4.0;
|
r1 += 1.0, 1.0;
|
||||||
|
r2 += 2.0, 3.0;
|
||||||
|
r3 += 1.0, 4.0;
|
||||||
table += r1, r2, r3;
|
table += r1, r2, r3;
|
||||||
DiscreteConditional::shared_ptr expected1 = //
|
auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table);
|
||||||
boost::make_shared<DiscreteConditional>(X | Y = table);
|
EXPECT(actual1);
|
||||||
EXPECT(expected1);
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional expected1(1, f1);
|
||||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||||
|
|
||||||
DecisionTreeFactor f2(X & Y & Z,
|
DecisionTreeFactor f2(
|
||||||
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor();
|
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
||||||
// EXPECT(assert_equal(f2, *actual2factor, 1e-9));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteConditional, constructors2)
|
TEST(DiscreteConditional, constructors2) {
|
||||||
{
|
|
||||||
// Declare keys and ordering
|
// Declare keys and ordering
|
||||||
DiscreteKey C(0,2), B(1,2);
|
DiscreteKey C(0, 2), B(1, 2);
|
||||||
DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25");
|
DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25");
|
||||||
Signature signature((C | B) = "4/1 3/1");
|
Signature signature((C | B) = "4/1 3/1");
|
||||||
DiscreteConditional actual(signature);
|
DiscreteConditional expected(signature);
|
||||||
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor();
|
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
||||||
EXPECT(assert_equal(expected, *actualFactor));
|
EXPECT(assert_equal(*expectedFactor, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteConditional, constructors3)
|
TEST(DiscreteConditional, constructors3) {
|
||||||
{
|
|
||||||
// Declare keys and ordering
|
// Declare keys and ordering
|
||||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||||
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||||
DiscreteConditional actual(signature);
|
DiscreteConditional expected(signature);
|
||||||
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor();
|
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
||||||
EXPECT(assert_equal(expected, *actualFactor));
|
EXPECT(assert_equal(*expectedFactor, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteConditional, Combine) {
|
TEST(DiscreteConditional, Combine) {
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
vector<DiscreteConditional::shared_ptr> c;
|
vector<DiscreteConditional::shared_ptr> c;
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
|
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
|
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
|
||||||
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
|
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
|
||||||
DiscreteConditional expected(2, factor);
|
DiscreteConditional actual(2, factor);
|
||||||
DiscreteConditional::shared_ptr actual = DiscreteConditional::Combine(
|
auto expected = DiscreteConditional::Combine(c.begin(), c.end());
|
||||||
c.begin(), c.end());
|
EXPECT(assert_equal(*expected, actual, 1e-5));
|
||||||
EXPECT(assert_equal(expected, *actual,1e-5));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
|
|
@ -146,8 +146,7 @@ TEST_UNSAFE( DiscreteMarginals, truss ) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Second truss example with non-trivial factors
|
// Second truss example with non-trivial factors
|
||||||
TEST_UNSAFE( DiscreteMarginals, truss2 ) {
|
TEST_UNSAFE(DiscreteMarginals, truss2) {
|
||||||
|
|
||||||
const int nrNodes = 5;
|
const int nrNodes = 5;
|
||||||
const size_t nrStates = 2;
|
const size_t nrStates = 2;
|
||||||
|
|
||||||
|
@ -160,40 +159,39 @@ TEST_UNSAFE( DiscreteMarginals, truss2 ) {
|
||||||
|
|
||||||
// create graph and add three truss potentials
|
// create graph and add three truss potentials
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.add(key[0] & key[2] & key[4],"1 2 3 4 5 6 7 8");
|
graph.add(key[0] & key[2] & key[4], "1 2 3 4 5 6 7 8");
|
||||||
graph.add(key[1] & key[3] & key[4],"1 2 3 4 5 6 7 8");
|
graph.add(key[1] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
||||||
graph.add(key[2] & key[3] & key[4],"1 2 3 4 5 6 7 8");
|
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
||||||
|
|
||||||
// Calculate the marginals by brute force
|
// Calculate the marginals by brute force
|
||||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
vector<DiscreteFactor::Values> allPosbValues =
|
||||||
key[0] & key[1] & key[2] & key[3] & key[4]);
|
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
|
||||||
Vector T = Z_5x1, F = Z_5x1;
|
Vector T = Z_5x1, F = Z_5x1;
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteFactor::Values x = allPosbValues[i];
|
||||||
double px = graph(x);
|
double px = graph(x);
|
||||||
for (size_t j=0;j<5;j++)
|
for (size_t j = 0; j < 5; j++)
|
||||||
if (x[j]) T[j]+=px; else F[j]+=px;
|
if (x[j])
|
||||||
// cout << x[0] << " " << x[1] << " "<< x[2] << " " << x[3] << " " << x[4] << " :\t" << px << endl;
|
T[j] += px;
|
||||||
|
else
|
||||||
|
F[j] += px;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check all marginals given by a sequential solver and Marginals
|
// Check all marginals given by a sequential solver and Marginals
|
||||||
// DiscreteSequentialSolver solver(graph);
|
// DiscreteSequentialSolver solver(graph);
|
||||||
DiscreteMarginals marginals(graph);
|
DiscreteMarginals marginals(graph);
|
||||||
for (size_t j=0;j<5;j++) {
|
for (size_t j = 0; j < 5; j++) {
|
||||||
double sum = T[j]+F[j];
|
double sum = T[j] + F[j];
|
||||||
T[j]/=sum;
|
T[j] /= sum;
|
||||||
F[j]/=sum;
|
F[j] /= sum;
|
||||||
|
|
||||||
// // solver
|
|
||||||
// Vector actualV = solver.marginalProbabilities(key[j]);
|
|
||||||
// EXPECT(assert_equal((Vector(2) << F[j], T[j]), actualV));
|
|
||||||
|
|
||||||
// Marginals
|
// Marginals
|
||||||
vector<double> table;
|
vector<double> table;
|
||||||
table += F[j],T[j];
|
table += F[j], T[j];
|
||||||
DecisionTreeFactor expectedM(key[j],table);
|
DecisionTreeFactor expectedM(key[j], table);
|
||||||
DiscreteFactor::shared_ptr actualM = marginals(j);
|
DiscreteFactor::shared_ptr actualM = marginals(j);
|
||||||
EXPECT(assert_equal(expectedM, *boost::dynamic_pointer_cast<DecisionTreeFactor>(actualM)));
|
EXPECT(assert_equal(
|
||||||
|
expectedM, *boost::dynamic_pointer_cast<DecisionTreeFactor>(actualM)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,36 +11,43 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file testSignature
|
* @file testSignature
|
||||||
* @brief Tests focusing on the details of Signatures to evaluate boost compliance
|
* @brief Tests focusing on the details of Signatures to evaluate boost
|
||||||
|
* compliance
|
||||||
* @author Alex Cunningham
|
* @author Alex Cunningham
|
||||||
* @date Sept 19th 2011
|
* @date Sept 19th 2011
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <boost/assign/std/vector.hpp>
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(testSignature, simple_conditional) {
|
TEST(testSignature, simple_conditional) {
|
||||||
Signature sig(X | Y = "1/1 2/3 1/4");
|
Signature sig(X | Y = "1/1 2/3 1/4");
|
||||||
|
Signature::Table table = *sig.table();
|
||||||
|
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
|
||||||
|
CHECK(row[0] == table[0]);
|
||||||
|
CHECK(row[1] == table[1]);
|
||||||
|
CHECK(row[2] == table[2]);
|
||||||
DiscreteKey actKey = sig.key();
|
DiscreteKey actKey = sig.key();
|
||||||
LONGS_EQUAL((long)X.first, (long)actKey.first);
|
LONGS_EQUAL(X.first, actKey.first);
|
||||||
|
|
||||||
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
|
DiscreteKeys actKeys = sig.discreteKeys();
|
||||||
LONGS_EQUAL(2, (long)actKeys.size());
|
LONGS_EQUAL(2, actKeys.size());
|
||||||
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
|
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||||
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
|
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||||
|
|
||||||
vector<double> actCpt = sig.cpt();
|
vector<double> actCpt = sig.cpt();
|
||||||
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
|
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -54,17 +61,20 @@ TEST(testSignature, simple_conditional_nonparser) {
|
||||||
|
|
||||||
Signature sig(X | Y = table);
|
Signature sig(X | Y = table);
|
||||||
DiscreteKey actKey = sig.key();
|
DiscreteKey actKey = sig.key();
|
||||||
EXPECT_LONGS_EQUAL((long)X.first, (long)actKey.first);
|
EXPECT_LONGS_EQUAL(X.first, actKey.first);
|
||||||
|
|
||||||
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
|
DiscreteKeys actKeys = sig.discreteKeys();
|
||||||
LONGS_EQUAL(2, (long)actKeys.size());
|
LONGS_EQUAL(2, actKeys.size());
|
||||||
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
|
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||||
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
|
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||||
|
|
||||||
vector<double> actCpt = sig.cpt();
|
vector<double> actCpt = sig.cpt();
|
||||||
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
|
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue