From dd509756685469413eb6e964b291b0a6b35dd1ec Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 18 Nov 2021 10:17:49 -0500 Subject: [PATCH 1/5] Revamped arc consistency --- gtsam_unstable/discrete/AllDiff.cpp | 184 ++++++++++++---------- gtsam_unstable/discrete/AllDiff.h | 119 +++++++------- gtsam_unstable/discrete/BinaryAllDiff.h | 163 ++++++++++--------- gtsam_unstable/discrete/CSP.cpp | 5 +- gtsam_unstable/discrete/Constraint.h | 119 +++++++------- gtsam_unstable/discrete/Domain.cpp | 155 +++++++++--------- gtsam_unstable/discrete/Domain.h | 35 ++-- gtsam_unstable/discrete/SingleValue.cpp | 125 +++++++-------- gtsam_unstable/discrete/SingleValue.h | 125 +++++++-------- gtsam_unstable/discrete/tests/testCSP.cpp | 129 +++++++++------ 10 files changed, 612 insertions(+), 547 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index d6e1c6453..ebc789ec2 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -5,105 +5,115 @@ * @author Frank Dellaert */ -#include -#include #include - +#include +#include #include namespace gtsam { -/* ************************************************************************* */ -AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) { - for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey); -} - -/* ************************************************************************* */ -void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { - std::cout << s << "AllDiff on "; - for (Key dkey : keys_) std::cout << formatter(dkey) << " "; - std::cout << std::endl; -} - -/* ************************************************************************* */ -double AllDiff::operator()(const Values& values) const { - std::set taken; // record values taken by keys - for (Key dkey : keys_) { - size_t value = values.at(dkey); // get the value for that key - if (taken.count(value)) return 0.0; // check if value alreday taken - taken.insert(value); // if not, record it as taken and keep checking + /* ************************************************************************* */ + AllDiff::AllDiff(const DiscreteKeys& dkeys) : + Constraint(dkeys.indices()) { + for(const DiscreteKey& dkey: dkeys) + cardinalities_.insert(dkey); } - return 1.0; -} -/* ************************************************************************* */ -DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { - // We will do this by converting the allDif into many BinaryAllDiff - // constraints - DecisionTreeFactor converted; - size_t nrKeys = keys_.size(); - for (size_t i1 = 0; i1 < nrKeys; i1++) - for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { - BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2)); - converted = converted * binary12.toDecisionTreeFactor(); + /* ************************************************************************* */ + void AllDiff::print(const std::string& s, + const KeyFormatter& formatter) const { + std::cout << s << "AllDiff on "; + for (Key dkey: keys_) + std::cout << formatter(dkey) << " "; + std::cout << std::endl; + } + + /* ************************************************************************* */ + double AllDiff::operator()(const Values& values) const { + std::set < size_t > taken; // record values taken by keys + for(Key dkey: keys_) { + size_t value = values.at(dkey); // get the value for that key + if (taken.count(value)) return 0.0;// check if value alreday taken + taken.insert(value);// if not, record it as taken and keep checking } - return converted; -} + return 1.0; + } -/* ************************************************************************* */ -DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; -} + /* ************************************************************************* */ + DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { + // We will do this by converting the allDif into many BinaryAllDiff constraints + DecisionTreeFactor converted; + size_t nrKeys = keys_.size(); + for (size_t i1 = 0; i1 < nrKeys; i1++) + for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { + BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); + converted = converted * binary12.toDecisionTreeFactor(); + } + return converted; + } -/* ************************************************************************* */ -bool AllDiff::ensureArcConsistency(size_t j, - std::vector& domains) const { - // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. - // If found, we make this a singleton... - // TODO: make a new constraint where this really is true - Domain& Dj = domains[j]; - if (Dj.checkAllDiff(keys_, domains)) return true; + /* ************************************************************************* */ + DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } - // Check all other domains for singletons and erase corresponding values - // This is the same as arc-consistency on the equivalent binary constraints - bool changed = false; - for (Key k : keys_) - if (k != j) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) { // check if singleton - size_t value = Dk.firstValue(); - if (Dj.contains(value)) { - Dj.erase(value); // erase value if true - changed = true; + /* ************************************************************************* */ + bool AllDiff::ensureArcConsistency(size_t j, + std::vector* domains) const { + // We are changing the domain of variable j. + // TODO(dellaert): confusing, I thought we were changing others... + Domain& Dj = domains->at(j); + + // Though strictly not part of allDiff, we check for + // a value in domains[j] that does not occur in any other connected domain. + // If found, we make this a singleton... + // TODO: make a new constraint where this really is true + boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); + if (maybeChanged) { + Dj = *maybeChanged; + return true; + } + + // Check all other domains for singletons and erase corresponding values. + // This is the same as arc-consistency on the equivalent binary constraints + bool changed = false; + for (Key k : keys_) + if (k != j) { + const Domain& Dk = domains->at(k); + if (Dk.isSingleton()) { // check if singleton + size_t value = Dk.firstValue(); + if (Dj.contains(value)) { + Dj.erase(value); // erase value if true + changed = true; + } } } - } - return changed; -} - -/* ************************************************************************* */ -Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { - DiscreteKeys newKeys; - // loop over keys and add them only if they do not appear in values - for (Key k : keys_) - if (values.find(k) == values.end()) { - newKeys.push_back(DiscreteKey(k, cardinalities_.at(k))); - } - return boost::make_shared(newKeys); -} - -/* ************************************************************************* */ -Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { - DiscreteFactor::Values known; - for (Key k : keys_) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) known[k] = Dk.firstValue(); + return changed; } - return partiallyApply(known); -} -/* ************************************************************************* */ -} // namespace gtsam + /* ************************************************************************* */ + Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { + DiscreteKeys newKeys; + // loop over keys and add them only if they do not appear in values + for(Key k: keys_) + if (values.find(k) == values.end()) { + newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); + } + return boost::make_shared(newKeys); + } + + /* ************************************************************************* */ + Constraint::shared_ptr AllDiff::partiallyApply( + const std::vector& domains) const { + DiscreteFactor::Values known; + for(Key k: keys_) { + const Domain& Dk = domains[k]; + if (Dk.isSingleton()) + known[k] = Dk.firstValue(); + } + return partiallyApply(known); + } + + /* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index b0fd1d631..8c83e5ba1 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -7,71 +7,70 @@ #pragma once -#include #include +#include namespace gtsam { -/** - * General AllDiff constraint - * Returns 1 if values for all keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Key and an Key. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. - */ -class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { - std::map cardinalities_; - - DiscreteKey discreteKey(size_t i) const { - Key j = keys_[i]; - return DiscreteKey(j, cardinalities_.at(j)); - } - - public: - /// Constructor - AllDiff(const DiscreteKeys& dkeys); - - // print - void print(const std::string& s = "", const KeyFormatter& formatter = - DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if (!dynamic_cast(&other)) - return false; - else { - const AllDiff& f(static_cast(other)); - return cardinalities_.size() == f.cardinalities_.size() && - std::equal(cardinalities_.begin(), cardinalities_.end(), - f.cardinalities_.begin()); - } - } - - /// Calculate value = expensive ! - double operator()(const Values& values) const override; - - /// Convert into a decisiontree, can be *very* expensive ! - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. - * @param j domain to be checked - * @param domains all other domains + /** + * General AllDiff constraint. + * Returns 1 if values for all keys are different, 0 otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint { - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override; + std::map cardinalities_; - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override; -}; + DiscreteKey discreteKey(size_t i) const { + Key j = keys_[i]; + return DiscreteKey(j,cardinalities_.at(j)); + } -} // namespace gtsam + public: + + /// Construct from keys. + AllDiff(const DiscreteKeys& dkeys); + + // print + void print(const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if(!dynamic_cast(&other)) + return false; + else { + const AllDiff& f(static_cast(other)); + return cardinalities_.size() == f.cardinalities_.size() + && std::equal(cardinalities_.begin(), cardinalities_.end(), + f.cardinalities_.begin()); + } + } + + /// Calculate value = expensive ! + double operator()(const Values& values) const override; + + /// Convert into a decisiontree, can be *very* expensive ! + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency + * Arc-consistency involves creating binaryAllDiff constraints + * In which case the combinatorial hyper-arc explosion disappears. + * @param j domain to be checked + * @param (in/out) domains all other domains + */ + bool ensureArcConsistency(size_t j, + std::vector* domains) const override; + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values&) const override; + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector&) const override; + }; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index d8e1a590a..acc3cc421 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -7,93 +7,92 @@ #pragma once -#include -#include #include +#include +#include namespace gtsam { -/** - * Binary AllDiff constraint - * Returns 1 if values for two keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Index and an Index. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. - */ -class BinaryAllDiff : public Constraint { - size_t cardinality0_, cardinality1_; /// cardinality - - public: - /// Constructor - BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) - : Constraint(key1.first, key2.first), - cardinality0_(key1.second), - cardinality1_(key2.second) {} - - // print - void print( - const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " - << formatter(keys_[1]) << std::endl; - } - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if (!dynamic_cast(&other)) - return false; - else { - const BinaryAllDiff& f(static_cast(other)); - return (cardinality0_ == f.cardinality0_) && - (cardinality1_ == f.cardinality1_); - } - } - - /// Calculate value - double operator()(const Values& values) const override { - return (double)(values.at(keys_[0]) != values.at(keys_[1])); - } - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - DiscreteKeys keys; - keys.push_back(DiscreteKey(keys_[0], cardinality0_)); - keys.push_back(DiscreteKey(keys_[1], cardinality1_)); - std::vector table; - for (size_t i1 = 0; i1 < cardinality0_; i1++) - for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains + /** + * Binary AllDiff constraint + * Returns 1 if values for two keys are different, 0 otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override { - // throw std::runtime_error( - // "BinaryAllDiff::ensureArcConsistency not implemented"); - return false; - } + class BinaryAllDiff: public Constraint { - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } + size_t cardinality0_, cardinality1_; /// cardinality - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } -}; + public: -} // namespace gtsam + /// Constructor + BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : + Constraint(key1.first, key2.first), + cardinality0_(key1.second), cardinality1_(key2.second) { + } + + // print + void print(const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " + << formatter(keys_[1]) << std::endl; + } + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if(!dynamic_cast(&other)) + return false; + else { + const BinaryAllDiff& f(static_cast(other)); + return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_); + } + } + + /// Calculate value + double operator()(const Values& values) const override { + return (double) (values.at(keys_[0]) != values.at(keys_[1])); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override { + DiscreteKeys keys; + keys.push_back(DiscreteKey(keys_[0],cardinality0_)); + keys.push_back(DiscreteKey(keys_[1],cardinality1_)); + std::vector table; + for (size_t i1 = 0; i1 < cardinality0_; i1++) + for (size_t i2 = 0; i2 < cardinality1_; i2++) + table.push_back(i1 != i2); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + /// + bool ensureArcConsistency(size_t j, + std::vector* domains) const override { + throw std::runtime_error( + "BinaryAllDiff::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + }; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index b1d70dc6e..bab1ac3c8 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -56,12 +56,11 @@ void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, // if not already a singleton if (!domains[v].isSingleton()) { // get the constraint and call its ensureArcConsistency method - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast((*this)[f]); + auto constraint = boost::dynamic_pointer_cast((*this)[f]); if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); changed[v] = - constraint->ensureArcConsistency(v, domains) || changed[v]; + constraint->ensureArcConsistency(v, &domains) || changed[v]; } } // f if (changed[v]) anyChange = true; diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index b8baccff9..ff6f3834e 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -17,68 +17,79 @@ #pragma once -#include #include - +#include #include namespace gtsam { -class Domain; + class Domain; -/** - * Base class for discrete probabilistic factors - * The most general one is the derived DecisionTreeFactor - */ -class Constraint : public DiscreteFactor { - public: - typedef boost::shared_ptr shared_ptr; - - protected: - /// Construct n-way factor - Constraint(const KeyVector& js) : DiscreteFactor(js) {} - - /// Construct unary factor - Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} - - /// Construct binary factor - Constraint(Key j1, Key j2) - : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} - - /// construct from container - template - Constraint(KeyIterator beginKey, KeyIterator endKey) - : DiscreteFactor(beginKey, endKey) {} - - public: - /// @name Standard Constructors - /// @{ - - /// Default constructor for I/O - Constraint(); - - /// Virtual destructor - ~Constraint() override {} - - /// @} - /// @name Standard Interface - /// @{ - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains + /** + * Base class for constraint factors + * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. */ - virtual bool ensureArcConsistency(size_t j, - std::vector& domains) const = 0; + class GTSAM_EXPORT Constraint : public DiscreteFactor { - /// Partially apply known values - virtual shared_ptr partiallyApply(const Values&) const = 0; + public: - /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; - /// @} -}; + typedef boost::shared_ptr shared_ptr; + + protected: + + /// Construct unary constraint factor. + Constraint(Key j) : + DiscreteFactor(boost::assign::cref_list_of<1>(j)) { + } + + /// Construct binary constraint factor. + Constraint(Key j1, Key j2) : + DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) { + } + + /// Construct n-way constraint factor. + Constraint(const KeyVector& js) : + DiscreteFactor(js) { + } + + /// construct from container + template + Constraint(KeyIterator beginKey, KeyIterator endKey) : + DiscreteFactor(beginKey, endKey) { + } + + public: + + /// @name Standard Constructors + /// @{ + + /// Default constructor for I/O + Constraint(); + + /// Virtual destructor + ~Constraint() override {} + + /// @} + /// @name Standard Interface + /// @{ + + /* + * Ensure Arc-consistency, possibly changing domains of connected variables. + * @param j domain to be checked + * @param (in/out) domains all other domains + * @return true if domains were changed, false otherwise. + */ + virtual bool ensureArcConsistency(size_t j, + std::vector* domains) const = 0; + + /// Partially apply known values + virtual shared_ptr partiallyApply(const Values&) const = 0; + + + /// Partially apply known values, domain version + virtual shared_ptr partiallyApply(const std::vector&) const = 0; + /// @} + }; // DiscreteFactor -} // namespace gtsam +}// namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index a81b1d1ad..c2ba1c7f9 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -5,89 +5,90 @@ * @author Frank Dellaert */ -#include -#include #include - +#include +#include #include namespace gtsam { -using namespace std; + using namespace std; -/* ************************************************************************* */ -void Domain::print(const string& s, const KeyFormatter& formatter) const { - // cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << - // formatter(keys_[0]) << ") with values"; - // for (size_t v: values_) cout << " " << v; - // cout << endl; - for (size_t v : values_) cout << v; -} - -/* ************************************************************************* */ -double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); -} - -/* ************************************************************************* */ -DecisionTreeFactor Domain::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); - DecisionTreeFactor converted(keys, table); - return converted; -} - -/* ************************************************************************* */ -DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; -} - -/* ************************************************************************* */ -bool Domain::ensureArcConsistency(size_t j, vector& domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); - Domain& D = domains[j]; - for (size_t value : values_) - if (!D.contains(value)) throw runtime_error("Unsatisfiable"); - D = *this; - return true; -} - -/* ************************************************************************* */ -bool Domain::checkAllDiff(const KeyVector keys, vector& domains) { - Key j = keys_[0]; - // for all values in this domain - for (size_t value : values_) { - // for all connected domains - for (Key k : keys) - // if any domain contains the value we cannot make this domain singleton - if (k != j && domains[k].contains(value)) goto found; - values_.clear(); - values_.insert(value); - return true; // we changed it - found:; + /* ************************************************************************* */ + void Domain::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << + formatter(keys_[0]) << ") with values"; + for (size_t v: values_) cout << " " << v; + cout << endl; + } + + /* ************************************************************************* */ + double Domain::operator()(const Values& values) const { + return contains(values.at(keys_[0])); + } + + /* ************************************************************************* */ + DecisionTreeFactor Domain::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; ++i1) + table.push_back(contains(i1)); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* ************************************************************************* */ + bool Domain::ensureArcConsistency(size_t j, vector* domains) const { + if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains->at(j); + for(size_t value: values_) + if (!D.contains(value)) throw runtime_error("Unsatisfiable"); + D = *this; + return true; + } + + /* ************************************************************************* */ + boost::optional Domain::checkAllDiff( + const KeyVector keys, const vector& domains) const { + Key j = keys_[0]; + // for all values in this domain + for (const size_t value : values_) { + // for all connected domains + for (const Key k : keys) + // if any domain contains the value we cannot make this domain singleton + if (k != j && domains[k].contains(value)) goto found; + // Otherwise: return a singleton: + return Domain(this->discreteKey(), value); + found:; + } + return boost::none; // we did not change it + } + + /* ************************************************************************* */ + Constraint::shared_ptr Domain::partiallyApply( + const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && !contains(it->second)) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (*this); + } + + /* ************************************************************************* */ + Constraint::shared_ptr Domain::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (Dk); } - return false; // we did not change it -} /* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && !contains(it->second)) - throw runtime_error("Domain::partiallyApply: unsatisfiable"); - return boost::make_shared(*this); -} - -/* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !contains(*Dk.begin())) - throw runtime_error("Domain::partiallyApply: unsatisfiable"); - return boost::make_shared(Dk); -} - -/* ************************************************************************* */ -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 15828b653..d06966081 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -7,18 +7,23 @@ #pragma once -#include #include +#include namespace gtsam { /** - * Domain restriction constraint + * The Domain class represents a constraint that restricts the possible values a + * particular variable, with given key, can take on. */ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { size_t cardinality_; /// Cardinality std::set values_; /// allowed values + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0], cardinality_); + } + public: typedef boost::shared_ptr shared_ptr; @@ -35,14 +40,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { values_.insert(v); } - /// Constructor - Domain(const Domain& other) - : Constraint(other.keys_[0]), values_(other.values_) {} - - /// insert a value, non const :-( + /// Insert a value, non const :-( void insert(size_t value) { values_.insert(value); } - /// erase a value, non const :-( + /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } size_t nrValues() const { return values_.size(); } @@ -82,15 +83,17 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { * @param domains all other domains */ bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + std::vector* domains) const override; /** - * Check for a value in domain that does not occur in any other connected - * domain. If found, we make this a singleton... Called in - * AllDiff::ensureArcConsistency - * @param keys connected domains through alldiff + * Check for a value in domain that does not occur in any other connected + * domain. If found, return a a new singleton domain... + * Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + * @param keys other domains */ - bool checkAllDiff(const KeyVector keys, std::vector& domains); + boost::optional checkAllDiff( + const KeyVector keys, const std::vector& domains) const; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; @@ -98,6 +101,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const std::vector& domains) const override; -}; + }; -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 105887dc9..e042e550c 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -5,74 +5,75 @@ * @author Frank Dellaert */ -#include -#include -#include #include - +#include +#include +#include #include namespace gtsam { -using namespace std; + using namespace std; -/* ************************************************************************* */ -void SingleValue::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "SingleValue on " - << "j=" << formatter(keys_[0]) << " with value " << value_ << endl; -} - -/* ************************************************************************* */ -double SingleValue::operator()(const Values& values) const { - return (double)(values.at(keys_[0]) == value_); -} - -/* ************************************************************************* */ -DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_); - DecisionTreeFactor converted(keys, table); - return converted; -} - -/* ************************************************************************* */ -DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; -} - -/* ************************************************************************* */ -bool SingleValue::ensureArcConsistency(size_t j, - vector& domains) const { - if (j != keys_[0]) - throw invalid_argument("SingleValue check on wrong domain"); - Domain& D = domains[j]; - if (D.isSingleton()) { - if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); - return false; + /* ************************************************************************* */ + void SingleValue::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << "SingleValue on " << "j=" << formatter(keys_[0]) + << " with value " << value_ << endl; + } + + /* ************************************************************************* */ + double SingleValue::operator()(const Values& values) const { + return (double) (values.at(keys_[0]) == value_); + } + + /* ************************************************************************* */ + DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; i1++) + table.push_back(i1 == value_); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* ************************************************************************* */ + bool SingleValue::ensureArcConsistency(size_t j, + vector* domains) const { + if (j != keys_[0]) + throw invalid_argument("SingleValue check on wrong domain"); + Domain& D = domains->at(j); + if (D.isSingleton()) { + if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); + return false; + } + D = Domain(discreteKey(), value_); + return true; + } + + /* ************************************************************************* */ + Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && it->second != value_) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(keys_[0], cardinality_, value_); + } + + /* ************************************************************************* */ + Constraint::shared_ptr SingleValue::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(discreteKey(), value_); } - D = Domain(discreteKey(), value_); - return true; -} /* ************************************************************************* */ -Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && it->second != value_) - throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared(keys_[0], cardinality_, value_); -} - -/* ************************************************************************* */ -Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !Dk.contains(value_)) - throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared(discreteKey(), value_); -} - -/* ************************************************************************* */ -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index a2aec338c..0f9a8fb0f 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -7,73 +7,74 @@ #pragma once -#include #include namespace gtsam { -/** - * SingleValue constraint - */ -class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { - /// Number of values - size_t cardinality_; - - /// allowed value - size_t value_; - - DiscreteKey discreteKey() const { - return DiscreteKey(keys_[0], cardinality_); - } - - public: - typedef boost::shared_ptr shared_ptr; - - /// Constructor - SingleValue(Key key, size_t n, size_t value) - : Constraint(key), cardinality_(n), value_(value) {} - - /// Constructor - SingleValue(const DiscreteKey& dkey, size_t value) - : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} - - // print - void print(const std::string& s = "", const KeyFormatter& formatter = - DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if (!dynamic_cast(&other)) - return false; - else { - const SingleValue& f(static_cast(other)); - return (cardinality_ == f.cardinality_) && (value_ == f.value_); - } - } - - /// Calculate value - double operator()(const Values& values) const override; - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains + /** + * SingleValue constraint: ensures a variable takes on a certain value. + * This could of course also be implemented by changing its `Domain`. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint { + + size_t cardinality_; /// < Number of values + size_t value_; ///< allowed value - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0],cardinality_); + } - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; -}; + public: -} // namespace gtsam + typedef boost::shared_ptr shared_ptr; + + /// Construct from key, cardinality, and given value. + SingleValue(Key key, size_t n, size_t value) : + Constraint(key), cardinality_(n), value_(value) { + } + + /// Construct from DiscreteKey and given value. + SingleValue(const DiscreteKey& dkey, size_t value) : + Constraint(dkey.first), cardinality_(dkey.second), value_(value) { + } + + // print + void print(const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if(!dynamic_cast(&other)) + return false; + else { + const SingleValue& f(static_cast(other)); + return (cardinality_==f.cardinality_) && (value_==f.value_); + } + } + + /// Calculate value + double operator()(const Values& values) const override; + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency: just sets domain[j] to {value_} + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, + std::vector* domains) const override; + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values& values) const override; + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector& domains) const override; + }; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 1552fcbf1..b1aaab303 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -19,12 +19,33 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE(BinaryAllDif, allInOne) { - // Create keys and ordering +TEST(CSP, SingleValue) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check that a single value is equal to a decision stump with only one "1": + SingleValue singleValue(AZ, 2); + DecisionTreeFactor f1(AZ, "0 0 1"); + EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); + + // Create domains, laid out as a vector. + // TODO(dellaert): should be map?? + vector domains; + domains += Domain(ID), Domain(AZ), Domain(UT); + + // Ensure arc-consistency: just wipes out values in AZ domain: + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + LONGS_EQUAL(3, domains[0].nrValues()); + LONGS_EQUAL(1, domains[1].nrValues()); + LONGS_EQUAL(3, domains[2].nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, BinaryAllDif) { + // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each: size_t nrColors = 2; - // DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", - // nrColors); - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Check construction and conversion BinaryAllDiff c1(ID, UT); @@ -36,16 +57,51 @@ TEST_UNSAFE(BinaryAllDif, allInOne) { DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); + // Check multiplication of factors with constraint: DecisionTreeFactor f3 = f1 * f2; EXPECT(assert_equal(f3, c1 * f2)); EXPECT(assert_equal(f3, c2 * f1)); } /* ************************************************************************* */ -TEST_UNSAFE(CSP, allInOne) { - // Create keys and ordering +TEST(CSP, AllDiff) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check construction and conversion + vector dkeys{ID, UT, AZ}; + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); + // GTSAM_PRINT(actual); + actual.dot("actual"); + DecisionTreeFactor f2( + ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2, actual)); + + // Create domains. + vector domains; + domains += Domain(ID), Domain(AZ), Domain(UT); + + // First constrict AZ domain: + SingleValue singleValue(AZ, 2); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + + // Arc-consistency + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains[0].nrValues()); + LONGS_EQUAL(1, domains[1].nrValues()); + LONGS_EQUAL(2, domains[2].nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, allInOne) { + // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each: size_t nrColors = 2; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Create the CSP CSP csp; @@ -81,15 +137,12 @@ TEST_UNSAFE(CSP, allInOne) { } /* ************************************************************************* */ -TEST_UNSAFE(CSP, WesternUS) { - // Create keys +TEST(CSP, WesternUS) { + // Create keys for all states in Western US, with 4 color possibilities. size_t nrColors = 4; - DiscreteKey - // Create ordering according to example in ND-CSP.lyx - WA(0, nrColors), - OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors), - UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors), - CO(7, nrColors), NM(6, nrColors); + DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors), + NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); // Create the CSP CSP csp; @@ -116,10 +169,12 @@ TEST_UNSAFE(CSP, WesternUS) { csp.addAllDiff(WY, CO); csp.addAllDiff(CO, NM); - // Solve + // Create ordering according to example in ND-CSP.lyx Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), Key(8), Key(9), Key(10); + + // Solve using that ordering: CSP::sharedValues mpe = csp.optimalAssignment(ordering); // GTSAM_PRINT(*mpe); CSP::Values expected; @@ -143,33 +198,17 @@ TEST_UNSAFE(CSP, WesternUS) { } /* ************************************************************************* */ -TEST_UNSAFE(CSP, AllDiff) { - // Create keys and ordering +TEST(CSP, ArcConsistency) { + // Create keys for Idaho, Arizona, and Utah, allowing three colors for each: size_t nrColors = 3; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); - // Create the CSP + // Create the CSP using just one all-diff constraint, plus constrain Arizona. CSP csp; - vector dkeys; - dkeys += ID, UT, AZ; + vector dkeys{ID, UT, AZ}; csp.addAllDiff(dkeys); csp.addSingleValue(AZ, 2); - // GTSAM_PRINT(csp); - - // Check construction and conversion - SingleValue s(AZ, 2); - DecisionTreeFactor f1(AZ, "0 0 1"); - EXPECT(assert_equal(f1, s.toDecisionTreeFactor())); - - // Check construction and conversion - AllDiff alldiff(dkeys); - DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); - // GTSAM_PRINT(actual); - // actual.dot("actual"); - DecisionTreeFactor f2( - ID & AZ & UT, - "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); - EXPECT(assert_equal(f2, actual)); + // GTSAM_PRINT(csp); // Check an invalid combination, with ID==UT==AZ all same color DiscreteFactor::Values invalid; @@ -192,14 +231,15 @@ TEST_UNSAFE(CSP, AllDiff) { EXPECT(assert_equal(expected, *mpe)); EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); - // Arc-consistency + // ensure arc-consistency, i.e., narrow domains... vector domains; domains += Domain(ID), Domain(AZ), Domain(UT); SingleValue singleValue(AZ, 2); - EXPECT(singleValue.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(0, domains)); - EXPECT(!alldiff.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(2, domains)); + AllDiff alldiff(dkeys); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); LONGS_EQUAL(2, domains[0].nrValues()); LONGS_EQUAL(1, domains[1].nrValues()); LONGS_EQUAL(2, domains[2].nrValues()); @@ -222,6 +262,7 @@ TEST_UNSAFE(CSP, AllDiff) { // full arc-consistency test csp.runArcConsistency(nrColors); + // GTSAM_PRINT(csp); } /* ************************************************************************* */ From b7f43906bc3fc136a14c8b62b240399d39544e5d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 18 Nov 2021 15:08:01 -0500 Subject: [PATCH 2/5] Formatting only --- gtsam_unstable/discrete/AllDiff.cpp | 189 +++++++++++----------- gtsam_unstable/discrete/AllDiff.h | 116 +++++++------ gtsam_unstable/discrete/BinaryAllDiff.h | 153 +++++++++--------- gtsam_unstable/discrete/Constraint.h | 120 +++++++------- gtsam_unstable/discrete/Domain.cpp | 158 +++++++++--------- gtsam_unstable/discrete/Domain.h | 6 +- gtsam_unstable/discrete/SingleValue.cpp | 129 ++++++++------- gtsam_unstable/discrete/SingleValue.h | 123 +++++++------- gtsam_unstable/discrete/tests/testCSP.cpp | 2 +- 9 files changed, 487 insertions(+), 509 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index ebc789ec2..ef18053a4 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -5,115 +5,112 @@ * @author Frank Dellaert */ -#include -#include #include +#include +#include + #include namespace gtsam { - /* ************************************************************************* */ - AllDiff::AllDiff(const DiscreteKeys& dkeys) : - Constraint(dkeys.indices()) { - for(const DiscreteKey& dkey: dkeys) - cardinalities_.insert(dkey); - } +/* ************************************************************************* */ +AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) { + for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey); +} - /* ************************************************************************* */ - void AllDiff::print(const std::string& s, - const KeyFormatter& formatter) const { - std::cout << s << "AllDiff on "; - for (Key dkey: keys_) - std::cout << formatter(dkey) << " "; - std::cout << std::endl; - } +/* ************************************************************************* */ +void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { + std::cout << s << "AllDiff on "; + for (Key dkey : keys_) std::cout << formatter(dkey) << " "; + std::cout << std::endl; +} - /* ************************************************************************* */ - double AllDiff::operator()(const Values& values) const { - std::set < size_t > taken; // record values taken by keys - for(Key dkey: keys_) { - size_t value = values.at(dkey); // get the value for that key - if (taken.count(value)) return 0.0;// check if value alreday taken - taken.insert(value);// if not, record it as taken and keep checking +/* ************************************************************************* */ +double AllDiff::operator()(const Values& values) const { + std::set taken; // record values taken by keys + for (Key dkey : keys_) { + size_t value = values.at(dkey); // get the value for that key + if (taken.count(value)) return 0.0; // check if value alreday taken + taken.insert(value); // if not, record it as taken and keep checking + } + return 1.0; +} + +/* ************************************************************************* */ +DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { + // We will do this by converting the allDif into many BinaryAllDiff + // constraints + DecisionTreeFactor converted; + size_t nrKeys = keys_.size(); + for (size_t i1 = 0; i1 < nrKeys; i1++) + for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { + BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2)); + converted = converted * binary12.toDecisionTreeFactor(); } - return 1.0; + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool AllDiff::ensureArcConsistency(size_t j, + std::vector* domains) const { + // We are changing the domain of variable j. + // TODO(dellaert): confusing, I thought we were changing others... + Domain& Dj = domains->at(j); + + // Though strictly not part of allDiff, we check for + // a value in domains[j] that does not occur in any other connected domain. + // If found, we make this a singleton... + // TODO: make a new constraint where this really is true + boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); + if (maybeChanged) { + Dj = *maybeChanged; + return true; } - /* ************************************************************************* */ - DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { - // We will do this by converting the allDif into many BinaryAllDiff constraints - DecisionTreeFactor converted; - size_t nrKeys = keys_.size(); - for (size_t i1 = 0; i1 < nrKeys; i1++) - for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { - BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); - converted = converted * binary12.toDecisionTreeFactor(); - } - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool AllDiff::ensureArcConsistency(size_t j, - std::vector* domains) const { - // We are changing the domain of variable j. - // TODO(dellaert): confusing, I thought we were changing others... - Domain& Dj = domains->at(j); - - // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. - // If found, we make this a singleton... - // TODO: make a new constraint where this really is true - boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); - if (maybeChanged) { - Dj = *maybeChanged; - return true; - } - - // Check all other domains for singletons and erase corresponding values. - // This is the same as arc-consistency on the equivalent binary constraints - bool changed = false; - for (Key k : keys_) - if (k != j) { - const Domain& Dk = domains->at(k); - if (Dk.isSingleton()) { // check if singleton - size_t value = Dk.firstValue(); - if (Dj.contains(value)) { - Dj.erase(value); // erase value if true - changed = true; - } + // Check all other domains for singletons and erase corresponding values. + // This is the same as arc-consistency on the equivalent binary constraints + bool changed = false; + for (Key k : keys_) + if (k != j) { + const Domain& Dk = domains->at(k); + if (Dk.isSingleton()) { // check if singleton + size_t value = Dk.firstValue(); + if (Dj.contains(value)) { + Dj.erase(value); // erase value if true + changed = true; } } - return changed; - } + } + return changed; +} - /* ************************************************************************* */ - Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { - DiscreteKeys newKeys; - // loop over keys and add them only if they do not appear in values - for(Key k: keys_) - if (values.find(k) == values.end()) { - newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); - } - return boost::make_shared(newKeys); - } +/* ************************************************************************* */ +Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { + DiscreteKeys newKeys; + // loop over keys and add them only if they do not appear in values + for (Key k : keys_) + if (values.find(k) == values.end()) { + newKeys.push_back(DiscreteKey(k, cardinalities_.at(k))); + } + return boost::make_shared(newKeys); +} - /* ************************************************************************* */ - Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { - DiscreteFactor::Values known; - for(Key k: keys_) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) - known[k] = Dk.firstValue(); - } - return partiallyApply(known); +/* ************************************************************************* */ +Constraint::shared_ptr AllDiff::partiallyApply( + const std::vector& domains) const { + DiscreteFactor::Values known; + for (Key k : keys_) { + const Domain& Dk = domains[k]; + if (Dk.isSingleton()) known[k] = Dk.firstValue(); } + return partiallyApply(known); +} - /* ************************************************************************* */ -} // namespace gtsam +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 8c83e5ba1..4deabda94 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -7,70 +7,68 @@ #pragma once -#include #include +#include namespace gtsam { - /** - * General AllDiff constraint. - * Returns 1 if values for all keys are different, 0 otherwise. +/** + * General AllDiff constraint. + * Returns 1 if values for all keys are different, 0 otherwise. + */ +class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { + std::map cardinalities_; + + DiscreteKey discreteKey(size_t i) const { + Key j = keys_[i]; + return DiscreteKey(j, cardinalities_.at(j)); + } + + public: + /// Construct from keys. + AllDiff(const DiscreteKeys& dkeys); + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const AllDiff& f(static_cast(other)); + return cardinalities_.size() == f.cardinalities_.size() && + std::equal(cardinalities_.begin(), cardinalities_.end(), + f.cardinalities_.begin()); + } + } + + /// Calculate value = expensive ! + double operator()(const Values& values) const override; + + /// Convert into a decisiontree, can be *very* expensive ! + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency + * Arc-consistency involves creating binaryAllDiff constraints + * In which case the combinatorial hyper-arc explosion disappears. + * @param j domain to be checked + * @param (in/out) domains all other domains */ - class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint { + bool ensureArcConsistency(size_t j, + std::vector* domains) const override; - std::map cardinalities_; + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values&) const override; - DiscreteKey discreteKey(size_t i) const { - Key j = keys_[i]; - return DiscreteKey(j,cardinalities_.at(j)); - } + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector&) const override; +}; - public: - - /// Construct from keys. - AllDiff(const DiscreteKeys& dkeys); - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const AllDiff& f(static_cast(other)); - return cardinalities_.size() == f.cardinalities_.size() - && std::equal(cardinalities_.begin(), cardinalities_.end(), - f.cardinalities_.begin()); - } - } - - /// Calculate value = expensive ! - double operator()(const Values& values) const override; - - /// Convert into a decisiontree, can be *very* expensive ! - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. - * @param j domain to be checked - * @param (in/out) domains all other domains - */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index acc3cc421..21cfb18f2 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -7,92 +7,91 @@ #pragma once -#include -#include #include +#include +#include namespace gtsam { - /** - * Binary AllDiff constraint - * Returns 1 if values for two keys are different, 0 otherwise. - */ - class BinaryAllDiff: public Constraint { +/** + * Binary AllDiff constraint + * Returns 1 if values for two keys are different, 0 otherwise. + */ +class BinaryAllDiff : public Constraint { + size_t cardinality0_, cardinality1_; /// cardinality - size_t cardinality0_, cardinality1_; /// cardinality + public: + /// Constructor + BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) + : Constraint(key1.first, key2.first), + cardinality0_(key1.second), + cardinality1_(key2.second) {} - public: + // print + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " + << formatter(keys_[1]) << std::endl; + } - /// Constructor - BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : - Constraint(key1.first, key2.first), - cardinality0_(key1.second), cardinality1_(key2.second) { - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " - << formatter(keys_[1]) << std::endl; - } - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const BinaryAllDiff& f(static_cast(other)); - return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_); - } - } - - /// Calculate value - double operator()(const Values& values) const override { - return (double) (values.at(keys_[0]) != values.at(keys_[1])); - } - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - DiscreteKeys keys; - keys.push_back(DiscreteKey(keys_[0],cardinality0_)); - keys.push_back(DiscreteKey(keys_[1],cardinality1_)); - std::vector table; - for (size_t i1 = 0; i1 < cardinality0_; i1++) - for (size_t i2 = 0; i2 < cardinality1_; i2++) - table.push_back(i1 != i2); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - /// - bool ensureArcConsistency(size_t j, - std::vector* domains) const override { - throw std::runtime_error( - "BinaryAllDiff::ensureArcConsistency not implemented"); + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) return false; + else { + const BinaryAllDiff& f(static_cast(other)); + return (cardinality0_ == f.cardinality0_) && + (cardinality1_ == f.cardinality1_); } + } - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } + /// Calculate value + double operator()(const Values& values) const override { + return (double)(values.at(keys_[0]) != values.at(keys_[1])); + } - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } - }; + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override { + DiscreteKeys keys; + keys.push_back(DiscreteKey(keys_[0], cardinality0_)); + keys.push_back(DiscreteKey(keys_[1], cardinality1_)); + std::vector table; + for (size_t i1 = 0; i1 < cardinality0_; i1++) + for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2); + DecisionTreeFactor converted(keys, table); + return converted; + } -} // namespace gtsam + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + /// + bool ensureArcConsistency(size_t j, + std::vector* domains) const override { + throw std::runtime_error( + "BinaryAllDiff::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } +}; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index ff6f3834e..e9714d6b4 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -17,79 +17,69 @@ #pragma once -#include #include +#include + #include namespace gtsam { - class Domain; +class Domain; - /** - * Base class for constraint factors - * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. +/** + * Base class for constraint factors + * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. + */ +class GTSAM_EXPORT Constraint : public DiscreteFactor { + public: + typedef boost::shared_ptr shared_ptr; + + protected: + /// Construct unary constraint factor. + Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} + + /// Construct binary constraint factor. + Constraint(Key j1, Key j2) + : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} + + /// Construct n-way constraint factor. + Constraint(const KeyVector& js) : DiscreteFactor(js) {} + + /// construct from container + template + Constraint(KeyIterator beginKey, KeyIterator endKey) + : DiscreteFactor(beginKey, endKey) {} + + public: + /// @name Standard Constructors + /// @{ + + /// Default constructor for I/O + Constraint(); + + /// Virtual destructor + ~Constraint() override {} + + /// @} + /// @name Standard Interface + /// @{ + + /* + * Ensure Arc-consistency, possibly changing domains of connected variables. + * @param j domain to be checked + * @param (in/out) domains all other domains + * @return true if domains were changed, false otherwise. */ - class GTSAM_EXPORT Constraint : public DiscreteFactor { + virtual bool ensureArcConsistency(size_t j, + std::vector* domains) const = 0; - public: + /// Partially apply known values + virtual shared_ptr partiallyApply(const Values&) const = 0; - typedef boost::shared_ptr shared_ptr; - - protected: - - /// Construct unary constraint factor. - Constraint(Key j) : - DiscreteFactor(boost::assign::cref_list_of<1>(j)) { - } - - /// Construct binary constraint factor. - Constraint(Key j1, Key j2) : - DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) { - } - - /// Construct n-way constraint factor. - Constraint(const KeyVector& js) : - DiscreteFactor(js) { - } - - /// construct from container - template - Constraint(KeyIterator beginKey, KeyIterator endKey) : - DiscreteFactor(beginKey, endKey) { - } - - public: - - /// @name Standard Constructors - /// @{ - - /// Default constructor for I/O - Constraint(); - - /// Virtual destructor - ~Constraint() override {} - - /// @} - /// @name Standard Interface - /// @{ - - /* - * Ensure Arc-consistency, possibly changing domains of connected variables. - * @param j domain to be checked - * @param (in/out) domains all other domains - * @return true if domains were changed, false otherwise. - */ - virtual bool ensureArcConsistency(size_t j, - std::vector* domains) const = 0; - - /// Partially apply known values - virtual shared_ptr partiallyApply(const Values&) const = 0; - - - /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; - /// @} - }; + /// Partially apply known values, domain version + virtual shared_ptr partiallyApply(const std::vector&) const = 0; + /// @} +}; // DiscreteFactor -}// namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index c2ba1c7f9..da23717f6 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -5,90 +5,88 @@ * @author Frank Dellaert */ -#include -#include #include +#include +#include + #include namespace gtsam { - using namespace std; - - /* ************************************************************************* */ - void Domain::print(const string& s, - const KeyFormatter& formatter) const { - cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << - formatter(keys_[0]) << ") with values"; - for (size_t v: values_) cout << " " << v; - cout << endl; - } - - /* ************************************************************************* */ - double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); - } - - /* ************************************************************************* */ - DecisionTreeFactor Domain::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0],cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; ++i1) - table.push_back(contains(i1)); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool Domain::ensureArcConsistency(size_t j, vector* domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); - Domain& D = domains->at(j); - for(size_t value: values_) - if (!D.contains(value)) throw runtime_error("Unsatisfiable"); - D = *this; - return true; - } - - /* ************************************************************************* */ - boost::optional Domain::checkAllDiff( - const KeyVector keys, const vector& domains) const { - Key j = keys_[0]; - // for all values in this domain - for (const size_t value : values_) { - // for all connected domains - for (const Key k : keys) - // if any domain contains the value we cannot make this domain singleton - if (k != j && domains[k].contains(value)) goto found; - // Otherwise: return a singleton: - return Domain(this->discreteKey(), value); - found:; - } - return boost::none; // we did not change it - } - - /* ************************************************************************* */ - Constraint::shared_ptr Domain::partiallyApply( - const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && !contains(it->second)) throw runtime_error( - "Domain::partiallyApply: unsatisfiable"); - return boost::make_shared < Domain > (*this); - } - - /* ************************************************************************* */ - Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( - "Domain::partiallyApply: unsatisfiable"); - return boost::make_shared < Domain > (Dk); - } +using namespace std; /* ************************************************************************* */ -} // namespace gtsam +void Domain::print(const string& s, const KeyFormatter& formatter) const { + cout << s << ": Domain on " << formatter(keys_[0]) + << " (j=" << formatter(keys_[0]) << ") with values"; + for (size_t v : values_) cout << " " << v; + cout << endl; +} + +/* ************************************************************************* */ +double Domain::operator()(const Values& values) const { + return contains(values.at(keys_[0])); +} + +/* ************************************************************************* */ +DecisionTreeFactor Domain::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0], cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); + DecisionTreeFactor converted(keys, table); + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool Domain::ensureArcConsistency(size_t j, vector* domains) const { + if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains->at(j); + for (size_t value : values_) + if (!D.contains(value)) throw runtime_error("Unsatisfiable"); + D = *this; + return true; +} + +/* ************************************************************************* */ +boost::optional Domain::checkAllDiff( + const KeyVector keys, const vector& domains) const { + Key j = keys_[0]; + // for all values in this domain + for (const size_t value : values_) { + // for all connected domains + for (const Key k : keys) + // if any domain contains the value we cannot make this domain singleton + if (k != j && domains[k].contains(value)) goto found; + // Otherwise: return a singleton: + return Domain(this->discreteKey(), value); + found:; + } + return boost::none; // we did not change it +} + +/* ************************************************************************* */ +Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && !contains(it->second)) + throw runtime_error("Domain::partiallyApply: unsatisfiable"); + return boost::make_shared(*this); +} + +/* ************************************************************************* */ +Constraint::shared_ptr Domain::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !contains(*Dk.begin())) + throw runtime_error("Domain::partiallyApply: unsatisfiable"); + return boost::make_shared(Dk); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index d06966081..9fa22175a 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -7,8 +7,8 @@ #pragma once -#include #include +#include namespace gtsam { @@ -101,6 +101,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const std::vector& domains) const override; - }; +}; -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index e042e550c..753d46cff 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -5,75 +5,74 @@ * @author Frank Dellaert */ -#include -#include -#include #include +#include +#include +#include + #include namespace gtsam { - using namespace std; - - /* ************************************************************************* */ - void SingleValue::print(const string& s, - const KeyFormatter& formatter) const { - cout << s << "SingleValue on " << "j=" << formatter(keys_[0]) - << " with value " << value_ << endl; - } - - /* ************************************************************************* */ - double SingleValue::operator()(const Values& values) const { - return (double) (values.at(keys_[0]) == value_); - } - - /* ************************************************************************* */ - DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0],cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; i1++) - table.push_back(i1 == value_); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool SingleValue::ensureArcConsistency(size_t j, - vector* domains) const { - if (j != keys_[0]) - throw invalid_argument("SingleValue check on wrong domain"); - Domain& D = domains->at(j); - if (D.isSingleton()) { - if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); - return false; - } - D = Domain(discreteKey(), value_); - return true; - } - - /* ************************************************************************* */ - Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && it->second != value_) throw runtime_error( - "SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared(keys_[0], cardinality_, value_); - } - - /* ************************************************************************* */ - Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( - "SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared(discreteKey(), value_); - } +using namespace std; /* ************************************************************************* */ -} // namespace gtsam +void SingleValue::print(const string& s, const KeyFormatter& formatter) const { + cout << s << "SingleValue on " + << "j=" << formatter(keys_[0]) << " with value " << value_ << endl; +} + +/* ************************************************************************* */ +double SingleValue::operator()(const Values& values) const { + return (double)(values.at(keys_[0]) == value_); +} + +/* ************************************************************************* */ +DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0], cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_); + DecisionTreeFactor converted(keys, table); + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool SingleValue::ensureArcConsistency(size_t j, + vector* domains) const { + if (j != keys_[0]) + throw invalid_argument("SingleValue check on wrong domain"); + Domain& D = domains->at(j); + if (D.isSingleton()) { + if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); + return false; + } + D = Domain(discreteKey(), value_); + return true; +} + +/* ************************************************************************* */ +Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && it->second != value_) + throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(keys_[0], cardinality_, value_); +} + +/* ************************************************************************* */ +Constraint::shared_ptr SingleValue::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !Dk.contains(value_)) + throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(discreteKey(), value_); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 0f9a8fb0f..d8a9a770b 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -7,74 +7,71 @@ #pragma once +#include #include namespace gtsam { - /** - * SingleValue constraint: ensures a variable takes on a certain value. - * This could of course also be implemented by changing its `Domain`. +/** + * SingleValue constraint: ensures a variable takes on a certain value. + * This could of course also be implemented by changing its `Domain`. + */ +class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { + size_t cardinality_; /// < Number of values + size_t value_; ///< allowed value + + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0], cardinality_); + } + + public: + typedef boost::shared_ptr shared_ptr; + + /// Construct from key, cardinality, and given value. + SingleValue(Key key, size_t n, size_t value) + : Constraint(key), cardinality_(n), value_(value) {} + + /// Construct from DiscreteKey and given value. + SingleValue(const DiscreteKey& dkey, size_t value) + : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const SingleValue& f(static_cast(other)); + return (cardinality_ == f.cardinality_) && (value_ == f.value_); + } + } + + /// Calculate value + double operator()(const Values& values) const override; + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency: just sets domain[j] to {value_} + * @param j domain to be checked + * @param domains all other domains */ - class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint { - - size_t cardinality_; /// < Number of values - size_t value_; ///< allowed value + bool ensureArcConsistency(size_t j, + std::vector* domains) const override; - DiscreteKey discreteKey() const { - return DiscreteKey(keys_[0],cardinality_); - } + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values& values) const override; - public: + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector& domains) const override; +}; - typedef boost::shared_ptr shared_ptr; - - /// Construct from key, cardinality, and given value. - SingleValue(Key key, size_t n, size_t value) : - Constraint(key), cardinality_(n), value_(value) { - } - - /// Construct from DiscreteKey and given value. - SingleValue(const DiscreteKey& dkey, size_t value) : - Constraint(dkey.first), cardinality_(dkey.second), value_(value) { - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const SingleValue& f(static_cast(other)); - return (cardinality_==f.cardinality_) && (value_==f.value_); - } - } - - /// Calculate value - double operator()(const Values& values) const override; - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency: just sets domain[j] to {value_} - * @param j domain to be checked - * @param domains all other domains - */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index b1aaab303..832175455 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -30,7 +30,7 @@ TEST(CSP, SingleValue) { EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); // Create domains, laid out as a vector. - // TODO(dellaert): should be map?? + // TODO(dellaert): should be map?? vector domains; domains += Domain(ID), Domain(AZ), Domain(UT); From 23bcf96da4955cd9f3332634d781f528c45cc451 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 20 Nov 2021 11:46:32 -0500 Subject: [PATCH 3/5] use emplace_shared --- gtsam/discrete/DiscreteFactorGraph.h | 16 +++++++++------- gtsam_unstable/discrete/CSP.h | 22 ++++++---------------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index f39adc9a8..3ea9c3cdd 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -101,25 +101,27 @@ public: /// @} - template + // Add single key decision-tree factor. + template void add(const DiscreteKey& j, SOURCE table) { DiscreteKeys keys; keys.push_back(j); - push_back(boost::make_shared(keys, table)); + emplace_shared(keys, table); } - template + // Add binary key decision-tree factor. + template void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { DiscreteKeys keys; keys.push_back(j1); keys.push_back(j2); - push_back(boost::make_shared(keys, table)); + emplace_shared(keys, table); } - /** add shared discreteFactor immediately from arguments */ - template + // Add shared discreteFactor immediately from arguments. + template void add(const DiscreteKeys& keys, SOURCE table) { - push_back(boost::make_shared(keys, table)); + emplace_shared(keys, table); } /** Return the set of variables involved in the factors (set union) */ diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index 544cdf0c9..e43e53932 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -21,32 +21,22 @@ namespace gtsam { class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { public: /** A map from keys to values */ - typedef KeyVector Indices; typedef Assignment Values; typedef boost::shared_ptr sharedValues; public: - // /// Constructor - // CSP() { - // } - /// Add a unary constraint, allowing only a single value void addSingleValue(const DiscreteKey& dkey, size_t value) { - boost::shared_ptr factor(new SingleValue(dkey, value)); - push_back(factor); + emplace_shared(dkey, value); } /// Add a binary AllDiff constraint void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { - boost::shared_ptr factor(new BinaryAllDiff(key1, key2)); - push_back(factor); + emplace_shared(key1, key2); } /// Add a general AllDiff constraint - void addAllDiff(const DiscreteKeys& dkeys) { - boost::shared_ptr factor(new AllDiff(dkeys)); - push_back(factor); - } + void addAllDiff(const DiscreteKeys& dkeys) { emplace_shared(dkeys); } // /** return product of all factors as a single factor */ // DecisionTreeFactor product() const { @@ -56,10 +46,10 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // return result; // } - /// Find the best total assignment - can be expensive + /// Find the best total assignment - can be expensive. sharedValues optimalAssignment() const; - /// Find the best total assignment - can be expensive + /// Find the best total assignment, with given ordering - can be expensive. sharedValues optimalAssignment(const Ordering& ordering) const; // /* @@ -78,7 +68,7 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { * Apply arc-consistency ~ Approximate loopy belief propagation * We need to give the domains to a constraint, and it returns * a domain whose values don't conflict in the arc-consistency way. - * TODO: should get cardinality from Indices + * TODO: should get cardinality from DiscreteKeys */ void runArcConsistency(size_t cardinality, size_t nrIterations = 10, bool print = false) const; From ad3225953b53dba1d5bf09e69248ad93d53de056 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 20 Nov 2021 15:52:12 -0500 Subject: [PATCH 4/5] Cleaned up AC1 implementation --- gtsam_unstable/discrete/AllDiff.cpp | 11 +- gtsam_unstable/discrete/AllDiff.h | 12 +- gtsam_unstable/discrete/BinaryAllDiff.h | 11 +- gtsam_unstable/discrete/CSP.cpp | 118 +++++++------- gtsam_unstable/discrete/CSP.h | 14 +- gtsam_unstable/discrete/Constraint.h | 13 +- gtsam_unstable/discrete/Domain.cpp | 36 +++-- gtsam_unstable/discrete/Domain.h | 31 ++-- gtsam_unstable/discrete/SingleValue.cpp | 7 +- gtsam_unstable/discrete/SingleValue.h | 10 +- gtsam_unstable/discrete/tests/testCSP.cpp | 40 +++-- gtsam_unstable/discrete/tests/testSudoku.cpp | 162 +++++++++++++------ 12 files changed, 270 insertions(+), 195 deletions(-) diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index ef18053a4..85cf0b472 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -57,14 +57,11 @@ DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool AllDiff::ensureArcConsistency(size_t j, - std::vector* domains) const { - // We are changing the domain of variable j. - // TODO(dellaert): confusing, I thought we were changing others... +bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { Domain& Dj = domains->at(j); // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. + // a value in domains->at(j) that does not occur in any other connected domain. // If found, we make this a singleton... // TODO: make a new constraint where this really is true boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); @@ -103,10 +100,10 @@ Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { + const Domains& domains) const { DiscreteFactor::Values known; for (Key k : keys_) { - const Domain& Dk = domains[k]; + const Domain& Dk = domains.at(k); if (Dk.isSingleton()) known[k] = Dk.firstValue(); } return partiallyApply(known); diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 4deabda94..57b0aeb5c 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -54,21 +54,19 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param (in/out) domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values&) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override; + const Domains&) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 21cfb18f2..a2c7ba660 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -70,13 +70,12 @@ class BinaryAllDiff : public Constraint { } /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - /// - bool ensureArcConsistency(size_t j, - std::vector* domains) const override { + bool ensureArcConsistency(Key j, Domains* domains) const override { throw std::runtime_error( "BinaryAllDiff::ensureArcConsistency not implemented"); return false; @@ -89,7 +88,7 @@ class BinaryAllDiff : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector&) const override { + const Domains&) const override { throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); } }; diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index bab1ac3c8..8c974f4fd 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -27,81 +27,75 @@ CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const { return mpe; } -void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, - bool print) const { +bool CSP::runArcConsistency(const VariableIndex& index, + Domains* domains) const { + bool changed = false; + + // iterate over all variables in the index + for (auto entry : index) { + // Get the variable's key and associated factors: + const Key key = entry.first; + const FactorIndices& factors = entry.second; + + // If this domain is already a singleton, we do nothing. + if (domains->at(key).isSingleton()) continue; + + // Otherwise, loop over all factors/constraints for variable with given key. + for (size_t f : factors) { + // If this factor is a constraint, call its ensureArcConsistency method: + auto constraint = boost::dynamic_pointer_cast((*this)[f]); + if (constraint) { + changed = constraint->ensureArcConsistency(key, domains) || changed; + } + } + } + return changed; +} + +// TODO(dellaert): This is AC1, which is inefficient as any change will cause +// the algorithm to revisit *all* variables again. Implement AC3. +Domains CSP::runArcConsistency(size_t cardinality, size_t maxIterations) const { // Create VariableIndex VariableIndex index(*this); - // index.print(); - - size_t n = index.size(); // Initialize domains - std::vector domains; - for (size_t j = 0; j < n; j++) - domains.push_back(Domain(DiscreteKey(j, cardinality))); + Domains domains; + for (auto entry : index) { + const Key key = entry.first; + domains.emplace(key, DiscreteKey(key, cardinality)); + } - // Create array of flags indicating a domain changed or not - std::vector changed(n); + // Iterate until convergence or not a single domain changed. + for (size_t it = 0; it < maxIterations; it++) { + bool changed = runArcConsistency(index, &domains); + if (!changed) break; + } + return domains; +} - // iterate nrIterations over entire grid - for (size_t it = 0; it < nrIterations; it++) { - bool anyChange = false; - // iterate over all cells - for (size_t v = 0; v < n; v++) { - // keep track of which domains changed - changed[v] = false; - // loop over all factors/constraints for variable v - const FactorIndices& factors = index[v]; - for (size_t f : factors) { - // if not already a singleton - if (!domains[v].isSingleton()) { - // get the constraint and call its ensureArcConsistency method - auto constraint = boost::dynamic_pointer_cast((*this)[f]); - if (!constraint) - throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - changed[v] = - constraint->ensureArcConsistency(v, &domains) || changed[v]; - } - } // f - if (changed[v]) anyChange = true; - } // v - if (!anyChange) break; - // TODO: Sudoku specific hack - if (print) { - if (cardinality == 9 && n == 81) { - for (size_t i = 0, v = 0; i < (size_t)std::sqrt((double)n); i++) { - for (size_t j = 0; j < (size_t)std::sqrt((double)n); j++, v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // i - cout << endl; - } // j - } else { - for (size_t v = 0; v < n; v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // v - } - cout << endl; - } // print - } // it - -#ifndef INPROGRESS - // Now create new problem with all singleton variables removed - // We do this by adding simplifying all factors using parial application +CSP CSP::partiallyApply(const Domains& domains) const { + // Create new problem with all singleton variables removed + // We do this by adding simplifying all factors using partial application. // TODO: create a new ordering as we go, to ensure a connected graph // KeyOrdering ordering; // vector dkeys; + CSP new_csp; + + // Add tightened domains as new factors: + for (auto key_domain : domains) { + new_csp.emplace_shared(key_domain.second); + } + + // Reduce all existing factors: for (const DiscreteFactor::shared_ptr& f : factors_) { - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast(f); + auto constraint = boost::dynamic_pointer_cast(f); if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); Constraint::shared_ptr reduced = constraint->partiallyApply(domains); - if (print) reduced->print(); + if (reduced->size() > 1) { + new_csp.push_back(reduced); + } } -#endif + return new_csp; } } // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index e43e53932..d94913682 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -62,7 +62,7 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // deep. // * It will be very expensive to exclude values that way. // */ - // void applyBeliefPropagation(size_t nrIterations = 10) const; + // void applyBeliefPropagation(size_t maxIterations = 10) const; /* * Apply arc-consistency ~ Approximate loopy belief propagation @@ -70,8 +70,16 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { * a domain whose values don't conflict in the arc-consistency way. * TODO: should get cardinality from DiscreteKeys */ - void runArcConsistency(size_t cardinality, size_t nrIterations = 10, - bool print = false) const; + Domains runArcConsistency(size_t cardinality, + size_t maxIterations = 10) const; + + /// Run arc consistency for all variables, return true if any domain changed. + bool runArcConsistency(const VariableIndex& index, Domains* domains) const; + + /* + * Create a new CSP, applying the given Domain constraints. + */ + CSP partiallyApply(const Domains& domains) const; }; // CSP } // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index e9714d6b4..f0e51b723 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -21,10 +21,12 @@ #include #include +#include namespace gtsam { class Domain; +using Domains = std::map; /** * Base class for constraint factors @@ -65,19 +67,18 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor { /// @{ /* - * Ensure Arc-consistency, possibly changing domains of connected variables. + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param (in/out) domains all other domains - * @return true if domains were changed, false otherwise. + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - virtual bool ensureArcConsistency(size_t j, - std::vector* domains) const = 0; + virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; /// Partially apply known values virtual shared_ptr partiallyApply(const Values&) const = 0; /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; + virtual shared_ptr partiallyApply(const Domains&) const = 0; /// @} }; // DiscreteFactor diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index da23717f6..98b735c6c 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -10,28 +10,35 @@ #include #include - +#include namespace gtsam { using namespace std; /* ************************************************************************* */ void Domain::print(const string& s, const KeyFormatter& formatter) const { - cout << s << ": Domain on " << formatter(keys_[0]) - << " (j=" << formatter(keys_[0]) << ") with values"; + cout << s << ": Domain on " << formatter(key()) << " (j=" << formatter(key()) + << ") with values"; for (size_t v : values_) cout << " " << v; cout << endl; } +/* ************************************************************************* */ +string Domain::base1Str() const { + stringstream ss; + for (size_t v : values_) ss << v + 1; + return ss.str(); +} + /* ************************************************************************* */ double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); + return contains(values.at(key())); } /* ************************************************************************* */ DecisionTreeFactor Domain::toDecisionTreeFactor() const { DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); + keys += DiscreteKey(key(), cardinality_); vector table; for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); DecisionTreeFactor converted(keys, table); @@ -45,8 +52,8 @@ DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool Domain::ensureArcConsistency(size_t j, vector* domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); +bool Domain::ensureArcConsistency(Key j, Domains* domains) const { + if (j != key()) throw invalid_argument("Domain check on wrong domain"); Domain& D = domains->at(j); for (size_t value : values_) if (!D.contains(value)) throw runtime_error("Unsatisfiable"); @@ -55,15 +62,15 @@ bool Domain::ensureArcConsistency(size_t j, vector* domains) const { } /* ************************************************************************* */ -boost::optional Domain::checkAllDiff( - const KeyVector keys, const vector& domains) const { - Key j = keys_[0]; +boost::optional Domain::checkAllDiff(const KeyVector keys, + const Domains& domains) const { + Key j = key(); // for all values in this domain for (const size_t value : values_) { // for all connected domains for (const Key k : keys) // if any domain contains the value we cannot make this domain singleton - if (k != j && domains[k].contains(value)) goto found; + if (k != j && domains.at(k).contains(value)) goto found; // Otherwise: return a singleton: return Domain(this->discreteKey(), value); found:; @@ -73,16 +80,15 @@ boost::optional Domain::checkAllDiff( /* ************************************************************************* */ Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); + Values::const_iterator it = values.find(key()); if (it != values.end() && !contains(it->second)) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(*this); } /* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; +Constraint::shared_ptr Domain::partiallyApply(const Domains& domains) const { + const Domain& Dk = domains.at(key()); if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error("Domain::partiallyApply: unsatisfiable"); return boost::make_shared(Dk); diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 9fa22175a..ae137ca33 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -20,10 +20,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { size_t cardinality_; /// Cardinality std::set values_; /// allowed values - DiscreteKey discreteKey() const { - return DiscreteKey(keys_[0], cardinality_); - } - public: typedef boost::shared_ptr shared_ptr; @@ -40,6 +36,12 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { values_.insert(v); } + /// The one key + Key key() const { return keys_[0]; } + + // The associated discrete key + DiscreteKey discreteKey() const { return DiscreteKey(key(), cardinality_); } + /// Insert a value, non const :-( void insert(size_t value) { values_.insert(value); } @@ -66,6 +68,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { } } + // Return concise string representation, mostly to debug arc consistency. + // Converts from base 0 to base1. + std::string base1Str() const; + + // Check whether domain cotains a specific value. bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value @@ -78,12 +85,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency + * Ensure Arc-consistency by checking every possible value of domain j. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /** * Check for a value in domain that does not occur in any other connected @@ -92,15 +100,14 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { * @param keys connected domains through alldiff * @param keys other domains */ - boost::optional checkAllDiff( - const KeyVector keys, const std::vector& domains) const; + boost::optional checkAllDiff(const KeyVector keys, + const Domains& domains) const; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + Constraint::shared_ptr partiallyApply(const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 753d46cff..162e21512 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -44,8 +44,7 @@ DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { } /* ************************************************************************* */ -bool SingleValue::ensureArcConsistency(size_t j, - vector* domains) const { +bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { if (j != keys_[0]) throw invalid_argument("SingleValue check on wrong domain"); Domain& D = domains->at(j); @@ -67,8 +66,8 @@ Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { /* ************************************************************************* */ Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; + const Domains& domains) const { + const Domain& Dk = domains.at(keys_[0]); if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); return boost::make_shared(discreteKey(), value_); diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index d8a9a770b..d826093df 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -59,19 +59,19 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; /* - * Ensure Arc-consistency: just sets domain[j] to {value_} + * Ensure Arc-consistency: just sets domain[j] to {value_}. * @param j domain to be checked - * @param domains all other domains + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector* domains) const override; + bool ensureArcConsistency(Key j, Domains* domains) const override; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; + const Domains& domains) const override; }; } // namespace gtsam diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 832175455..63069d710 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -29,16 +29,17 @@ TEST(CSP, SingleValue) { DecisionTreeFactor f1(AZ, "0 0 1"); EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); - // Create domains, laid out as a vector. - // TODO(dellaert): should be map?? - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + // Create domains + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); // Ensure arc-consistency: just wipes out values in AZ domain: EXPECT(singleValue.ensureArcConsistency(1, &domains)); - LONGS_EQUAL(3, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(3, domains[2].nrValues()); + LONGS_EQUAL(3, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(3, domains.at(2).nrValues()); } /* ************************************************************************* */ @@ -81,8 +82,10 @@ TEST(CSP, AllDiff) { EXPECT(assert_equal(f2, actual)); // Create domains. - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); // First constrict AZ domain: SingleValue singleValue(AZ, 2); @@ -92,9 +95,9 @@ TEST(CSP, AllDiff) { EXPECT(alldiff.ensureArcConsistency(0, &domains)); EXPECT(!alldiff.ensureArcConsistency(1, &domains)); EXPECT(alldiff.ensureArcConsistency(2, &domains)); - LONGS_EQUAL(2, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(2, domains[2].nrValues()); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); } /* ************************************************************************* */ @@ -232,17 +235,20 @@ TEST(CSP, ArcConsistency) { EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); // ensure arc-consistency, i.e., narrow domains... - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + SingleValue singleValue(AZ, 2); AllDiff alldiff(dkeys); EXPECT(singleValue.ensureArcConsistency(1, &domains)); EXPECT(alldiff.ensureArcConsistency(0, &domains)); EXPECT(!alldiff.ensureArcConsistency(1, &domains)); EXPECT(alldiff.ensureArcConsistency(2, &domains)); - LONGS_EQUAL(2, domains[0].nrValues()); - LONGS_EQUAL(1, domains[1].nrValues()); - LONGS_EQUAL(2, domains[2].nrValues()); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); // Parial application, version 1 DiscreteFactor::Values known; diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index 4843ae269..ee307fd5b 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -20,12 +21,12 @@ using namespace gtsam; #define PRINT false +/// A class that encodes Sudoku's as a CSP problem class Sudoku : public CSP { - /// sudoku size - size_t n_; + size_t n_; ///< Side of Sudoku, e.g. 4 or 9 - /// discrete keys - typedef std::pair IJ; + /// Mapping from base i,j coordinates to discrete keys: + using IJ = std::pair; std::map dkeys_; public: @@ -42,15 +43,14 @@ class Sudoku : public CSP { // Create variables, ordering, and unary constraints va_list ap; va_start(ap, n); - Key k = 0; for (size_t i = 0; i < n; ++i) { - for (size_t j = 0; j < n; ++j, ++k) { + for (size_t j = 0; j < n; ++j) { // create the key IJ ij(i, j); - dkeys_[ij] = DiscreteKey(k, n); + Symbol key('1' + i, j + 1); + dkeys_[ij] = DiscreteKey(key, n); // get the unary constraint, if any int value = va_arg(ap, int); - // cout << value << " "; if (value != 0) addSingleValue(dkeys_[ij], value - 1); } // cout << endl; @@ -88,7 +88,7 @@ class Sudoku : public CSP { } /// Print readable form of assignment - void printAssignment(DiscreteFactor::sharedValues assignment) const { + void printAssignment(const DiscreteFactor::sharedValues& assignment) const { for (size_t i = 0; i < n_; i++) { for (size_t j = 0; j < n_; j++) { Key k = key(i, j); @@ -99,10 +99,22 @@ class Sudoku : public CSP { } /// solve and print solution - void printSolution() { + void printSolution() const { DiscreteFactor::sharedValues MPE = optimalAssignment(); printAssignment(MPE); } + + // Print domain + void printDomains(const Domains& domains) { + for (size_t i = 0; i < n_; i++) { + for (size_t j = 0; j < n_; j++) { + Key k = key(i, j); + cout << domains.at(k).base1Str(); + cout << "\t"; + } // i + cout << endl; + } // j + } }; /* ************************************************************************* */ @@ -113,9 +125,6 @@ TEST_UNSAFE(Sudoku, small) { 4, 0, 2, 0, // 0, 1, 0, 0); - // Do BP - csp.runArcConsistency(4, 10, PRINT); - // optimize and check CSP::sharedValues solution = csp.optimalAssignment(); CSP::Values expected; @@ -126,73 +135,124 @@ TEST_UNSAFE(Sudoku, small) { csp.key(3, 3), 2); EXPECT(assert_equal(expected, *solution)); // csp.printAssignment(solution); + + // Do BP (AC1) + auto domains = csp.runArcConsistency(4, 3); + // csp.printDomains(domains); + Domain domain44 = domains.at(Symbol('4', 4)); + EXPECT_LONGS_EQUAL(1, domain44.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Should only be 16 new Domains + EXPECT_LONGS_EQUAL(16, new_csp.size()); + + // Check that solution + CSP::sharedValues new_solution = new_csp.optimalAssignment(); + // csp.printAssignment(new_solution); + EXPECT(assert_equal(expected, *new_solution)); } /* ************************************************************************* */ TEST_UNSAFE(Sudoku, easy) { - Sudoku sudoku(9, // - 0, 0, 5, 0, 9, 0, 0, 0, 1, // - 0, 0, 0, 0, 0, 2, 0, 7, 3, // - 7, 6, 0, 0, 0, 8, 2, 0, 0, // + Sudoku csp(9, // + 0, 0, 5, 0, 9, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 2, 0, 7, 3, // + 7, 6, 0, 0, 0, 8, 2, 0, 0, // - 0, 1, 2, 0, 0, 9, 0, 0, 4, // - 0, 0, 0, 2, 0, 3, 0, 0, 0, // - 3, 0, 0, 1, 0, 0, 9, 6, 0, // + 0, 1, 2, 0, 0, 9, 0, 0, 4, // + 0, 0, 0, 2, 0, 3, 0, 0, 0, // + 3, 0, 0, 1, 0, 0, 9, 6, 0, // - 0, 0, 1, 9, 0, 0, 0, 5, 8, // - 9, 7, 0, 5, 0, 0, 0, 0, 0, // - 5, 0, 0, 0, 3, 0, 7, 0, 0); + 0, 0, 1, 9, 0, 0, 0, 5, 8, // + 9, 7, 0, 5, 0, 0, 0, 0, 0, // + 5, 0, 0, 0, 3, 0, 7, 0, 0); - // Do BP - sudoku.runArcConsistency(4, 10, PRINT); + // csp.printSolution(); // don't do it - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 26 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 26, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ TEST_UNSAFE(Sudoku, extreme) { - Sudoku sudoku(9, // - 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // - 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // - 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // - 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // - 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // - 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // - 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // - 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); + Sudoku csp(9, // + 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // + 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // + 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // + 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // + 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // + 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + csp.runArcConsistency(9, 10); #ifdef METIS - VariableIndexOrdered index(sudoku); + VariableIndexOrdered index(csp); index.print("index"); ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt"); index.outputMetisFormat(os); #endif - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(2, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 20 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 20, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) { - Sudoku sudoku(9, // - 9, 5, 0, 0, 0, 6, 0, 0, 0, // - 0, 8, 4, 0, 7, 0, 0, 0, 0, // - 6, 2, 0, 5, 0, 0, 4, 0, 0, // + Sudoku csp(9, // + 9, 5, 0, 0, 0, 6, 0, 0, 0, // + 0, 8, 4, 0, 7, 0, 0, 0, 0, // + 6, 2, 0, 5, 0, 0, 4, 0, 0, // - 0, 0, 0, 2, 9, 0, 6, 0, 0, // - 0, 9, 0, 0, 0, 0, 0, 2, 0, // - 0, 0, 2, 0, 6, 3, 0, 0, 0, // + 0, 0, 0, 2, 9, 0, 6, 0, 0, // + 0, 9, 0, 0, 0, 0, 0, 2, 0, // + 0, 0, 2, 0, 6, 3, 0, 0, 0, // - 0, 0, 9, 0, 0, 7, 0, 6, 8, // - 0, 0, 0, 0, 3, 0, 2, 9, 0, // - 0, 0, 0, 1, 0, 0, 0, 3, 7); + 0, 0, 9, 0, 0, 7, 0, 6, 8, // + 0, 0, 0, 0, 3, 0, 2, 9, 0, // + 0, 0, 0, 1, 0, 0, 0, 3, 7); - // Do BP - sudoku.runArcConsistency(9, 10, PRINT); + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); - // sudoku.printSolution(); // don't do it + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Just the 81 new Domains + EXPECT_LONGS_EQUAL(81, new_csp.size()); + + // Check that solution + CSP::sharedValues solution = new_csp.optimalAssignment(); + // csp.printAssignment(solution); + EXPECT_LONGS_EQUAL(6, solution->at(key99)); } /* ************************************************************************* */ From 58dafd43e9caf63ade6443015f5e7735a97b77e7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 20 Nov 2021 16:44:17 -0500 Subject: [PATCH 5/5] Fixed up sudoku tests after merge --- gtsam_unstable/discrete/tests/testSudoku.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index a8546bc2f..808f98b1c 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -118,7 +118,7 @@ class Sudoku : public CSP { }; /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, small) { +TEST(Sudoku, small) { Sudoku csp(4, // 1, 0, 0, 4, // 0, 0, 0, 0, // @@ -148,13 +148,13 @@ TEST_UNSAFE(Sudoku, small) { EXPECT_LONGS_EQUAL(16, new_csp.size()); // Check that solution - CSP::sharedValues new_solution = new_csp.optimalAssignment(); + auto new_solution = new_csp.optimalAssignment(); // csp.printAssignment(new_solution); - EXPECT(assert_equal(expected, *new_solution)); + EXPECT(assert_equal(expected, new_solution)); } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, easy) { +TEST(Sudoku, easy) { Sudoku csp(9, // 0, 0, 5, 0, 9, 0, 0, 0, 1, // 0, 0, 0, 0, 0, 2, 0, 7, 3, // @@ -186,7 +186,7 @@ TEST_UNSAFE(Sudoku, easy) { } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, extreme) { +TEST(Sudoku, extreme) { Sudoku csp(9, // 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // @@ -223,7 +223,7 @@ TEST_UNSAFE(Sudoku, extreme) { } /* ************************************************************************* */ -TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) { +TEST(Sudoku, AJC_3star_Feb8_2012) { Sudoku csp(9, // 9, 5, 0, 0, 0, 6, 0, 0, 0, // 0, 8, 4, 0, 7, 0, 0, 0, 0, // @@ -250,9 +250,9 @@ TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) { EXPECT_LONGS_EQUAL(81, new_csp.size()); // Check that solution - CSP::sharedValues solution = new_csp.optimalAssignment(); + auto solution = new_csp.optimalAssignment(); // csp.printAssignment(solution); - EXPECT_LONGS_EQUAL(6, solution->at(key99)); + EXPECT_LONGS_EQUAL(6, solution.at(key99)); } /* ************************************************************************* */