diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index d6e1c6453..ebc789ec2 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -5,105 +5,115 @@ * @author Frank Dellaert */ -#include -#include #include - +#include +#include #include namespace gtsam { -/* ************************************************************************* */ -AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) { - for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey); -} - -/* ************************************************************************* */ -void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { - std::cout << s << "AllDiff on "; - for (Key dkey : keys_) std::cout << formatter(dkey) << " "; - std::cout << std::endl; -} - -/* ************************************************************************* */ -double AllDiff::operator()(const Values& values) const { - std::set taken; // record values taken by keys - for (Key dkey : keys_) { - size_t value = values.at(dkey); // get the value for that key - if (taken.count(value)) return 0.0; // check if value alreday taken - taken.insert(value); // if not, record it as taken and keep checking + /* ************************************************************************* */ + AllDiff::AllDiff(const DiscreteKeys& dkeys) : + Constraint(dkeys.indices()) { + for(const DiscreteKey& dkey: dkeys) + cardinalities_.insert(dkey); } - return 1.0; -} -/* ************************************************************************* */ -DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { - // We will do this by converting the allDif into many BinaryAllDiff - // constraints - DecisionTreeFactor converted; - size_t nrKeys = keys_.size(); - for (size_t i1 = 0; i1 < nrKeys; i1++) - for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { - BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2)); - converted = converted * binary12.toDecisionTreeFactor(); + /* ************************************************************************* */ + void AllDiff::print(const std::string& s, + const KeyFormatter& formatter) const { + std::cout << s << "AllDiff on "; + for (Key dkey: keys_) + std::cout << formatter(dkey) << " "; + std::cout << std::endl; + } + + /* ************************************************************************* */ + double AllDiff::operator()(const Values& values) const { + std::set < size_t > taken; // record values taken by keys + for(Key dkey: keys_) { + size_t value = values.at(dkey); // get the value for that key + if (taken.count(value)) return 0.0;// check if value alreday taken + taken.insert(value);// if not, record it as taken and keep checking } - return converted; -} + return 1.0; + } -/* ************************************************************************* */ -DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; -} + /* ************************************************************************* */ + DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { + // We will do this by converting the allDif into many BinaryAllDiff constraints + DecisionTreeFactor converted; + size_t nrKeys = keys_.size(); + for (size_t i1 = 0; i1 < nrKeys; i1++) + for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { + BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); + converted = converted * binary12.toDecisionTreeFactor(); + } + return converted; + } -/* ************************************************************************* */ -bool AllDiff::ensureArcConsistency(size_t j, - std::vector& domains) const { - // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. - // If found, we make this a singleton... - // TODO: make a new constraint where this really is true - Domain& Dj = domains[j]; - if (Dj.checkAllDiff(keys_, domains)) return true; + /* ************************************************************************* */ + DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } - // Check all other domains for singletons and erase corresponding values - // This is the same as arc-consistency on the equivalent binary constraints - bool changed = false; - for (Key k : keys_) - if (k != j) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) { // check if singleton - size_t value = Dk.firstValue(); - if (Dj.contains(value)) { - Dj.erase(value); // erase value if true - changed = true; + /* ************************************************************************* */ + bool AllDiff::ensureArcConsistency(size_t j, + std::vector* domains) const { + // We are changing the domain of variable j. + // TODO(dellaert): confusing, I thought we were changing others... + Domain& Dj = domains->at(j); + + // Though strictly not part of allDiff, we check for + // a value in domains[j] that does not occur in any other connected domain. + // If found, we make this a singleton... + // TODO: make a new constraint where this really is true + boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); + if (maybeChanged) { + Dj = *maybeChanged; + return true; + } + + // Check all other domains for singletons and erase corresponding values. + // This is the same as arc-consistency on the equivalent binary constraints + bool changed = false; + for (Key k : keys_) + if (k != j) { + const Domain& Dk = domains->at(k); + if (Dk.isSingleton()) { // check if singleton + size_t value = Dk.firstValue(); + if (Dj.contains(value)) { + Dj.erase(value); // erase value if true + changed = true; + } } } - } - return changed; -} - -/* ************************************************************************* */ -Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { - DiscreteKeys newKeys; - // loop over keys and add them only if they do not appear in values - for (Key k : keys_) - if (values.find(k) == values.end()) { - newKeys.push_back(DiscreteKey(k, cardinalities_.at(k))); - } - return boost::make_shared(newKeys); -} - -/* ************************************************************************* */ -Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { - DiscreteFactor::Values known; - for (Key k : keys_) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) known[k] = Dk.firstValue(); + return changed; } - return partiallyApply(known); -} -/* ************************************************************************* */ -} // namespace gtsam + /* ************************************************************************* */ + Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { + DiscreteKeys newKeys; + // loop over keys and add them only if they do not appear in values + for(Key k: keys_) + if (values.find(k) == values.end()) { + newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); + } + return boost::make_shared(newKeys); + } + + /* ************************************************************************* */ + Constraint::shared_ptr AllDiff::partiallyApply( + const std::vector& domains) const { + DiscreteFactor::Values known; + for(Key k: keys_) { + const Domain& Dk = domains[k]; + if (Dk.isSingleton()) + known[k] = Dk.firstValue(); + } + return partiallyApply(known); + } + + /* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index b0fd1d631..8c83e5ba1 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -7,71 +7,70 @@ #pragma once -#include #include +#include namespace gtsam { -/** - * General AllDiff constraint - * Returns 1 if values for all keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Key and an Key. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. - */ -class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { - std::map cardinalities_; - - DiscreteKey discreteKey(size_t i) const { - Key j = keys_[i]; - return DiscreteKey(j, cardinalities_.at(j)); - } - - public: - /// Constructor - AllDiff(const DiscreteKeys& dkeys); - - // print - void print(const std::string& s = "", const KeyFormatter& formatter = - DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if (!dynamic_cast(&other)) - return false; - else { - const AllDiff& f(static_cast(other)); - return cardinalities_.size() == f.cardinalities_.size() && - std::equal(cardinalities_.begin(), cardinalities_.end(), - f.cardinalities_.begin()); - } - } - - /// Calculate value = expensive ! - double operator()(const Values& values) const override; - - /// Convert into a decisiontree, can be *very* expensive ! - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. - * @param j domain to be checked - * @param domains all other domains + /** + * General AllDiff constraint. + * Returns 1 if values for all keys are different, 0 otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint { - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override; + std::map cardinalities_; - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override; -}; + DiscreteKey discreteKey(size_t i) const { + Key j = keys_[i]; + return DiscreteKey(j,cardinalities_.at(j)); + } -} // namespace gtsam + public: + + /// Construct from keys. + AllDiff(const DiscreteKeys& dkeys); + + // print + void print(const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if(!dynamic_cast(&other)) + return false; + else { + const AllDiff& f(static_cast(other)); + return cardinalities_.size() == f.cardinalities_.size() + && std::equal(cardinalities_.begin(), cardinalities_.end(), + f.cardinalities_.begin()); + } + } + + /// Calculate value = expensive ! + double operator()(const Values& values) const override; + + /// Convert into a decisiontree, can be *very* expensive ! + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency + * Arc-consistency involves creating binaryAllDiff constraints + * In which case the combinatorial hyper-arc explosion disappears. + * @param j domain to be checked + * @param (in/out) domains all other domains + */ + bool ensureArcConsistency(size_t j, + std::vector* domains) const override; + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values&) const override; + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector&) const override; + }; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index d8e1a590a..acc3cc421 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -7,93 +7,92 @@ #pragma once -#include -#include #include +#include +#include namespace gtsam { -/** - * Binary AllDiff constraint - * Returns 1 if values for two keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Index and an Index. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. - */ -class BinaryAllDiff : public Constraint { - size_t cardinality0_, cardinality1_; /// cardinality - - public: - /// Constructor - BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) - : Constraint(key1.first, key2.first), - cardinality0_(key1.second), - cardinality1_(key2.second) {} - - // print - void print( - const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " - << formatter(keys_[1]) << std::endl; - } - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if (!dynamic_cast(&other)) - return false; - else { - const BinaryAllDiff& f(static_cast(other)); - return (cardinality0_ == f.cardinality0_) && - (cardinality1_ == f.cardinality1_); - } - } - - /// Calculate value - double operator()(const Values& values) const override { - return (double)(values.at(keys_[0]) != values.at(keys_[1])); - } - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - DiscreteKeys keys; - keys.push_back(DiscreteKey(keys_[0], cardinality0_)); - keys.push_back(DiscreteKey(keys_[1], cardinality1_)); - std::vector table; - for (size_t i1 = 0; i1 < cardinality0_; i1++) - for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains + /** + * Binary AllDiff constraint + * Returns 1 if values for two keys are different, 0 otherwise. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override { - // throw std::runtime_error( - // "BinaryAllDiff::ensureArcConsistency not implemented"); - return false; - } + class BinaryAllDiff: public Constraint { - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } + size_t cardinality0_, cardinality1_; /// cardinality - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } -}; + public: -} // namespace gtsam + /// Constructor + BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : + Constraint(key1.first, key2.first), + cardinality0_(key1.second), cardinality1_(key2.second) { + } + + // print + void print(const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " + << formatter(keys_[1]) << std::endl; + } + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if(!dynamic_cast(&other)) + return false; + else { + const BinaryAllDiff& f(static_cast(other)); + return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_); + } + } + + /// Calculate value + double operator()(const Values& values) const override { + return (double) (values.at(keys_[0]) != values.at(keys_[1])); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override { + DiscreteKeys keys; + keys.push_back(DiscreteKey(keys_[0],cardinality0_)); + keys.push_back(DiscreteKey(keys_[1],cardinality1_)); + std::vector table; + for (size_t i1 = 0; i1 < cardinality0_; i1++) + for (size_t i2 = 0; i2 < cardinality1_; i2++) + table.push_back(i1 != i2); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + /// + bool ensureArcConsistency(size_t j, + std::vector* domains) const override { + throw std::runtime_error( + "BinaryAllDiff::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + }; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index b1d70dc6e..bab1ac3c8 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -56,12 +56,11 @@ void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, // if not already a singleton if (!domains[v].isSingleton()) { // get the constraint and call its ensureArcConsistency method - Constraint::shared_ptr constraint = - boost::dynamic_pointer_cast((*this)[f]); + auto constraint = boost::dynamic_pointer_cast((*this)[f]); if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); changed[v] = - constraint->ensureArcConsistency(v, domains) || changed[v]; + constraint->ensureArcConsistency(v, &domains) || changed[v]; } } // f if (changed[v]) anyChange = true; diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index b8baccff9..ff6f3834e 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -17,68 +17,79 @@ #pragma once -#include #include - +#include #include namespace gtsam { -class Domain; + class Domain; -/** - * Base class for discrete probabilistic factors - * The most general one is the derived DecisionTreeFactor - */ -class Constraint : public DiscreteFactor { - public: - typedef boost::shared_ptr shared_ptr; - - protected: - /// Construct n-way factor - Constraint(const KeyVector& js) : DiscreteFactor(js) {} - - /// Construct unary factor - Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} - - /// Construct binary factor - Constraint(Key j1, Key j2) - : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} - - /// construct from container - template - Constraint(KeyIterator beginKey, KeyIterator endKey) - : DiscreteFactor(beginKey, endKey) {} - - public: - /// @name Standard Constructors - /// @{ - - /// Default constructor for I/O - Constraint(); - - /// Virtual destructor - ~Constraint() override {} - - /// @} - /// @name Standard Interface - /// @{ - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains + /** + * Base class for constraint factors + * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. */ - virtual bool ensureArcConsistency(size_t j, - std::vector& domains) const = 0; + class GTSAM_EXPORT Constraint : public DiscreteFactor { - /// Partially apply known values - virtual shared_ptr partiallyApply(const Values&) const = 0; + public: - /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; - /// @} -}; + typedef boost::shared_ptr shared_ptr; + + protected: + + /// Construct unary constraint factor. + Constraint(Key j) : + DiscreteFactor(boost::assign::cref_list_of<1>(j)) { + } + + /// Construct binary constraint factor. + Constraint(Key j1, Key j2) : + DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) { + } + + /// Construct n-way constraint factor. + Constraint(const KeyVector& js) : + DiscreteFactor(js) { + } + + /// construct from container + template + Constraint(KeyIterator beginKey, KeyIterator endKey) : + DiscreteFactor(beginKey, endKey) { + } + + public: + + /// @name Standard Constructors + /// @{ + + /// Default constructor for I/O + Constraint(); + + /// Virtual destructor + ~Constraint() override {} + + /// @} + /// @name Standard Interface + /// @{ + + /* + * Ensure Arc-consistency, possibly changing domains of connected variables. + * @param j domain to be checked + * @param (in/out) domains all other domains + * @return true if domains were changed, false otherwise. + */ + virtual bool ensureArcConsistency(size_t j, + std::vector* domains) const = 0; + + /// Partially apply known values + virtual shared_ptr partiallyApply(const Values&) const = 0; + + + /// Partially apply known values, domain version + virtual shared_ptr partiallyApply(const std::vector&) const = 0; + /// @} + }; // DiscreteFactor -} // namespace gtsam +}// namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index a81b1d1ad..c2ba1c7f9 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -5,89 +5,90 @@ * @author Frank Dellaert */ -#include -#include #include - +#include +#include #include namespace gtsam { -using namespace std; + using namespace std; -/* ************************************************************************* */ -void Domain::print(const string& s, const KeyFormatter& formatter) const { - // cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << - // formatter(keys_[0]) << ") with values"; - // for (size_t v: values_) cout << " " << v; - // cout << endl; - for (size_t v : values_) cout << v; -} - -/* ************************************************************************* */ -double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); -} - -/* ************************************************************************* */ -DecisionTreeFactor Domain::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); - DecisionTreeFactor converted(keys, table); - return converted; -} - -/* ************************************************************************* */ -DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; -} - -/* ************************************************************************* */ -bool Domain::ensureArcConsistency(size_t j, vector& domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); - Domain& D = domains[j]; - for (size_t value : values_) - if (!D.contains(value)) throw runtime_error("Unsatisfiable"); - D = *this; - return true; -} - -/* ************************************************************************* */ -bool Domain::checkAllDiff(const KeyVector keys, vector& domains) { - Key j = keys_[0]; - // for all values in this domain - for (size_t value : values_) { - // for all connected domains - for (Key k : keys) - // if any domain contains the value we cannot make this domain singleton - if (k != j && domains[k].contains(value)) goto found; - values_.clear(); - values_.insert(value); - return true; // we changed it - found:; + /* ************************************************************************* */ + void Domain::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << + formatter(keys_[0]) << ") with values"; + for (size_t v: values_) cout << " " << v; + cout << endl; + } + + /* ************************************************************************* */ + double Domain::operator()(const Values& values) const { + return contains(values.at(keys_[0])); + } + + /* ************************************************************************* */ + DecisionTreeFactor Domain::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; ++i1) + table.push_back(contains(i1)); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* ************************************************************************* */ + bool Domain::ensureArcConsistency(size_t j, vector* domains) const { + if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains->at(j); + for(size_t value: values_) + if (!D.contains(value)) throw runtime_error("Unsatisfiable"); + D = *this; + return true; + } + + /* ************************************************************************* */ + boost::optional Domain::checkAllDiff( + const KeyVector keys, const vector& domains) const { + Key j = keys_[0]; + // for all values in this domain + for (const size_t value : values_) { + // for all connected domains + for (const Key k : keys) + // if any domain contains the value we cannot make this domain singleton + if (k != j && domains[k].contains(value)) goto found; + // Otherwise: return a singleton: + return Domain(this->discreteKey(), value); + found:; + } + return boost::none; // we did not change it + } + + /* ************************************************************************* */ + Constraint::shared_ptr Domain::partiallyApply( + const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && !contains(it->second)) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (*this); + } + + /* ************************************************************************* */ + Constraint::shared_ptr Domain::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (Dk); } - return false; // we did not change it -} /* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && !contains(it->second)) - throw runtime_error("Domain::partiallyApply: unsatisfiable"); - return boost::make_shared(*this); -} - -/* ************************************************************************* */ -Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !contains(*Dk.begin())) - throw runtime_error("Domain::partiallyApply: unsatisfiable"); - return boost::make_shared(Dk); -} - -/* ************************************************************************* */ -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 15828b653..d06966081 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -7,18 +7,23 @@ #pragma once -#include #include +#include namespace gtsam { /** - * Domain restriction constraint + * The Domain class represents a constraint that restricts the possible values a + * particular variable, with given key, can take on. */ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { size_t cardinality_; /// Cardinality std::set values_; /// allowed values + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0], cardinality_); + } + public: typedef boost::shared_ptr shared_ptr; @@ -35,14 +40,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { values_.insert(v); } - /// Constructor - Domain(const Domain& other) - : Constraint(other.keys_[0]), values_(other.values_) {} - - /// insert a value, non const :-( + /// Insert a value, non const :-( void insert(size_t value) { values_.insert(value); } - /// erase a value, non const :-( + /// Erase a value, non const :-( void erase(size_t value) { values_.erase(value); } size_t nrValues() const { return values_.size(); } @@ -82,15 +83,17 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { * @param domains all other domains */ bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + std::vector* domains) const override; /** - * Check for a value in domain that does not occur in any other connected - * domain. If found, we make this a singleton... Called in - * AllDiff::ensureArcConsistency - * @param keys connected domains through alldiff + * Check for a value in domain that does not occur in any other connected + * domain. If found, return a a new singleton domain... + * Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + * @param keys other domains */ - bool checkAllDiff(const KeyVector keys, std::vector& domains); + boost::optional checkAllDiff( + const KeyVector keys, const std::vector& domains) const; /// Partially apply known values Constraint::shared_ptr partiallyApply(const Values& values) const override; @@ -98,6 +101,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { /// Partially apply known values, domain version Constraint::shared_ptr partiallyApply( const std::vector& domains) const override; -}; + }; -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 105887dc9..e042e550c 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -5,74 +5,75 @@ * @author Frank Dellaert */ -#include -#include -#include #include - +#include +#include +#include #include namespace gtsam { -using namespace std; + using namespace std; -/* ************************************************************************* */ -void SingleValue::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "SingleValue on " - << "j=" << formatter(keys_[0]) << " with value " << value_ << endl; -} - -/* ************************************************************************* */ -double SingleValue::operator()(const Values& values) const { - return (double)(values.at(keys_[0]) == value_); -} - -/* ************************************************************************* */ -DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0], cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_); - DecisionTreeFactor converted(keys, table); - return converted; -} - -/* ************************************************************************* */ -DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; -} - -/* ************************************************************************* */ -bool SingleValue::ensureArcConsistency(size_t j, - vector& domains) const { - if (j != keys_[0]) - throw invalid_argument("SingleValue check on wrong domain"); - Domain& D = domains[j]; - if (D.isSingleton()) { - if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); - return false; + /* ************************************************************************* */ + void SingleValue::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << "SingleValue on " << "j=" << formatter(keys_[0]) + << " with value " << value_ << endl; + } + + /* ************************************************************************* */ + double SingleValue::operator()(const Values& values) const { + return (double) (values.at(keys_[0]) == value_); + } + + /* ************************************************************************* */ + DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; i1++) + table.push_back(i1 == value_); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* ************************************************************************* */ + bool SingleValue::ensureArcConsistency(size_t j, + vector* domains) const { + if (j != keys_[0]) + throw invalid_argument("SingleValue check on wrong domain"); + Domain& D = domains->at(j); + if (D.isSingleton()) { + if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); + return false; + } + D = Domain(discreteKey(), value_); + return true; + } + + /* ************************************************************************* */ + Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && it->second != value_) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(keys_[0], cardinality_, value_); + } + + /* ************************************************************************* */ + Constraint::shared_ptr SingleValue::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(discreteKey(), value_); } - D = Domain(discreteKey(), value_); - return true; -} /* ************************************************************************* */ -Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && it->second != value_) - throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared(keys_[0], cardinality_, value_); -} - -/* ************************************************************************* */ -Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !Dk.contains(value_)) - throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared(discreteKey(), value_); -} - -/* ************************************************************************* */ -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index a2aec338c..0f9a8fb0f 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -7,73 +7,74 @@ #pragma once -#include #include namespace gtsam { -/** - * SingleValue constraint - */ -class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { - /// Number of values - size_t cardinality_; - - /// allowed value - size_t value_; - - DiscreteKey discreteKey() const { - return DiscreteKey(keys_[0], cardinality_); - } - - public: - typedef boost::shared_ptr shared_ptr; - - /// Constructor - SingleValue(Key key, size_t n, size_t value) - : Constraint(key), cardinality_(n), value_(value) {} - - /// Constructor - SingleValue(const DiscreteKey& dkey, size_t value) - : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} - - // print - void print(const std::string& s = "", const KeyFormatter& formatter = - DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if (!dynamic_cast(&other)) - return false; - else { - const SingleValue& f(static_cast(other)); - return (cardinality_ == f.cardinality_) && (value_ == f.value_); - } - } - - /// Calculate value - double operator()(const Values& values) const override; - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains + /** + * SingleValue constraint: ensures a variable takes on a certain value. + * This could of course also be implemented by changing its `Domain`. */ - bool ensureArcConsistency(size_t j, - std::vector& domains) const override; + class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint { + + size_t cardinality_; /// < Number of values + size_t value_; ///< allowed value - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0],cardinality_); + } - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; -}; + public: -} // namespace gtsam + typedef boost::shared_ptr shared_ptr; + + /// Construct from key, cardinality, and given value. + SingleValue(Key key, size_t n, size_t value) : + Constraint(key), cardinality_(n), value_(value) { + } + + /// Construct from DiscreteKey and given value. + SingleValue(const DiscreteKey& dkey, size_t value) : + Constraint(dkey.first), cardinality_(dkey.second), value_(value) { + } + + // print + void print(const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if(!dynamic_cast(&other)) + return false; + else { + const SingleValue& f(static_cast(other)); + return (cardinality_==f.cardinality_) && (value_==f.value_); + } + } + + /// Calculate value + double operator()(const Values& values) const override; + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency: just sets domain[j] to {value_} + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, + std::vector* domains) const override; + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const Values& values) const override; + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const std::vector& domains) const override; + }; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 1552fcbf1..b1aaab303 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -19,12 +19,33 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE(BinaryAllDif, allInOne) { - // Create keys and ordering +TEST(CSP, SingleValue) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check that a single value is equal to a decision stump with only one "1": + SingleValue singleValue(AZ, 2); + DecisionTreeFactor f1(AZ, "0 0 1"); + EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); + + // Create domains, laid out as a vector. + // TODO(dellaert): should be map?? + vector domains; + domains += Domain(ID), Domain(AZ), Domain(UT); + + // Ensure arc-consistency: just wipes out values in AZ domain: + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + LONGS_EQUAL(3, domains[0].nrValues()); + LONGS_EQUAL(1, domains[1].nrValues()); + LONGS_EQUAL(3, domains[2].nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, BinaryAllDif) { + // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each: size_t nrColors = 2; - // DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", - // nrColors); - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Check construction and conversion BinaryAllDiff c1(ID, UT); @@ -36,16 +57,51 @@ TEST_UNSAFE(BinaryAllDif, allInOne) { DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); + // Check multiplication of factors with constraint: DecisionTreeFactor f3 = f1 * f2; EXPECT(assert_equal(f3, c1 * f2)); EXPECT(assert_equal(f3, c2 * f1)); } /* ************************************************************************* */ -TEST_UNSAFE(CSP, allInOne) { - // Create keys and ordering +TEST(CSP, AllDiff) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check construction and conversion + vector dkeys{ID, UT, AZ}; + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); + // GTSAM_PRINT(actual); + actual.dot("actual"); + DecisionTreeFactor f2( + ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2, actual)); + + // Create domains. + vector domains; + domains += Domain(ID), Domain(AZ), Domain(UT); + + // First constrict AZ domain: + SingleValue singleValue(AZ, 2); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + + // Arc-consistency + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains[0].nrValues()); + LONGS_EQUAL(1, domains[1].nrValues()); + LONGS_EQUAL(2, domains[2].nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, allInOne) { + // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each: size_t nrColors = 2; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Create the CSP CSP csp; @@ -81,15 +137,12 @@ TEST_UNSAFE(CSP, allInOne) { } /* ************************************************************************* */ -TEST_UNSAFE(CSP, WesternUS) { - // Create keys +TEST(CSP, WesternUS) { + // Create keys for all states in Western US, with 4 color possibilities. size_t nrColors = 4; - DiscreteKey - // Create ordering according to example in ND-CSP.lyx - WA(0, nrColors), - OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors), - UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors), - CO(7, nrColors), NM(6, nrColors); + DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors), + NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); // Create the CSP CSP csp; @@ -116,10 +169,12 @@ TEST_UNSAFE(CSP, WesternUS) { csp.addAllDiff(WY, CO); csp.addAllDiff(CO, NM); - // Solve + // Create ordering according to example in ND-CSP.lyx Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), Key(8), Key(9), Key(10); + + // Solve using that ordering: CSP::sharedValues mpe = csp.optimalAssignment(ordering); // GTSAM_PRINT(*mpe); CSP::Values expected; @@ -143,33 +198,17 @@ TEST_UNSAFE(CSP, WesternUS) { } /* ************************************************************************* */ -TEST_UNSAFE(CSP, AllDiff) { - // Create keys and ordering +TEST(CSP, ArcConsistency) { + // Create keys for Idaho, Arizona, and Utah, allowing three colors for each: size_t nrColors = 3; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); - // Create the CSP + // Create the CSP using just one all-diff constraint, plus constrain Arizona. CSP csp; - vector dkeys; - dkeys += ID, UT, AZ; + vector dkeys{ID, UT, AZ}; csp.addAllDiff(dkeys); csp.addSingleValue(AZ, 2); - // GTSAM_PRINT(csp); - - // Check construction and conversion - SingleValue s(AZ, 2); - DecisionTreeFactor f1(AZ, "0 0 1"); - EXPECT(assert_equal(f1, s.toDecisionTreeFactor())); - - // Check construction and conversion - AllDiff alldiff(dkeys); - DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); - // GTSAM_PRINT(actual); - // actual.dot("actual"); - DecisionTreeFactor f2( - ID & AZ & UT, - "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); - EXPECT(assert_equal(f2, actual)); + // GTSAM_PRINT(csp); // Check an invalid combination, with ID==UT==AZ all same color DiscreteFactor::Values invalid; @@ -192,14 +231,15 @@ TEST_UNSAFE(CSP, AllDiff) { EXPECT(assert_equal(expected, *mpe)); EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); - // Arc-consistency + // ensure arc-consistency, i.e., narrow domains... vector domains; domains += Domain(ID), Domain(AZ), Domain(UT); SingleValue singleValue(AZ, 2); - EXPECT(singleValue.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(0, domains)); - EXPECT(!alldiff.ensureArcConsistency(1, domains)); - EXPECT(alldiff.ensureArcConsistency(2, domains)); + AllDiff alldiff(dkeys); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); LONGS_EQUAL(2, domains[0].nrValues()); LONGS_EQUAL(1, domains[1].nrValues()); LONGS_EQUAL(2, domains[2].nrValues()); @@ -222,6 +262,7 @@ TEST_UNSAFE(CSP, AllDiff) { // full arc-consistency test csp.runArcConsistency(nrColors); + // GTSAM_PRINT(csp); } /* ************************************************************************* */