Cleaned up AC1 implementation

release/4.3a0
Frank Dellaert 2021-11-20 15:52:12 -05:00
parent 23bcf96da4
commit ad3225953b
12 changed files with 270 additions and 195 deletions

View File

@ -57,14 +57,11 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j, bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
std::vector<Domain>* domains) const {
// We are changing the domain of variable j.
// TODO(dellaert): confusing, I thought we were changing others...
Domain& Dj = domains->at(j); Domain& Dj = domains->at(j);
// Though strictly not part of allDiff, we check for // Though strictly not part of allDiff, we check for
// a value in domains[j] that does not occur in any other connected domain. // a value in domains->at(j) that does not occur in any other connected domain.
// If found, we make this a singleton... // If found, we make this a singleton...
// TODO: make a new constraint where this really is true // TODO: make a new constraint where this really is true
boost::optional<Domain> maybeChanged = Dj.checkAllDiff(keys_, *domains); boost::optional<Domain> maybeChanged = Dj.checkAllDiff(keys_, *domains);
@ -103,10 +100,10 @@ Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply( Constraint::shared_ptr AllDiff::partiallyApply(
const std::vector<Domain>& domains) const { const Domains& domains) const {
DiscreteFactor::Values known; DiscreteFactor::Values known;
for (Key k : keys_) { for (Key k : keys_) {
const Domain& Dk = domains[k]; const Domain& Dk = domains.at(k);
if (Dk.isSingleton()) known[k] = Dk.firstValue(); if (Dk.isSingleton()) known[k] = Dk.firstValue();
} }
return partiallyApply(known); return partiallyApply(known);

View File

@ -54,21 +54,19 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/* /*
* Ensure Arc-consistency * Ensure Arc-consistency by checking every possible value of domain j.
* Arc-consistency involves creating binaryAllDiff constraints
* In which case the combinatorial hyper-arc explosion disappears.
* @param j domain to be checked * @param j domain to be checked
* @param (in/out) domains all other domains * @param (in/out) domains all domains, but only domains->at(j) will be checked.
* @return true if domains->at(j) was changed, false otherwise.
*/ */
bool ensureArcConsistency(size_t j, bool ensureArcConsistency(Key j, Domains* domains) const override;
std::vector<Domain>* domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override; Constraint::shared_ptr partiallyApply(const Values&) const override;
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>&) const override; const Domains&) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -70,13 +70,12 @@ class BinaryAllDiff : public Constraint {
} }
/* /*
* Ensure Arc-consistency * Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param (in/out) domains all domains, but only domains->at(j) will be checked.
* @return true if domains->at(j) was changed, false otherwise.
*/ */
/// bool ensureArcConsistency(Key j, Domains* domains) const override {
bool ensureArcConsistency(size_t j,
std::vector<Domain>* domains) const override {
throw std::runtime_error( throw std::runtime_error(
"BinaryAllDiff::ensureArcConsistency not implemented"); "BinaryAllDiff::ensureArcConsistency not implemented");
return false; return false;
@ -89,7 +88,7 @@ class BinaryAllDiff : public Constraint {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>&) const override { const Domains&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
} }
}; };

View File

@ -27,81 +27,75 @@ CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const {
return mpe; return mpe;
} }
void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, bool CSP::runArcConsistency(const VariableIndex& index,
bool print) const { Domains* domains) const {
bool changed = false;
// iterate over all variables in the index
for (auto entry : index) {
// Get the variable's key and associated factors:
const Key key = entry.first;
const FactorIndices& factors = entry.second;
// If this domain is already a singleton, we do nothing.
if (domains->at(key).isSingleton()) continue;
// Otherwise, loop over all factors/constraints for variable with given key.
for (size_t f : factors) {
// If this factor is a constraint, call its ensureArcConsistency method:
auto constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
if (constraint) {
changed = constraint->ensureArcConsistency(key, domains) || changed;
}
}
}
return changed;
}
// TODO(dellaert): This is AC1, which is inefficient as any change will cause
// the algorithm to revisit *all* variables again. Implement AC3.
Domains CSP::runArcConsistency(size_t cardinality, size_t maxIterations) const {
// Create VariableIndex // Create VariableIndex
VariableIndex index(*this); VariableIndex index(*this);
// index.print();
size_t n = index.size();
// Initialize domains // Initialize domains
std::vector<Domain> domains; Domains domains;
for (size_t j = 0; j < n; j++) for (auto entry : index) {
domains.push_back(Domain(DiscreteKey(j, cardinality))); const Key key = entry.first;
domains.emplace(key, DiscreteKey(key, cardinality));
// Create array of flags indicating a domain changed or not
std::vector<bool> changed(n);
// iterate nrIterations over entire grid
for (size_t it = 0; it < nrIterations; it++) {
bool anyChange = false;
// iterate over all cells
for (size_t v = 0; v < n; v++) {
// keep track of which domains changed
changed[v] = false;
// loop over all factors/constraints for variable v
const FactorIndices& factors = index[v];
for (size_t f : factors) {
// if not already a singleton
if (!domains[v].isSingleton()) {
// get the constraint and call its ensureArcConsistency method
auto constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
if (!constraint)
throw runtime_error("CSP:runArcConsistency: non-constraint factor");
changed[v] =
constraint->ensureArcConsistency(v, &domains) || changed[v];
} }
} // f
if (changed[v]) anyChange = true;
} // v
if (!anyChange) break;
// TODO: Sudoku specific hack
if (print) {
if (cardinality == 9 && n == 81) {
for (size_t i = 0, v = 0; i < (size_t)std::sqrt((double)n); i++) {
for (size_t j = 0; j < (size_t)std::sqrt((double)n); j++, v++) {
if (changed[v]) cout << "*";
domains[v].print();
cout << "\t";
} // i
cout << endl;
} // j
} else {
for (size_t v = 0; v < n; v++) {
if (changed[v]) cout << "*";
domains[v].print();
cout << "\t";
} // v
}
cout << endl;
} // print
} // it
#ifndef INPROGRESS // Iterate until convergence or not a single domain changed.
// Now create new problem with all singleton variables removed for (size_t it = 0; it < maxIterations; it++) {
// We do this by adding simplifying all factors using parial application bool changed = runArcConsistency(index, &domains);
if (!changed) break;
}
return domains;
}
CSP CSP::partiallyApply(const Domains& domains) const {
// Create new problem with all singleton variables removed
// We do this by adding simplifying all factors using partial application.
// TODO: create a new ordering as we go, to ensure a connected graph // TODO: create a new ordering as we go, to ensure a connected graph
// KeyOrdering ordering; // KeyOrdering ordering;
// vector<Index> dkeys; // vector<Index> dkeys;
CSP new_csp;
// Add tightened domains as new factors:
for (auto key_domain : domains) {
new_csp.emplace_shared<Domain>(key_domain.second);
}
// Reduce all existing factors:
for (const DiscreteFactor::shared_ptr& f : factors_) { for (const DiscreteFactor::shared_ptr& f : factors_) {
Constraint::shared_ptr constraint = auto constraint = boost::dynamic_pointer_cast<Constraint>(f);
boost::dynamic_pointer_cast<Constraint>(f);
if (!constraint) if (!constraint)
throw runtime_error("CSP:runArcConsistency: non-constraint factor"); throw runtime_error("CSP:runArcConsistency: non-constraint factor");
Constraint::shared_ptr reduced = constraint->partiallyApply(domains); Constraint::shared_ptr reduced = constraint->partiallyApply(domains);
if (print) reduced->print(); if (reduced->size() > 1) {
new_csp.push_back(reduced);
} }
#endif }
return new_csp;
} }
} // namespace gtsam } // namespace gtsam

View File

@ -62,7 +62,7 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
// deep. // deep.
// * It will be very expensive to exclude values that way. // * It will be very expensive to exclude values that way.
// */ // */
// void applyBeliefPropagation(size_t nrIterations = 10) const; // void applyBeliefPropagation(size_t maxIterations = 10) const;
/* /*
* Apply arc-consistency ~ Approximate loopy belief propagation * Apply arc-consistency ~ Approximate loopy belief propagation
@ -70,8 +70,16 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
* a domain whose values don't conflict in the arc-consistency way. * a domain whose values don't conflict in the arc-consistency way.
* TODO: should get cardinality from DiscreteKeys * TODO: should get cardinality from DiscreteKeys
*/ */
void runArcConsistency(size_t cardinality, size_t nrIterations = 10, Domains runArcConsistency(size_t cardinality,
bool print = false) const; size_t maxIterations = 10) const;
/// Run arc consistency for all variables, return true if any domain changed.
bool runArcConsistency(const VariableIndex& index, Domains* domains) const;
/*
* Create a new CSP, applying the given Domain constraints.
*/
CSP partiallyApply(const Domains& domains) const;
}; // CSP }; // CSP
} // namespace gtsam } // namespace gtsam

View File

@ -21,10 +21,12 @@
#include <gtsam_unstable/dllexport.h> #include <gtsam_unstable/dllexport.h>
#include <boost/assign.hpp> #include <boost/assign.hpp>
#include <map>
namespace gtsam { namespace gtsam {
class Domain; class Domain;
using Domains = std::map<Key, Domain>;
/** /**
* Base class for constraint factors * Base class for constraint factors
@ -65,19 +67,18 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor {
/// @{ /// @{
/* /*
* Ensure Arc-consistency, possibly changing domains of connected variables. * Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked * @param j domain to be checked
* @param (in/out) domains all other domains * @param (in/out) domains all domains, but only domains->at(j) will be checked.
* @return true if domains were changed, false otherwise. * @return true if domains->at(j) was changed, false otherwise.
*/ */
virtual bool ensureArcConsistency(size_t j, virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
std::vector<Domain>* domains) const = 0;
/// Partially apply known values /// Partially apply known values
virtual shared_ptr partiallyApply(const Values&) const = 0; virtual shared_ptr partiallyApply(const Values&) const = 0;
/// Partially apply known values, domain version /// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0; virtual shared_ptr partiallyApply(const Domains&) const = 0;
/// @} /// @}
}; };
// DiscreteFactor // DiscreteFactor

View File

@ -10,28 +10,35 @@
#include <gtsam_unstable/discrete/Domain.h> #include <gtsam_unstable/discrete/Domain.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <sstream>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */ /* ************************************************************************* */
void Domain::print(const string& s, const KeyFormatter& formatter) const { void Domain::print(const string& s, const KeyFormatter& formatter) const {
cout << s << ": Domain on " << formatter(keys_[0]) cout << s << ": Domain on " << formatter(key()) << " (j=" << formatter(key())
<< " (j=" << formatter(keys_[0]) << ") with values"; << ") with values";
for (size_t v : values_) cout << " " << v; for (size_t v : values_) cout << " " << v;
cout << endl; cout << endl;
} }
/* ************************************************************************* */
string Domain::base1Str() const {
stringstream ss;
for (size_t v : values_) ss << v + 1;
return ss.str();
}
/* ************************************************************************* */ /* ************************************************************************* */
double Domain::operator()(const Values& values) const { double Domain::operator()(const Values& values) const {
return contains(values.at(keys_[0])); return contains(values.at(key()));
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor Domain::toDecisionTreeFactor() const { DecisionTreeFactor Domain::toDecisionTreeFactor() const {
DiscreteKeys keys; DiscreteKeys keys;
keys += DiscreteKey(keys_[0], cardinality_); keys += DiscreteKey(key(), cardinality_);
vector<double> table; vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1));
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
@ -45,8 +52,8 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool Domain::ensureArcConsistency(size_t j, vector<Domain>* domains) const { bool Domain::ensureArcConsistency(Key j, Domains* domains) const {
if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); if (j != key()) throw invalid_argument("Domain check on wrong domain");
Domain& D = domains->at(j); Domain& D = domains->at(j);
for (size_t value : values_) for (size_t value : values_)
if (!D.contains(value)) throw runtime_error("Unsatisfiable"); if (!D.contains(value)) throw runtime_error("Unsatisfiable");
@ -55,15 +62,15 @@ bool Domain::ensureArcConsistency(size_t j, vector<Domain>* domains) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
boost::optional<Domain> Domain::checkAllDiff( boost::optional<Domain> Domain::checkAllDiff(const KeyVector keys,
const KeyVector keys, const vector<Domain>& domains) const { const Domains& domains) const {
Key j = keys_[0]; Key j = key();
// for all values in this domain // for all values in this domain
for (const size_t value : values_) { for (const size_t value : values_) {
// for all connected domains // for all connected domains
for (const Key k : keys) for (const Key k : keys)
// if any domain contains the value we cannot make this domain singleton // if any domain contains the value we cannot make this domain singleton
if (k != j && domains[k].contains(value)) goto found; if (k != j && domains.at(k).contains(value)) goto found;
// Otherwise: return a singleton: // Otherwise: return a singleton:
return Domain(this->discreteKey(), value); return Domain(this->discreteKey(), value);
found:; found:;
@ -73,16 +80,15 @@ boost::optional<Domain> Domain::checkAllDiff(
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { Constraint::shared_ptr Domain::partiallyApply(const Values& values) const {
Values::const_iterator it = values.find(keys_[0]); Values::const_iterator it = values.find(key());
if (it != values.end() && !contains(it->second)) if (it != values.end() && !contains(it->second))
throw runtime_error("Domain::partiallyApply: unsatisfiable"); throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared<Domain>(*this); return boost::make_shared<Domain>(*this);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply( Constraint::shared_ptr Domain::partiallyApply(const Domains& domains) const {
const vector<Domain>& domains) const { const Domain& Dk = domains.at(key());
const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !contains(*Dk.begin())) if (Dk.isSingleton() && !contains(*Dk.begin()))
throw runtime_error("Domain::partiallyApply: unsatisfiable"); throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared<Domain>(Dk); return boost::make_shared<Domain>(Dk);

View File

@ -20,10 +20,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
size_t cardinality_; /// Cardinality size_t cardinality_; /// Cardinality
std::set<size_t> values_; /// allowed values std::set<size_t> values_; /// allowed values
DiscreteKey discreteKey() const {
return DiscreteKey(keys_[0], cardinality_);
}
public: public:
typedef boost::shared_ptr<Domain> shared_ptr; typedef boost::shared_ptr<Domain> shared_ptr;
@ -40,6 +36,12 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
values_.insert(v); values_.insert(v);
} }
/// The one key
Key key() const { return keys_[0]; }
// The associated discrete key
DiscreteKey discreteKey() const { return DiscreteKey(key(), cardinality_); }
/// Insert a value, non const :-( /// Insert a value, non const :-(
void insert(size_t value) { values_.insert(value); } void insert(size_t value) { values_.insert(value); }
@ -66,6 +68,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
} }
} }
// Return concise string representation, mostly to debug arc consistency.
// Converts from base 0 to base1.
std::string base1Str() const;
// Check whether domain cotains a specific value.
bool contains(size_t value) const { return values_.count(value) > 0; } bool contains(size_t value) const { return values_.count(value) > 0; }
/// Calculate value /// Calculate value
@ -78,12 +85,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/* /*
* Ensure Arc-consistency * Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param (in/out) domains all domains, but only domains->at(j) will be
* checked.
* @return true if domains->at(j) was changed, false otherwise.
*/ */
bool ensureArcConsistency(size_t j, bool ensureArcConsistency(Key j, Domains* domains) const override;
std::vector<Domain>* domains) const override;
/** /**
* Check for a value in domain that does not occur in any other connected * Check for a value in domain that does not occur in any other connected
@ -92,15 +100,14 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
* @param keys connected domains through alldiff * @param keys connected domains through alldiff
* @param keys other domains * @param keys other domains
*/ */
boost::optional<Domain> checkAllDiff( boost::optional<Domain> checkAllDiff(const KeyVector keys,
const KeyVector keys, const std::vector<Domain>& domains) const; const Domains& domains) const;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override; Constraint::shared_ptr partiallyApply(const Values& values) const override;
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(const Domains& domains) const override;
const std::vector<Domain>& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -44,8 +44,7 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool SingleValue::ensureArcConsistency(size_t j, bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const {
vector<Domain>* domains) const {
if (j != keys_[0]) if (j != keys_[0])
throw invalid_argument("SingleValue check on wrong domain"); throw invalid_argument("SingleValue check on wrong domain");
Domain& D = domains->at(j); Domain& D = domains->at(j);
@ -67,8 +66,8 @@ Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply( Constraint::shared_ptr SingleValue::partiallyApply(
const vector<Domain>& domains) const { const Domains& domains) const {
const Domain& Dk = domains[keys_[0]]; const Domain& Dk = domains.at(keys_[0]);
if (Dk.isSingleton() && !Dk.contains(value_)) if (Dk.isSingleton() && !Dk.contains(value_))
throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(discreteKey(), value_); return boost::make_shared<SingleValue>(discreteKey(), value_);

View File

@ -59,19 +59,19 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/* /*
* Ensure Arc-consistency: just sets domain[j] to {value_} * Ensure Arc-consistency: just sets domain[j] to {value_}.
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param (in/out) domains all domains, but only domains->at(j) will be checked.
* @return true if domains->at(j) was changed, false otherwise.
*/ */
bool ensureArcConsistency(size_t j, bool ensureArcConsistency(Key j, Domains* domains) const override;
std::vector<Domain>* domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override; Constraint::shared_ptr partiallyApply(const Values& values) const override;
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override; const Domains& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -29,16 +29,17 @@ TEST(CSP, SingleValue) {
DecisionTreeFactor f1(AZ, "0 0 1"); DecisionTreeFactor f1(AZ, "0 0 1");
EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor()));
// Create domains, laid out as a vector. // Create domains
// TODO(dellaert): should be map?? Domains domains;
vector<Domain> domains; domains.emplace(0, Domain(ID));
domains += Domain(ID), Domain(AZ), Domain(UT); domains.emplace(1, Domain(AZ));
domains.emplace(2, Domain(UT));
// Ensure arc-consistency: just wipes out values in AZ domain: // Ensure arc-consistency: just wipes out values in AZ domain:
EXPECT(singleValue.ensureArcConsistency(1, &domains)); EXPECT(singleValue.ensureArcConsistency(1, &domains));
LONGS_EQUAL(3, domains[0].nrValues()); LONGS_EQUAL(3, domains.at(0).nrValues());
LONGS_EQUAL(1, domains[1].nrValues()); LONGS_EQUAL(1, domains.at(1).nrValues());
LONGS_EQUAL(3, domains[2].nrValues()); LONGS_EQUAL(3, domains.at(2).nrValues());
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -81,8 +82,10 @@ TEST(CSP, AllDiff) {
EXPECT(assert_equal(f2, actual)); EXPECT(assert_equal(f2, actual));
// Create domains. // Create domains.
vector<Domain> domains; Domains domains;
domains += Domain(ID), Domain(AZ), Domain(UT); domains.emplace(0, Domain(ID));
domains.emplace(1, Domain(AZ));
domains.emplace(2, Domain(UT));
// First constrict AZ domain: // First constrict AZ domain:
SingleValue singleValue(AZ, 2); SingleValue singleValue(AZ, 2);
@ -92,9 +95,9 @@ TEST(CSP, AllDiff) {
EXPECT(alldiff.ensureArcConsistency(0, &domains)); EXPECT(alldiff.ensureArcConsistency(0, &domains));
EXPECT(!alldiff.ensureArcConsistency(1, &domains)); EXPECT(!alldiff.ensureArcConsistency(1, &domains));
EXPECT(alldiff.ensureArcConsistency(2, &domains)); EXPECT(alldiff.ensureArcConsistency(2, &domains));
LONGS_EQUAL(2, domains[0].nrValues()); LONGS_EQUAL(2, domains.at(0).nrValues());
LONGS_EQUAL(1, domains[1].nrValues()); LONGS_EQUAL(1, domains.at(1).nrValues());
LONGS_EQUAL(2, domains[2].nrValues()); LONGS_EQUAL(2, domains.at(2).nrValues());
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -232,17 +235,20 @@ TEST(CSP, ArcConsistency) {
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// ensure arc-consistency, i.e., narrow domains... // ensure arc-consistency, i.e., narrow domains...
vector<Domain> domains; Domains domains;
domains += Domain(ID), Domain(AZ), Domain(UT); domains.emplace(0, Domain(ID));
domains.emplace(1, Domain(AZ));
domains.emplace(2, Domain(UT));
SingleValue singleValue(AZ, 2); SingleValue singleValue(AZ, 2);
AllDiff alldiff(dkeys); AllDiff alldiff(dkeys);
EXPECT(singleValue.ensureArcConsistency(1, &domains)); EXPECT(singleValue.ensureArcConsistency(1, &domains));
EXPECT(alldiff.ensureArcConsistency(0, &domains)); EXPECT(alldiff.ensureArcConsistency(0, &domains));
EXPECT(!alldiff.ensureArcConsistency(1, &domains)); EXPECT(!alldiff.ensureArcConsistency(1, &domains));
EXPECT(alldiff.ensureArcConsistency(2, &domains)); EXPECT(alldiff.ensureArcConsistency(2, &domains));
LONGS_EQUAL(2, domains[0].nrValues()); LONGS_EQUAL(2, domains.at(0).nrValues());
LONGS_EQUAL(1, domains[1].nrValues()); LONGS_EQUAL(1, domains.at(1).nrValues());
LONGS_EQUAL(2, domains[2].nrValues()); LONGS_EQUAL(2, domains.at(2).nrValues());
// Parial application, version 1 // Parial application, version 1
DiscreteFactor::Values known; DiscreteFactor::Values known;

View File

@ -6,6 +6,7 @@
*/ */
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam_unstable/discrete/CSP.h> #include <gtsam_unstable/discrete/CSP.h>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
@ -20,12 +21,12 @@ using namespace gtsam;
#define PRINT false #define PRINT false
/// A class that encodes Sudoku's as a CSP problem
class Sudoku : public CSP { class Sudoku : public CSP {
/// sudoku size size_t n_; ///< Side of Sudoku, e.g. 4 or 9
size_t n_;
/// discrete keys /// Mapping from base i,j coordinates to discrete keys:
typedef std::pair<size_t, size_t> IJ; using IJ = std::pair<size_t, size_t>;
std::map<IJ, DiscreteKey> dkeys_; std::map<IJ, DiscreteKey> dkeys_;
public: public:
@ -42,15 +43,14 @@ class Sudoku : public CSP {
// Create variables, ordering, and unary constraints // Create variables, ordering, and unary constraints
va_list ap; va_list ap;
va_start(ap, n); va_start(ap, n);
Key k = 0;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
for (size_t j = 0; j < n; ++j, ++k) { for (size_t j = 0; j < n; ++j) {
// create the key // create the key
IJ ij(i, j); IJ ij(i, j);
dkeys_[ij] = DiscreteKey(k, n); Symbol key('1' + i, j + 1);
dkeys_[ij] = DiscreteKey(key, n);
// get the unary constraint, if any // get the unary constraint, if any
int value = va_arg(ap, int); int value = va_arg(ap, int);
// cout << value << " ";
if (value != 0) addSingleValue(dkeys_[ij], value - 1); if (value != 0) addSingleValue(dkeys_[ij], value - 1);
} }
// cout << endl; // cout << endl;
@ -88,7 +88,7 @@ class Sudoku : public CSP {
} }
/// Print readable form of assignment /// Print readable form of assignment
void printAssignment(DiscreteFactor::sharedValues assignment) const { void printAssignment(const DiscreteFactor::sharedValues& assignment) const {
for (size_t i = 0; i < n_; i++) { for (size_t i = 0; i < n_; i++) {
for (size_t j = 0; j < n_; j++) { for (size_t j = 0; j < n_; j++) {
Key k = key(i, j); Key k = key(i, j);
@ -99,10 +99,22 @@ class Sudoku : public CSP {
} }
/// solve and print solution /// solve and print solution
void printSolution() { void printSolution() const {
DiscreteFactor::sharedValues MPE = optimalAssignment(); DiscreteFactor::sharedValues MPE = optimalAssignment();
printAssignment(MPE); printAssignment(MPE);
} }
// Print domain
void printDomains(const Domains& domains) {
for (size_t i = 0; i < n_; i++) {
for (size_t j = 0; j < n_; j++) {
Key k = key(i, j);
cout << domains.at(k).base1Str();
cout << "\t";
} // i
cout << endl;
} // j
}
}; };
/* ************************************************************************* */ /* ************************************************************************* */
@ -113,9 +125,6 @@ TEST_UNSAFE(Sudoku, small) {
4, 0, 2, 0, // 4, 0, 2, 0, //
0, 1, 0, 0); 0, 1, 0, 0);
// Do BP
csp.runArcConsistency(4, 10, PRINT);
// optimize and check // optimize and check
CSP::sharedValues solution = csp.optimalAssignment(); CSP::sharedValues solution = csp.optimalAssignment();
CSP::Values expected; CSP::Values expected;
@ -126,11 +135,27 @@ TEST_UNSAFE(Sudoku, small) {
csp.key(3, 3), 2); csp.key(3, 3), 2);
EXPECT(assert_equal(expected, *solution)); EXPECT(assert_equal(expected, *solution));
// csp.printAssignment(solution); // csp.printAssignment(solution);
// Do BP (AC1)
auto domains = csp.runArcConsistency(4, 3);
// csp.printDomains(domains);
Domain domain44 = domains.at(Symbol('4', 4));
EXPECT_LONGS_EQUAL(1, domain44.nrValues());
// Test Creation of a new, simpler CSP
CSP new_csp = csp.partiallyApply(domains);
// Should only be 16 new Domains
EXPECT_LONGS_EQUAL(16, new_csp.size());
// Check that solution
CSP::sharedValues new_solution = new_csp.optimalAssignment();
// csp.printAssignment(new_solution);
EXPECT(assert_equal(expected, *new_solution));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(Sudoku, easy) { TEST_UNSAFE(Sudoku, easy) {
Sudoku sudoku(9, // Sudoku csp(9, //
0, 0, 5, 0, 9, 0, 0, 0, 1, // 0, 0, 5, 0, 9, 0, 0, 0, 1, //
0, 0, 0, 0, 0, 2, 0, 7, 3, // 0, 0, 0, 0, 0, 2, 0, 7, 3, //
7, 6, 0, 0, 0, 8, 2, 0, 0, // 7, 6, 0, 0, 0, 8, 2, 0, 0, //
@ -143,15 +168,26 @@ TEST_UNSAFE(Sudoku, easy) {
9, 7, 0, 5, 0, 0, 0, 0, 0, // 9, 7, 0, 5, 0, 0, 0, 0, 0, //
5, 0, 0, 0, 3, 0, 7, 0, 0); 5, 0, 0, 0, 3, 0, 7, 0, 0);
// Do BP // csp.printSolution(); // don't do it
sudoku.runArcConsistency(4, 10, PRINT);
// sudoku.printSolution(); // don't do it // Do BP (AC1)
auto domains = csp.runArcConsistency(9, 10);
// csp.printDomains(domains);
Key key99 = Symbol('9', 9);
Domain domain99 = domains.at(key99);
EXPECT_LONGS_EQUAL(1, domain99.nrValues());
// Test Creation of a new, simpler CSP
CSP new_csp = csp.partiallyApply(domains);
// 81 new Domains, and still 26 all-diff constraints
EXPECT_LONGS_EQUAL(81 + 26, new_csp.size());
// csp.printSolution(); // still don't do it ! :-(
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(Sudoku, extreme) { TEST_UNSAFE(Sudoku, extreme) {
Sudoku sudoku(9, // Sudoku csp(9, //
0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, //
0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, //
0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, //
@ -162,21 +198,33 @@ TEST_UNSAFE(Sudoku, extreme) {
0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0);
// Do BP // Do BP
sudoku.runArcConsistency(9, 10, PRINT); csp.runArcConsistency(9, 10);
#ifdef METIS #ifdef METIS
VariableIndexOrdered index(sudoku); VariableIndexOrdered index(csp);
index.print("index"); index.print("index");
ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt"); ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt");
index.outputMetisFormat(os); index.outputMetisFormat(os);
#endif #endif
// sudoku.printSolution(); // don't do it // Do BP (AC1)
auto domains = csp.runArcConsistency(9, 10);
// csp.printDomains(domains);
Key key99 = Symbol('9', 9);
Domain domain99 = domains.at(key99);
EXPECT_LONGS_EQUAL(2, domain99.nrValues());
// Test Creation of a new, simpler CSP
CSP new_csp = csp.partiallyApply(domains);
// 81 new Domains, and still 20 all-diff constraints
EXPECT_LONGS_EQUAL(81 + 20, new_csp.size());
// csp.printSolution(); // still don't do it ! :-(
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) { TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) {
Sudoku sudoku(9, // Sudoku csp(9, //
9, 5, 0, 0, 0, 6, 0, 0, 0, // 9, 5, 0, 0, 0, 6, 0, 0, 0, //
0, 8, 4, 0, 7, 0, 0, 0, 0, // 0, 8, 4, 0, 7, 0, 0, 0, 0, //
6, 2, 0, 5, 0, 0, 4, 0, 0, // 6, 2, 0, 5, 0, 0, 4, 0, 0, //
@ -189,10 +237,22 @@ TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) {
0, 0, 0, 0, 3, 0, 2, 9, 0, // 0, 0, 0, 0, 3, 0, 2, 9, 0, //
0, 0, 0, 1, 0, 0, 0, 3, 7); 0, 0, 0, 1, 0, 0, 0, 3, 7);
// Do BP // Do BP (AC1)
sudoku.runArcConsistency(9, 10, PRINT); auto domains = csp.runArcConsistency(9, 10);
// csp.printDomains(domains);
Key key99 = Symbol('9', 9);
Domain domain99 = domains.at(key99);
EXPECT_LONGS_EQUAL(1, domain99.nrValues());
// sudoku.printSolution(); // don't do it // Test Creation of a new, simpler CSP
CSP new_csp = csp.partiallyApply(domains);
// Just the 81 new Domains
EXPECT_LONGS_EQUAL(81, new_csp.size());
// Check that solution
CSP::sharedValues solution = new_csp.optimalAssignment();
// csp.printAssignment(solution);
EXPECT_LONGS_EQUAL(6, solution->at(key99));
} }
/* ************************************************************************* */ /* ************************************************************************* */