Revamped arc consistency
parent
770fda9a26
commit
dd50975668
|
@ -5,105 +5,115 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam_unstable/discrete/AllDiff.h>
|
||||
#include <gtsam_unstable/discrete/Domain.h>
|
||||
|
||||
#include <gtsam_unstable/discrete/AllDiff.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
|
||||
for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
|
||||
std::cout << s << "AllDiff on ";
|
||||
for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double AllDiff::operator()(const Values& values) const {
|
||||
std::set<size_t> taken; // record values taken by keys
|
||||
for (Key dkey : keys_) {
|
||||
size_t value = values.at(dkey); // get the value for that key
|
||||
if (taken.count(value)) return 0.0; // check if value alreday taken
|
||||
taken.insert(value); // if not, record it as taken and keep checking
|
||||
/* ************************************************************************* */
|
||||
AllDiff::AllDiff(const DiscreteKeys& dkeys) :
|
||||
Constraint(dkeys.indices()) {
|
||||
for(const DiscreteKey& dkey: dkeys)
|
||||
cardinalities_.insert(dkey);
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
|
||||
// We will do this by converting the allDif into many BinaryAllDiff
|
||||
// constraints
|
||||
DecisionTreeFactor converted;
|
||||
size_t nrKeys = keys_.size();
|
||||
for (size_t i1 = 0; i1 < nrKeys; i1++)
|
||||
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
|
||||
BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2));
|
||||
converted = converted * binary12.toDecisionTreeFactor();
|
||||
/* ************************************************************************* */
|
||||
void AllDiff::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
std::cout << s << "AllDiff on ";
|
||||
for (Key dkey: keys_)
|
||||
std::cout << formatter(dkey) << " ";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double AllDiff::operator()(const Values& values) const {
|
||||
std::set < size_t > taken; // record values taken by keys
|
||||
for(Key dkey: keys_) {
|
||||
size_t value = values.at(dkey); // get the value for that key
|
||||
if (taken.count(value)) return 0.0;// check if value alreday taken
|
||||
taken.insert(value);// if not, record it as taken and keep checking
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
|
||||
// We will do this by converting the allDif into many BinaryAllDiff constraints
|
||||
DecisionTreeFactor converted;
|
||||
size_t nrKeys = keys_.size();
|
||||
for (size_t i1 = 0; i1 < nrKeys; i1++)
|
||||
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
|
||||
BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2));
|
||||
converted = converted * binary12.toDecisionTreeFactor();
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool AllDiff::ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>& domains) const {
|
||||
// Though strictly not part of allDiff, we check for
|
||||
// a value in domains[j] that does not occur in any other connected domain.
|
||||
// If found, we make this a singleton...
|
||||
// TODO: make a new constraint where this really is true
|
||||
Domain& Dj = domains[j];
|
||||
if (Dj.checkAllDiff(keys_, domains)) return true;
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
// Check all other domains for singletons and erase corresponding values
|
||||
// This is the same as arc-consistency on the equivalent binary constraints
|
||||
bool changed = false;
|
||||
for (Key k : keys_)
|
||||
if (k != j) {
|
||||
const Domain& Dk = domains[k];
|
||||
if (Dk.isSingleton()) { // check if singleton
|
||||
size_t value = Dk.firstValue();
|
||||
if (Dj.contains(value)) {
|
||||
Dj.erase(value); // erase value if true
|
||||
changed = true;
|
||||
/* ************************************************************************* */
|
||||
bool AllDiff::ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>* domains) const {
|
||||
// We are changing the domain of variable j.
|
||||
// TODO(dellaert): confusing, I thought we were changing others...
|
||||
Domain& Dj = domains->at(j);
|
||||
|
||||
// Though strictly not part of allDiff, we check for
|
||||
// a value in domains[j] that does not occur in any other connected domain.
|
||||
// If found, we make this a singleton...
|
||||
// TODO: make a new constraint where this really is true
|
||||
boost::optional<Domain> maybeChanged = Dj.checkAllDiff(keys_, *domains);
|
||||
if (maybeChanged) {
|
||||
Dj = *maybeChanged;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check all other domains for singletons and erase corresponding values.
|
||||
// This is the same as arc-consistency on the equivalent binary constraints
|
||||
bool changed = false;
|
||||
for (Key k : keys_)
|
||||
if (k != j) {
|
||||
const Domain& Dk = domains->at(k);
|
||||
if (Dk.isSingleton()) { // check if singleton
|
||||
size_t value = Dk.firstValue();
|
||||
if (Dj.contains(value)) {
|
||||
Dj.erase(value); // erase value if true
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
|
||||
DiscreteKeys newKeys;
|
||||
// loop over keys and add them only if they do not appear in values
|
||||
for (Key k : keys_)
|
||||
if (values.find(k) == values.end()) {
|
||||
newKeys.push_back(DiscreteKey(k, cardinalities_.at(k)));
|
||||
}
|
||||
return boost::make_shared<AllDiff>(newKeys);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr AllDiff::partiallyApply(
|
||||
const std::vector<Domain>& domains) const {
|
||||
DiscreteFactor::Values known;
|
||||
for (Key k : keys_) {
|
||||
const Domain& Dk = domains[k];
|
||||
if (Dk.isSingleton()) known[k] = Dk.firstValue();
|
||||
return changed;
|
||||
}
|
||||
return partiallyApply(known);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
|
||||
DiscreteKeys newKeys;
|
||||
// loop over keys and add them only if they do not appear in values
|
||||
for(Key k: keys_)
|
||||
if (values.find(k) == values.end()) {
|
||||
newKeys.push_back(DiscreteKey(k,cardinalities_.at(k)));
|
||||
}
|
||||
return boost::make_shared<AllDiff>(newKeys);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr AllDiff::partiallyApply(
|
||||
const std::vector<Domain>& domains) const {
|
||||
DiscreteFactor::Values known;
|
||||
for(Key k: keys_) {
|
||||
const Domain& Dk = domains[k];
|
||||
if (Dk.isSingleton())
|
||||
known[k] = Dk.firstValue();
|
||||
}
|
||||
return partiallyApply(known);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -7,71 +7,70 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* General AllDiff constraint
|
||||
* Returns 1 if values for all keys are different, 0 otherwise
|
||||
* DiscreteFactors are all awkward in that they have to store two types of keys:
|
||||
* for each variable we have a Key and an Key. In this factor, we
|
||||
* keep the Indices locally, and the Indices are stored in IndexFactor.
|
||||
*/
|
||||
class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
||||
std::map<Key, size_t> cardinalities_;
|
||||
|
||||
DiscreteKey discreteKey(size_t i) const {
|
||||
Key j = keys_[i];
|
||||
return DiscreteKey(j, cardinalities_.at(j));
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
AllDiff(const DiscreteKeys& dkeys);
|
||||
|
||||
// print
|
||||
void print(const std::string& s = "", const KeyFormatter& formatter =
|
||||
DefaultKeyFormatter) const override;
|
||||
|
||||
/// equals
|
||||
bool equals(const DiscreteFactor& other, double tol) const override {
|
||||
if (!dynamic_cast<const AllDiff*>(&other))
|
||||
return false;
|
||||
else {
|
||||
const AllDiff& f(static_cast<const AllDiff&>(other));
|
||||
return cardinalities_.size() == f.cardinalities_.size() &&
|
||||
std::equal(cardinalities_.begin(), cardinalities_.end(),
|
||||
f.cardinalities_.begin());
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate value = expensive !
|
||||
double operator()(const Values& values) const override;
|
||||
|
||||
/// Convert into a decisiontree, can be *very* expensive !
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
/// Multiply into a decisiontree
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency
|
||||
* Arc-consistency involves creating binaryAllDiff constraints
|
||||
* In which case the combinatorial hyper-arc explosion disappears.
|
||||
* @param j domain to be checked
|
||||
* @param domains all other domains
|
||||
/**
|
||||
* General AllDiff constraint.
|
||||
* Returns 1 if values for all keys are different, 0 otherwise.
|
||||
*/
|
||||
bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>& domains) const override;
|
||||
class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint {
|
||||
|
||||
/// Partially apply known values
|
||||
Constraint::shared_ptr partiallyApply(const Values&) const override;
|
||||
std::map<Key,size_t> cardinalities_;
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
Constraint::shared_ptr partiallyApply(
|
||||
const std::vector<Domain>&) const override;
|
||||
};
|
||||
DiscreteKey discreteKey(size_t i) const {
|
||||
Key j = keys_[i];
|
||||
return DiscreteKey(j,cardinalities_.at(j));
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
public:
|
||||
|
||||
/// Construct from keys.
|
||||
AllDiff(const DiscreteKeys& dkeys);
|
||||
|
||||
// print
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// equals
|
||||
bool equals(const DiscreteFactor& other, double tol) const override {
|
||||
if(!dynamic_cast<const AllDiff*>(&other))
|
||||
return false;
|
||||
else {
|
||||
const AllDiff& f(static_cast<const AllDiff&>(other));
|
||||
return cardinalities_.size() == f.cardinalities_.size()
|
||||
&& std::equal(cardinalities_.begin(), cardinalities_.end(),
|
||||
f.cardinalities_.begin());
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate value = expensive !
|
||||
double operator()(const Values& values) const override;
|
||||
|
||||
/// Convert into a decisiontree, can be *very* expensive !
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
/// Multiply into a decisiontree
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency
|
||||
* Arc-consistency involves creating binaryAllDiff constraints
|
||||
* In which case the combinatorial hyper-arc explosion disappears.
|
||||
* @param j domain to be checked
|
||||
* @param (in/out) domains all other domains
|
||||
*/
|
||||
bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>* domains) const override;
|
||||
|
||||
/// Partially apply known values
|
||||
Constraint::shared_ptr partiallyApply(const Values&) const override;
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
Constraint::shared_ptr partiallyApply(
|
||||
const std::vector<Domain>&) const override;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -7,93 +7,92 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam_unstable/discrete/Constraint.h>
|
||||
#include <gtsam_unstable/discrete/Domain.h>
|
||||
#include <gtsam_unstable/discrete/Constraint.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* Binary AllDiff constraint
|
||||
* Returns 1 if values for two keys are different, 0 otherwise
|
||||
* DiscreteFactors are all awkward in that they have to store two types of keys:
|
||||
* for each variable we have a Index and an Index. In this factor, we
|
||||
* keep the Indices locally, and the Indices are stored in IndexFactor.
|
||||
*/
|
||||
class BinaryAllDiff : public Constraint {
|
||||
size_t cardinality0_, cardinality1_; /// cardinality
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2)
|
||||
: Constraint(key1.first, key2.first),
|
||||
cardinality0_(key1.second),
|
||||
cardinality1_(key2.second) {}
|
||||
|
||||
// print
|
||||
void print(
|
||||
const std::string& s = "",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
||||
std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
|
||||
<< formatter(keys_[1]) << std::endl;
|
||||
}
|
||||
|
||||
/// equals
|
||||
bool equals(const DiscreteFactor& other, double tol) const override {
|
||||
if (!dynamic_cast<const BinaryAllDiff*>(&other))
|
||||
return false;
|
||||
else {
|
||||
const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
|
||||
return (cardinality0_ == f.cardinality0_) &&
|
||||
(cardinality1_ == f.cardinality1_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate value
|
||||
double operator()(const Values& values) const override {
|
||||
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
|
||||
}
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
||||
DiscreteKeys keys;
|
||||
keys.push_back(DiscreteKey(keys_[0], cardinality0_));
|
||||
keys.push_back(DiscreteKey(keys_[1], cardinality1_));
|
||||
std::vector<double> table;
|
||||
for (size_t i1 = 0; i1 < cardinality0_; i1++)
|
||||
for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2);
|
||||
DecisionTreeFactor converted(keys, table);
|
||||
return converted;
|
||||
}
|
||||
|
||||
/// Multiply into a decisiontree
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency
|
||||
* @param j domain to be checked
|
||||
* @param domains all other domains
|
||||
/**
|
||||
* Binary AllDiff constraint
|
||||
* Returns 1 if values for two keys are different, 0 otherwise.
|
||||
*/
|
||||
bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>& domains) const override {
|
||||
// throw std::runtime_error(
|
||||
// "BinaryAllDiff::ensureArcConsistency not implemented");
|
||||
return false;
|
||||
}
|
||||
class BinaryAllDiff: public Constraint {
|
||||
|
||||
/// Partially apply known values
|
||||
Constraint::shared_ptr partiallyApply(const Values&) const override {
|
||||
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
||||
}
|
||||
size_t cardinality0_, cardinality1_; /// cardinality
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
Constraint::shared_ptr partiallyApply(
|
||||
const std::vector<Domain>&) const override {
|
||||
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
||||
}
|
||||
};
|
||||
public:
|
||||
|
||||
} // namespace gtsam
|
||||
/// Constructor
|
||||
BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) :
|
||||
Constraint(key1.first, key2.first),
|
||||
cardinality0_(key1.second), cardinality1_(key2.second) {
|
||||
}
|
||||
|
||||
// print
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
||||
std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
|
||||
<< formatter(keys_[1]) << std::endl;
|
||||
}
|
||||
|
||||
/// equals
|
||||
bool equals(const DiscreteFactor& other, double tol) const override {
|
||||
if(!dynamic_cast<const BinaryAllDiff*>(&other))
|
||||
return false;
|
||||
else {
|
||||
const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
|
||||
return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate value
|
||||
double operator()(const Values& values) const override {
|
||||
return (double) (values.at(keys_[0]) != values.at(keys_[1]));
|
||||
}
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
||||
DiscreteKeys keys;
|
||||
keys.push_back(DiscreteKey(keys_[0],cardinality0_));
|
||||
keys.push_back(DiscreteKey(keys_[1],cardinality1_));
|
||||
std::vector<double> table;
|
||||
for (size_t i1 = 0; i1 < cardinality0_; i1++)
|
||||
for (size_t i2 = 0; i2 < cardinality1_; i2++)
|
||||
table.push_back(i1 != i2);
|
||||
DecisionTreeFactor converted(keys, table);
|
||||
return converted;
|
||||
}
|
||||
|
||||
/// Multiply into a decisiontree
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency
|
||||
* @param j domain to be checked
|
||||
* @param domains all other domains
|
||||
*/
|
||||
///
|
||||
bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>* domains) const override {
|
||||
throw std::runtime_error(
|
||||
"BinaryAllDiff::ensureArcConsistency not implemented");
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Partially apply known values
|
||||
Constraint::shared_ptr partiallyApply(const Values&) const override {
|
||||
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
||||
}
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
Constraint::shared_ptr partiallyApply(
|
||||
const std::vector<Domain>&) const override {
|
||||
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -56,12 +56,11 @@ void CSP::runArcConsistency(size_t cardinality, size_t nrIterations,
|
|||
// if not already a singleton
|
||||
if (!domains[v].isSingleton()) {
|
||||
// get the constraint and call its ensureArcConsistency method
|
||||
Constraint::shared_ptr constraint =
|
||||
boost::dynamic_pointer_cast<Constraint>((*this)[f]);
|
||||
auto constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
|
||||
if (!constraint)
|
||||
throw runtime_error("CSP:runArcConsistency: non-constraint factor");
|
||||
changed[v] =
|
||||
constraint->ensureArcConsistency(v, domains) || changed[v];
|
||||
constraint->ensureArcConsistency(v, &domains) || changed[v];
|
||||
}
|
||||
} // f
|
||||
if (changed[v]) anyChange = true;
|
||||
|
|
|
@ -17,68 +17,79 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam_unstable/dllexport.h>
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <boost/assign.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class Domain;
|
||||
class Domain;
|
||||
|
||||
/**
|
||||
* Base class for discrete probabilistic factors
|
||||
* The most general one is the derived DecisionTreeFactor
|
||||
*/
|
||||
class Constraint : public DiscreteFactor {
|
||||
public:
|
||||
typedef boost::shared_ptr<Constraint> shared_ptr;
|
||||
|
||||
protected:
|
||||
/// Construct n-way factor
|
||||
Constraint(const KeyVector& js) : DiscreteFactor(js) {}
|
||||
|
||||
/// Construct unary factor
|
||||
Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {}
|
||||
|
||||
/// Construct binary factor
|
||||
Constraint(Key j1, Key j2)
|
||||
: DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {}
|
||||
|
||||
/// construct from container
|
||||
template <class KeyIterator>
|
||||
Constraint(KeyIterator beginKey, KeyIterator endKey)
|
||||
: DiscreteFactor(beginKey, endKey) {}
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor for I/O
|
||||
Constraint();
|
||||
|
||||
/// Virtual destructor
|
||||
~Constraint() override {}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency
|
||||
* @param j domain to be checked
|
||||
* @param domains all other domains
|
||||
/**
|
||||
* Base class for constraint factors
|
||||
* Derived classes include SingleValue, BinaryAllDiff, and AllDiff.
|
||||
*/
|
||||
virtual bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>& domains) const = 0;
|
||||
class GTSAM_EXPORT Constraint : public DiscreteFactor {
|
||||
|
||||
/// Partially apply known values
|
||||
virtual shared_ptr partiallyApply(const Values&) const = 0;
|
||||
public:
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
|
||||
/// @}
|
||||
};
|
||||
typedef boost::shared_ptr<Constraint> shared_ptr;
|
||||
|
||||
protected:
|
||||
|
||||
/// Construct unary constraint factor.
|
||||
Constraint(Key j) :
|
||||
DiscreteFactor(boost::assign::cref_list_of<1>(j)) {
|
||||
}
|
||||
|
||||
/// Construct binary constraint factor.
|
||||
Constraint(Key j1, Key j2) :
|
||||
DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {
|
||||
}
|
||||
|
||||
/// Construct n-way constraint factor.
|
||||
Constraint(const KeyVector& js) :
|
||||
DiscreteFactor(js) {
|
||||
}
|
||||
|
||||
/// construct from container
|
||||
template<class KeyIterator>
|
||||
Constraint(KeyIterator beginKey, KeyIterator endKey) :
|
||||
DiscreteFactor(beginKey, endKey) {
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor for I/O
|
||||
Constraint();
|
||||
|
||||
/// Virtual destructor
|
||||
~Constraint() override {}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency, possibly changing domains of connected variables.
|
||||
* @param j domain to be checked
|
||||
* @param (in/out) domains all other domains
|
||||
* @return true if domains were changed, false otherwise.
|
||||
*/
|
||||
virtual bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>* domains) const = 0;
|
||||
|
||||
/// Partially apply known values
|
||||
virtual shared_ptr partiallyApply(const Values&) const = 0;
|
||||
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
|
||||
/// @}
|
||||
};
|
||||
// DiscreteFactor
|
||||
|
||||
} // namespace gtsam
|
||||
}// namespace gtsam
|
||||
|
|
|
@ -5,89 +5,90 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam_unstable/discrete/Domain.h>
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
using namespace std;
|
||||
using namespace std;
|
||||
|
||||
/* ************************************************************************* */
|
||||
void Domain::print(const string& s, const KeyFormatter& formatter) const {
|
||||
// cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
|
||||
// formatter(keys_[0]) << ") with values";
|
||||
// for (size_t v: values_) cout << " " << v;
|
||||
// cout << endl;
|
||||
for (size_t v : values_) cout << v;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double Domain::operator()(const Values& values) const {
|
||||
return contains(values.at(keys_[0]));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor Domain::toDecisionTreeFactor() const {
|
||||
DiscreteKeys keys;
|
||||
keys += DiscreteKey(keys_[0], cardinality_);
|
||||
vector<double> table;
|
||||
for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1));
|
||||
DecisionTreeFactor converted(keys, table);
|
||||
return converted;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool Domain::ensureArcConsistency(size_t j, vector<Domain>& domains) const {
|
||||
if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
|
||||
Domain& D = domains[j];
|
||||
for (size_t value : values_)
|
||||
if (!D.contains(value)) throw runtime_error("Unsatisfiable");
|
||||
D = *this;
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool Domain::checkAllDiff(const KeyVector keys, vector<Domain>& domains) {
|
||||
Key j = keys_[0];
|
||||
// for all values in this domain
|
||||
for (size_t value : values_) {
|
||||
// for all connected domains
|
||||
for (Key k : keys)
|
||||
// if any domain contains the value we cannot make this domain singleton
|
||||
if (k != j && domains[k].contains(value)) goto found;
|
||||
values_.clear();
|
||||
values_.insert(value);
|
||||
return true; // we changed it
|
||||
found:;
|
||||
/* ************************************************************************* */
|
||||
void Domain::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
|
||||
formatter(keys_[0]) << ") with values";
|
||||
for (size_t v: values_) cout << " " << v;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double Domain::operator()(const Values& values) const {
|
||||
return contains(values.at(keys_[0]));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor Domain::toDecisionTreeFactor() const {
|
||||
DiscreteKeys keys;
|
||||
keys += DiscreteKey(keys_[0],cardinality_);
|
||||
vector<double> table;
|
||||
for (size_t i1 = 0; i1 < cardinality_; ++i1)
|
||||
table.push_back(contains(i1));
|
||||
DecisionTreeFactor converted(keys, table);
|
||||
return converted;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool Domain::ensureArcConsistency(size_t j, vector<Domain>* domains) const {
|
||||
if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
|
||||
Domain& D = domains->at(j);
|
||||
for(size_t value: values_)
|
||||
if (!D.contains(value)) throw runtime_error("Unsatisfiable");
|
||||
D = *this;
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
boost::optional<Domain> Domain::checkAllDiff(
|
||||
const KeyVector keys, const vector<Domain>& domains) const {
|
||||
Key j = keys_[0];
|
||||
// for all values in this domain
|
||||
for (const size_t value : values_) {
|
||||
// for all connected domains
|
||||
for (const Key k : keys)
|
||||
// if any domain contains the value we cannot make this domain singleton
|
||||
if (k != j && domains[k].contains(value)) goto found;
|
||||
// Otherwise: return a singleton:
|
||||
return Domain(this->discreteKey(), value);
|
||||
found:;
|
||||
}
|
||||
return boost::none; // we did not change it
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr Domain::partiallyApply(
|
||||
const Values& values) const {
|
||||
Values::const_iterator it = values.find(keys_[0]);
|
||||
if (it != values.end() && !contains(it->second)) throw runtime_error(
|
||||
"Domain::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared < Domain > (*this);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr Domain::partiallyApply(
|
||||
const vector<Domain>& domains) const {
|
||||
const Domain& Dk = domains[keys_[0]];
|
||||
if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error(
|
||||
"Domain::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared < Domain > (Dk);
|
||||
}
|
||||
return false; // we did not change it
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr Domain::partiallyApply(const Values& values) const {
|
||||
Values::const_iterator it = values.find(keys_[0]);
|
||||
if (it != values.end() && !contains(it->second))
|
||||
throw runtime_error("Domain::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared<Domain>(*this);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr Domain::partiallyApply(
|
||||
const vector<Domain>& domains) const {
|
||||
const Domain& Dk = domains[keys_[0]];
|
||||
if (Dk.isSingleton() && !contains(*Dk.begin()))
|
||||
throw runtime_error("Domain::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared<Domain>(Dk);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -7,18 +7,23 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam_unstable/discrete/Constraint.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* Domain restriction constraint
|
||||
* The Domain class represents a constraint that restricts the possible values a
|
||||
* particular variable, with given key, can take on.
|
||||
*/
|
||||
class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||
size_t cardinality_; /// Cardinality
|
||||
std::set<size_t> values_; /// allowed values
|
||||
|
||||
DiscreteKey discreteKey() const {
|
||||
return DiscreteKey(keys_[0], cardinality_);
|
||||
}
|
||||
|
||||
public:
|
||||
typedef boost::shared_ptr<Domain> shared_ptr;
|
||||
|
||||
|
@ -35,14 +40,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
|||
values_.insert(v);
|
||||
}
|
||||
|
||||
/// Constructor
|
||||
Domain(const Domain& other)
|
||||
: Constraint(other.keys_[0]), values_(other.values_) {}
|
||||
|
||||
/// insert a value, non const :-(
|
||||
/// Insert a value, non const :-(
|
||||
void insert(size_t value) { values_.insert(value); }
|
||||
|
||||
/// erase a value, non const :-(
|
||||
/// Erase a value, non const :-(
|
||||
void erase(size_t value) { values_.erase(value); }
|
||||
|
||||
size_t nrValues() const { return values_.size(); }
|
||||
|
@ -82,15 +83,17 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
|||
* @param domains all other domains
|
||||
*/
|
||||
bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>& domains) const override;
|
||||
std::vector<Domain>* domains) const override;
|
||||
|
||||
/**
|
||||
* Check for a value in domain that does not occur in any other connected
|
||||
* domain. If found, we make this a singleton... Called in
|
||||
* AllDiff::ensureArcConsistency
|
||||
* @param keys connected domains through alldiff
|
||||
* Check for a value in domain that does not occur in any other connected
|
||||
* domain. If found, return a a new singleton domain...
|
||||
* Called in AllDiff::ensureArcConsistency
|
||||
* @param keys connected domains through alldiff
|
||||
* @param keys other domains
|
||||
*/
|
||||
bool checkAllDiff(const KeyVector keys, std::vector<Domain>& domains);
|
||||
boost::optional<Domain> checkAllDiff(
|
||||
const KeyVector keys, const std::vector<Domain>& domains) const;
|
||||
|
||||
/// Partially apply known values
|
||||
Constraint::shared_ptr partiallyApply(const Values& values) const override;
|
||||
|
@ -98,6 +101,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
|||
/// Partially apply known values, domain version
|
||||
Constraint::shared_ptr partiallyApply(
|
||||
const std::vector<Domain>& domains) const override;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -5,74 +5,75 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam_unstable/discrete/Domain.h>
|
||||
#include <gtsam_unstable/discrete/SingleValue.h>
|
||||
|
||||
#include <gtsam_unstable/discrete/Domain.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
using namespace std;
|
||||
using namespace std;
|
||||
|
||||
/* ************************************************************************* */
|
||||
void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
|
||||
cout << s << "SingleValue on "
|
||||
<< "j=" << formatter(keys_[0]) << " with value " << value_ << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double SingleValue::operator()(const Values& values) const {
|
||||
return (double)(values.at(keys_[0]) == value_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
|
||||
DiscreteKeys keys;
|
||||
keys += DiscreteKey(keys_[0], cardinality_);
|
||||
vector<double> table;
|
||||
for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_);
|
||||
DecisionTreeFactor converted(keys, table);
|
||||
return converted;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool SingleValue::ensureArcConsistency(size_t j,
|
||||
vector<Domain>& domains) const {
|
||||
if (j != keys_[0])
|
||||
throw invalid_argument("SingleValue check on wrong domain");
|
||||
Domain& D = domains[j];
|
||||
if (D.isSingleton()) {
|
||||
if (D.firstValue() != value_) throw runtime_error("Unsatisfiable");
|
||||
return false;
|
||||
/* ************************************************************************* */
|
||||
void SingleValue::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s << "SingleValue on " << "j=" << formatter(keys_[0])
|
||||
<< " with value " << value_ << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double SingleValue::operator()(const Values& values) const {
|
||||
return (double) (values.at(keys_[0]) == value_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
|
||||
DiscreteKeys keys;
|
||||
keys += DiscreteKey(keys_[0],cardinality_);
|
||||
vector<double> table;
|
||||
for (size_t i1 = 0; i1 < cardinality_; i1++)
|
||||
table.push_back(i1 == value_);
|
||||
DecisionTreeFactor converted(keys, table);
|
||||
return converted;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
|
||||
// TODO: can we do this more efficiently?
|
||||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool SingleValue::ensureArcConsistency(size_t j,
|
||||
vector<Domain>* domains) const {
|
||||
if (j != keys_[0])
|
||||
throw invalid_argument("SingleValue check on wrong domain");
|
||||
Domain& D = domains->at(j);
|
||||
if (D.isSingleton()) {
|
||||
if (D.firstValue() != value_) throw runtime_error("Unsatisfiable");
|
||||
return false;
|
||||
}
|
||||
D = Domain(discreteKey(), value_);
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
|
||||
Values::const_iterator it = values.find(keys_[0]);
|
||||
if (it != values.end() && it->second != value_) throw runtime_error(
|
||||
"SingleValue::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr SingleValue::partiallyApply(
|
||||
const vector<Domain>& domains) const {
|
||||
const Domain& Dk = domains[keys_[0]];
|
||||
if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error(
|
||||
"SingleValue::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared<SingleValue>(discreteKey(), value_);
|
||||
}
|
||||
D = Domain(discreteKey(), value_);
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
|
||||
Values::const_iterator it = values.find(keys_[0]);
|
||||
if (it != values.end() && it->second != value_)
|
||||
throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Constraint::shared_ptr SingleValue::partiallyApply(
|
||||
const vector<Domain>& domains) const {
|
||||
const Domain& Dk = domains[keys_[0]];
|
||||
if (Dk.isSingleton() && !Dk.contains(value_))
|
||||
throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
|
||||
return boost::make_shared<SingleValue>(discreteKey(), value_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -7,73 +7,74 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam_unstable/discrete/Constraint.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* SingleValue constraint
|
||||
*/
|
||||
class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
||||
/// Number of values
|
||||
size_t cardinality_;
|
||||
|
||||
/// allowed value
|
||||
size_t value_;
|
||||
|
||||
DiscreteKey discreteKey() const {
|
||||
return DiscreteKey(keys_[0], cardinality_);
|
||||
}
|
||||
|
||||
public:
|
||||
typedef boost::shared_ptr<SingleValue> shared_ptr;
|
||||
|
||||
/// Constructor
|
||||
SingleValue(Key key, size_t n, size_t value)
|
||||
: Constraint(key), cardinality_(n), value_(value) {}
|
||||
|
||||
/// Constructor
|
||||
SingleValue(const DiscreteKey& dkey, size_t value)
|
||||
: Constraint(dkey.first), cardinality_(dkey.second), value_(value) {}
|
||||
|
||||
// print
|
||||
void print(const std::string& s = "", const KeyFormatter& formatter =
|
||||
DefaultKeyFormatter) const override;
|
||||
|
||||
/// equals
|
||||
bool equals(const DiscreteFactor& other, double tol) const override {
|
||||
if (!dynamic_cast<const SingleValue*>(&other))
|
||||
return false;
|
||||
else {
|
||||
const SingleValue& f(static_cast<const SingleValue&>(other));
|
||||
return (cardinality_ == f.cardinality_) && (value_ == f.value_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate value
|
||||
double operator()(const Values& values) const override;
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
/// Multiply into a decisiontree
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency
|
||||
* @param j domain to be checked
|
||||
* @param domains all other domains
|
||||
/**
|
||||
* SingleValue constraint: ensures a variable takes on a certain value.
|
||||
* This could of course also be implemented by changing its `Domain`.
|
||||
*/
|
||||
bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>& domains) const override;
|
||||
class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint {
|
||||
|
||||
size_t cardinality_; /// < Number of values
|
||||
size_t value_; ///< allowed value
|
||||
|
||||
/// Partially apply known values
|
||||
Constraint::shared_ptr partiallyApply(const Values& values) const override;
|
||||
DiscreteKey discreteKey() const {
|
||||
return DiscreteKey(keys_[0],cardinality_);
|
||||
}
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
Constraint::shared_ptr partiallyApply(
|
||||
const std::vector<Domain>& domains) const override;
|
||||
};
|
||||
public:
|
||||
|
||||
} // namespace gtsam
|
||||
typedef boost::shared_ptr<SingleValue> shared_ptr;
|
||||
|
||||
/// Construct from key, cardinality, and given value.
|
||||
SingleValue(Key key, size_t n, size_t value) :
|
||||
Constraint(key), cardinality_(n), value_(value) {
|
||||
}
|
||||
|
||||
/// Construct from DiscreteKey and given value.
|
||||
SingleValue(const DiscreteKey& dkey, size_t value) :
|
||||
Constraint(dkey.first), cardinality_(dkey.second), value_(value) {
|
||||
}
|
||||
|
||||
// print
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// equals
|
||||
bool equals(const DiscreteFactor& other, double tol) const override {
|
||||
if(!dynamic_cast<const SingleValue*>(&other))
|
||||
return false;
|
||||
else {
|
||||
const SingleValue& f(static_cast<const SingleValue&>(other));
|
||||
return (cardinality_==f.cardinality_) && (value_==f.value_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate value
|
||||
double operator()(const Values& values) const override;
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
/// Multiply into a decisiontree
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||
|
||||
/*
|
||||
* Ensure Arc-consistency: just sets domain[j] to {value_}
|
||||
* @param j domain to be checked
|
||||
* @param domains all other domains
|
||||
*/
|
||||
bool ensureArcConsistency(size_t j,
|
||||
std::vector<Domain>* domains) const override;
|
||||
|
||||
/// Partially apply known values
|
||||
Constraint::shared_ptr partiallyApply(const Values& values) const override;
|
||||
|
||||
/// Partially apply known values, domain version
|
||||
Constraint::shared_ptr partiallyApply(
|
||||
const std::vector<Domain>& domains) const override;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -19,12 +19,33 @@ using namespace std;
|
|||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(BinaryAllDif, allInOne) {
|
||||
// Create keys and ordering
|
||||
TEST(CSP, SingleValue) {
|
||||
// Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
|
||||
size_t nrColors = 3;
|
||||
DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
|
||||
|
||||
// Check that a single value is equal to a decision stump with only one "1":
|
||||
SingleValue singleValue(AZ, 2);
|
||||
DecisionTreeFactor f1(AZ, "0 0 1");
|
||||
EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor()));
|
||||
|
||||
// Create domains, laid out as a vector.
|
||||
// TODO(dellaert): should be map??
|
||||
vector<Domain> domains;
|
||||
domains += Domain(ID), Domain(AZ), Domain(UT);
|
||||
|
||||
// Ensure arc-consistency: just wipes out values in AZ domain:
|
||||
EXPECT(singleValue.ensureArcConsistency(1, &domains));
|
||||
LONGS_EQUAL(3, domains[0].nrValues());
|
||||
LONGS_EQUAL(1, domains[1].nrValues());
|
||||
LONGS_EQUAL(3, domains[2].nrValues());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(CSP, BinaryAllDif) {
|
||||
// Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each:
|
||||
size_t nrColors = 2;
|
||||
// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona",
|
||||
// nrColors);
|
||||
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
|
||||
DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
|
||||
|
||||
// Check construction and conversion
|
||||
BinaryAllDiff c1(ID, UT);
|
||||
|
@ -36,16 +57,51 @@ TEST_UNSAFE(BinaryAllDif, allInOne) {
|
|||
DecisionTreeFactor f2(UT & AZ, "0 1 1 0");
|
||||
EXPECT(assert_equal(f2, c2.toDecisionTreeFactor()));
|
||||
|
||||
// Check multiplication of factors with constraint:
|
||||
DecisionTreeFactor f3 = f1 * f2;
|
||||
EXPECT(assert_equal(f3, c1 * f2));
|
||||
EXPECT(assert_equal(f3, c2 * f1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(CSP, allInOne) {
|
||||
// Create keys and ordering
|
||||
TEST(CSP, AllDiff) {
|
||||
// Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
|
||||
size_t nrColors = 3;
|
||||
DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
|
||||
|
||||
// Check construction and conversion
|
||||
vector<DiscreteKey> dkeys{ID, UT, AZ};
|
||||
AllDiff alldiff(dkeys);
|
||||
DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
|
||||
// GTSAM_PRINT(actual);
|
||||
actual.dot("actual");
|
||||
DecisionTreeFactor f2(
|
||||
ID & AZ & UT,
|
||||
"0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
|
||||
EXPECT(assert_equal(f2, actual));
|
||||
|
||||
// Create domains.
|
||||
vector<Domain> domains;
|
||||
domains += Domain(ID), Domain(AZ), Domain(UT);
|
||||
|
||||
// First constrict AZ domain:
|
||||
SingleValue singleValue(AZ, 2);
|
||||
EXPECT(singleValue.ensureArcConsistency(1, &domains));
|
||||
|
||||
// Arc-consistency
|
||||
EXPECT(alldiff.ensureArcConsistency(0, &domains));
|
||||
EXPECT(!alldiff.ensureArcConsistency(1, &domains));
|
||||
EXPECT(alldiff.ensureArcConsistency(2, &domains));
|
||||
LONGS_EQUAL(2, domains[0].nrValues());
|
||||
LONGS_EQUAL(1, domains[1].nrValues());
|
||||
LONGS_EQUAL(2, domains[2].nrValues());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(CSP, allInOne) {
|
||||
// Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each:
|
||||
size_t nrColors = 2;
|
||||
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
|
||||
DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
|
||||
|
||||
// Create the CSP
|
||||
CSP csp;
|
||||
|
@ -81,15 +137,12 @@ TEST_UNSAFE(CSP, allInOne) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(CSP, WesternUS) {
|
||||
// Create keys
|
||||
TEST(CSP, WesternUS) {
|
||||
// Create keys for all states in Western US, with 4 color possibilities.
|
||||
size_t nrColors = 4;
|
||||
DiscreteKey
|
||||
// Create ordering according to example in ND-CSP.lyx
|
||||
WA(0, nrColors),
|
||||
OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors),
|
||||
UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors),
|
||||
CO(7, nrColors), NM(6, nrColors);
|
||||
DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),
|
||||
NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors),
|
||||
MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors);
|
||||
|
||||
// Create the CSP
|
||||
CSP csp;
|
||||
|
@ -116,10 +169,12 @@ TEST_UNSAFE(CSP, WesternUS) {
|
|||
csp.addAllDiff(WY, CO);
|
||||
csp.addAllDiff(CO, NM);
|
||||
|
||||
// Solve
|
||||
// Create ordering according to example in ND-CSP.lyx
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
|
||||
Key(8), Key(9), Key(10);
|
||||
|
||||
// Solve using that ordering:
|
||||
CSP::sharedValues mpe = csp.optimalAssignment(ordering);
|
||||
// GTSAM_PRINT(*mpe);
|
||||
CSP::Values expected;
|
||||
|
@ -143,33 +198,17 @@ TEST_UNSAFE(CSP, WesternUS) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(CSP, AllDiff) {
|
||||
// Create keys and ordering
|
||||
TEST(CSP, ArcConsistency) {
|
||||
// Create keys for Idaho, Arizona, and Utah, allowing three colors for each:
|
||||
size_t nrColors = 3;
|
||||
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
|
||||
DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
|
||||
|
||||
// Create the CSP
|
||||
// Create the CSP using just one all-diff constraint, plus constrain Arizona.
|
||||
CSP csp;
|
||||
vector<DiscreteKey> dkeys;
|
||||
dkeys += ID, UT, AZ;
|
||||
vector<DiscreteKey> dkeys{ID, UT, AZ};
|
||||
csp.addAllDiff(dkeys);
|
||||
csp.addSingleValue(AZ, 2);
|
||||
// GTSAM_PRINT(csp);
|
||||
|
||||
// Check construction and conversion
|
||||
SingleValue s(AZ, 2);
|
||||
DecisionTreeFactor f1(AZ, "0 0 1");
|
||||
EXPECT(assert_equal(f1, s.toDecisionTreeFactor()));
|
||||
|
||||
// Check construction and conversion
|
||||
AllDiff alldiff(dkeys);
|
||||
DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
|
||||
// GTSAM_PRINT(actual);
|
||||
// actual.dot("actual");
|
||||
DecisionTreeFactor f2(
|
||||
ID & AZ & UT,
|
||||
"0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
|
||||
EXPECT(assert_equal(f2, actual));
|
||||
// GTSAM_PRINT(csp);
|
||||
|
||||
// Check an invalid combination, with ID==UT==AZ all same color
|
||||
DiscreteFactor::Values invalid;
|
||||
|
@ -192,14 +231,15 @@ TEST_UNSAFE(CSP, AllDiff) {
|
|||
EXPECT(assert_equal(expected, *mpe));
|
||||
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
|
||||
|
||||
// Arc-consistency
|
||||
// ensure arc-consistency, i.e., narrow domains...
|
||||
vector<Domain> domains;
|
||||
domains += Domain(ID), Domain(AZ), Domain(UT);
|
||||
SingleValue singleValue(AZ, 2);
|
||||
EXPECT(singleValue.ensureArcConsistency(1, domains));
|
||||
EXPECT(alldiff.ensureArcConsistency(0, domains));
|
||||
EXPECT(!alldiff.ensureArcConsistency(1, domains));
|
||||
EXPECT(alldiff.ensureArcConsistency(2, domains));
|
||||
AllDiff alldiff(dkeys);
|
||||
EXPECT(singleValue.ensureArcConsistency(1, &domains));
|
||||
EXPECT(alldiff.ensureArcConsistency(0, &domains));
|
||||
EXPECT(!alldiff.ensureArcConsistency(1, &domains));
|
||||
EXPECT(alldiff.ensureArcConsistency(2, &domains));
|
||||
LONGS_EQUAL(2, domains[0].nrValues());
|
||||
LONGS_EQUAL(1, domains[1].nrValues());
|
||||
LONGS_EQUAL(2, domains[2].nrValues());
|
||||
|
@ -222,6 +262,7 @@ TEST_UNSAFE(CSP, AllDiff) {
|
|||
|
||||
// full arc-consistency test
|
||||
csp.runArcConsistency(nrColors);
|
||||
// GTSAM_PRINT(csp);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue