Fixed CSP, now a DiscreteFactorGraph again, uses dynamic_cast for Constraint-specific functionality :-(
parent
421a0725dd
commit
4ed447ca8e
|
@ -21,7 +21,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void AllDiff::print(const std::string& s) const {
|
void AllDiff::print(const std::string& s) const {
|
||||||
std::cout << s << ": AllDiff on ";
|
std::cout << s << "AllDiff on ";
|
||||||
BOOST_FOREACH (Index dkey, keys_)
|
BOOST_FOREACH (Index dkey, keys_)
|
||||||
std::cout << dkey << " ";
|
std::cout << dkey << " ";
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace gtsam {
|
||||||
|
|
||||||
// print
|
// print
|
||||||
virtual void print(const std::string& s = "") const {
|
virtual void print(const std::string& s = "") const {
|
||||||
std::cout << s << ": BinaryAllDiff on " << keys_[0] << " and " << keys_[1]
|
std::cout << s << "BinaryAllDiff on " << keys_[0] << " and " << keys_[1]
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,8 @@ set (discrete_full_libs
|
||||||
# Exclude tests that don't work
|
# Exclude tests that don't work
|
||||||
set (discrete_excluded_tests
|
set (discrete_excluded_tests
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/tests/testScheduler.cpp"
|
"${CMAKE_CURRENT_SOURCE_DIR}/tests/testScheduler.cpp"
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/tests/testCSP.cpp")
|
#"${CMAKE_CURRENT_SOURCE_DIR}/tests/testCSP.cpp"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Add all tests
|
# Add all tests
|
||||||
|
|
|
@ -49,8 +49,9 @@ namespace gtsam {
|
||||||
// if not already a singleton
|
// if not already a singleton
|
||||||
if (!domains[v].isSingleton()) {
|
if (!domains[v].isSingleton()) {
|
||||||
// get the constraint and call its ensureArcConsistency method
|
// get the constraint and call its ensureArcConsistency method
|
||||||
Constraint::shared_ptr factor = (*this)[f];
|
Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
|
||||||
changed[v] = factor->ensureArcConsistency(v,domains) || changed[v];
|
if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor");
|
||||||
|
changed[v] = constraint->ensureArcConsistency(v,domains) || changed[v];
|
||||||
}
|
}
|
||||||
} // f
|
} // f
|
||||||
if (changed[v]) anyChange = true;
|
if (changed[v]) anyChange = true;
|
||||||
|
@ -84,8 +85,10 @@ namespace gtsam {
|
||||||
// TODO: create a new ordering as we go, to ensure a connected graph
|
// TODO: create a new ordering as we go, to ensure a connected graph
|
||||||
// KeyOrdering ordering;
|
// KeyOrdering ordering;
|
||||||
// vector<Index> dkeys;
|
// vector<Index> dkeys;
|
||||||
BOOST_FOREACH(const Constraint::shared_ptr& factor, factors_) {
|
BOOST_FOREACH(const DiscreteFactor::shared_ptr& f, factors_) {
|
||||||
Constraint::shared_ptr reduced = factor->partiallyApply(domains);
|
Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>(f);
|
||||||
|
if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor");
|
||||||
|
Constraint::shared_ptr reduced = constraint->partiallyApply(domains);
|
||||||
if (print) reduced->print();
|
if (print) reduced->print();
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -18,7 +18,7 @@ namespace gtsam {
|
||||||
* A specialization of a DiscreteFactorGraph.
|
* A specialization of a DiscreteFactorGraph.
|
||||||
* It knows about CSP-specific constraints and algorithms
|
* It knows about CSP-specific constraints and algorithms
|
||||||
*/
|
*/
|
||||||
class CSP: public FactorGraph<Constraint> {
|
class CSP: public DiscreteFactorGraph {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** A map from keys to values */
|
/** A map from keys to values */
|
||||||
|
@ -27,30 +27,10 @@ namespace gtsam {
|
||||||
typedef boost::shared_ptr<Values> sharedValues;
|
typedef boost::shared_ptr<Values> sharedValues;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// Constructor
|
|
||||||
CSP() {
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class SOURCE>
|
// /// Constructor
|
||||||
void add(const DiscreteKey& j, SOURCE table) {
|
// CSP() {
|
||||||
DiscreteKeys keys;
|
// }
|
||||||
keys.push_back(j);
|
|
||||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class SOURCE>
|
|
||||||
void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) {
|
|
||||||
DiscreteKeys keys;
|
|
||||||
keys.push_back(j1);
|
|
||||||
keys.push_back(j2);
|
|
||||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
|
||||||
}
|
|
||||||
|
|
||||||
/** add shared discreteFactor immediately from arguments */
|
|
||||||
template<class SOURCE>
|
|
||||||
void add(const DiscreteKeys& keys, SOURCE table) {
|
|
||||||
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a unary constraint, allowing only a single value
|
/// Add a unary constraint, allowing only a single value
|
||||||
void addSingleValue(const DiscreteKey& dkey, size_t value) {
|
void addSingleValue(const DiscreteKey& dkey, size_t value) {
|
||||||
|
@ -71,19 +51,28 @@ namespace gtsam {
|
||||||
push_back(factor);
|
push_back(factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// /** return product of all factors as a single factor */
|
||||||
|
// DecisionTreeFactor product() const {
|
||||||
|
// DecisionTreeFactor result;
|
||||||
|
// BOOST_FOREACH(const sharedFactor& factor, *this)
|
||||||
|
// if (factor) result = (*factor) * result;
|
||||||
|
// return result;
|
||||||
|
// }
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive
|
/// Find the best total assignment - can be expensive
|
||||||
sharedValues optimalAssignment() const;
|
sharedValues optimalAssignment() const;
|
||||||
|
|
||||||
/*
|
// /*
|
||||||
* Perform loopy belief propagation
|
// * Perform loopy belief propagation
|
||||||
* True belief propagation would check for each value in domain
|
// * True belief propagation would check for each value in domain
|
||||||
* whether any satisfying separator assignment can be found.
|
// * whether any satisfying separator assignment can be found.
|
||||||
* This corresponds to hyper-arc consistency in CSP speak.
|
// * This corresponds to hyper-arc consistency in CSP speak.
|
||||||
* This can be done by creating a mini-factor graph and search.
|
// * This can be done by creating a mini-factor graph and search.
|
||||||
* For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep.
|
// * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep.
|
||||||
* It will be very expensive to exclude values that way.
|
// * It will be very expensive to exclude values that way.
|
||||||
*/
|
// */
|
||||||
// void applyBeliefPropagation(size_t nrIterations = 10) const;
|
// void applyBeliefPropagation(size_t nrIterations = 10) const;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Apply arc-consistency ~ Approximate loopy belief propagation
|
* Apply arc-consistency ~ Approximate loopy belief propagation
|
||||||
* We need to give the domains to a constraint, and it returns
|
* We need to give the domains to a constraint, and it returns
|
||||||
|
@ -92,7 +81,7 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
void runArcConsistency(size_t cardinality, size_t nrIterations = 10,
|
void runArcConsistency(size_t cardinality, size_t nrIterations = 10,
|
||||||
bool print = false) const;
|
bool print = false) const;
|
||||||
};
|
}; // CSP
|
||||||
|
|
||||||
} // gtsam
|
} // gtsam
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void SingleValue::print(const string& s) const {
|
void SingleValue::print(const string& s) const {
|
||||||
cout << s << ": SingleValue on " << keys_[0] << " (j=" << keys_[0]
|
cout << s << "SingleValue on " << "j=" << keys_[0]
|
||||||
<< ") with value " << value_ << endl;
|
<< " with value " << value_ << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -54,17 +54,17 @@ TEST_UNSAFE( CSP, allInOne)
|
||||||
invalid[ID.first] = 0;
|
invalid[ID.first] = 0;
|
||||||
invalid[UT.first] = 0;
|
invalid[UT.first] = 0;
|
||||||
invalid[AZ.first] = 0;
|
invalid[AZ.first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // FIXME: fails due to lack of operator() interface
|
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
|
||||||
|
|
||||||
// Check a valid combination
|
// Check a valid combination
|
||||||
DiscreteFactor::Values valid;
|
DiscreteFactor::Values valid;
|
||||||
valid[ID.first] = 0;
|
valid[ID.first] = 0;
|
||||||
valid[UT.first] = 1;
|
valid[UT.first] = 1;
|
||||||
valid[AZ.first] = 0;
|
valid[AZ.first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // FIXME: fails due to lack of operator() interface
|
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||||
|
|
||||||
// Just for fun, create the product and check it
|
// Just for fun, create the product and check it
|
||||||
DecisionTreeFactor product = csp.product(); // FIXME: fails due to lack of product()
|
DecisionTreeFactor product = csp.product();
|
||||||
// product.dot("product");
|
// product.dot("product");
|
||||||
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
|
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
|
||||||
EXPECT(assert_equal(expectedProduct,product));
|
EXPECT(assert_equal(expectedProduct,product));
|
||||||
|
@ -74,7 +74,7 @@ TEST_UNSAFE( CSP, allInOne)
|
||||||
CSP::Values expected;
|
CSP::Values expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
||||||
EXPECT(assert_equal(expected,*mpe));
|
EXPECT(assert_equal(expected,*mpe));
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // FIXME: fails due to lack of operator() interface
|
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -122,7 +122,7 @@ TEST_UNSAFE( CSP, WesternUS)
|
||||||
(MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2)
|
(MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2)
|
||||||
(ID.first,2)(UT.first,1)(AZ.first,0);
|
(ID.first,2)(UT.first,1)(AZ.first,0);
|
||||||
EXPECT(assert_equal(expected,*mpe));
|
EXPECT(assert_equal(expected,*mpe));
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // FIXME: fails due to lack of operator() interface
|
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
|
||||||
|
|
||||||
// Write out the dual graph for hmetis
|
// Write out the dual graph for hmetis
|
||||||
#ifdef DUAL
|
#ifdef DUAL
|
||||||
|
@ -146,7 +146,7 @@ TEST_UNSAFE( CSP, AllDiff)
|
||||||
dkeys += ID,UT,AZ;
|
dkeys += ID,UT,AZ;
|
||||||
csp.addAllDiff(dkeys);
|
csp.addAllDiff(dkeys);
|
||||||
csp.addSingleValue(AZ,2);
|
csp.addSingleValue(AZ,2);
|
||||||
//GTSAM_PRINT(csp);
|
// GTSAM_PRINT(csp);
|
||||||
|
|
||||||
// Check construction and conversion
|
// Check construction and conversion
|
||||||
SingleValue s(AZ,2);
|
SingleValue s(AZ,2);
|
||||||
|
@ -167,21 +167,21 @@ TEST_UNSAFE( CSP, AllDiff)
|
||||||
invalid[ID.first] = 0;
|
invalid[ID.first] = 0;
|
||||||
invalid[UT.first] = 1;
|
invalid[UT.first] = 1;
|
||||||
invalid[AZ.first] = 0;
|
invalid[AZ.first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // FIXME: fails due to lack of operator() interface
|
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
|
||||||
|
|
||||||
// Check a valid combination
|
// Check a valid combination
|
||||||
DiscreteFactor::Values valid;
|
DiscreteFactor::Values valid;
|
||||||
valid[ID.first] = 0;
|
valid[ID.first] = 0;
|
||||||
valid[UT.first] = 1;
|
valid[UT.first] = 1;
|
||||||
valid[AZ.first] = 2;
|
valid[AZ.first] = 2;
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // FIXME: fails due to lack of operator() interface
|
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||||
|
|
||||||
// Solve
|
// Solve
|
||||||
CSP::sharedValues mpe = csp.optimalAssignment();
|
CSP::sharedValues mpe = csp.optimalAssignment();
|
||||||
CSP::Values expected;
|
CSP::Values expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
||||||
EXPECT(assert_equal(expected,*mpe));
|
EXPECT(assert_equal(expected,*mpe));
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // FIXME: fails due to lack of operator() interface
|
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
|
||||||
|
|
||||||
// Arc-consistency
|
// Arc-consistency
|
||||||
vector<Domain> domains;
|
vector<Domain> domains;
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
#define PRINT false
|
||||||
|
|
||||||
class Sudoku: public CSP {
|
class Sudoku: public CSP {
|
||||||
|
|
||||||
/// sudoku size
|
/// sudoku size
|
||||||
|
@ -119,7 +121,7 @@ TEST_UNSAFE( Sudoku, small)
|
||||||
0,1, 0,0);
|
0,1, 0,0);
|
||||||
|
|
||||||
// Do BP
|
// Do BP
|
||||||
csp.runArcConsistency(4);
|
csp.runArcConsistency(4,10,PRINT);
|
||||||
|
|
||||||
// optimize and check
|
// optimize and check
|
||||||
CSP::sharedValues solution = csp.optimalAssignment();
|
CSP::sharedValues solution = csp.optimalAssignment();
|
||||||
|
@ -150,7 +152,7 @@ TEST_UNSAFE( Sudoku, easy)
|
||||||
5,0,0, 0,3,0, 7,0,0);
|
5,0,0, 0,3,0, 7,0,0);
|
||||||
|
|
||||||
// Do BP
|
// Do BP
|
||||||
sudoku.runArcConsistency(4);
|
sudoku.runArcConsistency(4,10,PRINT);
|
||||||
|
|
||||||
// sudoku.printSolution(); // don't do it
|
// sudoku.printSolution(); // don't do it
|
||||||
}
|
}
|
||||||
|
@ -172,7 +174,7 @@ TEST_UNSAFE( Sudoku, extreme)
|
||||||
0,0,0, 2,7,5, 9,0,0);
|
0,0,0, 2,7,5, 9,0,0);
|
||||||
|
|
||||||
// Do BP
|
// Do BP
|
||||||
sudoku.runArcConsistency(9,10,false);
|
sudoku.runArcConsistency(9,10,PRINT);
|
||||||
|
|
||||||
#ifdef METIS
|
#ifdef METIS
|
||||||
VariableIndex index(sudoku);
|
VariableIndex index(sudoku);
|
||||||
|
@ -201,7 +203,7 @@ TEST_UNSAFE( Sudoku, AJC_3star_Feb8_2012)
|
||||||
0,0,0, 1,0,0, 0,3,7);
|
0,0,0, 1,0,0, 0,3,7);
|
||||||
|
|
||||||
// Do BP
|
// Do BP
|
||||||
sudoku.runArcConsistency(9,10,true);
|
sudoku.runArcConsistency(9,10,PRINT);
|
||||||
|
|
||||||
//sudoku.printSolution(); // don't do it
|
//sudoku.printSolution(); // don't do it
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue