Fixed CSP, now a DiscreteFactorGraph again, uses dynamic_cast for Constraint-specific functionality :-(

release/4.3a0
Frank Dellaert 2012-05-25 14:51:03 +00:00
parent 421a0725dd
commit 4ed447ca8e
8 changed files with 52 additions and 57 deletions

View File

@ -21,7 +21,7 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
void AllDiff::print(const std::string& s) const { void AllDiff::print(const std::string& s) const {
std::cout << s << ": AllDiff on "; std::cout << s << "AllDiff on ";
BOOST_FOREACH (Index dkey, keys_) BOOST_FOREACH (Index dkey, keys_)
std::cout << dkey << " "; std::cout << dkey << " ";
std::cout << std::endl; std::cout << std::endl;

View File

@ -34,7 +34,7 @@ namespace gtsam {
// print // print
virtual void print(const std::string& s = "") const { virtual void print(const std::string& s = "") const {
std::cout << s << ": BinaryAllDiff on " << keys_[0] << " and " << keys_[1] std::cout << s << "BinaryAllDiff on " << keys_[0] << " and " << keys_[1]
<< std::endl; << std::endl;
} }

View File

@ -18,7 +18,8 @@ set (discrete_full_libs
# Exclude tests that don't work # Exclude tests that don't work
set (discrete_excluded_tests set (discrete_excluded_tests
"${CMAKE_CURRENT_SOURCE_DIR}/tests/testScheduler.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/tests/testScheduler.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tests/testCSP.cpp") #"${CMAKE_CURRENT_SOURCE_DIR}/tests/testCSP.cpp"
)
# Add all tests # Add all tests

View File

@ -49,8 +49,9 @@ namespace gtsam {
// if not already a singleton // if not already a singleton
if (!domains[v].isSingleton()) { if (!domains[v].isSingleton()) {
// get the constraint and call its ensureArcConsistency method // get the constraint and call its ensureArcConsistency method
Constraint::shared_ptr factor = (*this)[f]; Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
changed[v] = factor->ensureArcConsistency(v,domains) || changed[v]; if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor");
changed[v] = constraint->ensureArcConsistency(v,domains) || changed[v];
} }
} // f } // f
if (changed[v]) anyChange = true; if (changed[v]) anyChange = true;
@ -84,8 +85,10 @@ namespace gtsam {
// TODO: create a new ordering as we go, to ensure a connected graph // TODO: create a new ordering as we go, to ensure a connected graph
// KeyOrdering ordering; // KeyOrdering ordering;
// vector<Index> dkeys; // vector<Index> dkeys;
BOOST_FOREACH(const Constraint::shared_ptr& factor, factors_) { BOOST_FOREACH(const DiscreteFactor::shared_ptr& f, factors_) {
Constraint::shared_ptr reduced = factor->partiallyApply(domains); Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>(f);
if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor");
Constraint::shared_ptr reduced = constraint->partiallyApply(domains);
if (print) reduced->print(); if (print) reduced->print();
} }
#endif #endif

View File

@ -18,7 +18,7 @@ namespace gtsam {
* A specialization of a DiscreteFactorGraph. * A specialization of a DiscreteFactorGraph.
* It knows about CSP-specific constraints and algorithms * It knows about CSP-specific constraints and algorithms
*/ */
class CSP: public FactorGraph<Constraint> { class CSP: public DiscreteFactorGraph {
public: public:
/** A map from keys to values */ /** A map from keys to values */
@ -27,30 +27,10 @@ namespace gtsam {
typedef boost::shared_ptr<Values> sharedValues; typedef boost::shared_ptr<Values> sharedValues;
public: public:
/// Constructor
CSP() {
}
template<class SOURCE> // /// Constructor
void add(const DiscreteKey& j, SOURCE table) { // CSP() {
DiscreteKeys keys; // }
keys.push_back(j);
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
}
template<class SOURCE>
void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) {
DiscreteKeys keys;
keys.push_back(j1);
keys.push_back(j2);
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
}
/** add shared discreteFactor immediately from arguments */
template<class SOURCE>
void add(const DiscreteKeys& keys, SOURCE table) {
push_back(boost::make_shared<DecisionTreeFactor>(keys, table));
}
/// Add a unary constraint, allowing only a single value /// Add a unary constraint, allowing only a single value
void addSingleValue(const DiscreteKey& dkey, size_t value) { void addSingleValue(const DiscreteKey& dkey, size_t value) {
@ -71,19 +51,28 @@ namespace gtsam {
push_back(factor); push_back(factor);
} }
// /** return product of all factors as a single factor */
// DecisionTreeFactor product() const {
// DecisionTreeFactor result;
// BOOST_FOREACH(const sharedFactor& factor, *this)
// if (factor) result = (*factor) * result;
// return result;
// }
/// Find the best total assignment - can be expensive /// Find the best total assignment - can be expensive
sharedValues optimalAssignment() const; sharedValues optimalAssignment() const;
/* // /*
* Perform loopy belief propagation // * Perform loopy belief propagation
* True belief propagation would check for each value in domain // * True belief propagation would check for each value in domain
* whether any satisfying separator assignment can be found. // * whether any satisfying separator assignment can be found.
* This corresponds to hyper-arc consistency in CSP speak. // * This corresponds to hyper-arc consistency in CSP speak.
* This can be done by creating a mini-factor graph and search. // * This can be done by creating a mini-factor graph and search.
* For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep. // * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep.
* It will be very expensive to exclude values that way. // * It will be very expensive to exclude values that way.
*/ // */
// void applyBeliefPropagation(size_t nrIterations = 10) const; // void applyBeliefPropagation(size_t nrIterations = 10) const;
/* /*
* Apply arc-consistency ~ Approximate loopy belief propagation * Apply arc-consistency ~ Approximate loopy belief propagation
* We need to give the domains to a constraint, and it returns * We need to give the domains to a constraint, and it returns
@ -92,7 +81,7 @@ namespace gtsam {
*/ */
void runArcConsistency(size_t cardinality, size_t nrIterations = 10, void runArcConsistency(size_t cardinality, size_t nrIterations = 10,
bool print = false) const; bool print = false) const;
}; }; // CSP
} // gtsam } // gtsam

View File

@ -17,8 +17,8 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
void SingleValue::print(const string& s) const { void SingleValue::print(const string& s) const {
cout << s << ": SingleValue on " << keys_[0] << " (j=" << keys_[0] cout << s << "SingleValue on " << "j=" << keys_[0]
<< ") with value " << value_ << endl; << " with value " << value_ << endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -54,17 +54,17 @@ TEST_UNSAFE( CSP, allInOne)
invalid[ID.first] = 0; invalid[ID.first] = 0;
invalid[UT.first] = 0; invalid[UT.first] = 0;
invalid[AZ.first] = 0; invalid[AZ.first] = 0;
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // FIXME: fails due to lack of operator() interface EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
// Check a valid combination // Check a valid combination
DiscreteFactor::Values valid; DiscreteFactor::Values valid;
valid[ID.first] = 0; valid[ID.first] = 0;
valid[UT.first] = 1; valid[UT.first] = 1;
valid[AZ.first] = 0; valid[AZ.first] = 0;
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // FIXME: fails due to lack of operator() interface EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Just for fun, create the product and check it // Just for fun, create the product and check it
DecisionTreeFactor product = csp.product(); // FIXME: fails due to lack of product() DecisionTreeFactor product = csp.product();
// product.dot("product"); // product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct,product)); EXPECT(assert_equal(expectedProduct,product));
@ -74,7 +74,7 @@ TEST_UNSAFE( CSP, allInOne)
CSP::Values expected; CSP::Values expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected,*mpe)); EXPECT(assert_equal(expected,*mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // FIXME: fails due to lack of operator() interface EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -122,7 +122,7 @@ TEST_UNSAFE( CSP, WesternUS)
(MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2) (MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2)
(ID.first,2)(UT.first,1)(AZ.first,0); (ID.first,2)(UT.first,1)(AZ.first,0);
EXPECT(assert_equal(expected,*mpe)); EXPECT(assert_equal(expected,*mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // FIXME: fails due to lack of operator() interface EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Write out the dual graph for hmetis // Write out the dual graph for hmetis
#ifdef DUAL #ifdef DUAL
@ -146,7 +146,7 @@ TEST_UNSAFE( CSP, AllDiff)
dkeys += ID,UT,AZ; dkeys += ID,UT,AZ;
csp.addAllDiff(dkeys); csp.addAllDiff(dkeys);
csp.addSingleValue(AZ,2); csp.addSingleValue(AZ,2);
//GTSAM_PRINT(csp); // GTSAM_PRINT(csp);
// Check construction and conversion // Check construction and conversion
SingleValue s(AZ,2); SingleValue s(AZ,2);
@ -167,21 +167,21 @@ TEST_UNSAFE( CSP, AllDiff)
invalid[ID.first] = 0; invalid[ID.first] = 0;
invalid[UT.first] = 1; invalid[UT.first] = 1;
invalid[AZ.first] = 0; invalid[AZ.first] = 0;
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // FIXME: fails due to lack of operator() interface EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
// Check a valid combination // Check a valid combination
DiscreteFactor::Values valid; DiscreteFactor::Values valid;
valid[ID.first] = 0; valid[ID.first] = 0;
valid[UT.first] = 1; valid[UT.first] = 1;
valid[AZ.first] = 2; valid[AZ.first] = 2;
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // FIXME: fails due to lack of operator() interface EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Solve // Solve
CSP::sharedValues mpe = csp.optimalAssignment(); CSP::sharedValues mpe = csp.optimalAssignment();
CSP::Values expected; CSP::Values expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected,*mpe)); EXPECT(assert_equal(expected,*mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // FIXME: fails due to lack of operator() interface EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Arc-consistency // Arc-consistency
vector<Domain> domains; vector<Domain> domains;

View File

@ -14,6 +14,8 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
#define PRINT false
class Sudoku: public CSP { class Sudoku: public CSP {
/// sudoku size /// sudoku size
@ -119,7 +121,7 @@ TEST_UNSAFE( Sudoku, small)
0,1, 0,0); 0,1, 0,0);
// Do BP // Do BP
csp.runArcConsistency(4); csp.runArcConsistency(4,10,PRINT);
// optimize and check // optimize and check
CSP::sharedValues solution = csp.optimalAssignment(); CSP::sharedValues solution = csp.optimalAssignment();
@ -150,7 +152,7 @@ TEST_UNSAFE( Sudoku, easy)
5,0,0, 0,3,0, 7,0,0); 5,0,0, 0,3,0, 7,0,0);
// Do BP // Do BP
sudoku.runArcConsistency(4); sudoku.runArcConsistency(4,10,PRINT);
// sudoku.printSolution(); // don't do it // sudoku.printSolution(); // don't do it
} }
@ -172,7 +174,7 @@ TEST_UNSAFE( Sudoku, extreme)
0,0,0, 2,7,5, 9,0,0); 0,0,0, 2,7,5, 9,0,0);
// Do BP // Do BP
sudoku.runArcConsistency(9,10,false); sudoku.runArcConsistency(9,10,PRINT);
#ifdef METIS #ifdef METIS
VariableIndex index(sudoku); VariableIndex index(sudoku);
@ -201,7 +203,7 @@ TEST_UNSAFE( Sudoku, AJC_3star_Feb8_2012)
0,0,0, 1,0,0, 0,3,7); 0,0,0, 1,0,0, 0,3,7);
// Do BP // Do BP
sudoku.runArcConsistency(9,10,true); sudoku.runArcConsistency(9,10,PRINT);
//sudoku.printSolution(); // don't do it //sudoku.printSolution(); // don't do it
} }