gtsam/gtsam_unstable/discrete/tests/testCSP.cpp

227 lines
6.6 KiB
C++

/*
* testCSP.cpp
* @brief Unit tests for the CSP solver and its constraint types
* @date Feb 5, 2012
* @author Frank Dellaert
*/
#include <gtsam_unstable/discrete/CSP.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/assign/std/map.hpp>
using boost::assign::insert;
#include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <fstream>
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST_UNSAFE( BinaryAllDif, allInOne)
{
  // Two-colour variables addressed by integer index:
  // Idaho = 0, Arizona = 1, Utah = 2.
  const size_t numColors = 2;
  DiscreteKey ID(0, numColors);
  DiscreteKey AZ(1, numColors);
  DiscreteKey UT(2, numColors);

  // The Idaho/Utah constraint must convert to the table "0 1 1 0".
  BinaryAllDiff idahoUtah(ID, UT);
  DecisionTreeFactor idahoUtahTable(ID & UT, "0 1 1 0");
  EXPECT(assert_equal(idahoUtahTable, static_cast<DecisionTreeFactor>(idahoUtah)));

  // Same check for the Utah/Arizona constraint.
  BinaryAllDiff utahArizona(UT, AZ);
  DecisionTreeFactor utahArizonaTable(UT & AZ, "0 1 1 0");
  EXPECT(assert_equal(utahArizonaTable, static_cast<DecisionTreeFactor>(utahArizona)));

  // Multiplying a constraint with the other constraint's table must give
  // the same result as multiplying the two tables directly.
  DecisionTreeFactor tableProduct = idahoUtahTable * utahArizonaTable;
  EXPECT(assert_equal(tableProduct, idahoUtah * utahArizonaTable));
  EXPECT(assert_equal(tableProduct, utahArizona * idahoUtahTable));
}
/* ************************************************************************* */
TEST_UNSAFE( CSP, allInOne)
{
  // Two-colour variables addressed by integer index: ID = 0, AZ = 1, UT = 2.
  const size_t numColors = 2;
  DiscreteKey ID(0, numColors);
  DiscreteKey AZ(1, numColors);
  DiscreteKey UT(2, numColors);

  // Build the problem from two pairwise all-different constraints.
  CSP csp;
  csp.addAllDiff(ID, UT);
  csp.addAllDiff(UT, AZ);

  // Assigning colour 0 to every variable must evaluate to 0.
  DiscreteFactor::Values allZero;
  allZero[ID.first] = 0;
  allZero[UT.first] = 0;
  allZero[AZ.first] = 0;
  EXPECT_DOUBLES_EQUAL(0, csp(allZero), 1e-9);

  // ID = AZ = 0 with UT = 1 must evaluate to 1.
  DiscreteFactor::Values alternating;
  alternating[ID.first] = 0;
  alternating[UT.first] = 1;
  alternating[AZ.first] = 0;
  EXPECT_DOUBLES_EQUAL(1, csp(alternating), 1e-9);

  // The product of all constraints must match the expected table.
  // For visual inspection: product.dot("product");
  DecisionTreeFactor product = csp.product();
  DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
  EXPECT(assert_equal(expectedProduct, product));

  // Solve, then check both the returned assignment and its value under csp.
  CSP::sharedValues solution = csp.optimalAssignment();
  CSP::Values expected;
  insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
  EXPECT(assert_equal(expected, *solution));
  EXPECT_DOUBLES_EQUAL(1, csp(*solution), 1e-9);
}
/* ************************************************************************* */
TEST_UNSAFE( CSP, WesternUS)
{
  // One four-colour variable per state. The integer indices follow the
  // ordering of the example in ND-CSP.lyx.
  const size_t numColors = 4;
  DiscreteKey WA(0, numColors), OR(3, numColors), CA(1, numColors), NV(2, numColors);
  DiscreteKey ID(8, numColors), UT(9, numColors), AZ(10, numColors);
  DiscreteKey MT(4, numColors), WY(5, numColors), CO(7, numColors), NM(6, numColors);

  // State pairs that get a pairwise all-different constraint.
  // The constraints are added to the CSP in exactly this order.
  DiscreteKey borders[][2] = {
    {WA, ID}, {WA, OR}, {OR, ID}, {OR, CA}, {OR, NV}, {CA, NV}, {CA, AZ},
    {ID, MT}, {ID, WY}, {ID, UT}, {ID, NV}, {NV, UT}, {NV, AZ},
    {UT, WY}, {UT, CO}, {UT, NM}, {UT, AZ}, {AZ, CO}, {AZ, NM},
    {MT, WY}, {WY, CO}, {CO, NM}
  };
  const size_t numBorders = sizeof(borders) / sizeof(borders[0]);

  // Create the CSP
  CSP csp;
  for (size_t i = 0; i < numBorders; ++i)
    csp.addAllDiff(borders[i][0], borders[i][1]);

  // Solve and compare against the expected colouring.
  CSP::sharedValues solution = csp.optimalAssignment();
  // GTSAM_PRINT(*solution);
  CSP::Values expected;
  insert(expected)
    (WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)
    (MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)
    (ID.first, 2)(UT.first, 1)(AZ.first, 0);
  EXPECT(assert_equal(expected, *solution));
  EXPECT_DOUBLES_EQUAL(1, csp(*solution), 1e-9);

  // Optionally write out the dual graph for hmetis (only when DUAL is defined).
  // NOTE(review): the output path below is an absolute, machine-specific path.
#ifdef DUAL
  VariableIndex variableIndex(csp);
  variableIndex.print("index");
  ofstream dualFile("/Users/dellaert/src/hmetis-1.5-osx-i686/US-West-dual.txt");
  variableIndex.outputMetisFormat(dualFile);
#endif
}
/* ************************************************************************* */
TEST_UNSAFE( CSP, AllDiff)
{
// Exercises the AllDiff and SingleValue constraints on three 3-valued
// variables: conversion to DecisionTreeFactor, evaluation through the CSP,
// solving, arc consistency, and partial application.

// Create keys and ordering (integer indices: ID = 0, AZ = 1, UT = 2)
size_t nrColors = 3;
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
// Create the CSP: one all-different constraint over (ID, UT, AZ) plus a
// single-value constraint on AZ with value 2
CSP csp;
vector<DiscreteKey> dkeys;
// NOTE(review): the comma-separated `+=` looks like Boost.Assign's vector
// syntax; this file only includes boost/assign/std/map.hpp, so it presumably
// arrives through another header -- confirm.
dkeys += ID,UT,AZ;
csp.addAllDiff(dkeys);
csp.addSingleValue(AZ,2);
// GTSAM_PRINT(csp);
// Check construction of SingleValue and its conversion to a table over AZ
SingleValue s(AZ,2);
DecisionTreeFactor f1(AZ,"0 0 1");
EXPECT(assert_equal(f1,(DecisionTreeFactor)s));
// Check construction of AllDiff and its conversion to a 27-entry table
// over (ID, AZ, UT)
AllDiff alldiff(dkeys);
DecisionTreeFactor actual = (DecisionTreeFactor)alldiff;
// GTSAM_PRINT(actual);
// actual.dot("actual");
DecisionTreeFactor f2(ID & AZ & UT,
"0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
EXPECT(assert_equal(f2,actual));
// Check an invalid combination: ID and AZ share value 0 (and AZ is not 2),
// so the CSP is expected to evaluate to 0
DiscreteFactor::Values invalid;
invalid[ID.first] = 0;
invalid[UT.first] = 1;
invalid[AZ.first] = 0;
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
// Check a valid combination: all three values differ and AZ is 2
DiscreteFactor::Values valid;
valid[ID.first] = 0;
valid[UT.first] = 1;
valid[AZ.first] = 2;
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Solve and compare with the expected assignment
CSP::sharedValues mpe = csp.optimalAssignment();
CSP::Values expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected,*mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Arc-consistency
// The domains are stored in the order ID, AZ, UT, so position j holds the
// domain of the key with index j. The first argument of
// ensureArcConsistency is presumably that index -- confirm against the
// constraint API. The calls below are order-dependent: each one may modify
// `domains`, which later calls and assertions then read.
vector<Domain> domains;
domains += Domain(ID), Domain(AZ), Domain(UT);
SingleValue singleValue(AZ,2);
EXPECT(singleValue.ensureArcConsistency(1,domains));
EXPECT(alldiff.ensureArcConsistency(0,domains));
EXPECT(!alldiff.ensureArcConsistency(1,domains));
EXPECT(alldiff.ensureArcConsistency(2,domains));
// Expected domain sizes afterwards: 2 for ID, 1 for AZ, 2 for UT
LONGS_EQUAL(2,domains[0].nrValues());
LONGS_EQUAL(1,domains[1].nrValues());
LONGS_EQUAL(2,domains[2].nrValues());
// Partial application, version 1: apply a known assignment (AZ = 2)
DiscreteFactor::Values known;
known[AZ.first] = 2;
DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known);
DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0");
EXPECT(assert_equal(f3,reduced1->operator DecisionTreeFactor()));
DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known);
DecisionTreeFactor f4(AZ, "0 0 1");
EXPECT(assert_equal(f4,reduced2->operator DecisionTreeFactor()));
// Partial application, version 2: apply the domains computed above; the
// expected tables are the same as in version 1
DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains);
EXPECT(assert_equal(f3,reduced3->operator DecisionTreeFactor()));
DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains);
EXPECT(assert_equal(f4,reduced4->operator DecisionTreeFactor()));
// full arc-consistency test (the call's result is not checked here)
csp.runArcConsistency(nrColors);
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */