diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index ef18053a4..85cf0b472 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -57,14 +57,11 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool AllDiff::ensureArcConsistency(size_t j, - std::vector* domains) const { - // We are changing the domain of variable j. - // TODO(dellaert): confusing, I thought we were changing others... +bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { Domain& Dj = domains->at(j); // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. + // a value in domains->at(j) that does not occur in any other connected domain. // If found, we make this a singleton... // TODO: make a new constraint where this really is true boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); @@ -103,10 +100,10 @@ Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { + const Domains& domains) const { DiscreteFactor::Values known; for (Key k : keys_) { - const Domain& Dk = domains[k]; + const Domain& Dk = domains.at(k); if (Dk.isSingleton()) known[k] = Dk.firstValue(); } return partiallyApply(known); diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 4deabda94..57b0aeb5c 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -54,21 +54,19 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param (in/out) domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values&) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override; + const Domains&) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 21cfb18f2..a2c7ba660 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -70,13 +70,12 @@ class BinaryAllDiff : public Constraint { } /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - /// - bool ensureArcConsistency(size_t j, - std::vector* domains) const override { + bool ensureArcConsistency(Key j, Domains* domains) const override { throw std::runtime_error( "BinaryAllDiff::ensureArcConsistency not implemented"); return false; @@ -89,7 +88,7 @@ class BinaryAllDiff : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override { + const Domains&) const override { throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); } }; diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index bab1ac3c8..8c974f4fd 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -27,81 +27,75 @@ CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const { return mpe; } -void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, - bool print) const { +bool CSP::runArcConsistency(const VariableIndex& index, + Domains* domains) const { + bool changed = false; + + // iterate over all variables in the index + for (auto entry : index) { + // Get the variable's key and associated factors: + const Key key = entry.first; + const FactorIndices& factors = entry.second; + + // If this domain is already a singleton, we do nothing. + if (domains->at(key).isSingleton()) continue; + + // Otherwise, loop over all factors/constraints for variable with given key. + for (size_t f : factors) { + // If this factor is a constraint, call its ensureArcConsistency method: + auto constraint = boost::dynamic_pointer_cast((*this)[f]); + if (constraint) { + changed = constraint->ensureArcConsistency(key, domains) || changed; + } + } + } + return changed; +} + +// TODO(dellaert): This is AC1, which is inefficient as any change will cause +// the algorithm to revisit *all* variables again. Implement AC3. +Domains CSP::runArcConsistency(size_t cardinality, size_t maxIterations) const { // Create VariableIndex VariableIndex index(*this); - // index.print(); - - size_t n = index.size(); // Initialize domains - std::vector domains; - for (size_t j = 0; j < n; j++) - domains.push_back(Domain(DiscreteKey(j, cardinality))); + Domains domains; + for (auto entry : index) { + const Key key = entry.first; + domains.emplace(key, DiscreteKey(key, cardinality)); + } - // Create array of flags indicating a domain changed or not - std::vector changed(n); + // Iterate until convergence or not a single domain changed. + for (size_t it = 0; it < maxIterations; it++) { + bool changed = runArcConsistency(index, &domains); + if (!changed) break; + } + return domains; +} - // iterate nrIterations over entire grid - for (size_t it = 0; it < nrIterations; it++) { - bool anyChange = false; - // iterate over all cells - for (size_t v = 0; v < n; v++) { - // keep track of which domains changed - changed[v] = false; - // loop over all factors/constraints for variable v - const FactorIndices& factors = index[v]; - for (size_t f : factors) { - // if not already a singleton - if (!domains[v].isSingleton()) { - // get the constraint and call its ensureArcConsistency method - auto constraint = boost::dynamic_pointer_cast((*this)[f]); - if (!constraint) - throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - changed[v] = - constraint->ensureArcConsistency(v, &domains) || changed[v]; - } - } // f - if (changed[v]) anyChange = true; - } // v - if (!anyChange) break; - // TODO: Sudoku specific hack - if (print) { - if (cardinality == 9 && n == 81) { - for (size_t i = 0, v = 0; i < (size_t)std::sqrt((double)n); i++) { - for (size_t j = 0; j < (size_t)std::sqrt((double)n); j++, v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // i - cout << endl; - } // j - } else { - for (size_t v = 0; v < n; v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // v - } - cout << endl; - } // print - } // it - -#ifndef INPROGRESS - // Now create new problem with all singleton variables removed - // We do this by adding simplifying all factors using parial application +CSP CSP::partiallyApply(const Domains& domains) const { + // Create new problem with all singleton variables removed + // We do this by adding simplifying all factors using partial application. // TODO: create a new ordering as we go, to ensure a connected graph // KeyOrdering ordering; // vector dkeys; + CSP new_csp; + + // Add tightened domains as new factors: + for (auto key_domain : domains) { + new_csp.emplace_shared(key_domain.second); + } + + // Reduce all existing factors: for (const DiscreteFactor::shared_ptr& f : factors_) { - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast(f); + auto constraint = boost::dynamic_pointer_cast(f); if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); Constraint::shared_ptr reduced = constraint->partiallyApply(domains); - if (print) reduced->print(); + if (reduced->size() > 1) { + new_csp.push_back(reduced); + } } -#endif + return new_csp; } } // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index e43e53932..d94913682 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -62,7 +62,7 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // deep. // * It will be very expensive to exclude values that way. // */ - // void applyBeliefPropagation(size_t nrIterations = 10) const; + // void applyBeliefPropagation(size_t maxIterations = 10) const; /* * Apply arc-consistency ~ Approximate loopy belief propagation @@ -70,8 +70,16 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { * a domain whose values don't conflict in the arc-consistency way. * TODO: should get cardinality from DiscreteKeys */ - void runArcConsistency(size_t cardinality, size_t nrIterations = 10, - bool print = false) const; + Domains runArcConsistency(size_t cardinality, + size_t maxIterations = 10) const; + + /// Run arc consistency for all variables, return true if any domain changed. + bool runArcConsistency(const VariableIndex& index, Domains* domains) const; + + /* + * Create a new CSP, applying the given Domain constraints. + */ + CSP partiallyApply(const Domains& domains) const; }; // CSP } // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index e9714d6b4..f0e51b723 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -21,10 +21,12 @@ #include #include +#include namespace gtsam { class Domain; +using Domains = std::map; /** * Base class for constraint factors @@ -65,19 +67,18 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor { /// @{ /* - * Ensure Arc-consistency, possibly changing domains of connected variables. + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param (in/out) domains all other domains - * @return true if domains were changed, false otherwise. + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - virtual bool ensureArcConsistency(size_t j, - std::vector* domains) const = 0; + virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; /// Partially apply known values virtual shared_ptr partiallyApply(const Values&) const = 0; /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; + virtual shared_ptr partiallyApply(const Domains&) const = 0; /// @} }; // DiscreteFactor diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index da23717f6..98b735c6c 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -10,28 +10,35 @@ #include #include - +#include namespace gtsam { using namespace std; /* ************************************************************************* */ void Domain::print(const string& s, const KeyFormatter& formatter) const { - cout << s << ": Domain on " << formatter(keys_[0]) - << " (j=" << formatter(keys_[0]) << ") with values"; + cout << s << ": Domain on " << formatter(key()) << " (j=" << formatter(key()) + << ") with values"; for (size_t v : values_) cout << " " << v; cout << endl; } +/* ************************************************************************* */ +string Domain::base1Str() const { + stringstream ss; + for (size_t v : values_) ss << v + 1; + return ss.str(); +} + /* ************************************************************************* */ double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); + return contains(values.at(key())); } /* ************************************************************************* */ DecisionTreeFactor Domain::toDecisionTreeFactor() const { DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); + keys += DiscreteKey(key(), cardinality_); vector table; for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); DecisionTreeFactor converted(keys, table); @@ -45,8 +52,8 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool Domain::ensureArcConsistency(size_t j, vector* domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); +bool Domain::ensureArcConsistency(Key j, Domains* domains) const { + if (j != key()) throw invalid_argument("Domain check on wrong domain"); Domain& D = domains->at(j); for (size_t value : values_) if (!D.contains(value)) throw runtime_error("Unsatisfiable"); @@ -55,15 +62,15 @@ bool Domain::ensureArcConsistency(size_t j, vector* domains) const { } /* ************************************************************************* */ -boost::optional Domain::checkAllDiff( - const KeyVector keys, const vector& domains) const { - Key j = keys_[0]; +boost::optional Domain::checkAllDiff(const KeyVector keys, + const Domains& domains) const { + Key j = key(); // for all values in this domain for (const size_t value : values_) { // for all connected domains for (const Key k : keys) // if any domain contains the value we cannot make this domain singleton - if (k != j && domains[k].contains(value)) goto found; + if (k != j && domains.at(k).contains(value)) goto found; // Otherwise: return a singleton: return Domain(this->discreteKey(), value); found:; @@ -73,16 +80,15 @@ boost::optional Domain::checkAllDiff( /* ************************************************************************* */ Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); + Values::const_iterator it = values.find(key()); if (it != values.end() && !contains(it->second)) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(*this); } /* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; +Constraint::shared_ptr Domain::partiallyApply(const Domains& domains) const { + const Domain& Dk = domains.at(key()); if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(Dk); diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 9fa22175a..ae137ca33 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -20,10 +20,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { size_t cardinality_; /// Cardinality std::set values_; /// allowed values - DiscreteKey discreteKey() const { - return DiscreteKey(keys_[0], cardinality_); - } - public: typedef boost::shared_ptr shared_ptr; @@ -40,6 +36,12 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { values_.insert(v); } + /// The one key + Key key() const { return keys_[0]; } + + // The associated discrete key + DiscreteKey discreteKey() const { return DiscreteKey(key(), cardinality_); } + /// Insert a value, non const :-( void insert(size_t value) { values_.insert(value); } @@ -66,6 +68,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { } } + // Return concise string representation, mostly to debug arc consistency. + // Converts from base 0 to base1. + std::string base1Str() const; + + // Check whether domain cotains a specific value. bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value @@ -78,12 +85,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /** * Check for a value in domain that does not occur in any other connected @@ -92,15 +100,14 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { * @param keys connected domains through alldiff * @param keys other domains */ - boost::optional checkAllDiff( - const KeyVector keys, const std::vector& domains) const; + boost::optional checkAllDiff(const KeyVector keys, + const Domains& domains) const; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + Constraint::shared_ptr partiallyApply(const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 753d46cff..162e21512 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -44,8 +44,7 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool SingleValue::ensureArcConsistency(size_t j, - vector* domains) const { +bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { if (j != keys_[0]) throw invalid_argument("SingleValue check on wrong domain"); Domain& D = domains->at(j); @@ -67,8 +66,8 @@ Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; + const Domains& domains) const { + const Domain& Dk = domains.at(keys_[0]); if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); return boost::make_shared(discreteKey(), value_); diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index d8a9a770b..d826093df 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -59,19 +59,19 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency: just sets domain[j] to {value_} + * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 832175455..63069d710 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -29,16 +29,17 @@ TEST(CSP, SingleValue) { DecisionTreeFactor f1(AZ, "0 0 1"); EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); - // Create domains, laid out as a vector. - // TODO(dellaert): should be map?? - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + // Create domains + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); // Ensure arc-consistency: just wipes out values in AZ domain: EXPECT(singleValue.ensureArcConsistency(1, &domains)); - LONGS_EQUAL(3, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(3, domains[2].nrValues()); + LONGS_EQUAL(3, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(3, domains.at(2).nrValues()); } /* ************************************************************************* */ @@ -81,8 +82,10 @@ TEST(CSP, AllDiff) { EXPECT(assert_equal(f2, actual)); // Create domains. - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); // First constrict AZ domain: SingleValue singleValue(AZ, 2); @@ -92,9 +95,9 @@ TEST(CSP, AllDiff) { EXPECT(alldiff.ensureArcConsistency(0, &domains)); EXPECT(!alldiff.ensureArcConsistency(1, &domains)); EXPECT(alldiff.ensureArcConsistency(2, &domains)); - LONGS_EQUAL(2, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(2, domains[2].nrValues()); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); } /* ************************************************************************* */ @@ -232,17 +235,20 @@ TEST(CSP, ArcConsistency) { EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // ensure arc-consistency, i.e., narrow domains... - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + SingleValue singleValue(AZ, 2); AllDiff alldiff(dkeys); EXPECT(singleValue.ensureArcConsistency(1, &domains)); EXPECT(alldiff.ensureArcConsistency(0, &domains)); EXPECT(!alldiff.ensureArcConsistency(1, &domains)); EXPECT(alldiff.ensureArcConsistency(2, &domains)); - LONGS_EQUAL(2, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(2, domains[2].nrValues()); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); // Parial application, version 1 DiscreteFactor::Values known; diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index 4843ae269..ee307fd5b 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -20,12 +21,12 @@ using namespace gtsam; #define PRINT false +/// A class that encodes Sudoku's as a CSP problem class Sudoku : public CSP { - /// sudoku size - size_t n_; + size_t n_; ///< Side of Sudoku, e.g. 4 or 9 - /// discrete keys - typedef std::pair IJ; + /// Mapping from base i,j coordinates to discrete keys: + using IJ = std::pair; std::map dkeys_; public: @@ -42,15 +43,14 @@ class Sudoku : public CSP { // Create variables, ordering, and unary constraints va_list ap; va_start(ap, n); - Key k = 0; for (size_t i = 0; i < n; ++i) { - for (size_t j = 0; j < n; ++j, ++k) { + for (size_t j = 0; j < n; ++j) { // create the key IJ ij(i, j); - dkeys_[ij] = DiscreteKey(k, n); + Symbol key('1' + i, j + 1); + dkeys_[ij] = DiscreteKey(key, n); // get the unary constraint, if any int value = va_arg(ap, int); - // cout << value << " "; if (value != 0) addSingleValue(dkeys_[ij], value - 1); } // cout << endl; @@ -88,7 +88,7 @@ class Sudoku : public CSP { } /// Print readable form of assignment - void printAssignment(DiscreteFactor::sharedValues assignment) const { + void printAssignment(const DiscreteFactor::sharedValues& assignment) const { for (size_t i = 0; i < n_; i++) { for (size_t j = 0; j < n_; j++) { Key k = key(i, j); @@ -99,10 +99,22 @@ class Sudoku : public CSP { } /// solve and print solution - void printSolution() { + void printSolution() const { DiscreteFactor::sharedValues MPE = optimalAssignment(); printAssignment(MPE); } + + // Print domain + void printDomains(const Domains& domains) { + for (size_t i = 0; i < n_; i++) { + for (size_t j = 0; j < n_; j++) { + Key k = key(i, j); + cout << domains.at(k).base1Str(); + cout << "\t"; + } // i + cout << endl; + } // j + } }; /* ************************************************************************* */ @@ -113,9 +125,6 @@ TEST_UNSAFE(Sudoku, small) { 4, 0, 2, 0, // 0, 1, 0, 0); - // Do BP - csp.runArcConsistency(4, 10, PRINT); - // optimize and check CSP::sharedValues solution = csp.optimalAssignment(); CSP::Values expected; @@ -126,73 +135,124 @@ TEST_UNSAFE(Sudoku, small) { csp.key(3, 3), 2); EXPECT(assert_equal(expected, *solution)); // csp.printAssignment(solution); + + // Do BP (AC1) + auto domains = csp.runArcConsistency(4, 3); + // csp.printDomains(domains); + Domain domain44 = domains.at(Symbol('4', 4)); + EXPECT_LONGS_EQUAL(1, domain44.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Should only be 16 new Domains + EXPECT_LONGS_EQUAL(16, new_csp.size()); + + // Check that solution + CSP::sharedValues new_solution = new_csp.optimalAssignment(); + // csp.printAssignment(new_solution); + EXPECT(assert_equal(expected, *new_solution)); } /* ************************************************************************* */ TEST_UNSAFE(Sudoku, easy) { - Sudoku sudoku(9, // - 0, 0, 5, 0, 9, 0, 0, 0, 1, // - 0, 0, 0, 0, 0, 2, 0, 7, 3, // - 7, 6, 0, 0, 0, 8, 2, 0, 0, // + Sudoku csp(9, // + 0, 0, 5, 0, 9, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 2, 0, 7, 3, // + 7, 6, 0, 0, 0, 8, 2, 0, 0, // - 0, 1, 2, 0, 0, 9, 0, 0, 4, // - 0, 0, 0, 2, 0, 3, 0, 0, 0, // - 3, 0, 0, 1, 0, 0, 9, 6, 0, // + 0, 1, 2, 0, 0, 9, 0, 0, 4, // + 0, 0, 0, 2, 0, 3, 0, 0, 0, // + 3, 0, 0, 1, 0, 0, 9, 6, 0, // - 0, 0, 1, 9, 0, 0, 0, 5, 8, // - 9, 7, 0, 5, 0, 0, 0, 0, 0, // - 5, 0, 0, 0, 3, 0, 7, 0, 0); + 0, 0, 1, 9, 0, 0, 0, 5, 8, // + 9, 7, 0, 5, 0, 0, 0, 0, 0, // + 5, 0, 0, 0, 3, 0, 7, 0, 0); - // Do BP - sudoku.runArcConsistency(4, 10, PRINT); + // csp.printSolution(); // don't do it - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 26 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 26, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ TEST_UNSAFE(Sudoku, extreme) { - Sudoku sudoku(9, // - 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // - 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // - 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // - 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // - 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // - 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // - 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // - 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); + Sudoku csp(9, // + 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // + 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // + 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // + 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // + 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // + 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + csp.runArcConsistency(9, 10); #ifdef METIS - VariableIndexOrdered index(sudoku); + VariableIndexOrdered index(csp); index.print("index"); ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt"); index.outputMetisFormat(os); #endif - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(2, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 20 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 20, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) { - Sudoku sudoku(9, // - 9, 5, 0, 0, 0, 6, 0, 0, 0, // - 0, 8, 4, 0, 7, 0, 0, 0, 0, // - 6, 2, 0, 5, 0, 0, 4, 0, 0, // + Sudoku csp(9, // + 9, 5, 0, 0, 0, 6, 0, 0, 0, // + 0, 8, 4, 0, 7, 0, 0, 0, 0, // + 6, 2, 0, 5, 0, 0, 4, 0, 0, // - 0, 0, 0, 2, 9, 0, 6, 0, 0, // - 0, 9, 0, 0, 0, 0, 0, 2, 0, // - 0, 0, 2, 0, 6, 3, 0, 0, 0, // + 0, 0, 0, 2, 9, 0, 6, 0, 0, // + 0, 9, 0, 0, 0, 0, 0, 2, 0, // + 0, 0, 2, 0, 6, 3, 0, 0, 0, // - 0, 0, 9, 0, 0, 7, 0, 6, 8, // - 0, 0, 0, 0, 3, 0, 2, 9, 0, // - 0, 0, 0, 1, 0, 0, 0, 3, 7); + 0, 0, 9, 0, 0, 7, 0, 6, 8, // + 0, 0, 0, 0, 3, 0, 2, 9, 0, // + 0, 0, 0, 1, 0, 0, 0, 3, 7); - // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); - // sudoku.printSolution(); // don't do it + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Just the 81 new Domains + EXPECT_LONGS_EQUAL(81, new_csp.size()); + + // Check that solution + CSP::sharedValues solution = new_csp.optimalAssignment(); + // csp.printAssignment(solution); + EXPECT_LONGS_EQUAL(6, solution->at(key99)); } /* ************************************************************************* */