Revamped arc consistency

release/4.3a0
Frank Dellaert 2021-11-18 10:17:49 -05:00
parent 770fda9a26
commit dd50975668
10 changed files with 612 additions and 547 deletions

View File

@ -5,73 +5,82 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam_unstable/discrete/Domain.h> #include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam/base/Testable.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) { AllDiff::AllDiff(const DiscreteKeys& dkeys) :
for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey); Constraint(dkeys.indices()) {
} for(const DiscreteKey& dkey: dkeys)
cardinalities_.insert(dkey);
}
/* ************************************************************************* */ /* ************************************************************************* */
void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { void AllDiff::print(const std::string& s,
const KeyFormatter& formatter) const {
std::cout << s << "AllDiff on "; std::cout << s << "AllDiff on ";
for (Key dkey : keys_) std::cout << formatter(dkey) << " "; for (Key dkey: keys_)
std::cout << formatter(dkey) << " ";
std::cout << std::endl; std::cout << std::endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */
double AllDiff::operator()(const Values& values) const { double AllDiff::operator()(const Values& values) const {
std::set<size_t> taken; // record values taken by keys std::set < size_t > taken; // record values taken by keys
for (Key dkey : keys_) { for(Key dkey: keys_) {
size_t value = values.at(dkey); // get the value for that key size_t value = values.at(dkey); // get the value for that key
if (taken.count(value)) return 0.0; // check if value alreday taken if (taken.count(value)) return 0.0;// check if value alreday taken
taken.insert(value); // if not, record it as taken and keep checking taken.insert(value);// if not, record it as taken and keep checking
} }
return 1.0; return 1.0;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
// We will do this by converting the allDif into many BinaryAllDiff // We will do this by converting the allDif into many BinaryAllDiff constraints
// constraints
DecisionTreeFactor converted; DecisionTreeFactor converted;
size_t nrKeys = keys_.size(); size_t nrKeys = keys_.size();
for (size_t i1 = 0; i1 < nrKeys; i1++) for (size_t i1 = 0; i1 < nrKeys; i1++)
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2)); BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2));
converted = converted * binary12.toDecisionTreeFactor(); converted = converted * binary12.toDecisionTreeFactor();
} }
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j,
std::vector<Domain>* domains) const {
// We are changing the domain of variable j.
// TODO(dellaert): confusing, I thought we were changing others...
Domain& Dj = domains->at(j);
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const {
// Though strictly not part of allDiff, we check for // Though strictly not part of allDiff, we check for
// a value in domains[j] that does not occur in any other connected domain. // a value in domains[j] that does not occur in any other connected domain.
// If found, we make this a singleton... // If found, we make this a singleton...
// TODO: make a new constraint where this really is true // TODO: make a new constraint where this really is true
Domain& Dj = domains[j]; boost::optional<Domain> maybeChanged = Dj.checkAllDiff(keys_, *domains);
if (Dj.checkAllDiff(keys_, domains)) return true; if (maybeChanged) {
Dj = *maybeChanged;
return true;
}
// Check all other domains for singletons and erase corresponding values // Check all other domains for singletons and erase corresponding values.
// This is the same as arc-consistency on the equivalent binary constraints // This is the same as arc-consistency on the equivalent binary constraints
bool changed = false; bool changed = false;
for (Key k : keys_) for (Key k : keys_)
if (k != j) { if (k != j) {
const Domain& Dk = domains[k]; const Domain& Dk = domains->at(k);
if (Dk.isSingleton()) { // check if singleton if (Dk.isSingleton()) { // check if singleton
size_t value = Dk.firstValue(); size_t value = Dk.firstValue();
if (Dj.contains(value)) { if (Dj.contains(value)) {
@ -81,29 +90,30 @@ bool AllDiff::ensureArcConsistency(size_t j,
} }
} }
return changed; return changed;
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
DiscreteKeys newKeys; DiscreteKeys newKeys;
// loop over keys and add them only if they do not appear in values // loop over keys and add them only if they do not appear in values
for (Key k : keys_) for(Key k: keys_)
if (values.find(k) == values.end()) { if (values.find(k) == values.end()) {
newKeys.push_back(DiscreteKey(k, cardinalities_.at(k))); newKeys.push_back(DiscreteKey(k,cardinalities_.at(k)));
} }
return boost::make_shared<AllDiff>(newKeys); return boost::make_shared<AllDiff>(newKeys);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply( Constraint::shared_ptr AllDiff::partiallyApply(
const std::vector<Domain>& domains) const { const std::vector<Domain>& domains) const {
DiscreteFactor::Values known; DiscreteFactor::Values known;
for (Key k : keys_) { for(Key k: keys_) {
const Domain& Dk = domains[k]; const Domain& Dk = domains[k];
if (Dk.isSingleton()) known[k] = Dk.firstValue(); if (Dk.isSingleton())
known[k] = Dk.firstValue();
} }
return partiallyApply(known); return partiallyApply(known);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,42 +7,41 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/BinaryAllDiff.h> #include <gtsam_unstable/discrete/BinaryAllDiff.h>
#include <gtsam/discrete/DiscreteKey.h>
namespace gtsam { namespace gtsam {
/** /**
* General AllDiff constraint * General AllDiff constraint.
* Returns 1 if values for all keys are different, 0 otherwise * Returns 1 if values for all keys are different, 0 otherwise.
* DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Key and an Key. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor.
*/ */
class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint {
std::map<Key, size_t> cardinalities_;
std::map<Key,size_t> cardinalities_;
DiscreteKey discreteKey(size_t i) const { DiscreteKey discreteKey(size_t i) const {
Key j = keys_[i]; Key j = keys_[i];
return DiscreteKey(j, cardinalities_.at(j)); return DiscreteKey(j,cardinalities_.at(j));
} }
public: public:
/// Constructor
/// Construct from keys.
AllDiff(const DiscreteKeys& dkeys); AllDiff(const DiscreteKeys& dkeys);
// print // print
void print(const std::string& s = "", const KeyFormatter& formatter = void print(const std::string& s = "",
DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if (!dynamic_cast<const AllDiff*>(&other)) if(!dynamic_cast<const AllDiff*>(&other))
return false; return false;
else { else {
const AllDiff& f(static_cast<const AllDiff&>(other)); const AllDiff& f(static_cast<const AllDiff&>(other));
return cardinalities_.size() == f.cardinalities_.size() && return cardinalities_.size() == f.cardinalities_.size()
std::equal(cardinalities_.begin(), cardinalities_.end(), && std::equal(cardinalities_.begin(), cardinalities_.end(),
f.cardinalities_.begin()); f.cardinalities_.begin());
} }
} }
@ -61,10 +60,10 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
* Arc-consistency involves creating binaryAllDiff constraints * Arc-consistency involves creating binaryAllDiff constraints
* In which case the combinatorial hyper-arc explosion disappears. * In which case the combinatorial hyper-arc explosion disappears.
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param (in/out) domains all other domains
*/ */
bool ensureArcConsistency(size_t j, bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override; std::vector<Domain>* domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override; Constraint::shared_ptr partiallyApply(const Values&) const override;
@ -72,6 +71,6 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>&) const override; const std::vector<Domain>&) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -7,32 +7,30 @@
#pragma once #pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam_unstable/discrete/Domain.h> #include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
namespace gtsam { namespace gtsam {
/** /**
* Binary AllDiff constraint * Binary AllDiff constraint
* Returns 1 if values for two keys are different, 0 otherwise * Returns 1 if values for two keys are different, 0 otherwise.
* DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Index and an Index. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor.
*/ */
class BinaryAllDiff : public Constraint { class BinaryAllDiff: public Constraint {
size_t cardinality0_, cardinality1_; /// cardinality size_t cardinality0_, cardinality1_; /// cardinality
public: public:
/// Constructor /// Constructor
BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) :
: Constraint(key1.first, key2.first), Constraint(key1.first, key2.first),
cardinality0_(key1.second), cardinality0_(key1.second), cardinality1_(key2.second) {
cardinality1_(key2.second) {} }
// print // print
void print( void print(const std::string& s = "",
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override { const KeyFormatter& formatter = DefaultKeyFormatter) const override {
std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
<< formatter(keys_[1]) << std::endl; << formatter(keys_[1]) << std::endl;
@ -40,28 +38,28 @@ class BinaryAllDiff : public Constraint {
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if (!dynamic_cast<const BinaryAllDiff*>(&other)) if(!dynamic_cast<const BinaryAllDiff*>(&other))
return false; return false;
else { else {
const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other)); const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
return (cardinality0_ == f.cardinality0_) && return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_);
(cardinality1_ == f.cardinality1_);
} }
} }
/// Calculate value /// Calculate value
double operator()(const Values& values) const override { double operator()(const Values& values) const override {
return (double)(values.at(keys_[0]) != values.at(keys_[1])); return (double) (values.at(keys_[0]) != values.at(keys_[1]));
} }
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override {
DiscreteKeys keys; DiscreteKeys keys;
keys.push_back(DiscreteKey(keys_[0], cardinality0_)); keys.push_back(DiscreteKey(keys_[0],cardinality0_));
keys.push_back(DiscreteKey(keys_[1], cardinality1_)); keys.push_back(DiscreteKey(keys_[1],cardinality1_));
std::vector<double> table; std::vector<double> table;
for (size_t i1 = 0; i1 < cardinality0_; i1++) for (size_t i1 = 0; i1 < cardinality0_; i1++)
for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2); for (size_t i2 = 0; i2 < cardinality1_; i2++)
table.push_back(i1 != i2);
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
@ -77,10 +75,11 @@ class BinaryAllDiff : public Constraint {
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param domains all other domains
*/ */
///
bool ensureArcConsistency(size_t j, bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override { std::vector<Domain>* domains) const override {
// throw std::runtime_error( throw std::runtime_error(
// "BinaryAllDiff::ensureArcConsistency not implemented"); "BinaryAllDiff::ensureArcConsistency not implemented");
return false; return false;
} }
@ -94,6 +93,6 @@ class BinaryAllDiff : public Constraint {
const std::vector<Domain>&) const override { const std::vector<Domain>&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
} }
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -56,12 +56,11 @@ void CSP::runArcConsistency(size_t cardinality, size_t nrIterations,
// if not already a singleton // if not already a singleton
if (!domains[v].isSingleton()) { if (!domains[v].isSingleton()) {
// get the constraint and call its ensureArcConsistency method // get the constraint and call its ensureArcConsistency method
Constraint::shared_ptr constraint = auto constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
boost::dynamic_pointer_cast<Constraint>((*this)[f]);
if (!constraint) if (!constraint)
throw runtime_error("CSP:runArcConsistency: non-constraint factor"); throw runtime_error("CSP:runArcConsistency: non-constraint factor");
changed[v] = changed[v] =
constraint->ensureArcConsistency(v, domains) || changed[v]; constraint->ensureArcConsistency(v, &domains) || changed[v];
} }
} // f } // f
if (changed[v]) anyChange = true; if (changed[v]) anyChange = true;

View File

@ -17,40 +17,49 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam_unstable/dllexport.h> #include <gtsam_unstable/dllexport.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <boost/assign.hpp> #include <boost/assign.hpp>
namespace gtsam { namespace gtsam {
class Domain; class Domain;
/** /**
* Base class for discrete probabilistic factors * Base class for constraint factors
* The most general one is the derived DecisionTreeFactor * Derived classes include SingleValue, BinaryAllDiff, and AllDiff.
*/ */
class Constraint : public DiscreteFactor { class GTSAM_EXPORT Constraint : public DiscreteFactor {
public: public:
typedef boost::shared_ptr<Constraint> shared_ptr; typedef boost::shared_ptr<Constraint> shared_ptr;
protected: protected:
/// Construct n-way factor
Constraint(const KeyVector& js) : DiscreteFactor(js) {}
/// Construct unary factor /// Construct unary constraint factor.
Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} Constraint(Key j) :
DiscreteFactor(boost::assign::cref_list_of<1>(j)) {
}
/// Construct binary factor /// Construct binary constraint factor.
Constraint(Key j1, Key j2) Constraint(Key j1, Key j2) :
: DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {
}
/// Construct n-way constraint factor.
Constraint(const KeyVector& js) :
DiscreteFactor(js) {
}
/// construct from container /// construct from container
template <class KeyIterator> template<class KeyIterator>
Constraint(KeyIterator beginKey, KeyIterator endKey) Constraint(KeyIterator beginKey, KeyIterator endKey) :
: DiscreteFactor(beginKey, endKey) {} DiscreteFactor(beginKey, endKey) {
}
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -65,20 +74,22 @@ class Constraint : public DiscreteFactor {
/// @{ /// @{
/* /*
* Ensure Arc-consistency * Ensure Arc-consistency, possibly changing domains of connected variables.
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param (in/out) domains all other domains
* @return true if domains were changed, false otherwise.
*/ */
virtual bool ensureArcConsistency(size_t j, virtual bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const = 0; std::vector<Domain>* domains) const = 0;
/// Partially apply known values /// Partially apply known values
virtual shared_ptr partiallyApply(const Values&) const = 0; virtual shared_ptr partiallyApply(const Values&) const = 0;
/// Partially apply known values, domain version /// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0; virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
/// @} /// @}
}; };
// DiscreteFactor // DiscreteFactor
} // namespace gtsam }// namespace gtsam

View File

@ -5,89 +5,90 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h> #include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */ /* ************************************************************************* */
void Domain::print(const string& s, const KeyFormatter& formatter) const { void Domain::print(const string& s,
// cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << const KeyFormatter& formatter) const {
// formatter(keys_[0]) << ") with values"; cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
// for (size_t v: values_) cout << " " << v; formatter(keys_[0]) << ") with values";
// cout << endl; for (size_t v: values_) cout << " " << v;
for (size_t v : values_) cout << v; cout << endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */
double Domain::operator()(const Values& values) const { double Domain::operator()(const Values& values) const {
return contains(values.at(keys_[0])); return contains(values.at(keys_[0]));
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor Domain::toDecisionTreeFactor() const { DecisionTreeFactor Domain::toDecisionTreeFactor() const {
DiscreteKeys keys; DiscreteKeys keys;
keys += DiscreteKey(keys_[0], cardinality_); keys += DiscreteKey(keys_[0],cardinality_);
vector<double> table; vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); for (size_t i1 = 0; i1 < cardinality_; ++i1)
table.push_back(contains(i1));
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool Domain::ensureArcConsistency(size_t j, vector<Domain>& domains) const { bool Domain::ensureArcConsistency(size_t j, vector<Domain>* domains) const {
if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
Domain& D = domains[j]; Domain& D = domains->at(j);
for (size_t value : values_) for(size_t value: values_)
if (!D.contains(value)) throw runtime_error("Unsatisfiable"); if (!D.contains(value)) throw runtime_error("Unsatisfiable");
D = *this; D = *this;
return true; return true;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool Domain::checkAllDiff(const KeyVector keys, vector<Domain>& domains) { boost::optional<Domain> Domain::checkAllDiff(
const KeyVector keys, const vector<Domain>& domains) const {
Key j = keys_[0]; Key j = keys_[0];
// for all values in this domain // for all values in this domain
for (size_t value : values_) { for (const size_t value : values_) {
// for all connected domains // for all connected domains
for (Key k : keys) for (const Key k : keys)
// if any domain contains the value we cannot make this domain singleton // if any domain contains the value we cannot make this domain singleton
if (k != j && domains[k].contains(value)) goto found; if (k != j && domains[k].contains(value)) goto found;
values_.clear(); // Otherwise: return a singleton:
values_.insert(value); return Domain(this->discreteKey(), value);
return true; // we changed it
found:; found:;
} }
return false; // we did not change it return boost::none; // we did not change it
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { Constraint::shared_ptr Domain::partiallyApply(
const Values& values) const {
Values::const_iterator it = values.find(keys_[0]); Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && !contains(it->second)) if (it != values.end() && !contains(it->second)) throw runtime_error(
throw runtime_error("Domain::partiallyApply: unsatisfiable"); "Domain::partiallyApply: unsatisfiable");
return boost::make_shared<Domain>(*this); return boost::make_shared < Domain > (*this);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply( Constraint::shared_ptr Domain::partiallyApply(
const vector<Domain>& domains) const { const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]]; const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !contains(*Dk.begin())) if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error(
throw runtime_error("Domain::partiallyApply: unsatisfiable"); "Domain::partiallyApply: unsatisfiable");
return boost::make_shared<Domain>(Dk); return boost::make_shared < Domain > (Dk);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,18 +7,23 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h> #include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DiscreteKey.h>
namespace gtsam { namespace gtsam {
/** /**
* Domain restriction constraint * The Domain class represents a constraint that restricts the possible values a
* particular variable, with given key, can take on.
*/ */
class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
size_t cardinality_; /// Cardinality size_t cardinality_; /// Cardinality
std::set<size_t> values_; /// allowed values std::set<size_t> values_; /// allowed values
DiscreteKey discreteKey() const {
return DiscreteKey(keys_[0], cardinality_);
}
public: public:
typedef boost::shared_ptr<Domain> shared_ptr; typedef boost::shared_ptr<Domain> shared_ptr;
@ -35,14 +40,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
values_.insert(v); values_.insert(v);
} }
/// Constructor /// Insert a value, non const :-(
Domain(const Domain& other)
: Constraint(other.keys_[0]), values_(other.values_) {}
/// insert a value, non const :-(
void insert(size_t value) { values_.insert(value); } void insert(size_t value) { values_.insert(value); }
/// erase a value, non const :-( /// Erase a value, non const :-(
void erase(size_t value) { values_.erase(value); } void erase(size_t value) { values_.erase(value); }
size_t nrValues() const { return values_.size(); } size_t nrValues() const { return values_.size(); }
@ -82,15 +83,17 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
* @param domains all other domains * @param domains all other domains
*/ */
bool ensureArcConsistency(size_t j, bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override; std::vector<Domain>* domains) const override;
/** /**
* Check for a value in domain that does not occur in any other connected * Check for a value in domain that does not occur in any other connected
* domain. If found, we make this a singleton... Called in * domain. If found, return a a new singleton domain...
* AllDiff::ensureArcConsistency * Called in AllDiff::ensureArcConsistency
* @param keys connected domains through alldiff * @param keys connected domains through alldiff
* @param keys other domains
*/ */
bool checkAllDiff(const KeyVector keys, std::vector<Domain>& domains); boost::optional<Domain> checkAllDiff(
const KeyVector keys, const std::vector<Domain>& domains) const;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override; Constraint::shared_ptr partiallyApply(const Values& values) const override;
@ -98,6 +101,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override; const std::vector<Domain>& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -5,74 +5,75 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/SingleValue.h> #include <gtsam_unstable/discrete/SingleValue.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */ /* ************************************************************************* */
void SingleValue::print(const string& s, const KeyFormatter& formatter) const { void SingleValue::print(const string& s,
cout << s << "SingleValue on " const KeyFormatter& formatter) const {
<< "j=" << formatter(keys_[0]) << " with value " << value_ << endl; cout << s << "SingleValue on " << "j=" << formatter(keys_[0])
} << " with value " << value_ << endl;
}
/* ************************************************************************* */ /* ************************************************************************* */
double SingleValue::operator()(const Values& values) const { double SingleValue::operator()(const Values& values) const {
return (double)(values.at(keys_[0]) == value_); return (double) (values.at(keys_[0]) == value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
DiscreteKeys keys; DiscreteKeys keys;
keys += DiscreteKey(keys_[0], cardinality_); keys += DiscreteKey(keys_[0],cardinality_);
vector<double> table; vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_); for (size_t i1 = 0; i1 < cardinality_; i1++)
table.push_back(i1 == value_);
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool SingleValue::ensureArcConsistency(size_t j, bool SingleValue::ensureArcConsistency(size_t j,
vector<Domain>& domains) const { vector<Domain>* domains) const {
if (j != keys_[0]) if (j != keys_[0])
throw invalid_argument("SingleValue check on wrong domain"); throw invalid_argument("SingleValue check on wrong domain");
Domain& D = domains[j]; Domain& D = domains->at(j);
if (D.isSingleton()) { if (D.isSingleton()) {
if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); if (D.firstValue() != value_) throw runtime_error("Unsatisfiable");
return false; return false;
} }
D = Domain(discreteKey(), value_); D = Domain(discreteKey(), value_);
return true; return true;
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
Values::const_iterator it = values.find(keys_[0]); Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && it->second != value_) if (it != values.end() && it->second != value_) throw runtime_error(
throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); "SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_); return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply( Constraint::shared_ptr SingleValue::partiallyApply(
const vector<Domain>& domains) const { const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]]; const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !Dk.contains(value_)) if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error(
throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); "SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(discreteKey(), value_); return boost::make_shared<SingleValue>(discreteKey(), value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,47 +7,48 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h> #include <gtsam_unstable/discrete/Constraint.h>
namespace gtsam { namespace gtsam {
/** /**
* SingleValue constraint * SingleValue constraint: ensures a variable takes on a certain value.
* This could of course also be implemented by changing its `Domain`.
*/ */
class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint {
/// Number of values
size_t cardinality_;
/// allowed value size_t cardinality_; /// < Number of values
size_t value_; size_t value_; ///< allowed value
DiscreteKey discreteKey() const { DiscreteKey discreteKey() const {
return DiscreteKey(keys_[0], cardinality_); return DiscreteKey(keys_[0],cardinality_);
} }
public: public:
typedef boost::shared_ptr<SingleValue> shared_ptr; typedef boost::shared_ptr<SingleValue> shared_ptr;
/// Constructor /// Construct from key, cardinality, and given value.
SingleValue(Key key, size_t n, size_t value) SingleValue(Key key, size_t n, size_t value) :
: Constraint(key), cardinality_(n), value_(value) {} Constraint(key), cardinality_(n), value_(value) {
}
/// Constructor /// Construct from DiscreteKey and given value.
SingleValue(const DiscreteKey& dkey, size_t value) SingleValue(const DiscreteKey& dkey, size_t value) :
: Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} Constraint(dkey.first), cardinality_(dkey.second), value_(value) {
}
// print // print
void print(const std::string& s = "", const KeyFormatter& formatter = void print(const std::string& s = "",
DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if (!dynamic_cast<const SingleValue*>(&other)) if(!dynamic_cast<const SingleValue*>(&other))
return false; return false;
else { else {
const SingleValue& f(static_cast<const SingleValue&>(other)); const SingleValue& f(static_cast<const SingleValue&>(other));
return (cardinality_ == f.cardinality_) && (value_ == f.value_); return (cardinality_==f.cardinality_) && (value_==f.value_);
} }
} }
@ -61,12 +62,12 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/* /*
* Ensure Arc-consistency * Ensure Arc-consistency: just sets domain[j] to {value_}
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param domains all other domains
*/ */
bool ensureArcConsistency(size_t j, bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override; std::vector<Domain>* domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override; Constraint::shared_ptr partiallyApply(const Values& values) const override;
@ -74,6 +75,6 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override; const std::vector<Domain>& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -19,12 +19,33 @@ using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(BinaryAllDif, allInOne) { TEST(CSP, SingleValue) {
// Create keys and ordering // Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
size_t nrColors = 3;
DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
// Check that a single value is equal to a decision stump with only one "1":
SingleValue singleValue(AZ, 2);
DecisionTreeFactor f1(AZ, "0 0 1");
EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor()));
// Create domains, laid out as a vector.
// TODO(dellaert): should be map??
vector<Domain> domains;
domains += Domain(ID), Domain(AZ), Domain(UT);
// Ensure arc-consistency: just wipes out values in AZ domain:
EXPECT(singleValue.ensureArcConsistency(1, &domains));
LONGS_EQUAL(3, domains[0].nrValues());
LONGS_EQUAL(1, domains[1].nrValues());
LONGS_EQUAL(3, domains[2].nrValues());
}
/* ************************************************************************* */
TEST(CSP, BinaryAllDif) {
// Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each:
size_t nrColors = 2; size_t nrColors = 2;
// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
// nrColors);
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
// Check construction and conversion // Check construction and conversion
BinaryAllDiff c1(ID, UT); BinaryAllDiff c1(ID, UT);
@ -36,16 +57,51 @@ TEST_UNSAFE(BinaryAllDif, allInOne) {
DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); DecisionTreeFactor f2(UT & AZ, "0 1 1 0");
EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); EXPECT(assert_equal(f2, c2.toDecisionTreeFactor()));
// Check multiplication of factors with constraint:
DecisionTreeFactor f3 = f1 * f2; DecisionTreeFactor f3 = f1 * f2;
EXPECT(assert_equal(f3, c1 * f2)); EXPECT(assert_equal(f3, c1 * f2));
EXPECT(assert_equal(f3, c2 * f1)); EXPECT(assert_equal(f3, c2 * f1));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(CSP, allInOne) { TEST(CSP, AllDiff) {
// Create keys and ordering // Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
size_t nrColors = 3;
DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
// Check construction and conversion
vector<DiscreteKey> dkeys{ID, UT, AZ};
AllDiff alldiff(dkeys);
DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
// GTSAM_PRINT(actual);
actual.dot("actual");
DecisionTreeFactor f2(
ID & AZ & UT,
"0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
EXPECT(assert_equal(f2, actual));
// Create domains.
vector<Domain> domains;
domains += Domain(ID), Domain(AZ), Domain(UT);
// First constrict AZ domain:
SingleValue singleValue(AZ, 2);
EXPECT(singleValue.ensureArcConsistency(1, &domains));
// Arc-consistency
EXPECT(alldiff.ensureArcConsistency(0, &domains));
EXPECT(!alldiff.ensureArcConsistency(1, &domains));
EXPECT(alldiff.ensureArcConsistency(2, &domains));
LONGS_EQUAL(2, domains[0].nrValues());
LONGS_EQUAL(1, domains[1].nrValues());
LONGS_EQUAL(2, domains[2].nrValues());
}
/* ************************************************************************* */
TEST(CSP, allInOne) {
// Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each:
size_t nrColors = 2; size_t nrColors = 2;
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
// Create the CSP // Create the CSP
CSP csp; CSP csp;
@ -81,15 +137,12 @@ TEST_UNSAFE(CSP, allInOne) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(CSP, WesternUS) { TEST(CSP, WesternUS) {
// Create keys // Create keys for all states in Western US, with 4 color possibilities.
size_t nrColors = 4; size_t nrColors = 4;
DiscreteKey DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),
// Create ordering according to example in ND-CSP.lyx NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors),
WA(0, nrColors), MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors);
OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors),
UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors),
CO(7, nrColors), NM(6, nrColors);
// Create the CSP // Create the CSP
CSP csp; CSP csp;
@ -116,10 +169,12 @@ TEST_UNSAFE(CSP, WesternUS) {
csp.addAllDiff(WY, CO); csp.addAllDiff(WY, CO);
csp.addAllDiff(CO, NM); csp.addAllDiff(CO, NM);
// Solve // Create ordering according to example in ND-CSP.lyx
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
Key(8), Key(9), Key(10); Key(8), Key(9), Key(10);
// Solve using that ordering:
CSP::sharedValues mpe = csp.optimalAssignment(ordering); CSP::sharedValues mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(*mpe); // GTSAM_PRINT(*mpe);
CSP::Values expected; CSP::Values expected;
@ -143,34 +198,18 @@ TEST_UNSAFE(CSP, WesternUS) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(CSP, AllDiff) { TEST(CSP, ArcConsistency) {
// Create keys and ordering // Create keys for Idaho, Arizona, and Utah, allowing three colors for each:
size_t nrColors = 3; size_t nrColors = 3;
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
// Create the CSP // Create the CSP using just one all-diff constraint, plus constrain Arizona.
CSP csp; CSP csp;
vector<DiscreteKey> dkeys; vector<DiscreteKey> dkeys{ID, UT, AZ};
dkeys += ID, UT, AZ;
csp.addAllDiff(dkeys); csp.addAllDiff(dkeys);
csp.addSingleValue(AZ, 2); csp.addSingleValue(AZ, 2);
// GTSAM_PRINT(csp); // GTSAM_PRINT(csp);
// Check construction and conversion
SingleValue s(AZ, 2);
DecisionTreeFactor f1(AZ, "0 0 1");
EXPECT(assert_equal(f1, s.toDecisionTreeFactor()));
// Check construction and conversion
AllDiff alldiff(dkeys);
DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
// GTSAM_PRINT(actual);
// actual.dot("actual");
DecisionTreeFactor f2(
ID & AZ & UT,
"0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
EXPECT(assert_equal(f2, actual));
// Check an invalid combination, with ID==UT==AZ all same color // Check an invalid combination, with ID==UT==AZ all same color
DiscreteFactor::Values invalid; DiscreteFactor::Values invalid;
invalid[ID.first] = 0; invalid[ID.first] = 0;
@ -192,14 +231,15 @@ TEST_UNSAFE(CSP, AllDiff) {
EXPECT(assert_equal(expected, *mpe)); EXPECT(assert_equal(expected, *mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Arc-consistency // ensure arc-consistency, i.e., narrow domains...
vector<Domain> domains; vector<Domain> domains;
domains += Domain(ID), Domain(AZ), Domain(UT); domains += Domain(ID), Domain(AZ), Domain(UT);
SingleValue singleValue(AZ, 2); SingleValue singleValue(AZ, 2);
EXPECT(singleValue.ensureArcConsistency(1, domains)); AllDiff alldiff(dkeys);
EXPECT(alldiff.ensureArcConsistency(0, domains)); EXPECT(singleValue.ensureArcConsistency(1, &domains));
EXPECT(!alldiff.ensureArcConsistency(1, domains)); EXPECT(alldiff.ensureArcConsistency(0, &domains));
EXPECT(alldiff.ensureArcConsistency(2, domains)); EXPECT(!alldiff.ensureArcConsistency(1, &domains));
EXPECT(alldiff.ensureArcConsistency(2, &domains));
LONGS_EQUAL(2, domains[0].nrValues()); LONGS_EQUAL(2, domains[0].nrValues());
LONGS_EQUAL(1, domains[1].nrValues()); LONGS_EQUAL(1, domains[1].nrValues());
LONGS_EQUAL(2, domains[2].nrValues()); LONGS_EQUAL(2, domains[2].nrValues());
@ -222,6 +262,7 @@ TEST_UNSAFE(CSP, AllDiff) {
// full arc-consistency test // full arc-consistency test
csp.runArcConsistency(nrColors); csp.runArcConsistency(nrColors);
// GTSAM_PRINT(csp);
} }
/* ************************************************************************* */ /* ************************************************************************* */