Formatting with Google style

release/4.3a0
Frank Dellaert 2021-11-18 10:54:00 -05:00
parent 13b0136e03
commit d27d6b60a7
16 changed files with 1301 additions and 1386 deletions

View File

@ -5,61 +5,60 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
AllDiff::AllDiff(const DiscreteKeys& dkeys) : AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
Constraint(dkeys.indices()) { for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey);
for(const DiscreteKey& dkey: dkeys) }
cardinalities_.insert(dkey);
}
/* ************************************************************************* */ /* ************************************************************************* */
void AllDiff::print(const std::string& s, void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
std::cout << s << "AllDiff on "; std::cout << s << "AllDiff on ";
for (Key dkey: keys_) for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
std::cout << formatter(dkey) << " ";
std::cout << std::endl; std::cout << std::endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */
double AllDiff::operator()(const Values& values) const { double AllDiff::operator()(const Values& values) const {
std::set < size_t > taken; // record values taken by keys std::set<size_t> taken; // record values taken by keys
for(Key dkey: keys_) { for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key size_t value = values.at(dkey); // get the value for that key
if (taken.count(value)) return 0.0;// check if value alreday taken if (taken.count(value)) return 0.0; // check if value alreday taken
taken.insert(value);// if not, record it as taken and keep checking taken.insert(value); // if not, record it as taken and keep checking
} }
return 1.0; return 1.0;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
// We will do this by converting the allDif into many BinaryAllDiff constraints // We will do this by converting the allDif into many BinaryAllDiff
// constraints
DecisionTreeFactor converted; DecisionTreeFactor converted;
size_t nrKeys = keys_.size(); size_t nrKeys = keys_.size();
for (size_t i1 = 0; i1 < nrKeys; i1++) for (size_t i1 = 0; i1 < nrKeys; i1++)
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2));
converted = converted * binary12.toDecisionTreeFactor(); converted = converted * binary12.toDecisionTreeFactor();
} }
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j, std::vector<Domain>& domains) const { bool AllDiff::ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const {
// Though strictly not part of allDiff, we check for // Though strictly not part of allDiff, we check for
// a value in domains[j] that does not occur in any other connected domain. // a value in domains[j] that does not occur in any other connected domain.
// If found, we make this a singleton... // If found, we make this a singleton...
@ -70,7 +69,7 @@ namespace gtsam {
// Check all other domains for singletons and erase corresponding values // Check all other domains for singletons and erase corresponding values
// This is the same as arc-consistency on the equivalent binary constraints // This is the same as arc-consistency on the equivalent binary constraints
bool changed = false; bool changed = false;
for(Key k: keys_) for (Key k : keys_)
if (k != j) { if (k != j) {
const Domain& Dk = domains[k]; const Domain& Dk = domains[k];
if (Dk.isSingleton()) { // check if singleton if (Dk.isSingleton()) { // check if singleton
@ -82,30 +81,29 @@ namespace gtsam {
} }
} }
return changed; return changed;
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
DiscreteKeys newKeys; DiscreteKeys newKeys;
// loop over keys and add them only if they do not appear in values // loop over keys and add them only if they do not appear in values
for(Key k: keys_) for (Key k : keys_)
if (values.find(k) == values.end()) { if (values.find(k) == values.end()) {
newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); newKeys.push_back(DiscreteKey(k, cardinalities_.at(k)));
} }
return boost::make_shared<AllDiff>(newKeys); return boost::make_shared<AllDiff>(newKeys);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply( Constraint::shared_ptr AllDiff::partiallyApply(
const std::vector<Domain>& domains) const { const std::vector<Domain>& domains) const {
DiscreteFactor::Values known; DiscreteFactor::Values known;
for(Key k: keys_) { for (Key k : keys_) {
const Domain& Dk = domains[k]; const Domain& Dk = domains[k];
if (Dk.isSingleton()) if (Dk.isSingleton()) known[k] = Dk.firstValue();
known[k] = Dk.firstValue();
} }
return partiallyApply(known); return partiallyApply(known);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,44 +7,42 @@
#pragma once #pragma once
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
namespace gtsam { namespace gtsam {
/** /**
* General AllDiff constraint * General AllDiff constraint
* Returns 1 if values for all keys are different, 0 otherwise * Returns 1 if values for all keys are different, 0 otherwise
* DiscreteFactors are all awkward in that they have to store two types of keys: * DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Key and an Key. In this factor, we * for each variable we have a Key and an Key. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor. * keep the Indices locally, and the Indices are stored in IndexFactor.
*/ */
class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint { class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
std::map<Key, size_t> cardinalities_;
std::map<Key,size_t> cardinalities_;
DiscreteKey discreteKey(size_t i) const { DiscreteKey discreteKey(size_t i) const {
Key j = keys_[i]; Key j = keys_[i];
return DiscreteKey(j,cardinalities_.at(j)); return DiscreteKey(j, cardinalities_.at(j));
} }
public: public:
/// Constructor /// Constructor
AllDiff(const DiscreteKeys& dkeys); AllDiff(const DiscreteKeys& dkeys);
// print // print
void print(const std::string& s = "", void print(const std::string& s = "", const KeyFormatter& formatter =
const KeyFormatter& formatter = DefaultKeyFormatter) const override; DefaultKeyFormatter) const override;
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const AllDiff*>(&other)) if (!dynamic_cast<const AllDiff*>(&other))
return false; return false;
else { else {
const AllDiff& f(static_cast<const AllDiff&>(other)); const AllDiff& f(static_cast<const AllDiff&>(other));
return cardinalities_.size() == f.cardinalities_.size() return cardinalities_.size() == f.cardinalities_.size() &&
&& std::equal(cardinalities_.begin(), cardinalities_.end(), std::equal(cardinalities_.begin(), cardinalities_.end(),
f.cardinalities_.begin()); f.cardinalities_.begin());
} }
} }
@ -65,13 +63,15 @@ namespace gtsam {
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param domains all other domains
*/ */
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override; bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override; Constraint::shared_ptr partiallyApply(const Values&) const override;
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(const std::vector<Domain>&) const override; Constraint::shared_ptr partiallyApply(
}; const std::vector<Domain>&) const override;
};
} // namespace gtsam } // namespace gtsam

View File

@ -7,33 +7,32 @@
#pragma once #pragma once
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam_unstable/discrete/Domain.h>
namespace gtsam { namespace gtsam {
/** /**
* Binary AllDiff constraint * Binary AllDiff constraint
* Returns 1 if values for two keys are different, 0 otherwise * Returns 1 if values for two keys are different, 0 otherwise
* DiscreteFactors are all awkward in that they have to store two types of keys: * DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Index and an Index. In this factor, we * for each variable we have a Index and an Index. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor. * keep the Indices locally, and the Indices are stored in IndexFactor.
*/ */
class BinaryAllDiff: public Constraint { class BinaryAllDiff : public Constraint {
size_t cardinality0_, cardinality1_; /// cardinality size_t cardinality0_, cardinality1_; /// cardinality
public: public:
/// Constructor /// Constructor
BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2)
Constraint(key1.first, key2.first), : Constraint(key1.first, key2.first),
cardinality0_(key1.second), cardinality1_(key2.second) { cardinality0_(key1.second),
} cardinality1_(key2.second) {}
// print // print
void print(const std::string& s = "", void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override { const KeyFormatter& formatter = DefaultKeyFormatter) const override {
std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
<< formatter(keys_[1]) << std::endl; << formatter(keys_[1]) << std::endl;
@ -41,28 +40,28 @@ namespace gtsam {
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const BinaryAllDiff*>(&other)) if (!dynamic_cast<const BinaryAllDiff*>(&other))
return false; return false;
else { else {
const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other)); const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_); return (cardinality0_ == f.cardinality0_) &&
(cardinality1_ == f.cardinality1_);
} }
} }
/// Calculate value /// Calculate value
double operator()(const Values& values) const override { double operator()(const Values& values) const override {
return (double) (values.at(keys_[0]) != values.at(keys_[1])); return (double)(values.at(keys_[0]) != values.at(keys_[1]));
} }
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override {
DiscreteKeys keys; DiscreteKeys keys;
keys.push_back(DiscreteKey(keys_[0],cardinality0_)); keys.push_back(DiscreteKey(keys_[0], cardinality0_));
keys.push_back(DiscreteKey(keys_[1],cardinality1_)); keys.push_back(DiscreteKey(keys_[1], cardinality1_));
std::vector<double> table; std::vector<double> table;
for (size_t i1 = 0; i1 < cardinality0_; i1++) for (size_t i1 = 0; i1 < cardinality0_; i1++)
for (size_t i2 = 0; i2 < cardinality1_; i2++) for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2);
table.push_back(i1 != i2);
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
@ -78,10 +77,10 @@ namespace gtsam {
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param domains all other domains
*/ */
/// bool ensureArcConsistency(size_t j,
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override { std::vector<Domain>& domains) const override {
// throw std::runtime_error( // throw std::runtime_error(
// "BinaryAllDiff::ensureArcConsistency not implemented"); // "BinaryAllDiff::ensureArcConsistency not implemented");
return false; return false;
} }
@ -95,6 +94,6 @@ namespace gtsam {
const std::vector<Domain>&) const override { const std::vector<Domain>&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
} }
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -5,29 +5,30 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/CSP.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam_unstable/discrete/CSP.h>
#include <gtsam_unstable/discrete/Domain.h>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
/// Find the best total assignment - can be expensive /// Find the best total assignment - can be expensive
CSP::sharedValues CSP::optimalAssignment() const { CSP::sharedValues CSP::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(); DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
sharedValues mpe = chordal->optimize(); sharedValues mpe = chordal->optimize();
return mpe; return mpe;
} }
/// Find the best total assignment - can be expensive /// Find the best total assignment - can be expensive
CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const { CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
sharedValues mpe = chordal->optimize(); sharedValues mpe = chordal->optimize();
return mpe; return mpe;
} }
void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, bool print) const { void CSP::runArcConsistency(size_t cardinality, size_t nrIterations,
bool print) const {
// Create VariableIndex // Create VariableIndex
VariableIndex index(*this); VariableIndex index(*this);
// index.print(); // index.print();
@ -35,9 +36,9 @@ namespace gtsam {
size_t n = index.size(); size_t n = index.size();
// Initialize domains // Initialize domains
std::vector < Domain > domains; std::vector<Domain> domains;
for (size_t j = 0; j < n; j++) for (size_t j = 0; j < n; j++)
domains.push_back(Domain(DiscreteKey(j,cardinality))); domains.push_back(Domain(DiscreteKey(j, cardinality)));
// Create array of flags indicating a domain changed or not // Create array of flags indicating a domain changed or not
std::vector<bool> changed(n); std::vector<bool> changed(n);
@ -51,13 +52,16 @@ namespace gtsam {
changed[v] = false; changed[v] = false;
// loop over all factors/constraints for variable v // loop over all factors/constraints for variable v
const FactorIndices& factors = index[v]; const FactorIndices& factors = index[v];
for(size_t f: factors) { for (size_t f : factors) {
// if not already a singleton // if not already a singleton
if (!domains[v].isSingleton()) { if (!domains[v].isSingleton()) {
// get the constraint and call its ensureArcConsistency method // get the constraint and call its ensureArcConsistency method
Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]); Constraint::shared_ptr constraint =
if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); boost::dynamic_pointer_cast<Constraint>((*this)[f]);
changed[v] = constraint->ensureArcConsistency(v,domains) || changed[v]; if (!constraint)
throw runtime_error("CSP:runArcConsistency: non-constraint factor");
changed[v] =
constraint->ensureArcConsistency(v, domains) || changed[v];
} }
} // f } // f
if (changed[v]) anyChange = true; if (changed[v]) anyChange = true;
@ -91,13 +95,14 @@ namespace gtsam {
// TODO: create a new ordering as we go, to ensure a connected graph // TODO: create a new ordering as we go, to ensure a connected graph
// KeyOrdering ordering; // KeyOrdering ordering;
// vector<Index> dkeys; // vector<Index> dkeys;
for(const DiscreteFactor::shared_ptr& f: factors_) { for (const DiscreteFactor::shared_ptr& f : factors_) {
Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>(f); Constraint::shared_ptr constraint =
if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); boost::dynamic_pointer_cast<Constraint>(f);
if (!constraint)
throw runtime_error("CSP:runArcConsistency: non-constraint factor");
Constraint::shared_ptr reduced = constraint->partiallyApply(domains); Constraint::shared_ptr reduced = constraint->partiallyApply(domains);
if (print) reduced->print(); if (print) reduced->print();
} }
#endif #endif
} }
} // gtsam } // namespace gtsam

