diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index d6e1c6453..85cf0b472 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -57,21 +57,25 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool AllDiff::ensureArcConsistency(size_t j, - std::vector& domains) const { +bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { + Domain& Dj = domains->at(j); + // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. + // a value in domains->at(j) that does not occur in any other connected domain. // If found, we make this a singleton... // TODO: make a new constraint where this really is true - Domain& Dj = domains[j]; - if (Dj.checkAllDiff(keys_, domains)) return true; + boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); + if (maybeChanged) { + Dj = *maybeChanged; + return true; + } - // Check all other domains for singletons and erase corresponding values + // Check all other domains for singletons and erase corresponding values. // This is the same as arc-consistency on the equivalent binary constraints bool changed = false; for (Key k : keys_) if (k != j) { - const Domain& Dk = domains[k]; + const Domain& Dk = domains->at(k); if (Dk.isSingleton()) { // check if singleton size_t value = Dk.firstValue(); if (Dj.contains(value)) { @@ -96,10 +100,10 @@ Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { + const Domains& domains) const { DiscreteFactor::Values known; for (Key k : keys_) { - const Domain& Dk = domains[k]; + const Domain& Dk = domains.at(k); if (Dk.isSingleton()) known[k] = Dk.firstValue(); } return partiallyApply(known); diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index b0fd1d631..57b0aeb5c 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -13,11 +13,8 @@ namespace gtsam { /** - * General AllDiff constraint - * Returns 1 if values for all keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Key and an Key. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. + * General AllDiff constraint. + * Returns 1 if values for all keys are different, 0 otherwise. */ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { std::map cardinalities_; @@ -28,7 +25,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } public: - /// Constructor + /// Construct from keys. AllDiff(const DiscreteKeys& dkeys); // print @@ -57,21 +54,19 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values&) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override; + const Domains&) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index d8e1a590a..a2c7ba660 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -15,10 +15,7 @@ namespace gtsam { /** * Binary AllDiff constraint - * Returns 1 if values for two keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Index and an Index. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. + * Returns 1 if values for two keys are different, 0 otherwise. */ class BinaryAllDiff : public Constraint { size_t cardinality0_, cardinality1_; /// cardinality @@ -73,14 +70,14 @@ class BinaryAllDiff : public Constraint { } /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override { - // throw std::runtime_error( - // "BinaryAllDiff::ensureArcConsistency not implemented"); + bool ensureArcConsistency(Key j, Domains* domains) const override { + throw std::runtime_error( + "BinaryAllDiff::ensureArcConsistency not implemented"); return false; } @@ -91,7 +88,7 @@ class BinaryAllDiff : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override { + const Domains&) const override { throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); } }; diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index ad88469c5..1b38d20d6 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -25,82 +25,75 @@ CSP::Values CSP::optimalAssignment(const Ordering& ordering) const { return chordal->optimize(); } -void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, - bool print) const { +bool CSP::runArcConsistency(const VariableIndex& index, + Domains* domains) const { + bool changed = false; + + // iterate over all variables in the index + for (auto entry : index) { + // Get the variable's key and associated factors: + const Key key = entry.first; + const FactorIndices& factors = entry.second; + + // If this domain is already a singleton, we do nothing. + if (domains->at(key).isSingleton()) continue; + + // Otherwise, loop over all factors/constraints for variable with given key. + for (size_t f : factors) { + // If this factor is a constraint, call its ensureArcConsistency method: + auto constraint = boost::dynamic_pointer_cast((*this)[f]); + if (constraint) { + changed = constraint->ensureArcConsistency(key, domains) || changed; + } + } + } + return changed; +} + +// TODO(dellaert): This is AC1, which is inefficient as any change will cause +// the algorithm to revisit *all* variables again. Implement AC3. +Domains CSP::runArcConsistency(size_t cardinality, size_t maxIterations) const { // Create VariableIndex VariableIndex index(*this); - // index.print(); - - size_t n = index.size(); // Initialize domains - std::vector domains; - for (size_t j = 0; j < n; j++) - domains.push_back(Domain(DiscreteKey(j, cardinality))); + Domains domains; + for (auto entry : index) { + const Key key = entry.first; + domains.emplace(key, DiscreteKey(key, cardinality)); + } - // Create array of flags indicating a domain changed or not - std::vector changed(n); + // Iterate until convergence or not a single domain changed. + for (size_t it = 0; it < maxIterations; it++) { + bool changed = runArcConsistency(index, &domains); + if (!changed) break; + } + return domains; +} - // iterate nrIterations over entire grid - for (size_t it = 0; it < nrIterations; it++) { - bool anyChange = false; - // iterate over all cells - for (size_t v = 0; v < n; v++) { - // keep track of which domains changed - changed[v] = false; - // loop over all factors/constraints for variable v - const FactorIndices& factors = index[v]; - for (size_t f : factors) { - // if not already a singleton - if (!domains[v].isSingleton()) { - // get the constraint and call its ensureArcConsistency method - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast((*this)[f]); - if (!constraint) - throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - changed[v] = - constraint->ensureArcConsistency(v, domains) || changed[v]; - } - } // f - if (changed[v]) anyChange = true; - } // v - if (!anyChange) break; - // TODO: Sudoku specific hack - if (print) { - if (cardinality == 9 && n == 81) { - for (size_t i = 0, v = 0; i < (size_t)std::sqrt((double)n); i++) { - for (size_t j = 0; j < (size_t)std::sqrt((double)n); j++, v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // i - cout << endl; - } // j - } else { - for (size_t v = 0; v < n; v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // v - } - cout << endl; - } // print - } // it - -#ifndef INPROGRESS - // Now create new problem with all singleton variables removed - // We do this by adding simplifying all factors using parial application +CSP CSP::partiallyApply(const Domains& domains) const { + // Create new problem with all singleton variables removed + // We do this by adding simplifying all factors using partial application. // TODO: create a new ordering as we go, to ensure a connected graph // KeyOrdering ordering; // vector dkeys; + CSP new_csp; + + // Add tightened domains as new factors: + for (auto key_domain : domains) { + new_csp.emplace_shared(key_domain.second); + } + + // Reduce all existing factors: for (const DiscreteFactor::shared_ptr& f : factors_) { - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast(f); + auto constraint = boost::dynamic_pointer_cast(f); if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); Constraint::shared_ptr reduced = constraint->partiallyApply(domains); - if (print) reduced->print(); + if (reduced->size() > 1) { + new_csp.push_back(reduced); + } } -#endif + return new_csp; } } // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index 098ef56c9..0d4e0b177 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -61,7 +61,7 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // deep. // * It will be very expensive to exclude values that way. // */ - // void applyBeliefPropagation(size_t nrIterations = 10) const; + // void applyBeliefPropagation(size_t maxIterations = 10) const; /* * Apply arc-consistency ~ Approximate loopy belief propagation @@ -69,8 +69,16 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { * a domain whose values don't conflict in the arc-consistency way. * TODO: should get cardinality from DiscreteKeys */ - void runArcConsistency(size_t cardinality, size_t nrIterations = 10, - bool print = false) const; + Domains runArcConsistency(size_t cardinality, + size_t maxIterations = 10) const; + + /// Run arc consistency for all variables, return true if any domain changed. + bool runArcConsistency(const VariableIndex& index, Domains* domains) const; + + /* + * Create a new CSP, applying the given Domain constraints. + */ + CSP partiallyApply(const Domains& domains) const; }; // CSP } // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index b8baccff9..f0e51b723 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -21,30 +21,32 @@ #include #include +#include namespace gtsam { class Domain; +using Domains = std::map; /** - * Base class for discrete probabilistic factors - * The most general one is the derived DecisionTreeFactor + * Base class for constraint factors + * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. */ -class Constraint : public DiscreteFactor { +class GTSAM_EXPORT Constraint : public DiscreteFactor { public: typedef boost::shared_ptr shared_ptr; protected: - /// Construct n-way factor - Constraint(const KeyVector& js) : DiscreteFactor(js) {} - - /// Construct unary factor + /// Construct unary constraint factor. Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} - /// Construct binary factor + /// Construct binary constraint factor. Constraint(Key j1, Key j2) : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} + /// Construct n-way constraint factor. + Constraint(const KeyVector& js) : DiscreteFactor(js) {} + /// construct from container template Constraint(KeyIterator beginKey, KeyIterator endKey) @@ -65,18 +67,18 @@ class Constraint : public DiscreteFactor { /// @{ /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - virtual bool ensureArcConsistency(size_t j, - std::vector& domains) const = 0; + virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; /// Partially apply known values virtual shared_ptr partiallyApply(const Values&) const = 0; /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; + virtual shared_ptr partiallyApply(const Domains&) const = 0; /// @} }; // DiscreteFactor diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index a81b1d1ad..98b735c6c 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -10,29 +10,35 @@ #include #include - +#include namespace gtsam { using namespace std; /* ************************************************************************* */ void Domain::print(const string& s, const KeyFormatter& formatter) const { - // cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << - // formatter(keys_[0]) << ") with values"; - // for (size_t v: values_) cout << " " << v; - // cout << endl; - for (size_t v : values_) cout << v; + cout << s << ": Domain on " << formatter(key()) << " (j=" << formatter(key()) + << ") with values"; + for (size_t v : values_) cout << " " << v; + cout << endl; +} + +/* ************************************************************************* */ +string Domain::base1Str() const { + stringstream ss; + for (size_t v : values_) ss << v + 1; + return ss.str(); } /* ************************************************************************* */ double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); + return contains(values.at(key())); } /* ************************************************************************* */ DecisionTreeFactor Domain::toDecisionTreeFactor() const { DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); + keys += DiscreteKey(key(), cardinality_); vector table; for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); DecisionTreeFactor converted(keys, table); @@ -46,9 +52,9 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool Domain::ensureArcConsistency(size_t j, vector& domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); - Domain& D = domains[j]; +bool Domain::ensureArcConsistency(Key j, Domains* domains) const { + if (j != key()) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains->at(j); for (size_t value : values_) if (!D.contains(value)) throw runtime_error("Unsatisfiable"); D = *this; @@ -56,34 +62,33 @@ bool Domain::ensureArcConsistency(size_t j, vector& domains) const { } /* ************************************************************************* */ -bool Domain::checkAllDiff(const KeyVector keys, vector& domains) { - Key j = keys_[0]; +boost::optional Domain::checkAllDiff(const KeyVector keys, + const Domains& domains) const { + Key j = key(); // for all values in this domain - for (size_t value : values_) { + for (const size_t value : values_) { // for all connected domains - for (Key k : keys) + for (const Key k : keys) // if any domain contains the value we cannot make this domain singleton - if (k != j && domains[k].contains(value)) goto found; - values_.clear(); - values_.insert(value); - return true; // we changed it + if (k != j && domains.at(k).contains(value)) goto found; + // Otherwise: return a singleton: + return Domain(this->discreteKey(), value); found:; } - return false; // we did not change it + return boost::none; // we did not change it } /* ************************************************************************* */ Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); + Values::const_iterator it = values.find(key()); if (it != values.end() && !contains(it->second)) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(*this); } /* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; +Constraint::shared_ptr Domain::partiallyApply(const Domains& domains) const { + const Domain& Dk = domains.at(key()); if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(Dk); diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 15828b653..ae137ca33 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -13,7 +13,8 @@ namespace gtsam { /** - * Domain restriction constraint + * The Domain class represents a constraint that restricts the possible values a + * particular variable, with given key, can take on. */ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { size_t cardinality_; /// Cardinality @@ -35,14 +36,16 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { values_.insert(v); } - /// Constructor - Domain(const Domain& other) - : Constraint(other.keys_[0]), values_(other.values_) {} + /// The one key + Key key() const { return keys_[0]; } - /// insert a value, non const :-( + // The associated discrete key + DiscreteKey discreteKey() const { return DiscreteKey(key(), cardinality_); } + + /// Insert a value, non const :-( void insert(size_t value) { values_.insert(value); } - /// erase a value, non const :-( + /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } size_t nrValues() const { return values_.size(); } @@ -65,6 +68,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { } } + // Return concise string representation, mostly to debug arc consistency. + // Converts from base 0 to base1. + std::string base1Str() const; + + // Check whether domain cotains a specific value. bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value @@ -77,27 +85,29 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /** - * Check for a value in domain that does not occur in any other connected - * domain. If found, we make this a singleton... Called in - * AllDiff::ensureArcConsistency - * @param keys connected domains through alldiff + * Check for a value in domain that does not occur in any other connected + * domain. If found, return a a new singleton domain... + * Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + * @param keys other domains */ - bool checkAllDiff(const KeyVector keys, std::vector& domains); + boost::optional checkAllDiff(const KeyVector keys, + const Domains& domains) const; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + Constraint::shared_ptr partiallyApply(const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 105887dc9..162e21512 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -44,11 +44,10 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool SingleValue::ensureArcConsistency(size_t j, - vector& domains) const { +bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { if (j != keys_[0]) throw invalid_argument("SingleValue check on wrong domain"); - Domain& D = domains[j]; + Domain& D = domains->at(j); if (D.isSingleton()) { if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); return false; @@ -67,8 +66,8 @@ Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; + const Domains& domains) const { + const Domain& Dk = domains.at(keys_[0]); if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); return boost::make_shared(discreteKey(), value_); diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index a2aec338c..d826093df 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -13,14 +13,12 @@ namespace gtsam { /** - * SingleValue constraint + * SingleValue constraint: ensures a variable takes on a certain value. + * This could of course also be implemented by changing its `Domain`. */ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { - /// Number of values - size_t cardinality_; - - /// allowed value - size_t value_; + size_t cardinality_; /// < Number of values + size_t value_; ///< allowed value DiscreteKey discreteKey() const { return DiscreteKey(keys_[0], cardinality_); @@ -29,11 +27,11 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { public: typedef boost::shared_ptr shared_ptr; - /// Constructor + /// Construct from key, cardinality, and given value. SingleValue(Key key, size_t n, size_t value) : Constraint(key), cardinality_(n), value_(value) {} - /// Constructor + /// Construct from DiscreteKey and given value. SingleValue(const DiscreteKey& dkey, size_t value) : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} @@ -61,19 +59,19 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency + * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index e27cd3486..4a3b22561 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -19,12 +19,34 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE(BinaryAllDif, allInOne) { - // Create keys and ordering +TEST(CSP, SingleValue) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check that a single value is equal to a decision stump with only one "1": + SingleValue singleValue(AZ, 2); + DecisionTreeFactor f1(AZ, "0 0 1"); + EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); + + // Create domains + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // Ensure arc-consistency: just wipes out values in AZ domain: + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + LONGS_EQUAL(3, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(3, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, BinaryAllDif) { + // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each: size_t nrColors = 2; - // DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", - // nrColors); - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Check construction and conversion BinaryAllDiff c1(ID, UT); @@ -36,16 +58,53 @@ TEST_UNSAFE(BinaryAllDif, allInOne) { DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); + // Check multiplication of factors with constraint: DecisionTreeFactor f3 = f1 * f2; EXPECT(assert_equal(f3, c1 * f2)); EXPECT(assert_equal(f3, c2 * f1)); } /* ************************************************************************* */ -TEST_UNSAFE(CSP, allInOne) { - // Create keys and ordering +TEST(CSP, AllDiff) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check construction and conversion + vector dkeys{ID, UT, AZ}; + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); + // GTSAM_PRINT(actual); + actual.dot("actual"); + DecisionTreeFactor f2( + ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2, actual)); + + // Create domains. + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // First constrict AZ domain: + SingleValue singleValue(AZ, 2); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + + // Arc-consistency + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, allInOne) { + // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each: size_t nrColors = 2; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Create the CSP CSP csp; @@ -81,15 +140,12 @@ TEST_UNSAFE(CSP, allInOne) { } /* ************************************************************************* */ -TEST_UNSAFE(CSP, WesternUS) { - // Create keys +TEST(CSP, WesternUS) { + // Create keys for all states in Western US, with 4 color possibilities. size_t nrColors = 4; - DiscreteKey - // Create ordering according to example in ND-CSP.lyx - WA(0, nrColors), - OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors), - UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors), - CO(7, nrColors), NM(6, nrColors); + DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors), + NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); // Create the CSP CSP csp; @@ -116,10 +172,11 @@ TEST_UNSAFE(CSP, WesternUS) { csp.addAllDiff(WY, CO); csp.addAllDiff(CO, NM); - // Solve + // Create ordering according to example in ND-CSP.lyx Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), Key(8), Key(9), Key(10); + // Solve using that ordering: auto mpe = csp.optimalAssignment(ordering); // GTSAM_PRINT(mpe); CSP::Values expected; @@ -143,33 +200,17 @@ TEST_UNSAFE(CSP, WesternUS) { } /* ************************************************************************* */ -TEST_UNSAFE(CSP, AllDiff) { - // Create keys and ordering +TEST(CSP, ArcConsistency) { + // Create keys for Idaho, Arizona, and Utah, allowing three colors for each: size_t nrColors = 3; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); - // Create the CSP + // Create the CSP using just one all-diff constraint, plus constrain Arizona. CSP csp; - vector dkeys; - dkeys += ID, UT, AZ; + vector dkeys{ID, UT, AZ}; csp.addAllDiff(dkeys); csp.addSingleValue(AZ, 2); - // GTSAM_PRINT(csp); - - // Check construction and conversion - SingleValue s(AZ, 2); - DecisionTreeFactor f1(AZ, "0 0 1"); - EXPECT(assert_equal(f1, s.toDecisionTreeFactor())); - - // Check construction and conversion - AllDiff alldiff(dkeys); - DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); - // GTSAM_PRINT(actual); - // actual.dot("actual"); - DecisionTreeFactor f2( - ID & AZ & UT, - "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); - EXPECT(assert_equal(f2, actual)); + // GTSAM_PRINT(csp); // Check an invalid combination, with ID==UT==AZ all same color DiscreteFactor::Values invalid; @@ -192,17 +233,21 @@ TEST_UNSAFE(CSP, AllDiff) { EXPECT(assert_equal(expected, mpe)); EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); - // Arc-consistency - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + // ensure arc-consistency, i.e., narrow domains... + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + SingleValue singleValue(AZ, 2); - EXPECT(singleValue.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(0, domains)); - EXPECT(!alldiff.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(2, domains)); - LONGS_EQUAL(2, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(2, domains[2].nrValues()); + AllDiff alldiff(dkeys); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); // Parial application, version 1 DiscreteFactor::Values known; @@ -222,6 +267,7 @@ TEST_UNSAFE(CSP, AllDiff) { // full arc-consistency test csp.runArcConsistency(nrColors); + // GTSAM_PRINT(csp); } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index cedf900a2..808f98b1c 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -20,12 +21,12 @@ using namespace gtsam; #define PRINT false +/// A class that encodes Sudoku's as a CSP problem class Sudoku : public CSP { - /// sudoku size - size_t n_; + size_t n_; ///< Side of Sudoku, e.g. 4 or 9 - /// discrete keys - typedef std::pair IJ; + /// Mapping from base i,j coordinates to discrete keys: + using IJ = std::pair; std::map dkeys_; public: @@ -42,15 +43,14 @@ class Sudoku : public CSP { // Create variables, ordering, and unary constraints va_list ap; va_start(ap, n); - Key k = 0; for (size_t i = 0; i < n; ++i) { - for (size_t j = 0; j < n; ++j, ++k) { + for (size_t j = 0; j < n; ++j) { // create the key IJ ij(i, j); - dkeys_[ij] = DiscreteKey(k, n); + Symbol key('1' + i, j + 1); + dkeys_[ij] = DiscreteKey(key, n); // get the unary constraint, if any int value = va_arg(ap, int); - // cout << value << " "; if (value != 0) addSingleValue(dkeys_[ij], value - 1); } // cout << endl; @@ -99,23 +99,32 @@ class Sudoku : public CSP { } /// solve and print solution - void printSolution() { + void printSolution() const { auto MPE = optimalAssignment(); printAssignment(MPE); } + + // Print domain + void printDomains(const Domains& domains) { + for (size_t i = 0; i < n_; i++) { + for (size_t j = 0; j < n_; j++) { + Key k = key(i, j); + cout << domains.at(k).base1Str(); + cout << "\t"; + } // i + cout << endl; + } // j + } }; /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, small) { +TEST(Sudoku, small) { Sudoku csp(4, // 1, 0, 0, 4, // 0, 0, 0, 0, // 4, 0, 2, 0, // 0, 1, 0, 0); - // Do BP - csp.runArcConsistency(4, 10, PRINT); - // optimize and check auto solution = csp.optimalAssignment(); CSP::Values expected; @@ -126,73 +135,124 @@ TEST_UNSAFE(Sudoku, small) { csp.key(3, 3), 2); EXPECT(assert_equal(expected, solution)); // csp.printAssignment(solution); + + // Do BP (AC1) + auto domains = csp.runArcConsistency(4, 3); + // csp.printDomains(domains); + Domain domain44 = domains.at(Symbol('4', 4)); + EXPECT_LONGS_EQUAL(1, domain44.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Should only be 16 new Domains + EXPECT_LONGS_EQUAL(16, new_csp.size()); + + // Check that solution + auto new_solution = new_csp.optimalAssignment(); + // csp.printAssignment(new_solution); + EXPECT(assert_equal(expected, new_solution)); } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, easy) { - Sudoku sudoku(9, // - 0, 0, 5, 0, 9, 0, 0, 0, 1, // - 0, 0, 0, 0, 0, 2, 0, 7, 3, // - 7, 6, 0, 0, 0, 8, 2, 0, 0, // +TEST(Sudoku, easy) { + Sudoku csp(9, // + 0, 0, 5, 0, 9, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 2, 0, 7, 3, // + 7, 6, 0, 0, 0, 8, 2, 0, 0, // - 0, 1, 2, 0, 0, 9, 0, 0, 4, // - 0, 0, 0, 2, 0, 3, 0, 0, 0, // - 3, 0, 0, 1, 0, 0, 9, 6, 0, // + 0, 1, 2, 0, 0, 9, 0, 0, 4, // + 0, 0, 0, 2, 0, 3, 0, 0, 0, // + 3, 0, 0, 1, 0, 0, 9, 6, 0, // - 0, 0, 1, 9, 0, 0, 0, 5, 8, // - 9, 7, 0, 5, 0, 0, 0, 0, 0, // - 5, 0, 0, 0, 3, 0, 7, 0, 0); + 0, 0, 1, 9, 0, 0, 0, 5, 8, // + 9, 7, 0, 5, 0, 0, 0, 0, 0, // + 5, 0, 0, 0, 3, 0, 7, 0, 0); - // Do BP - sudoku.runArcConsistency(4, 10, PRINT); + // csp.printSolution(); // don't do it - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 26 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 26, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, extreme) { - Sudoku sudoku(9, // - 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // - 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // - 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // - 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // - 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // - 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // - 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // - 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); +TEST(Sudoku, extreme) { + Sudoku csp(9, // + 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // + 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // + 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // + 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // + 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // + 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + csp.runArcConsistency(9, 10); #ifdef METIS - VariableIndexOrdered index(sudoku); + VariableIndexOrdered index(csp); index.print("index"); ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt"); index.outputMetisFormat(os); #endif - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(2, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 20 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 20, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) { - Sudoku sudoku(9, // - 9, 5, 0, 0, 0, 6, 0, 0, 0, // - 0, 8, 4, 0, 7, 0, 0, 0, 0, // - 6, 2, 0, 5, 0, 0, 4, 0, 0, // +TEST(Sudoku, AJC_3star_Feb8_2012) { + Sudoku csp(9, // + 9, 5, 0, 0, 0, 6, 0, 0, 0, // + 0, 8, 4, 0, 7, 0, 0, 0, 0, // + 6, 2, 0, 5, 0, 0, 4, 0, 0, // - 0, 0, 0, 2, 9, 0, 6, 0, 0, // - 0, 9, 0, 0, 0, 0, 0, 2, 0, // - 0, 0, 2, 0, 6, 3, 0, 0, 0, // + 0, 0, 0, 2, 9, 0, 6, 0, 0, // + 0, 9, 0, 0, 0, 0, 0, 2, 0, // + 0, 0, 2, 0, 6, 3, 0, 0, 0, // - 0, 0, 9, 0, 0, 7, 0, 6, 8, // - 0, 0, 0, 0, 3, 0, 2, 9, 0, // - 0, 0, 0, 1, 0, 0, 0, 3, 7); + 0, 0, 9, 0, 0, 7, 0, 6, 8, // + 0, 0, 0, 0, 3, 0, 2, 9, 0, // + 0, 0, 0, 1, 0, 0, 0, 3, 7); - // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); - // sudoku.printSolution(); // don't do it + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Just the 81 new Domains + EXPECT_LONGS_EQUAL(81, new_csp.size()); + + // Check that solution + auto solution = new_csp.optimalAssignment(); + // csp.printAssignment(solution); + EXPECT_LONGS_EQUAL(6, solution.at(key99)); } /* ************************************************************************* */