Cleaned up tests

release/4.3a0
Frank dellaert 2020-07-11 13:18:48 -04:00
parent 550dc377e3
commit 4c7ba2a98f
4 changed files with 100 additions and 109 deletions

View File

@ -25,8 +25,9 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/map.hpp>
#include <boost/assign/list_inserter.hpp> #include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace boost::assign; using namespace boost::assign;
@ -68,94 +69,84 @@ TEST(DiscreteBayesNet, bayesNet) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) TEST(DiscreteBayesNet, Asia) {
{
DiscreteBayesNet asia; DiscreteBayesNet asia;
// DiscreteKey A("Asia"), S("Smoking"), T("Tuberculosis"), L("LungCancer"), B( DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
// "Bronchitis"), E("Either"), X("XRay"), D("Dyspnoea"); Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2);
// TODO: make a version that doesn't use the parser asia.add(Asia % "99/1");
asia.add(A % "99/1"); asia.add(Smoking % "50/50");
asia.add(S % "50/50");
asia.add(T | A = "99/1 95/5"); asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(L | S = "99/1 90/10"); asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(B | S = "70/30 40/60"); asia.add(Bronchitis | Smoking = "70/30 40/60");
asia.add((E | T, L) = "F T T T"); asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
asia.add(X | E = "95/5 2/98"); asia.add(XRay | Either = "95/5 2/98");
// next lines are same as asia.add((D | E, B) = "9/1 2/8 3/7 1/9"); asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
DiscreteConditional::shared_ptr actual =
boost::make_shared<DiscreteConditional>((D | E, B) = "9/1 2/8 3/7 1/9");
asia.push_back(actual);
// GTSAM_PRINT(asia);
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph fg(asia); DiscreteFactorGraph fg(asia);
// GTSAM_PRINT(fg); LONGS_EQUAL(3, fg.back()->size());
LONGS_EQUAL(3,fg.back()->size());
Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9"); // Check the marginals we know (of the parent-less nodes)
CHECK(assert_equal(expected,(Potentials::ADT)*actual)); DiscreteMarginals marginals(fg);
Vector2 va(0.99, 0.01), vs(0.5, 0.5);
EXPECT(assert_equal(va, marginals.marginalProbabilities(Asia)));
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
// Create solver and eliminate // Create solver and eliminate
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7); ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// GTSAM_PRINT(*chordal); DiscreteConditional expected2(Bronchitis % "11/9");
DiscreteConditional expected2(B % "11/9"); EXPECT(assert_equal(expected2, *chordal->back()));
CHECK(assert_equal(expected2,*chordal->back()));
// solve // solve
DiscreteFactor::sharedValues actualMPE = chordal->optimize(); DiscreteFactor::sharedValues actualMPE = chordal->optimize();
DiscreteFactor::Values expectedMPE; DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first, insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
0)(E.first, 0)(L.first, 0)(B.first, 0); Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0);
EXPECT(assert_equal(expectedMPE, *actualMPE)); EXPECT(assert_equal(expectedMPE, *actualMPE));
// add evidence, we were in Asia and we have Dispnoea // add evidence, we were in Asia and we have dyspnea
fg.add(A, "0 1"); fg.add(Asia, "0 1");
fg.add(D, "0 1"); fg.add(Dyspnea, "0 1");
// fg.product().dot("fg");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
// GTSAM_PRINT(*chordal2);
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize(); DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
DiscreteFactor::Values expectedMPE2; DiscreteFactor::Values expectedMPE2;
insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first, insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
1)(E.first, 0)(L.first, 0)(B.first, 1); Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, *actualMPE2)); EXPECT(assert_equal(expectedMPE2, *actualMPE2));
// now sample from it // now sample from it
DiscreteFactor::Values expectedSample; DiscreteFactor::Values expectedSample;
SETDEBUG("DiscreteConditional::sample", false); SETDEBUG("DiscreteConditional::sample", false);
insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 1)(T.first, 0)( insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
S.first, 1)(E.first, 1)(L.first, 1)(B.first, 0); Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
LungCancer.first, 1)(Bronchitis.first, 0);
DiscreteFactor::sharedValues actualSample = chordal2->sample(); DiscreteFactor::sharedValues actualSample = chordal2->sample();
EXPECT(assert_equal(expectedSample, *actualSample)); EXPECT(assert_equal(expectedSample, *actualSample));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesNet, Sugar) TEST_UNSAFE(DiscreteBayesNet, Sugar) {
{ DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
DiscreteKey T(0,2), L(1,2), E(2,2), D(3,2), C(8,3), S(7,2);
DiscreteBayesNet bn; DiscreteBayesNet bn;
// test some mistakes
// add(bn, D);
// add(bn, D | E);
// add(bn, D | E = "blah");
// try logic // try logic
bn.add((E | T, L) = "OR"); bn.add((E | T, L) = "OR");
bn.add((E | T, L) = "AND"); bn.add((E | T, L) = "AND");
// // try multivalued // try multivalued
bn.add(C % "1/1/2"); bn.add(C % "1/1/2");
bn.add(C | S = "1/1/2 5/2/3"); bn.add(C | S = "1/1/2 5/2/3");
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -164,4 +155,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -16,9 +16,9 @@
* @date Feb 14, 2011 * @date Feb 14, 2011
*/ */
#include <boost/make_shared.hpp>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/make_shared.hpp>
using namespace boost::assign; using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
@ -48,71 +48,68 @@ TEST( DiscreteConditional, constructors)
DecisionTreeFactor f2(X & Y & Z, DecisionTreeFactor f2(X & Y & Z,
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor(); EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
// EXPECT(assert_equal(f2, *actual2factor, 1e-9));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteConditional, constructors_alt_interface) TEST(DiscreteConditional, constructors_alt_interface) {
{ DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
Signature::Table table; Signature::Table table;
Signature::Row r1, r2, r3; Signature::Row r1, r2, r3;
r1 += 1.0, 1.0; r2 += 2.0, 3.0; r3 += 1.0, 4.0; r1 += 1.0, 1.0;
r2 += 2.0, 3.0;
r3 += 1.0, 4.0;
table += r1, r2, r3; table += r1, r2, r3;
DiscreteConditional::shared_ptr expected1 = // auto actual1 = boost::make_shared<DiscreteConditional>(X | Y = table);
boost::make_shared<DiscreteConditional>(X | Y = table); EXPECT(actual1);
EXPECT(expected1);
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional actual1(1, f1); DiscreteConditional expected1(1, f1);
EXPECT(assert_equal(*expected1, actual1, 1e-9)); EXPECT(assert_equal(expected1, *actual1, 1e-9));
DecisionTreeFactor f2(X & Y & Z, DecisionTreeFactor f2(
"0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor::shared_ptr actual2factor = actual2.toFactor(); EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
// EXPECT(assert_equal(f2, *actual2factor, 1e-9));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteConditional, constructors2) TEST(DiscreteConditional, constructors2) {
{
// Declare keys and ordering // Declare keys and ordering
DiscreteKey C(0,2), B(1,2); DiscreteKey C(0, 2), B(1, 2);
DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25");
Signature signature((C | B) = "4/1 3/1"); Signature signature((C | B) = "4/1 3/1");
DiscreteConditional actual(signature); DiscreteConditional expected(signature);
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
EXPECT(assert_equal(expected, *actualFactor)); EXPECT(assert_equal(*expectedFactor, actual));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteConditional, constructors3) TEST(DiscreteConditional, constructors3) {
{
// Declare keys and ordering // Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2); DiscreteKey C(0, 2), B(1, 2), A(2, 2);
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
DiscreteConditional actual(signature); DiscreteConditional expected(signature);
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
EXPECT(assert_equal(expected, *actualFactor)); EXPECT(assert_equal(*expectedFactor, actual));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteConditional, Combine) { TEST(DiscreteConditional, Combine) {
DiscreteKey A(0, 2), B(1, 2); DiscreteKey A(0, 2), B(1, 2);
vector<DiscreteConditional::shared_ptr> c; vector<DiscreteConditional::shared_ptr> c;
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1")); c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2")); c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
DiscreteConditional expected(2, factor); DiscreteConditional actual(2, factor);
DiscreteConditional::shared_ptr actual = DiscreteConditional::Combine( auto expected = DiscreteConditional::Combine(c.begin(), c.end());
c.begin(), c.end()); EXPECT(assert_equal(*expected, actual, 1e-5));
EXPECT(assert_equal(expected, *actual,1e-5));
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); } int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -146,8 +146,7 @@ TEST_UNSAFE( DiscreteMarginals, truss ) {
/* ************************************************************************* */ /* ************************************************************************* */
// Second truss example with non-trivial factors // Second truss example with non-trivial factors
TEST_UNSAFE( DiscreteMarginals, truss2 ) { TEST_UNSAFE(DiscreteMarginals, truss2) {
const int nrNodes = 5; const int nrNodes = 5;
const size_t nrStates = 2; const size_t nrStates = 2;
@ -160,40 +159,39 @@ TEST_UNSAFE( DiscreteMarginals, truss2 ) {
// create graph and add three truss potentials // create graph and add three truss potentials
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(key[0] & key[2] & key[4],"1 2 3 4 5 6 7 8"); graph.add(key[0] & key[2] & key[4], "1 2 3 4 5 6 7 8");
graph.add(key[1] & key[3] & key[4],"1 2 3 4 5 6 7 8"); graph.add(key[1] & key[3] & key[4], "1 2 3 4 5 6 7 8");
graph.add(key[2] & key[3] & key[4],"1 2 3 4 5 6 7 8"); graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
// Calculate the marginals by brute force // Calculate the marginals by brute force
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct( vector<DiscreteFactor::Values> allPosbValues =
key[0] & key[1] & key[2] & key[3] & key[4]); cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
Vector T = Z_5x1, F = Z_5x1; Vector T = Z_5x1, F = Z_5x1;
for (size_t i = 0; i < allPosbValues.size(); ++i) { for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i]; DiscreteFactor::Values x = allPosbValues[i];
double px = graph(x); double px = graph(x);
for (size_t j=0;j<5;j++) for (size_t j = 0; j < 5; j++)
if (x[j]) T[j]+=px; else F[j]+=px; if (x[j])
// cout << x[0] << " " << x[1] << " "<< x[2] << " " << x[3] << " " << x[4] << " :\t" << px << endl; T[j] += px;
else
F[j] += px;
} }
// Check all marginals given by a sequential solver and Marginals // Check all marginals given by a sequential solver and Marginals
// DiscreteSequentialSolver solver(graph); // DiscreteSequentialSolver solver(graph);
DiscreteMarginals marginals(graph); DiscreteMarginals marginals(graph);
for (size_t j=0;j<5;j++) { for (size_t j = 0; j < 5; j++) {
double sum = T[j]+F[j]; double sum = T[j] + F[j];
T[j]/=sum; T[j] /= sum;
F[j]/=sum; F[j] /= sum;
// // solver
// Vector actualV = solver.marginalProbabilities(key[j]);
// EXPECT(assert_equal((Vector(2) << F[j], T[j]), actualV));
// Marginals // Marginals
vector<double> table; vector<double> table;
table += F[j],T[j]; table += F[j], T[j];
DecisionTreeFactor expectedM(key[j],table); DecisionTreeFactor expectedM(key[j], table);
DiscreteFactor::shared_ptr actualM = marginals(j); DiscreteFactor::shared_ptr actualM = marginals(j);
EXPECT(assert_equal(expectedM, *boost::dynamic_pointer_cast<DecisionTreeFactor>(actualM))); EXPECT(assert_equal(
expectedM, *boost::dynamic_pointer_cast<DecisionTreeFactor>(actualM)));
} }
} }

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <vector>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -32,6 +33,11 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(testSignature, simple_conditional) { TEST(testSignature, simple_conditional) {
Signature sig(X | Y = "1/1 2/3 1/4"); Signature sig(X | Y = "1/1 2/3 1/4");
Signature::Table table = *sig.table();
vector<double> row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}};
CHECK(row[0] == table[0]);
CHECK(row[1] == table[1]);
CHECK(row[2] == table[2]);
DiscreteKey actKey = sig.key(); DiscreteKey actKey = sig.key();
LONGS_EQUAL(X.first, actKey.first); LONGS_EQUAL(X.first, actKey.first);