View File

@ -7,30 +7,28 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam_unstable/discrete/AllDiff.h> #include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam_unstable/discrete/SingleValue.h> #include <gtsam_unstable/discrete/SingleValue.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
namespace gtsam { namespace gtsam {
/** /**
* Constraint Satisfaction Problem class * Constraint Satisfaction Problem class
* A specialization of a DiscreteFactorGraph. * A specialization of a DiscreteFactorGraph.
* It knows about CSP-specific constraints and algorithms * It knows about CSP-specific constraints and algorithms
*/ */
class GTSAM_UNSTABLE_EXPORT CSP: public DiscreteFactorGraph { class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
public: public:
/** A map from keys to values */ /** A map from keys to values */
typedef KeyVector Indices; typedef KeyVector Indices;
typedef Assignment<Key> Values; typedef Assignment<Key> Values;
typedef boost::shared_ptr<Values> sharedValues; typedef boost::shared_ptr<Values> sharedValues;
public: public:
// /// Constructor
// /// Constructor // CSP() {
// CSP() { // }
// }
/// Add a unary constraint, allowing only a single value /// Add a unary constraint, allowing only a single value
void addSingleValue(const DiscreteKey& dkey, size_t value) { void addSingleValue(const DiscreteKey& dkey, size_t value) {
@ -40,8 +38,7 @@ namespace gtsam {
/// Add a binary AllDiff constraint /// Add a binary AllDiff constraint
void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) {
boost::shared_ptr<BinaryAllDiff> factor( boost::shared_ptr<BinaryAllDiff> factor(new BinaryAllDiff(key1, key2));
new BinaryAllDiff(key1, key2));
push_back(factor); push_back(factor);
} }
@ -51,13 +48,13 @@ namespace gtsam {
push_back(factor); push_back(factor);
} }
// /** return product of all factors as a single factor */ // /** return product of all factors as a single factor */
// DecisionTreeFactor product() const { // DecisionTreeFactor product() const {
// DecisionTreeFactor result; // DecisionTreeFactor result;
// for(const sharedFactor& factor: *this) // for(const sharedFactor& factor: *this)
// if (factor) result = (*factor) * result; // if (factor) result = (*factor) * result;
// return result; // return result;
// } // }
/// Find the best total assignment - can be expensive /// Find the best total assignment - can be expensive
sharedValues optimalAssignment() const; sharedValues optimalAssignment() const;
@ -65,16 +62,17 @@ namespace gtsam {
/// Find the best total assignment - can be expensive /// Find the best total assignment - can be expensive
sharedValues optimalAssignment(const Ordering& ordering) const; sharedValues optimalAssignment(const Ordering& ordering) const;
// /* // /*
// * Perform loopy belief propagation // * Perform loopy belief propagation
// * True belief propagation would check for each value in domain // * True belief propagation would check for each value in domain
// * whether any satisfying separator assignment can be found. // * whether any satisfying separator assignment can be found.
// * This corresponds to hyper-arc consistency in CSP speak. // * This corresponds to hyper-arc consistency in CSP speak.
// * This can be done by creating a mini-factor graph and search. // * This can be done by creating a mini-factor graph and search.
// * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep. // * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels
// * It will be very expensive to exclude values that way. // deep.
// */ // * It will be very expensive to exclude values that way.
// void applyBeliefPropagation(size_t nrIterations = 10) const; // */
// void applyBeliefPropagation(size_t nrIterations = 10) const;
/* /*
* Apply arc-consistency ~ Approximate loopy belief propagation * Apply arc-consistency ~ Approximate loopy belief propagation
@ -84,7 +82,6 @@ namespace gtsam {
*/ */
void runArcConsistency(size_t cardinality, size_t nrIterations = 10, void runArcConsistency(size_t cardinality, size_t nrIterations = 10,
bool print = false) const; bool print = false) const;
}; // CSP }; // CSP
} // gtsam
} // namespace gtsam

View File

@ -17,49 +17,40 @@
#pragma once #pragma once
#include <gtsam_unstable/dllexport.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam_unstable/dllexport.h>
#include <boost/assign.hpp> #include <boost/assign.hpp>
namespace gtsam { namespace gtsam {
class Domain; class Domain;
/** /**
* Base class for discrete probabilistic factors * Base class for discrete probabilistic factors
* The most general one is the derived DecisionTreeFactor * The most general one is the derived DecisionTreeFactor
*/ */
class Constraint : public DiscreteFactor { class Constraint : public DiscreteFactor {
public: public:
typedef boost::shared_ptr<Constraint> shared_ptr; typedef boost::shared_ptr<Constraint> shared_ptr;
protected: protected:
/// Construct n-way factor /// Construct n-way factor
Constraint(const KeyVector& js) : Constraint(const KeyVector& js) : DiscreteFactor(js) {}
DiscreteFactor(js) {
}
/// Construct unary factor /// Construct unary factor
Constraint(Key j) : Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {}
DiscreteFactor(boost::assign::cref_list_of<1>(j)) {
}
/// Construct binary factor /// Construct binary factor
Constraint(Key j1, Key j2) : Constraint(Key j1, Key j2)
DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) { : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {}
}
/// construct from container /// construct from container
template<class KeyIterator> template <class KeyIterator>
Constraint(KeyIterator beginKey, KeyIterator endKey) : Constraint(KeyIterator beginKey, KeyIterator endKey)
DiscreteFactor(beginKey, endKey) { : DiscreteFactor(beginKey, endKey) {}
}
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -78,16 +69,16 @@ namespace gtsam {
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param domains all other domains
*/ */
virtual bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const = 0; virtual bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const = 0;
/// Partially apply known values /// Partially apply known values
virtual shared_ptr partiallyApply(const Values&) const = 0; virtual shared_ptr partiallyApply(const Values&) const = 0;
/// Partially apply known values, domain version /// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0; virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
/// @} /// @}
}; };
// DiscreteFactor // DiscreteFactor
}// namespace gtsam } // namespace gtsam

View File

@ -5,92 +5,89 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */ /* ************************************************************************* */
void Domain::print(const string& s, void Domain::print(const string& s, const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const { // cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
// cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << // formatter(keys_[0]) << ") with values";
// formatter(keys_[0]) << ") with values"; // for (size_t v: values_) cout << " " << v;
// for (size_t v: values_) cout << " " << v; // cout << endl;
// cout << endl; for (size_t v : values_) cout << v;
for (size_t v: values_) cout << v; }
}
/* ************************************************************************* */ /* ************************************************************************* */
double Domain::operator()(const Values& values) const { double Domain::operator()(const Values& values) const {
return contains(values.at(keys_[0])); return contains(values.at(keys_[0]));
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor Domain::toDecisionTreeFactor() const { DecisionTreeFactor Domain::toDecisionTreeFactor() const {
DiscreteKeys keys; DiscreteKeys keys;
keys += DiscreteKey(keys_[0],cardinality_); keys += DiscreteKey(keys_[0], cardinality_);
vector<double> table; vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; ++i1) for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1));
table.push_back(contains(i1));
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool Domain::ensureArcConsistency(size_t j, vector<Domain>& domains) const { bool Domain::ensureArcConsistency(size_t j, vector<Domain>& domains) const {
if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
Domain& D = domains[j]; Domain& D = domains[j];
for(size_t value: values_) for (size_t value : values_)
if (!D.contains(value)) throw runtime_error("Unsatisfiable"); if (!D.contains(value)) throw runtime_error("Unsatisfiable");
D = *this; D = *this;
return true; return true;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool Domain::checkAllDiff(const KeyVector keys, vector<Domain>& domains) { bool Domain::checkAllDiff(const KeyVector keys, vector<Domain>& domains) {
Key j = keys_[0]; Key j = keys_[0];
// for all values in this domain // for all values in this domain
for(size_t value: values_) { for (size_t value : values_) {
// for all connected domains // for all connected domains
for(Key k: keys) for (Key k : keys)
// if any domain contains the value we cannot make this domain singleton // if any domain contains the value we cannot make this domain singleton
if (k!=j && domains[k].contains(value)) if (k != j && domains[k].contains(value)) goto found;
goto found;
values_.clear(); values_.clear();
values_.insert(value); values_.insert(value);
return true; // we changed it return true; // we changed it
found:; found:;
} }
return false; // we did not change it return false; // we did not change it
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply( Constraint::shared_ptr Domain::partiallyApply(const Values& values) const {
const Values& values) const {
Values::const_iterator it = values.find(keys_[0]); Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && !contains(it->second)) throw runtime_error( if (it != values.end() && !contains(it->second))
"Domain::partiallyApply: unsatisfiable"); throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared < Domain > (*this); return boost::make_shared<Domain>(*this);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply( Constraint::shared_ptr Domain::partiallyApply(
const vector<Domain>& domains) const { const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]]; const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( if (Dk.isSingleton() && !contains(*Dk.begin()))
"Domain::partiallyApply: unsatisfiable"); throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared < Domain > (Dk); return boost::make_shared<Domain>(Dk);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,81 +7,65 @@
#pragma once #pragma once
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h>
namespace gtsam { namespace gtsam {
/** /**
* Domain restriction constraint * Domain restriction constraint
*/ */
class GTSAM_UNSTABLE_EXPORT Domain: public Constraint { class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
size_t cardinality_; /// Cardinality size_t cardinality_; /// Cardinality
std::set<size_t> values_; /// allowed values std::set<size_t> values_; /// allowed values
public: public:
typedef boost::shared_ptr<Domain> shared_ptr; typedef boost::shared_ptr<Domain> shared_ptr;
// Constructor on Discrete Key initializes an "all-allowed" domain // Constructor on Discrete Key initializes an "all-allowed" domain
Domain(const DiscreteKey& dkey) : Domain(const DiscreteKey& dkey)
Constraint(dkey.first), cardinality_(dkey.second) { : Constraint(dkey.first), cardinality_(dkey.second) {
for (size_t v = 0; v < cardinality_; v++) for (size_t v = 0; v < cardinality_; v++) values_.insert(v);
values_.insert(v);
} }
// Constructor on Discrete Key with single allowed value // Constructor on Discrete Key with single allowed value
// Consider SingleValue constraint // Consider SingleValue constraint
Domain(const DiscreteKey& dkey, size_t v) : Domain(const DiscreteKey& dkey, size_t v)
Constraint(dkey.first), cardinality_(dkey.second) { : Constraint(dkey.first), cardinality_(dkey.second) {
values_.insert(v); values_.insert(v);
} }
/// Constructor /// Constructor
Domain(const Domain& other) : Domain(const Domain& other)
Constraint(other.keys_[0]), values_(other.values_) { : Constraint(other.keys_[0]), values_(other.values_) {}
}
/// insert a value, non const :-( /// insert a value, non const :-(
void insert(size_t value) { void insert(size_t value) { values_.insert(value); }
values_.insert(value);
}
/// erase a value, non const :-( /// erase a value, non const :-(
void erase(size_t value) { void erase(size_t value) { values_.erase(value); }
values_.erase(value);
}
size_t nrValues() const { size_t nrValues() const { return values_.size(); }
return values_.size();
}
bool isSingleton() const { bool isSingleton() const { return nrValues() == 1; }
return nrValues() == 1;
}
size_t firstValue() const { size_t firstValue() const { return *values_.begin(); }
return *values_.begin();
}
// print // print
void print(const std::string& s = "", void print(const std::string& s = "", const KeyFormatter& formatter =
const KeyFormatter& formatter = DefaultKeyFormatter) const override; DefaultKeyFormatter) const override;
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const Domain*>(&other)) if (!dynamic_cast<const Domain*>(&other))
return false; return false;
else { else {
const Domain& f(static_cast<const Domain&>(other)); const Domain& f(static_cast<const Domain&>(other));
return (cardinality_==f.cardinality_) && (values_==f.values_); return (cardinality_ == f.cardinality_) && (values_ == f.values_);
} }
} }
bool contains(size_t value) const { bool contains(size_t value) const { return values_.count(value) > 0; }
return values_.count(value)>0;
}
/// Calculate value /// Calculate value
double operator()(const Values& values) const override; double operator()(const Values& values) const override;
@ -97,11 +81,13 @@ namespace gtsam {
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param domains all other domains
*/ */
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override; bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override;
/** /**
* Check for a value in domain that does not occur in any other connected domain. * Check for a value in domain that does not occur in any other connected
* If found, we make this a singleton... Called in AllDiff::ensureArcConsistency * domain. If found, we make this a singleton... Called in
* AllDiff::ensureArcConsistency
* @param keys connected domains through alldiff * @param keys connected domains through alldiff
*/ */
bool checkAllDiff(const KeyVector keys, std::vector<Domain>& domains); bool checkAllDiff(const KeyVector keys, std::vector<Domain>& domains);
@ -112,6 +98,6 @@ namespace gtsam {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override; const std::vector<Domain>& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -5,24 +5,22 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/Scheduler.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam_unstable/discrete/Scheduler.h>
#include <boost/tokenizer.hpp> #include <boost/tokenizer.hpp>
#include <cmath>
#include <fstream> #include <fstream>
#include <iomanip> #include <iomanip>
#include <cmath>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
Scheduler::Scheduler(size_t maxNrStudents, const string& filename): Scheduler::Scheduler(size_t maxNrStudents, const string& filename)
maxNrStudents_(maxNrStudents) : maxNrStudents_(maxNrStudents) {
{
typedef boost::tokenizer<boost::escaped_list_separator<char> > Tokenizer; typedef boost::tokenizer<boost::escaped_list_separator<char> > Tokenizer;
// open file // open file
@ -38,8 +36,7 @@ namespace gtsam {
if (getline(is, line, '\r')) { if (getline(is, line, '\r')) {
Tokenizer tok(line); Tokenizer tok(line);
Tokenizer::iterator it = tok.begin(); Tokenizer::iterator it = tok.begin();
for (++it; it != tok.end(); ++it) for (++it; it != tok.end(); ++it) addFaculty(*it);
addFaculty(*it);
} }
// for all remaining lines // for all remaining lines
@ -50,17 +47,16 @@ namespace gtsam {
Tokenizer::iterator it = tok.begin(); Tokenizer::iterator it = tok.begin();
addSlot(*it++); // add slot addSlot(*it++); // add slot
// add availability // add availability
for (; it != tok.end(); ++it) for (; it != tok.end(); ++it) available_ += (it->empty()) ? "0 " : "1 ";
available_ += (it->empty()) ? "0 " : "1 ";
available_ += '\n'; available_ += '\n';
} }
} // constructor } // constructor
/** addStudent has to be called after adding slots and faculty */ /** addStudent has to be called after adding slots and faculty */
void Scheduler::addStudent(const string& studentName, void Scheduler::addStudent(const string& studentName, const string& area1,
const string& area1, const string& area2, const string& area2, const string& area3,
const string& area3, const string& advisor) { const string& advisor) {
assert(nrStudents()<maxNrStudents_); assert(nrStudents() < maxNrStudents_);
assert(facultyInArea_.count(area1)); assert(facultyInArea_.count(area1));
assert(facultyInArea_.count(area2)); assert(facultyInArea_.count(area2));
assert(facultyInArea_.count(area3)); assert(facultyInArea_.count(area3));
@ -69,71 +65,73 @@ namespace gtsam {
student.name_ = studentName; student.name_ = studentName;
// We fix the ordering by assigning a higher index to the student // We fix the ordering by assigning a higher index to the student
// and numbering the areas lower // and numbering the areas lower
Key j = 3*maxNrStudents_ + nrStudents(); Key j = 3 * maxNrStudents_ + nrStudents();
student.key_ = DiscreteKey(j, nrTimeSlots()); student.key_ = DiscreteKey(j, nrTimeSlots());
Key base = 3*nrStudents(); Key base = 3 * nrStudents();
student.keys_[0] = DiscreteKey(base+0, nrFaculty()); student.keys_[0] = DiscreteKey(base + 0, nrFaculty());
student.keys_[1] = DiscreteKey(base+1, nrFaculty()); student.keys_[1] = DiscreteKey(base + 1, nrFaculty());
student.keys_[2] = DiscreteKey(base+2, nrFaculty()); student.keys_[2] = DiscreteKey(base + 2, nrFaculty());
student.areaName_[0] = area1; student.areaName_[0] = area1;
student.areaName_[1] = area2; student.areaName_[1] = area2;
student.areaName_[2] = area3; student.areaName_[2] = area3;
students_.push_back(student); students_.push_back(student);
} }
/** get key for student and area, 0 is time slot itself */ /** get key for student and area, 0 is time slot itself */
const DiscreteKey& Scheduler::key(size_t s, boost::optional<size_t> area) const { const DiscreteKey& Scheduler::key(size_t s,
boost::optional<size_t> area) const {
return area ? students_[s].keys_[*area] : students_[s].key_; return area ? students_[s].keys_[*area] : students_[s].key_;
} }
const string& Scheduler::studentName(size_t i) const { const string& Scheduler::studentName(size_t i) const {
assert(i<nrStudents()); assert(i < nrStudents());
return students_[i].name_; return students_[i].name_;
} }
const DiscreteKey& Scheduler::studentKey(size_t i) const { const DiscreteKey& Scheduler::studentKey(size_t i) const {
assert(i<nrStudents()); assert(i < nrStudents());
return students_[i].key_; return students_[i].key_;
} }
const string& Scheduler::studentArea(size_t i, size_t area) const { const string& Scheduler::studentArea(size_t i, size_t area) const {
assert(i<nrStudents()); assert(i < nrStudents());
return students_[i].areaName_[area]; return students_[i].areaName_[area];
} }
/** Add student-specific constraints to the graph */ /** Add student-specific constraints to the graph */
void Scheduler::addStudentSpecificConstraints(size_t i, boost::optional<size_t> slot) { void Scheduler::addStudentSpecificConstraints(size_t i,
boost::optional<size_t> slot) {
bool debug = ISDEBUG("Scheduler::buildGraph"); bool debug = ISDEBUG("Scheduler::buildGraph");
assert(i<nrStudents()); assert(i < nrStudents());
const Student& s = students_[i]; const Student& s = students_[i];
if (!slot && !slotsAvailable_.empty()) { if (!slot && !slotsAvailable_.empty()) {
if (debug) cout << "Adding availability of slots" << endl; if (debug) cout << "Adding availability of slots" << endl;
assert(slotsAvailable_.size()==s.key_.second); assert(slotsAvailable_.size() == s.key_.second);
CSP::add(s.key_, slotsAvailable_); CSP::add(s.key_, slotsAvailable_);
} }
// For all areas // For all areas
for (size_t area = 0; area < 3; area++) { for (size_t area = 0; area < 3; area++) {
DiscreteKey areaKey = s.keys_[area]; DiscreteKey areaKey = s.keys_[area];
const string& areaName = s.areaName_[area]; const string& areaName = s.areaName_[area];
if (debug) cout << "Area constraints " << areaName << endl; if (debug) cout << "Area constraints " << areaName << endl;
assert(facultyInArea_[areaName].size()==areaKey.second); assert(facultyInArea_[areaName].size() == areaKey.second);
CSP::add(areaKey, facultyInArea_[areaName]); CSP::add(areaKey, facultyInArea_[areaName]);
if (debug) cout << "Advisor constraint " << areaName << endl; if (debug) cout << "Advisor constraint " << areaName << endl;
assert(s.advisor_.size()==areaKey.second); assert(s.advisor_.size() == areaKey.second);
CSP::add(areaKey, s.advisor_); CSP::add(areaKey, s.advisor_);
if (debug) cout << "Availability of faculty " << areaName << endl; if (debug) cout << "Availability of faculty " << areaName << endl;
if (slot) { if (slot) {
// get all constraints then specialize to slot // get all constraints then specialize to slot
size_t dummyIndex = maxNrStudents_*3+maxNrStudents_; size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_;
DiscreteKey dummy(dummyIndex, nrTimeSlots()); DiscreteKey dummy(dummyIndex, nrTimeSlots());
Potentials::ADT p(dummy & areaKey, available_); // available_ is Doodle string Potentials::ADT p(dummy & areaKey,
available_); // available_ is Doodle string
Potentials::ADT q = p.choose(dummyIndex, *slot); Potentials::ADT q = p.choose(dummyIndex, *slot);
DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q)); DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q));
CSP::push_back(f); CSP::push_back(f);
@ -145,25 +143,22 @@ namespace gtsam {
// add mutex // add mutex
if (debug) cout << "Mutex for faculty" << endl; if (debug) cout << "Mutex for faculty" << endl;
addAllDiff(s.keys_[0] & s.keys_[1] & s.keys_[2]); addAllDiff(s.keys_[0] & s.keys_[1] & s.keys_[2]);
} }
/** Main routine that builds factor graph */
/** Main routine that builds factor graph */ void Scheduler::buildGraph(size_t mutexBound) {
void Scheduler::buildGraph(size_t mutexBound) {
bool debug = ISDEBUG("Scheduler::buildGraph"); bool debug = ISDEBUG("Scheduler::buildGraph");
if (debug) cout << "Adding student-specific constraints" << endl; if (debug) cout << "Adding student-specific constraints" << endl;
for (size_t i = 0; i < nrStudents(); i++) for (size_t i = 0; i < nrStudents(); i++) addStudentSpecificConstraints(i);
addStudentSpecificConstraints(i);
// special constraint for MN // special constraint for MN
if (studentName(0) == "Michael N") CSP::add(studentKey(0), if (studentName(0) == "Michael N")
"0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"); CSP::add(studentKey(0), "0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1");
if (!mutexBound) { if (!mutexBound) {
DiscreteKeys dkeys; DiscreteKeys dkeys;
for(const Student& s: students_) for (const Student& s : students_) dkeys.push_back(s.key_);
dkeys.push_back(s.key_);
addAllDiff(dkeys); addAllDiff(dkeys);
} else { } else {
if (debug) cout << "Mutex for Students" << endl; if (debug) cout << "Mutex for Students" << endl;
@ -175,103 +170,98 @@ namespace gtsam {
} }
} }
} }
} // buildGraph } // buildGraph
/** print */ /** print */
void Scheduler::print(const string& s, const KeyFormatter& formatter) const { void Scheduler::print(const string& s, const KeyFormatter& formatter) const {
cout << s << " Faculty:" << endl; cout << s << " Faculty:" << endl;
for(const string& name: facultyName_) for (const string& name : facultyName_) cout << name << '\n';
cout << name << '\n';
cout << endl; cout << endl;
cout << s << " Slots:\n"; cout << s << " Slots:\n";
size_t i = 0; size_t i = 0;
for(const string& name: slotName_) for (const string& name : slotName_) cout << i++ << " " << name << endl;
cout << i++ << " " << name << endl;
cout << endl; cout << endl;
cout << "Availability:\n" << available_ << '\n'; cout << "Availability:\n" << available_ << '\n';
cout << s << " Area constraints:\n"; cout << s << " Area constraints:\n";
for(const FacultyInArea::value_type& it: facultyInArea_) for (const FacultyInArea::value_type& it : facultyInArea_) {
{
cout << setw(12) << it.first << ": "; cout << setw(12) << it.first << ": ";
for(double v: it.second) for (double v : it.second) cout << v << " ";
cout << v << " ";
cout << '\n'; cout << '\n';
} }
cout << endl; cout << endl;
cout << s << " Students:\n"; cout << s << " Students:\n";
for (const Student& student: students_) for (const Student& student : students_) student.print();
student.print();
cout << endl; cout << endl;
CSP::print(s + " Factor graph"); CSP::print(s + " Factor graph");
cout << endl; cout << endl;
} // print } // print
/** Print readable form of assignment */ /** Print readable form of assignment */
void Scheduler::printAssignment(sharedValues assignment) const { void Scheduler::printAssignment(sharedValues assignment) const {
// Not intended to be general! Assumes very particular ordering ! // Not intended to be general! Assumes very particular ordering !
cout << endl; cout << endl;
for (size_t s = 0; s < nrStudents(); s++) { for (size_t s = 0; s < nrStudents(); s++) {
Key j = 3*maxNrStudents_ + s; Key j = 3 * maxNrStudents_ + s;
size_t slot = assignment->at(j); size_t slot = assignment->at(j);
cout << studentName(s) << " slot: " << slotName_[slot] << endl; cout << studentName(s) << " slot: " << slotName_[slot] << endl;
Key base = 3*s; Key base = 3 * s;
for (size_t area = 0; area < 3; area++) { for (size_t area = 0; area < 3; area++) {
size_t faculty = assignment->at(base+area); size_t faculty = assignment->at(base + area);
cout << setw(12) << studentArea(s,area) << ": " << facultyName_[faculty] cout << setw(12) << studentArea(s, area) << ": " << facultyName_[faculty]
<< endl; << endl;
} }
cout << endl; cout << endl;
} }
} }
/** Special print for single-student case */ /** Special print for single-student case */
void Scheduler::printSpecial(sharedValues assignment) const { void Scheduler::printSpecial(sharedValues assignment) const {
Values::const_iterator it = assignment->begin(); Values::const_iterator it = assignment->begin();
for (size_t area = 0; area < 3; area++, it++) { for (size_t area = 0; area < 3; area++, it++) {
size_t f = it->second; size_t f = it->second;
cout << setw(12) << studentArea(0,area) << ": " << facultyName_[f] << endl; cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl;
} }
cout << endl; cout << endl;
} }
/** Accumulate faculty stats */ /** Accumulate faculty stats */
void Scheduler::accumulateStats(sharedValues assignment, vector< void Scheduler::accumulateStats(sharedValues assignment,
size_t>& stats) const { vector<size_t>& stats) const {
for (size_t s = 0; s < nrStudents(); s++) { for (size_t s = 0; s < nrStudents(); s++) {
Key base = 3*s; Key base = 3 * s;
for (size_t area = 0; area < 3; area++) { for (size_t area = 0; area < 3; area++) {
size_t f = assignment->at(base+area); size_t f = assignment->at(base + area);
assert(f<stats.size()); assert(f < stats.size());
stats[f]++; stats[f]++;
} // area } // area
} // s } // s
} }
/** Eliminate, return a Bayes net */ /** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
gttic(my_eliminate); gttic(my_eliminate);
// TODO: fix this!! // TODO: fix this!!
size_t maxKey = keys().size(); size_t maxKey = keys().size();
Ordering defaultKeyOrdering; Ordering defaultKeyOrdering;
for (size_t i = 0; i<maxKey; ++i) for (size_t i = 0; i < maxKey; ++i) defaultKeyOrdering += Key(i);
defaultKeyOrdering += Key(i); DiscreteBayesNet::shared_ptr chordal =
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(defaultKeyOrdering); this->eliminateSequential(defaultKeyOrdering);
gttoc(my_eliminate); gttoc(my_eliminate);
return chordal; return chordal;
} }
/** Find the best total assignment - can be expensive */ /** Find the best total assignment - can be expensive */
Scheduler::sharedValues Scheduler::optimalAssignment() const { Scheduler::sharedValues Scheduler::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = eliminate(); DiscreteBayesNet::shared_ptr chordal = eliminate();
if (ISDEBUG("Scheduler::optimalAssignment")) { if (ISDEBUG("Scheduler::optimalAssignment")) {
DiscreteBayesNet::const_iterator it = chordal->end()-1; DiscreteBayesNet::const_iterator it = chordal->end() - 1;
const Student & student = students_.front(); const Student& student = students_.front();
cout << endl; cout << endl;
(*it)->print(student.name_); (*it)->print(student.name_);
} }
@ -280,23 +270,21 @@ namespace gtsam {
sharedValues mpe = chordal->optimize(); sharedValues mpe = chordal->optimize();
gttoc(my_optimize); gttoc(my_optimize);
return mpe; return mpe;
} }
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
Scheduler::sharedValues Scheduler::bestSchedule() const { Scheduler::sharedValues Scheduler::bestSchedule() const {
sharedValues best; sharedValues best;
throw runtime_error("bestSchedule not implemented"); throw runtime_error("bestSchedule not implemented");
return best; return best;
} }
/** find the corresponding most desirable committee assignment */ /** find the corresponding most desirable committee assignment */
Scheduler::sharedValues Scheduler::bestAssignment( Scheduler::sharedValues Scheduler::bestAssignment(
sharedValues bestSchedule) const { sharedValues bestSchedule) const {
sharedValues best; sharedValues best;
throw runtime_error("bestAssignment not implemented"); throw runtime_error("bestAssignment not implemented");
return best; return best;
} }
} // gtsam
} // namespace gtsam

View File

@ -11,17 +11,15 @@
namespace gtsam { namespace gtsam {
/** /**
* Scheduler class * Scheduler class
* Creates one variable for each student, and three variables for each * Creates one variable for each student, and three variables for each
* of the student's areas, for a total of 4*nrStudents variables. * of the student's areas, for a total of 4*nrStudents variables.
* The "student" variable will determine when the student takes the qual. * The "student" variable will determine when the student takes the qual.
* The "area" variables determine which faculty are on his/her committee. * The "area" variables determine which faculty are on his/her committee.
*/ */
class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
private: private:
/** Internal data structure for students */ /** Internal data structure for students */
struct Student { struct Student {
std::string name_; std::string name_;
@ -29,15 +27,14 @@ namespace gtsam {
std::vector<DiscreteKey> keys_; // key for areas std::vector<DiscreteKey> keys_; // key for areas
std::vector<std::string> areaName_; std::vector<std::string> areaName_;
std::vector<double> advisor_; std::vector<double> advisor_;
Student(size_t nrFaculty, size_t advisorIndex) : Student(size_t nrFaculty, size_t advisorIndex)
keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) { : keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) {
advisor_[advisorIndex] = 0.0; advisor_[advisorIndex] = 0.0;
} }
void print() const { void print() const {
using std::cout; using std::cout;
cout << name_ << ": "; cout << name_ << ": ";
for (size_t area = 0; area < 3; area++) for (size_t area = 0; area < 3; area++) cout << areaName_[area] << " ";
cout << areaName_[area] << " ";
cout << std::endl; cout << std::endl;
} }
}; };
@ -63,7 +60,6 @@ namespace gtsam {
std::vector<double> slotsAvailable_; std::vector<double> slotsAvailable_;
public: public:
/** /**
* Constructor * Constructor
* We need to know the number of students in advance for ordering keys. * We need to know the number of students in advance for ordering keys.
@ -79,26 +75,16 @@ namespace gtsam {
facultyName_.push_back(facultyName); facultyName_.push_back(facultyName);
} }
size_t nrFaculty() const { size_t nrFaculty() const { return facultyName_.size(); }
return facultyName_.size();
}
/** boolean std::string of nrTimeSlots * nrFaculty */ /** boolean std::string of nrTimeSlots * nrFaculty */
void setAvailability(const std::string& available) { void setAvailability(const std::string& available) { available_ = available; }
available_ = available;
}
void addSlot(const std::string& slotName) { void addSlot(const std::string& slotName) { slotName_.push_back(slotName); }
slotName_.push_back(slotName);
}
size_t nrTimeSlots() const { size_t nrTimeSlots() const { return slotName_.size(); }
return slotName_.size();
}
const std::string& slotName(size_t s) const { const std::string& slotName(size_t s) const { return slotName_[s]; }
return slotName_[s];
}
/** slots available, boolean */ /** slots available, boolean */
void setSlotsAvailable(const std::vector<double>& slotsAvailable) { void setSlotsAvailable(const std::vector<double>& slotsAvailable) {
@ -107,7 +93,8 @@ namespace gtsam {
void addArea(const std::string& facultyName, const std::string& areaName) { void addArea(const std::string& facultyName, const std::string& areaName) {
areaName_.push_back(areaName); areaName_.push_back(areaName);
std::vector<double>& table = facultyInArea_[areaName]; // will create if needed std::vector<double>& table =
facultyInArea_[areaName]; // will create if needed
if (table.empty()) table.resize(nrFaculty(), 0); if (table.empty()) table.resize(nrFaculty(), 0);
table[facultyIndex_[facultyName]] = 1; table[facultyIndex_[facultyName]] = 1;
} }
@ -119,7 +106,8 @@ namespace gtsam {
Scheduler(size_t maxNrStudents, const std::string& filename); Scheduler(size_t maxNrStudents, const std::string& filename);
/** get key for student and area, 0 is time slot itself */ /** get key for student and area, 0 is time slot itself */
const DiscreteKey& key(size_t s, boost::optional<size_t> area = boost::none) const; const DiscreteKey& key(size_t s,
boost::optional<size_t> area = boost::none) const;
/** addStudent has to be called after adding slots and faculty */ /** addStudent has to be called after adding slots and faculty */
void addStudent(const std::string& studentName, const std::string& area1, void addStudent(const std::string& studentName, const std::string& area1,
@ -127,16 +115,15 @@ namespace gtsam {
const std::string& advisor); const std::string& advisor);
/// current number of students /// current number of students
size_t nrStudents() const { size_t nrStudents() const { return students_.size(); }
return students_.size();
}
const std::string& studentName(size_t i) const; const std::string& studentName(size_t i) const;
const DiscreteKey& studentKey(size_t i) const; const DiscreteKey& studentKey(size_t i) const;
const std::string& studentArea(size_t i, size_t area) const; const std::string& studentArea(size_t i, size_t area) const;
/** Add student-specific constraints to the graph */ /** Add student-specific constraints to the graph */
void addStudentSpecificConstraints(size_t i, boost::optional<size_t> slot = boost::none); void addStudentSpecificConstraints(
size_t i, boost::optional<size_t> slot = boost::none);
/** Main routine that builds factor graph */ /** Main routine that builds factor graph */
void buildGraph(size_t mutexBound = 7); void buildGraph(size_t mutexBound = 7);
@ -168,8 +155,6 @@ namespace gtsam {
/** find the corresponding most desirable committee assignment */ /** find the corresponding most desirable committee assignment */
sharedValues bestAssignment(sharedValues bestSchedule) const; sharedValues bestAssignment(sharedValues bestSchedule) const;
}; // Scheduler }; // Scheduler
} // gtsam
} // namespace gtsam

View File

@ -5,75 +5,74 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/SingleValue.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/SingleValue.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */ /* ************************************************************************* */
void SingleValue::print(const string& s, void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const { cout << s << "SingleValue on "
cout << s << "SingleValue on " << "j=" << formatter(keys_[0]) << "j=" << formatter(keys_[0]) << " with value " << value_ << endl;
<< " with value " << value_ << endl; }
}
/* ************************************************************************* */ /* ************************************************************************* */
double SingleValue::operator()(const Values& values) const { double SingleValue::operator()(const Values& values) const {
return (double) (values.at(keys_[0]) == value_); return (double)(values.at(keys_[0]) == value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
DiscreteKeys keys; DiscreteKeys keys;
keys += DiscreteKey(keys_[0],cardinality_); keys += DiscreteKey(keys_[0], cardinality_);
vector<double> table; vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; i1++) for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_);
table.push_back(i1 == value_);
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool SingleValue::ensureArcConsistency(size_t j, bool SingleValue::ensureArcConsistency(size_t j,
vector<Domain>& domains) const { vector<Domain>& domains) const {
if (j != keys_[0]) throw invalid_argument( if (j != keys_[0])
"SingleValue check on wrong domain"); throw invalid_argument("SingleValue check on wrong domain");
Domain& D = domains[j]; Domain& D = domains[j];
if (D.isSingleton()) { if (D.isSingleton()) {
if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); if (D.firstValue() != value_) throw runtime_error("Unsatisfiable");
return false; return false;
} }
D = Domain(discreteKey(),value_); D = Domain(discreteKey(), value_);
return true; return true;
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
Values::const_iterator it = values.find(keys_[0]); Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && it->second != value_) throw runtime_error( if (it != values.end() && it->second != value_)
"SingleValue::partiallyApply: unsatisfiable"); throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared < SingleValue > (keys_[0], cardinality_, value_); return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply( Constraint::shared_ptr SingleValue::partiallyApply(
const vector<Domain>& domains) const { const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]]; const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( if (Dk.isSingleton() && !Dk.contains(value_))
"SingleValue::partiallyApply: unsatisfiable"); throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared < SingleValue > (discreteKey(), value_); return boost::make_shared<SingleValue>(discreteKey(), value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,16 +7,15 @@
#pragma once #pragma once
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h>
namespace gtsam { namespace gtsam {
/** /**
* SingleValue constraint * SingleValue constraint
*/ */
class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint { class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Number of values /// Number of values
size_t cardinality_; size_t cardinality_;
@ -24,34 +23,31 @@ namespace gtsam {
size_t value_; size_t value_;
DiscreteKey discreteKey() const { DiscreteKey discreteKey() const {
return DiscreteKey(keys_[0],cardinality_); return DiscreteKey(keys_[0], cardinality_);
} }
public: public:
typedef boost::shared_ptr<SingleValue> shared_ptr; typedef boost::shared_ptr<SingleValue> shared_ptr;
/// Constructor /// Constructor
SingleValue(Key key, size_t n, size_t value) : SingleValue(Key key, size_t n, size_t value)
Constraint(key), cardinality_(n), value_(value) { : Constraint(key), cardinality_(n), value_(value) {}
}
/// Constructor /// Constructor
SingleValue(const DiscreteKey& dkey, size_t value) : SingleValue(const DiscreteKey& dkey, size_t value)
Constraint(dkey.first), cardinality_(dkey.second), value_(value) { : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {}
}
// print // print
void print(const std::string& s = "", void print(const std::string& s = "", const KeyFormatter& formatter =
const KeyFormatter& formatter = DefaultKeyFormatter) const override; DefaultKeyFormatter) const override;
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const SingleValue*>(&other)) if (!dynamic_cast<const SingleValue*>(&other))
return false; return false;
else { else {
const SingleValue& f(static_cast<const SingleValue&>(other)); const SingleValue& f(static_cast<const SingleValue&>(other));
return (cardinality_==f.cardinality_) && (value_==f.value_); return (cardinality_ == f.cardinality_) && (value_ == f.value_);
} }
} }
@ -69,7 +65,8 @@ namespace gtsam {
* @param j domain to be checked * @param j domain to be checked
* @param domains all other domains * @param domains all other domains
*/ */
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override; bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override; Constraint::shared_ptr partiallyApply(const Values& values) const override;
@ -77,6 +74,6 @@ namespace gtsam {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override; const std::vector<Domain>& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -7,49 +7,50 @@
#include <gtsam_unstable/discrete/CSP.h> #include <gtsam_unstable/discrete/CSP.h>
#include <gtsam_unstable/discrete/Domain.h> #include <gtsam_unstable/discrete/Domain.h>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
using boost::assign::insert; using boost::assign::insert;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <fstream> #include <fstream>
#include <iostream>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( BinaryAllDif, allInOne) TEST_UNSAFE(BinaryAllDif, allInOne) {
{
// Create keys and ordering // Create keys and ordering
size_t nrColors = 2; size_t nrColors = 2;
// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", nrColors); // DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona",
// nrColors);
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
// Check construction and conversion // Check construction and conversion
BinaryAllDiff c1(ID, UT); BinaryAllDiff c1(ID, UT);
DecisionTreeFactor f1(ID & UT, "0 1 1 0"); DecisionTreeFactor f1(ID & UT, "0 1 1 0");
EXPECT(assert_equal(f1,c1.toDecisionTreeFactor())); EXPECT(assert_equal(f1, c1.toDecisionTreeFactor()));
// Check construction and conversion // Check construction and conversion
BinaryAllDiff c2(UT, AZ); BinaryAllDiff c2(UT, AZ);
DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); DecisionTreeFactor f2(UT & AZ, "0 1 1 0");
EXPECT(assert_equal(f2,c2.toDecisionTreeFactor())); EXPECT(assert_equal(f2, c2.toDecisionTreeFactor()));
DecisionTreeFactor f3 = f1*f2; DecisionTreeFactor f3 = f1 * f2;
EXPECT(assert_equal(f3,c1*f2)); EXPECT(assert_equal(f3, c1 * f2));
EXPECT(assert_equal(f3,c2*f1)); EXPECT(assert_equal(f3, c2 * f1));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( CSP, allInOne) TEST_UNSAFE(CSP, allInOne) {
{
// Create keys and ordering // Create keys and ordering
size_t nrColors = 2; size_t nrColors = 2;
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
// Create the CSP // Create the CSP
CSP csp; CSP csp;
csp.addAllDiff(ID,UT); csp.addAllDiff(ID, UT);
csp.addAllDiff(UT,AZ); csp.addAllDiff(UT, AZ);
// Check an invalid combination, with ID==UT==AZ all same color // Check an invalid combination, with ID==UT==AZ all same color
DiscreteFactor::Values invalid; DiscreteFactor::Values invalid;
@ -69,67 +70,67 @@ TEST_UNSAFE( CSP, allInOne)
DecisionTreeFactor product = csp.product(); DecisionTreeFactor product = csp.product();
// product.dot("product"); // product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct,product)); EXPECT(assert_equal(expectedProduct, product));
// Solve // Solve
CSP::sharedValues mpe = csp.optimalAssignment(); CSP::sharedValues mpe = csp.optimalAssignment();
CSP::Values expected; CSP::Values expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected,*mpe)); EXPECT(assert_equal(expected, *mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( CSP, WesternUS) TEST_UNSAFE(CSP, WesternUS) {
{
// Create keys // Create keys
size_t nrColors = 4; size_t nrColors = 4;
DiscreteKey DiscreteKey
// Create ordering according to example in ND-CSP.lyx // Create ordering according to example in ND-CSP.lyx
WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),NV(2, nrColors), WA(0, nrColors),
ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors),
MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors),
CO(7, nrColors), NM(6, nrColors);
// Create the CSP // Create the CSP
CSP csp; CSP csp;
csp.addAllDiff(WA,ID); csp.addAllDiff(WA, ID);
csp.addAllDiff(WA,OR); csp.addAllDiff(WA, OR);
csp.addAllDiff(OR,ID); csp.addAllDiff(OR, ID);
csp.addAllDiff(OR,CA); csp.addAllDiff(OR, CA);
csp.addAllDiff(OR,NV); csp.addAllDiff(OR, NV);
csp.addAllDiff(CA,NV); csp.addAllDiff(CA, NV);
csp.addAllDiff(CA,AZ); csp.addAllDiff(CA, AZ);
csp.addAllDiff(ID,MT); csp.addAllDiff(ID, MT);
csp.addAllDiff(ID,WY); csp.addAllDiff(ID, WY);
csp.addAllDiff(ID,UT); csp.addAllDiff(ID, UT);
csp.addAllDiff(ID,NV); csp.addAllDiff(ID, NV);
csp.addAllDiff(NV,UT); csp.addAllDiff(NV, UT);
csp.addAllDiff(NV,AZ); csp.addAllDiff(NV, AZ);
csp.addAllDiff(UT,WY); csp.addAllDiff(UT, WY);
csp.addAllDiff(UT,CO); csp.addAllDiff(UT, CO);
csp.addAllDiff(UT,NM); csp.addAllDiff(UT, NM);
csp.addAllDiff(UT,AZ); csp.addAllDiff(UT, AZ);
csp.addAllDiff(AZ,CO); csp.addAllDiff(AZ, CO);
csp.addAllDiff(AZ,NM); csp.addAllDiff(AZ, NM);
csp.addAllDiff(MT,WY); csp.addAllDiff(MT, WY);
csp.addAllDiff(WY,CO); csp.addAllDiff(WY, CO);
csp.addAllDiff(CO,NM); csp.addAllDiff(CO, NM);
// Solve // Solve
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7),Key(8),Key(9),Key(10); ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
Key(8), Key(9), Key(10);
CSP::sharedValues mpe = csp.optimalAssignment(ordering); CSP::sharedValues mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(*mpe); // GTSAM_PRINT(*mpe);
CSP::Values expected; CSP::Values expected;
insert(expected) insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
(WA.first,1)(CA.first,1)(NV.first,3)(OR.first,0) MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
(MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2) UT.first, 1)(AZ.first, 0);
(ID.first,2)(UT.first,1)(AZ.first,0);
// TODO: Fix me! mpe result seems to be right. (See the printing) // TODO: Fix me! mpe result seems to be right. (See the printing)
// It has the same prob as the expected solution. // It has the same prob as the expected solution.
// Is mpe another solution, or the expected solution is unique??? // Is mpe another solution, or the expected solution is unique???
EXPECT(assert_equal(expected,*mpe)); EXPECT(assert_equal(expected, *mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Write out the dual graph for hmetis // Write out the dual graph for hmetis
@ -142,8 +143,7 @@ TEST_UNSAFE( CSP, WesternUS)
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( CSP, AllDiff) TEST_UNSAFE(CSP, AllDiff) {
{
// Create keys and ordering // Create keys and ordering
size_t nrColors = 3; size_t nrColors = 3;
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
@ -151,24 +151,25 @@ TEST_UNSAFE( CSP, AllDiff)
// Create the CSP // Create the CSP
CSP csp; CSP csp;
vector<DiscreteKey> dkeys; vector<DiscreteKey> dkeys;
dkeys += ID,UT,AZ; dkeys += ID, UT, AZ;
csp.addAllDiff(dkeys); csp.addAllDiff(dkeys);
csp.addSingleValue(AZ,2); csp.addSingleValue(AZ, 2);
// GTSAM_PRINT(csp); // GTSAM_PRINT(csp);
// Check construction and conversion // Check construction and conversion
SingleValue s(AZ,2); SingleValue s(AZ, 2);
DecisionTreeFactor f1(AZ,"0 0 1"); DecisionTreeFactor f1(AZ, "0 0 1");
EXPECT(assert_equal(f1,s.toDecisionTreeFactor())); EXPECT(assert_equal(f1, s.toDecisionTreeFactor()));
// Check construction and conversion // Check construction and conversion
AllDiff alldiff(dkeys); AllDiff alldiff(dkeys);
DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
// GTSAM_PRINT(actual); // GTSAM_PRINT(actual);
// actual.dot("actual"); // actual.dot("actual");
DecisionTreeFactor f2(ID & AZ & UT, DecisionTreeFactor f2(
ID & AZ & UT,
"0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
EXPECT(assert_equal(f2,actual)); EXPECT(assert_equal(f2, actual));
// Check an invalid combination, with ID==UT==AZ all same color // Check an invalid combination, with ID==UT==AZ all same color
DiscreteFactor::Values invalid; DiscreteFactor::Values invalid;
@ -188,36 +189,36 @@ TEST_UNSAFE( CSP, AllDiff)
CSP::sharedValues mpe = csp.optimalAssignment(); CSP::sharedValues mpe = csp.optimalAssignment();
CSP::Values expected; CSP::Values expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected,*mpe)); EXPECT(assert_equal(expected, *mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Arc-consistency // Arc-consistency
vector<Domain> domains; vector<Domain> domains;
domains += Domain(ID), Domain(AZ), Domain(UT); domains += Domain(ID), Domain(AZ), Domain(UT);
SingleValue singleValue(AZ,2); SingleValue singleValue(AZ, 2);
EXPECT(singleValue.ensureArcConsistency(1,domains)); EXPECT(singleValue.ensureArcConsistency(1, domains));
EXPECT(alldiff.ensureArcConsistency(0,domains)); EXPECT(alldiff.ensureArcConsistency(0, domains));
EXPECT(!alldiff.ensureArcConsistency(1,domains)); EXPECT(!alldiff.ensureArcConsistency(1, domains));
EXPECT(alldiff.ensureArcConsistency(2,domains)); EXPECT(alldiff.ensureArcConsistency(2, domains));
LONGS_EQUAL(2,domains[0].nrValues()); LONGS_EQUAL(2, domains[0].nrValues());
LONGS_EQUAL(1,domains[1].nrValues()); LONGS_EQUAL(1, domains[1].nrValues());
LONGS_EQUAL(2,domains[2].nrValues()); LONGS_EQUAL(2, domains[2].nrValues());
// Parial application, version 1 // Parial application, version 1
DiscreteFactor::Values known; DiscreteFactor::Values known;
known[AZ.first] = 2; known[AZ.first] = 2;
DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known);
DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0"); DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0");
EXPECT(assert_equal(f3,reduced1->toDecisionTreeFactor())); EXPECT(assert_equal(f3, reduced1->toDecisionTreeFactor()));
DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known); DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known);
DecisionTreeFactor f4(AZ, "0 0 1"); DecisionTreeFactor f4(AZ, "0 0 1");
EXPECT(assert_equal(f4,reduced2->toDecisionTreeFactor())); EXPECT(assert_equal(f4, reduced2->toDecisionTreeFactor()));
// Parial application, version 2 // Parial application, version 2
DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains); DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains);
EXPECT(assert_equal(f3,reduced3->toDecisionTreeFactor())); EXPECT(assert_equal(f3, reduced3->toDecisionTreeFactor()));
DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains); DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains);
EXPECT(assert_equal(f4,reduced4->toDecisionTreeFactor())); EXPECT(assert_equal(f4, reduced4->toDecisionTreeFactor()));
// full arc-consistency test // full arc-consistency test
csp.runArcConsistency(nrColors); csp.runArcConsistency(nrColors);
@ -229,4 +230,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -5,14 +5,15 @@
* @date Oct 11, 2013 * @date Oct 11, 2013
*/ */
#include <gtsam/inference/VariableIndex.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <CppUnitLite/TestHarness.h> #include <gtsam/inference/VariableIndex.h>
#include <boost/range/adaptor/map.hpp>
#include <boost/assign/list_of.hpp> #include <boost/assign/list_of.hpp>
#include <iostream> #include <boost/range/adaptor/map.hpp>
#include <fstream> #include <fstream>
#include <iostream>
using namespace std; using namespace std;
using namespace boost; using namespace boost;
@ -23,11 +24,12 @@ using namespace gtsam;
* Loopy belief solver for graphs with only binary and unary factors * Loopy belief solver for graphs with only binary and unary factors
*/ */
class LoopyBelief { class LoopyBelief {
/** Star graph struct for each node, containing /** Star graph struct for each node, containing
* - the star graph itself * - the star graph itself
* - the product of original unary factors so we don't have to recompute it later, and * - the product of original unary factors so we don't have to recompute it
* - the factor indices of the corrected belief factors of the neighboring nodes * later, and
* - the factor indices of the corrected belief factors of the neighboring
* nodes
*/ */
typedef std::map<Key, size_t> CorrectedBeliefIndices; typedef std::map<Key, size_t> CorrectedBeliefIndices;
struct StarGraph { struct StarGraph {
@ -37,40 +39,40 @@ class LoopyBelief {
VariableIndex varIndex_; VariableIndex varIndex_;
StarGraph(const DiscreteFactorGraph::shared_ptr& _star, StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
const CorrectedBeliefIndices& _beliefIndices, const CorrectedBeliefIndices& _beliefIndices,
const DecisionTreeFactor::shared_ptr& _unary) : const DecisionTreeFactor::shared_ptr& _unary)
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_( : star(_star),
*_star) { correctedBeliefIndices(_beliefIndices),
} unary(_unary),
varIndex_(*_star) {}
void print(const std::string& s = "") const { void print(const std::string& s = "") const {
cout << s << ":" << endl; cout << s << ":" << endl;
star->print("Star graph: "); star->print("Star graph: ");
for(Key key: correctedBeliefIndices | boost::adaptors::map_keys) { for (Key key : correctedBeliefIndices | boost::adaptors::map_keys) {
cout << "Belief factor index for " << key << ": " cout << "Belief factor index for " << key << ": "
<< correctedBeliefIndices.at(key) << endl; << correctedBeliefIndices.at(key) << endl;
} }
if (unary) if (unary) unary->print("Unary: ");
unary->print("Unary: ");
} }
}; };
typedef std::map<Key, StarGraph> StarGraphs; typedef std::map<Key, StarGraph> StarGraphs;
StarGraphs starGraphs_; ///< star graph at each variable StarGraphs starGraphs_; ///< star graph at each variable
public: public:
/** Constructor /** Constructor
* Need all discrete keys to access node's cardinality for creating belief factors * Need all discrete keys to access node's cardinality for creating belief
* factors
* TODO: so troublesome!! * TODO: so troublesome!!
*/ */
LoopyBelief(const DiscreteFactorGraph& graph, LoopyBelief(const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) : const std::map<Key, DiscreteKey>& allDiscreteKeys)
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {}
}
/// print /// print
void print(const std::string& s = "") const { void print(const std::string& s = "") const {
cout << s << ":" << endl; cout << s << ":" << endl;
for(Key key: starGraphs_ | boost::adaptors::map_keys) { for (Key key : starGraphs_ | boost::adaptors::map_keys) {
starGraphs_.at(key).print((boost::format("Node %d:") % key).str()); starGraphs_.at(key).print((boost::format("Node %d:") % key).str());
} }
} }
@ -79,12 +81,13 @@ public:
DiscreteFactorGraph::shared_ptr iterate( DiscreteFactorGraph::shared_ptr iterate(
const std::map<Key, DiscreteKey>& allDiscreteKeys) { const std::map<Key, DiscreteKey>& allDiscreteKeys) {
static const bool debug = false; static const bool debug = false;
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination static DiscreteConditional::shared_ptr
dummyCond; // unused by-product of elimination
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages; std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
// Eliminate each star graph // Eliminate each star graph
for(Key key: starGraphs_ | boost::adaptors::map_keys) { for (Key key : starGraphs_ | boost::adaptors::map_keys) {
// cout << "***** Node " << key << "*****" << endl; // cout << "***** Node " << key << "*****" << endl;
// initialize belief to the unary factor from the original graph // initialize belief to the unary factor from the original graph
DecisionTreeFactor::shared_ptr beliefAtKey; DecisionTreeFactor::shared_ptr beliefAtKey;
@ -92,15 +95,16 @@ public:
std::map<Key, DiscreteFactor::shared_ptr> messages; std::map<Key, DiscreteFactor::shared_ptr> messages;
// eliminate each neighbor in this star graph one by one // eliminate each neighbor in this star graph one by one
for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices |
boost::adaptors::map_keys) {
DiscreteFactorGraph subGraph; DiscreteFactorGraph subGraph;
for(size_t factor: starGraphs_.at(key).varIndex_[neighbor]) { for (size_t factor : starGraphs_.at(key).varIndex_[neighbor]) {
subGraph.push_back(starGraphs_.at(key).star->at(factor)); subGraph.push_back(starGraphs_.at(key).star->at(factor));
} }
if (debug) subGraph.print("------- Subgraph:"); if (debug) subGraph.print("------- Subgraph:");
DiscreteFactor::shared_ptr message; DiscreteFactor::shared_ptr message;
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph, boost::tie(dummyCond, message) =
Ordering(list_of(neighbor))); EliminateDiscrete(subGraph, Ordering(list_of(neighbor)));
// store the new factor into messages // store the new factor into messages
messages.insert(make_pair(neighbor, message)); messages.insert(make_pair(neighbor, message));
if (debug) message->print("------- Message: "); if (debug) message->print("------- Message: ");
@ -108,14 +112,12 @@ public:
// Belief is the product of all messages and the unary factor // Belief is the product of all messages and the unary factor
// Incorporate new the factor to belief // Incorporate new the factor to belief
if (!beliefAtKey) if (!beliefAtKey)
beliefAtKey = boost::dynamic_pointer_cast<DecisionTreeFactor>(
message);
else
beliefAtKey = beliefAtKey =
boost::make_shared<DecisionTreeFactor>( boost::dynamic_pointer_cast<DecisionTreeFactor>(message);
(*beliefAtKey) else
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>( beliefAtKey = boost::make_shared<DecisionTreeFactor>(
message))); (*beliefAtKey) *
(*boost::dynamic_pointer_cast<DecisionTreeFactor>(message)));
} }
if (starGraphs_.at(key).unary) if (starGraphs_.at(key).unary)
beliefAtKey = boost::make_shared<DecisionTreeFactor>( beliefAtKey = boost::make_shared<DecisionTreeFactor>(
@ -133,7 +135,8 @@ public:
sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str(); sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str();
DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable); DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
if (debug) sumFactor.print("denomFactor: "); if (debug) sumFactor.print("denomFactor: ");
beliefAtKey = boost::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor); beliefAtKey =
boost::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
if (debug) beliefAtKey->print("New belief at key normalized: "); if (debug) beliefAtKey->print("New belief at key normalized: ");
beliefs->push_back(beliefAtKey); beliefs->push_back(beliefAtKey);
allMessages[key] = messages; allMessages[key] = messages;
@ -141,17 +144,20 @@ public:
// Update corrected beliefs // Update corrected beliefs
VariableIndex beliefFactors(*beliefs); VariableIndex beliefFactors(*beliefs);
for(Key key: starGraphs_ | boost::adaptors::map_keys) { for (Key key : starGraphs_ | boost::adaptors::map_keys) {
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key]; std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices |
DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast< boost::adaptors::map_keys) {
DecisionTreeFactor>(beliefs->at(beliefFactors[key].front()))) DecisionTreeFactor correctedBelief =
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>( (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
beliefs->at(beliefFactors[key].front()))) /
(*boost::dynamic_pointer_cast<DecisionTreeFactor>(
messages.at(neighbor))); messages.at(neighbor)));
if (debug) correctedBelief.print("correctedBelief: "); if (debug) correctedBelief.print("correctedBelief: ");
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at( size_t beliefIndex =
key); starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
starGraphs_.at(neighbor).star->replace(beliefIndex, starGraphs_.at(neighbor).star->replace(
beliefIndex,
boost::make_shared<DecisionTreeFactor>(correctedBelief)); boost::make_shared<DecisionTreeFactor>(correctedBelief));
} }
} }
@ -161,21 +167,22 @@ public:
return beliefs; return beliefs;
} }
private: private:
/** /**
* Build star graphs for each node. * Build star graphs for each node.
*/ */
StarGraphs buildStarGraphs(const DiscreteFactorGraph& graph, StarGraphs buildStarGraphs(
const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) const { const std::map<Key, DiscreteKey>& allDiscreteKeys) const {
StarGraphs starGraphs; StarGraphs starGraphs;
VariableIndex varIndex(graph); ///< access to all factors of each node VariableIndex varIndex(graph); ///< access to all factors of each node
for(Key key: varIndex | boost::adaptors::map_keys) { for (Key key : varIndex | boost::adaptors::map_keys) {
// initialize to multiply with other unary factors later // initialize to multiply with other unary factors later
DecisionTreeFactor::shared_ptr prodOfUnaries; DecisionTreeFactor::shared_ptr prodOfUnaries;
// collect all factors involving this key in the original graph // collect all factors involving this key in the original graph
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph()); DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
for(size_t factorIndex: varIndex[key]) { for (size_t factorIndex : varIndex[key]) {
star->push_back(graph.at(factorIndex)); star->push_back(graph.at(factorIndex));
// accumulate unary factors // accumulate unary factors
@ -185,8 +192,8 @@ private:
graph.at(factorIndex)); graph.at(factorIndex));
else else
prodOfUnaries = boost::make_shared<DecisionTreeFactor>( prodOfUnaries = boost::make_shared<DecisionTreeFactor>(
*prodOfUnaries *prodOfUnaries *
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>( (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
graph.at(factorIndex)))); graph.at(factorIndex))));
} }
} }
@ -196,7 +203,7 @@ private:
KeySet neighbors = star->keys(); KeySet neighbors = star->keys();
neighbors.erase(key); neighbors.erase(key);
CorrectedBeliefIndices correctedBeliefIndices; CorrectedBeliefIndices correctedBeliefIndices;
for(Key neighbor: neighbors) { for (Key neighbor : neighbors) {
// TODO: default table for keys with more than 2 values? // TODO: default table for keys with more than 2 values?
string initialBelief; string initialBelief;
for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) { for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
@ -207,9 +214,8 @@ private:
DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief)); DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1)); correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
} }
starGraphs.insert( starGraphs.insert(make_pair(
make_pair(key, key, StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
} }
return starGraphs; return starGraphs;
} }
@ -249,7 +255,6 @@ TEST_UNSAFE(LoopyBelief, construction) {
DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys); DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys);
beliefs->print(); beliefs->print();
} }
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -5,14 +5,13 @@
*/ */
//#define ENABLE_TIMING //#define ENABLE_TIMING
#include <gtsam_unstable/discrete/Scheduler.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <gtsam_unstable/discrete/Scheduler.h>
#include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/vector.hpp>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
using namespace boost::assign; using namespace boost::assign;
@ -22,7 +21,6 @@ using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
// Create the expected graph of constraints // Create the expected graph of constraints
DiscreteFactorGraph createExpected() { DiscreteFactorGraph createExpected() {
// Start building // Start building
size_t nrFaculty = 4, nrTimeSlots = 3; size_t nrFaculty = 4, nrTimeSlots = 3;
@ -79,8 +77,7 @@ DiscreteFactorGraph createExpected() {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( schedulingExample, test) TEST(schedulingExample, test) {
{
Scheduler s(2); Scheduler s(2);
// add faculty // add faculty
@ -121,7 +118,7 @@ TEST( schedulingExample, test)
// Do brute force product and output that to file // Do brute force product and output that to file
DecisionTreeFactor product = s.product(); DecisionTreeFactor product = s.product();
//product.dot("scheduling", false); // product.dot("scheduling", false);
// Do exact inference // Do exact inference
gttic(small); gttic(small);
@ -129,25 +126,24 @@ TEST( schedulingExample, test)
gttoc(small); gttoc(small);
// print MPE, commented out as unit tests don't print // print MPE, commented out as unit tests don't print
// s.printAssignment(MPE); // s.printAssignment(MPE);
// Commented out as does not work yet // Commented out as does not work yet
// s.runArcConsistency(8,10,true); // s.runArcConsistency(8,10,true);
// find the assignment of students to slots with most possible committees // find the assignment of students to slots with most possible committees
// Commented out as not implemented yet // Commented out as not implemented yet
// sharedValues bestSchedule = s.bestSchedule(); // sharedValues bestSchedule = s.bestSchedule();
// GTSAM_PRINT(*bestSchedule); // GTSAM_PRINT(*bestSchedule);
// find the corresponding most desirable committee assignment // find the corresponding most desirable committee assignment
// Commented out as not implemented yet // Commented out as not implemented yet
// sharedValues bestAssignment = s.bestAssignment(bestSchedule); // sharedValues bestAssignment = s.bestAssignment(bestSchedule);
// GTSAM_PRINT(*bestAssignment); // GTSAM_PRINT(*bestAssignment);
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( schedulingExample, smallFromFile) TEST(schedulingExample, smallFromFile) {
{
string path(TOPSRCDIR "/gtsam_unstable/discrete/examples/"); string path(TOPSRCDIR "/gtsam_unstable/discrete/examples/");
Scheduler s(2, path + "small.csv"); Scheduler s(2, path + "small.csv");
@ -179,4 +175,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -5,21 +5,22 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/CSP.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam_unstable/discrete/CSP.h>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
using boost::assign::insert; using boost::assign::insert;
#include <stdarg.h>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <stdarg.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
#define PRINT false #define PRINT false
class Sudoku: public CSP { class Sudoku : public CSP {
/// sudoku size /// sudoku size
size_t n_; size_t n_;
@ -27,25 +28,21 @@ class Sudoku: public CSP {
typedef std::pair<size_t, size_t> IJ; typedef std::pair<size_t, size_t> IJ;
std::map<IJ, DiscreteKey> dkeys_; std::map<IJ, DiscreteKey> dkeys_;
public: public:
/// return DiscreteKey for cell(i,j) /// return DiscreteKey for cell(i,j)
const DiscreteKey& dkey(size_t i, size_t j) const { const DiscreteKey& dkey(size_t i, size_t j) const {
return dkeys_.at(IJ(i, j)); return dkeys_.at(IJ(i, j));
} }
/// return Key for cell(i,j) /// return Key for cell(i,j)
Key key(size_t i, size_t j) const { Key key(size_t i, size_t j) const { return dkey(i, j).first; }
return dkey(i, j).first;
}
/// Constructor /// Constructor
Sudoku(size_t n, ...) : Sudoku(size_t n, ...) : n_(n) {
n_(n) {
// Create variables, ordering, and unary constraints // Create variables, ordering, and unary constraints
va_list ap; va_list ap;
va_start(ap, n); va_start(ap, n);
Key k=0; Key k = 0;
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
for (size_t j = 0; j < n; ++j, ++k) { for (size_t j = 0; j < n; ++j, ++k) {
// create the key // create the key
@ -56,23 +53,21 @@ public:
// cout << value << " "; // cout << value << " ";
if (value != 0) addSingleValue(dkeys_[ij], value - 1); if (value != 0) addSingleValue(dkeys_[ij], value - 1);
} }
//cout << endl; // cout << endl;
} }
va_end(ap); va_end(ap);
// add row constraints // add row constraints
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (size_t j = 0; j < n; j++) for (size_t j = 0; j < n; j++) dkeys += dkey(i, j);
dkeys += dkey(i, j);
addAllDiff(dkeys); addAllDiff(dkeys);
} }
// add col constraints // add col constraints
for (size_t j = 0; j < n; j++) { for (size_t j = 0; j < n; j++) {
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (size_t i = 0; i < n; i++) for (size_t i = 0; i < n; i++) dkeys += dkey(i, j);
dkeys += dkey(i, j);
addAllDiff(dkeys); addAllDiff(dkeys);
} }
@ -84,8 +79,7 @@ public:
// Box I,J // Box I,J
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (size_t i = i0; i < i0 + N; i++) for (size_t i = i0; i < i0 + N; i++)
for (size_t j = j0; j < j0 + N; j++) for (size_t j = j0; j < j0 + N; j++) dkeys += dkey(i, j);
dkeys += dkey(i, j);
addAllDiff(dkeys); addAllDiff(dkeys);
j0 += N; j0 += N;
} }
@ -109,74 +103,59 @@ public:
DiscreteFactor::sharedValues MPE = optimalAssignment(); DiscreteFactor::sharedValues MPE = optimalAssignment();
printAssignment(MPE); printAssignment(MPE);
} }
}; };
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( Sudoku, small) TEST_UNSAFE(Sudoku, small) {
{ Sudoku csp(4, 1, 0, 0, 4, 0, 0, 0, 0,
Sudoku csp(4,
1,0, 0,4,
0,0, 0,0,
4,0, 2,0, 4, 0, 2, 0, 0, 1, 0, 0);
0,1, 0,0);
// Do BP // Do BP
csp.runArcConsistency(4,10,PRINT); csp.runArcConsistency(4, 10, PRINT);
// optimize and check // optimize and check
CSP::sharedValues solution = csp.optimalAssignment(); CSP::sharedValues solution = csp.optimalAssignment();
CSP::Values expected; CSP::Values expected;
insert(expected) insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
(csp.key(0,0), 0)(csp.key(0,1), 1)(csp.key(0,2), 2)(csp.key(0,3), 3) csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
(csp.key(1,0), 2)(csp.key(1,1), 3)(csp.key(1,2), 0)(csp.key(1,3), 1) csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)(
(csp.key(2,0), 3)(csp.key(2,1), 2)(csp.key(2,2), 1)(csp.key(2,3), 0) csp.key(2, 3), 0)(csp.key(3, 0), 1)(csp.key(3, 1), 0)(csp.key(3, 2), 3)(
(csp.key(3,0), 1)(csp.key(3,1), 0)(csp.key(3,2), 3)(csp.key(3,3), 2); csp.key(3, 3), 2);
EXPECT(assert_equal(expected,*solution)); EXPECT(assert_equal(expected, *solution));
//csp.printAssignment(solution); // csp.printAssignment(solution);
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( Sudoku, easy) TEST_UNSAFE(Sudoku, easy) {
{ Sudoku sudoku(9, 0, 0, 5, 0, 9, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 7, 3, 7, 6,
Sudoku sudoku(9, 0, 0, 0, 8, 2, 0, 0,
0,0,5, 0,9,0, 0,0,1,
0,0,0, 0,0,2, 0,7,3,
7,6,0, 0,0,8, 2,0,0,
0,1,2, 0,0,9, 0,0,4, 0, 1, 2, 0, 0, 9, 0, 0, 4, 0, 0, 0, 2, 0, 3, 0, 0, 0, 3, 0, 0,
0,0,0, 2,0,3, 0,0,0, 1, 0, 0, 9, 6, 0,
3,0,0, 1,0,0, 9,6,0,
0,0,1, 9,0,0, 0,5,8, 0, 0, 1, 9, 0, 0, 0, 5, 8, 9, 7, 0, 5, 0, 0, 0, 0, 0, 5, 0, 0,
9,7,0, 5,0,0, 0,0,0, 0, 3, 0, 7, 0, 0);
5,0,0, 0,3,0, 7,0,0);
// Do BP // Do BP
sudoku.runArcConsistency(4,10,PRINT); sudoku.runArcConsistency(4, 10, PRINT);
// sudoku.printSolution(); // don't do it // sudoku.printSolution(); // don't do it
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( Sudoku, extreme) TEST_UNSAFE(Sudoku, extreme) {
{ Sudoku sudoku(9, 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
Sudoku sudoku(9, 0, 1, 0, 9, 0, 0, 0,
0,0,9, 7,4,8, 0,0,0,
7,0,0, 0,0,0, 0,0,0,
0,2,0, 1,0,9, 0,0,0,
0,0,7, 0,0,0, 2,4,0, 0, 0, 7, 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, 1, 0, 5, 9, 0, 0, 9, 8,
0,6,4, 0,1,0, 5,9,0, 0, 0, 0, 3, 0, 0,
0,9,8, 0,0,0, 3,0,0,
0,0,0, 8,0,3, 0,2,0, 0, 0, 0, 8, 0, 3, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0,
0,0,0, 0,0,0, 0,0,6, 2, 7, 5, 9, 0, 0);
0,0,0, 2,7,5, 9,0,0);
// Do BP // Do BP
sudoku.runArcConsistency(9,10,PRINT); sudoku.runArcConsistency(9, 10, PRINT);
#ifdef METIS #ifdef METIS
VariableIndexOrdered index(sudoku); VariableIndexOrdered index(sudoku);
@ -185,29 +164,24 @@ TEST_UNSAFE( Sudoku, extreme)
index.outputMetisFormat(os); index.outputMetisFormat(os);
#endif #endif
//sudoku.printSolution(); // don't do it // sudoku.printSolution(); // don't do it
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( Sudoku, AJC_3star_Feb8_2012) TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) {
{ Sudoku sudoku(9, 9, 5, 0, 0, 0, 6, 0, 0, 0, 0, 8, 4, 0, 7, 0, 0, 0, 0, 6, 2,
Sudoku sudoku(9, 0, 5, 0, 0, 4, 0, 0,
9,5,0, 0,0,6, 0,0,0,
0,8,4, 0,7,0, 0,0,0,
6,2,0, 5,0,0, 4,0,0,
0,0,0, 2,9,0, 6,0,0, 0, 0, 0, 2, 9, 0, 6, 0, 0, 0, 9, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2,
0,9,0, 0,0,0, 0,2,0, 0, 6, 3, 0, 0, 0,
0,0,2, 0,6,3, 0,0,0,
0,0,9, 0,0,7, 0,6,8, 0, 0, 9, 0, 0, 7, 0, 6, 8, 0, 0, 0, 0, 3, 0, 2, 9, 0, 0, 0, 0,
0,0,0, 0,3,0, 2,9,0, 1, 0, 0, 0, 3, 7);
0,0,0, 1,0,0, 0,3,7);
// Do BP // Do BP
sudoku.runArcConsistency(9,10,PRINT); sudoku.runArcConsistency(9, 10, PRINT);
//sudoku.printSolution(); // don't do it // sudoku.printSolution(); // don't do it
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -216,4 +190,3 @@ int main() {
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);
} }
/* ************************************************************************* */ /* ************************************************************************* */