Formatting with Google style

release/4.3a0
Frank Dellaert 2021-11-18 10:54:00 -05:00
parent 13b0136e03
commit d27d6b60a7
16 changed files with 1301 additions and 1386 deletions

View File

@ -5,61 +5,60 @@
* @author Frank Dellaert
*/
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam/base/Testable.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/make_shared.hpp>
namespace gtsam {
/* ************************************************************************* */
AllDiff::AllDiff(const DiscreteKeys& dkeys) :
Constraint(dkeys.indices()) {
for(const DiscreteKey& dkey: dkeys)
cardinalities_.insert(dkey);
}
/* ************************************************************************* */
AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey);
}
/* ************************************************************************* */
void AllDiff::print(const std::string& s,
const KeyFormatter& formatter) const {
/* ************************************************************************* */
void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
std::cout << s << "AllDiff on ";
for (Key dkey: keys_)
std::cout << formatter(dkey) << " ";
for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
std::cout << std::endl;
}
}
/* ************************************************************************* */
double AllDiff::operator()(const Values& values) const {
std::set < size_t > taken; // record values taken by keys
for(Key dkey: keys_) {
/* ************************************************************************* */
double AllDiff::operator()(const Values& values) const {
std::set<size_t> taken; // record values taken by keys
for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key
if (taken.count(value)) return 0.0;// check if value alreday taken
taken.insert(value);// if not, record it as taken and keep checking
if (taken.count(value)) return 0.0; // check if value alreday taken
taken.insert(value); // if not, record it as taken and keep checking
}
return 1.0;
}
}
/* ************************************************************************* */
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
// We will do this by converting the allDif into many BinaryAllDiff constraints
/* ************************************************************************* */
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
// We will do this by converting the allDif into many BinaryAllDiff
// constraints
DecisionTreeFactor converted;
size_t nrKeys = keys_.size();
for (size_t i1 = 0; i1 < nrKeys; i1++)
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2));
BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2));
converted = converted * binary12.toDecisionTreeFactor();
}
return converted;
}
}
/* ************************************************************************* */
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
/* ************************************************************************* */
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f;
}
}
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j, std::vector<Domain>& domains) const {
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const {
// Though strictly not part of allDiff, we check for
// a value in domains[j] that does not occur in any other connected domain.
// If found, we make this a singleton...
@ -70,7 +69,7 @@ namespace gtsam {
// Check all other domains for singletons and erase corresponding values
// This is the same as arc-consistency on the equivalent binary constraints
bool changed = false;
for(Key k: keys_)
for (Key k : keys_)
if (k != j) {
const Domain& Dk = domains[k];
if (Dk.isSingleton()) { // check if singleton
@ -82,30 +81,29 @@ namespace gtsam {
}
}
return changed;
}
}
/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
DiscreteKeys newKeys;
// loop over keys and add them only if they do not appear in values
for(Key k: keys_)
for (Key k : keys_)
if (values.find(k) == values.end()) {
newKeys.push_back(DiscreteKey(k,cardinalities_.at(k)));
newKeys.push_back(DiscreteKey(k, cardinalities_.at(k)));
}
return boost::make_shared<AllDiff>(newKeys);
}
}
/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(
/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(
const std::vector<Domain>& domains) const {
DiscreteFactor::Values known;
for(Key k: keys_) {
for (Key k : keys_) {
const Domain& Dk = domains[k];
if (Dk.isSingleton())
known[k] = Dk.firstValue();
if (Dk.isSingleton()) known[k] = Dk.firstValue();
}
return partiallyApply(known);
}
}
/* ************************************************************************* */
/* ************************************************************************* */
} // namespace gtsam

View File

@ -7,44 +7,42 @@
#pragma once
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
namespace gtsam {
/**
/**
* General AllDiff constraint
* Returns 1 if values for all keys are different, 0 otherwise
* DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Key and an Key. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor.
*/
class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint {
std::map<Key,size_t> cardinalities_;
class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
std::map<Key, size_t> cardinalities_;
DiscreteKey discreteKey(size_t i) const {
Key j = keys_[i];
return DiscreteKey(j,cardinalities_.at(j));
return DiscreteKey(j, cardinalities_.at(j));
}
public:
/// Constructor
AllDiff(const DiscreteKeys& dkeys);
// print
void print(const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
void print(const std::string& s = "", const KeyFormatter& formatter =
DefaultKeyFormatter) const override;
/// equals
bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const AllDiff*>(&other))
if (!dynamic_cast<const AllDiff*>(&other))
return false;
else {
const AllDiff& f(static_cast<const AllDiff&>(other));
return cardinalities_.size() == f.cardinalities_.size()
&& std::equal(cardinalities_.begin(), cardinalities_.end(),
return cardinalities_.size() == f.cardinalities_.size() &&
std::equal(cardinalities_.begin(), cardinalities_.end(),
f.cardinalities_.begin());
}
}
@ -65,13 +63,15 @@ namespace gtsam {
* @param j domain to be checked
* @param domains all other domains
*/
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override;
bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override;
/// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override;
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(const std::vector<Domain>&) const override;
};
Constraint::shared_ptr partiallyApply(
const std::vector<Domain>&) const override;
};
} // namespace gtsam

View File

@ -7,33 +7,32 @@
#pragma once
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam_unstable/discrete/Domain.h>
namespace gtsam {
/**
/**
* Binary AllDiff constraint
* Returns 1 if values for two keys are different, 0 otherwise
* DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Index and an Index. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor.
*/
class BinaryAllDiff: public Constraint {
class BinaryAllDiff : public Constraint {
size_t cardinality0_, cardinality1_; /// cardinality
public:
/// Constructor
BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) :
Constraint(key1.first, key2.first),
cardinality0_(key1.second), cardinality1_(key2.second) {
}
BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2)
: Constraint(key1.first, key2.first),
cardinality0_(key1.second),
cardinality1_(key2.second) {}
// print
void print(const std::string& s = "",
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
<< formatter(keys_[1]) << std::endl;
@ -41,28 +40,28 @@ namespace gtsam {
/// equals
bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const BinaryAllDiff*>(&other))
if (!dynamic_cast<const BinaryAllDiff*>(&other))
return false;
else {
const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_);
return (cardinality0_ == f.cardinality0_) &&
(cardinality1_ == f.cardinality1_);
}
}
/// Calculate value
double operator()(const Values& values) const override {
return (double) (values.at(keys_[0]) != values.at(keys_[1]));
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
}
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override {
DiscreteKeys keys;
keys.push_back(DiscreteKey(keys_[0],cardinality0_));
keys.push_back(DiscreteKey(keys_[1],cardinality1_));
keys.push_back(DiscreteKey(keys_[0], cardinality0_));
keys.push_back(DiscreteKey(keys_[1], cardinality1_));
std::vector<double> table;
for (size_t i1 = 0; i1 < cardinality0_; i1++)
for (size_t i2 = 0; i2 < cardinality1_; i2++)
table.push_back(i1 != i2);
for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2);
DecisionTreeFactor converted(keys, table);
return converted;
}
@ -78,10 +77,10 @@ namespace gtsam {
* @param j domain to be checked
* @param domains all other domains
*/
///
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override {
// throw std::runtime_error(
// "BinaryAllDiff::ensureArcConsistency not implemented");
bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override {
// throw std::runtime_error(
// "BinaryAllDiff::ensureArcConsistency not implemented");
return false;
}
@ -95,6 +94,6 @@ namespace gtsam {
const std::vector<Domain>&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
}
};
};
} // namespace gtsam

View File

@ -5,29 +5,30 @@
* @author Frank Dellaert
*/
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/CSP.h>
#include <gtsam/base/Testable.h>
#include <gtsam_unstable/discrete/CSP.h>
#include <gtsam_unstable/discrete/Domain.h>
using namespace std;
namespace gtsam {
/// Find the best total assignment - can be expensive
CSP::sharedValues CSP::optimalAssignment() const {
/// Find the best total assignment - can be expensive
CSP::sharedValues CSP::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
sharedValues mpe = chordal->optimize();
return mpe;
}
}
/// Find the best total assignment - can be expensive
CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const {
/// Find the best total assignment - can be expensive
CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
sharedValues mpe = chordal->optimize();
return mpe;
}
}
void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, bool print) const {
void CSP::runArcConsistency(size_t cardinality, size_t nrIterations,
bool print) const {
// Create VariableIndex
VariableIndex index(*this);
// index.print();
@ -35,9 +36,9 @@ namespace gtsam {
size_t n = index.size();
// Initialize domains
std::vector < Domain > domains;
std::vector<Domain> domains;
for (size_t j = 0; j < n; j++)
domains.push_back(Domain(DiscreteKey(j,cardinality)));
domains.push_back(Domain(DiscreteKey(j, cardinality)));
// Create array of flags indicating a domain changed or not
std::vector<bool> changed(n);
@ -51,13 +52,16 @@ namespace gtsam {
changed[v] = false;
// loop over all factors/constraints for variable v
const FactorIndices& factors = index[v];
for(size_t f: factors) {
for (size_t f : factors) {
// if not already a singleton
if (!domains[v].isSingleton()) {
// get the constraint and call its ensureArcConsistency method
Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor");
changed[v] = constraint->ensureArcConsistency(v,domains) || changed[v];
Constraint::shared_ptr constraint =
boost::dynamic_pointer_cast<Constraint>((*this)[f]);
if (!constraint)
throw runtime_error("CSP:runArcConsistency: non-constraint factor");
changed[v] =
constraint->ensureArcConsistency(v, domains) || changed[v];
}
} // f
if (changed[v]) anyChange = true;
@ -91,13 +95,14 @@ namespace gtsam {
// TODO: create a new ordering as we go, to ensure a connected graph
// KeyOrdering ordering;
// vector<Index> dkeys;
for(const DiscreteFactor::shared_ptr& f: factors_) {
Constraint::shared_ptr constraint = boost::dynamic_pointer_cast<Constraint>(f);
if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor");
for (const DiscreteFactor::shared_ptr& f : factors_) {
Constraint::shared_ptr constraint =
boost::dynamic_pointer_cast<Constraint>(f);
if (!constraint)
throw runtime_error("CSP:runArcConsistency: non-constraint factor");
Constraint::shared_ptr reduced = constraint->partiallyApply(domains);
if (print) reduced->print();
}
#endif
}
} // gtsam
}
} // namespace gtsam

View File

@ -7,30 +7,28 @@
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam_unstable/discrete/SingleValue.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
namespace gtsam {
/**
/**
* Constraint Satisfaction Problem class
* A specialization of a DiscreteFactorGraph.
* It knows about CSP-specific constraints and algorithms
*/
class GTSAM_UNSTABLE_EXPORT CSP: public DiscreteFactorGraph {
class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
public:
/** A map from keys to values */
typedef KeyVector Indices;
typedef Assignment<Key> Values;
typedef boost::shared_ptr<Values> sharedValues;
public:
// /// Constructor
// CSP() {
// }
// /// Constructor
// CSP() {
// }
/// Add a unary constraint, allowing only a single value
void addSingleValue(const DiscreteKey& dkey, size_t value) {
@ -40,8 +38,7 @@ namespace gtsam {
/// Add a binary AllDiff constraint
void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) {
boost::shared_ptr<BinaryAllDiff> factor(
new BinaryAllDiff(key1, key2));
boost::shared_ptr<BinaryAllDiff> factor(new BinaryAllDiff(key1, key2));
push_back(factor);
}
@ -51,13 +48,13 @@ namespace gtsam {
push_back(factor);
}
// /** return product of all factors as a single factor */
// DecisionTreeFactor product() const {
// DecisionTreeFactor result;
// for(const sharedFactor& factor: *this)
// if (factor) result = (*factor) * result;
// return result;
// }
// /** return product of all factors as a single factor */
// DecisionTreeFactor product() const {
// DecisionTreeFactor result;
// for(const sharedFactor& factor: *this)
// if (factor) result = (*factor) * result;
// return result;
// }
/// Find the best total assignment - can be expensive
sharedValues optimalAssignment() const;
@ -65,16 +62,17 @@ namespace gtsam {
/// Find the best total assignment - can be expensive
sharedValues optimalAssignment(const Ordering& ordering) const;
// /*
// * Perform loopy belief propagation
// * True belief propagation would check for each value in domain
// * whether any satisfying separator assignment can be found.
// * This corresponds to hyper-arc consistency in CSP speak.
// * This can be done by creating a mini-factor graph and search.
// * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep.
// * It will be very expensive to exclude values that way.
// */
// void applyBeliefPropagation(size_t nrIterations = 10) const;
// /*
// * Perform loopy belief propagation
// * True belief propagation would check for each value in domain
// * whether any satisfying separator assignment can be found.
// * This corresponds to hyper-arc consistency in CSP speak.
// * This can be done by creating a mini-factor graph and search.
// * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels
// deep.
// * It will be very expensive to exclude values that way.
// */
// void applyBeliefPropagation(size_t nrIterations = 10) const;
/*
* Apply arc-consistency ~ Approximate loopy belief propagation
@ -84,7 +82,6 @@ namespace gtsam {
*/
void runArcConsistency(size_t cardinality, size_t nrIterations = 10,
bool print = false) const;
}; // CSP
} // gtsam
}; // CSP
} // namespace gtsam

View File

@ -17,49 +17,40 @@
#pragma once
#include <gtsam_unstable/dllexport.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam_unstable/dllexport.h>
#include <boost/assign.hpp>
namespace gtsam {
class Domain;
class Domain;
/**
/**
* Base class for discrete probabilistic factors
* The most general one is the derived DecisionTreeFactor
*/
class Constraint : public DiscreteFactor {
class Constraint : public DiscreteFactor {
public:
typedef boost::shared_ptr<Constraint> shared_ptr;
protected:
/// Construct n-way factor
Constraint(const KeyVector& js) :
DiscreteFactor(js) {
}
Constraint(const KeyVector& js) : DiscreteFactor(js) {}
/// Construct unary factor
Constraint(Key j) :
DiscreteFactor(boost::assign::cref_list_of<1>(j)) {
}
Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {}
/// Construct binary factor
Constraint(Key j1, Key j2) :
DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {
}
Constraint(Key j1, Key j2)
: DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {}
/// construct from container
template<class KeyIterator>
Constraint(KeyIterator beginKey, KeyIterator endKey) :
DiscreteFactor(beginKey, endKey) {
}
template <class KeyIterator>
Constraint(KeyIterator beginKey, KeyIterator endKey)
: DiscreteFactor(beginKey, endKey) {}
public:
/// @name Standard Constructors
/// @{
@ -78,16 +69,16 @@ namespace gtsam {
* @param j domain to be checked
* @param domains all other domains
*/
virtual bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const = 0;
virtual bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const = 0;
/// Partially apply known values
virtual shared_ptr partiallyApply(const Values&) const = 0;
/// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
/// @}
};
};
// DiscreteFactor
}// namespace gtsam
} // namespace gtsam

View File

@ -5,92 +5,89 @@
* @author Frank Dellaert
*/
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/make_shared.hpp>
namespace gtsam {
using namespace std;
using namespace std;
/* ************************************************************************* */
void Domain::print(const string& s,
const KeyFormatter& formatter) const {
// cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
// formatter(keys_[0]) << ") with values";
// for (size_t v: values_) cout << " " << v;
// cout << endl;
for (size_t v: values_) cout << v;
}
/* ************************************************************************* */
void Domain::print(const string& s, const KeyFormatter& formatter) const {
// cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
// formatter(keys_[0]) << ") with values";
// for (size_t v: values_) cout << " " << v;
// cout << endl;
for (size_t v : values_) cout << v;
}
/* ************************************************************************* */
double Domain::operator()(const Values& values) const {
/* ************************************************************************* */
double Domain::operator()(const Values& values) const {
return contains(values.at(keys_[0]));
}
}
/* ************************************************************************* */
DecisionTreeFactor Domain::toDecisionTreeFactor() const {
/* ************************************************************************* */
DecisionTreeFactor Domain::toDecisionTreeFactor() const {
DiscreteKeys keys;
keys += DiscreteKey(keys_[0],cardinality_);
keys += DiscreteKey(keys_[0], cardinality_);
vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; ++i1)
table.push_back(contains(i1));
for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1));
DecisionTreeFactor converted(keys, table);
return converted;
}
}
/* ************************************************************************* */
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
/* ************************************************************************* */
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f;
}
}
/* ************************************************************************* */
bool Domain::ensureArcConsistency(size_t j, vector<Domain>& domains) const {
/* ************************************************************************* */
bool Domain::ensureArcConsistency(size_t j, vector<Domain>& domains) const {
if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
Domain& D = domains[j];
for(size_t value: values_)
for (size_t value : values_)
if (!D.contains(value)) throw runtime_error("Unsatisfiable");
D = *this;
return true;
}
}
/* ************************************************************************* */
bool Domain::checkAllDiff(const KeyVector keys, vector<Domain>& domains) {
/* ************************************************************************* */
bool Domain::checkAllDiff(const KeyVector keys, vector<Domain>& domains) {
Key j = keys_[0];
// for all values in this domain
for(size_t value: values_) {
for (size_t value : values_) {
// for all connected domains
for(Key k: keys)
for (Key k : keys)
// if any domain contains the value we cannot make this domain singleton
if (k!=j && domains[k].contains(value))
goto found;
if (k != j && domains[k].contains(value)) goto found;
values_.clear();
values_.insert(value);
return true; // we changed it
found:;
}
return false; // we did not change it
}
}
/* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply(
const Values& values) const {
/* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply(const Values& values) const {
Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && !contains(it->second)) throw runtime_error(
"Domain::partiallyApply: unsatisfiable");
return boost::make_shared < Domain > (*this);
}
if (it != values.end() && !contains(it->second))
throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared<Domain>(*this);
}
/* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply(
/* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply(
const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error(
"Domain::partiallyApply: unsatisfiable");
return boost::make_shared < Domain > (Dk);
}
if (Dk.isSingleton() && !contains(*Dk.begin()))
throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared<Domain>(Dk);
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -7,81 +7,65 @@
#pragma once
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h>
namespace gtsam {
/**
/**
* Domain restriction constraint
*/
class GTSAM_UNSTABLE_EXPORT Domain: public Constraint {
class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
size_t cardinality_; /// Cardinality
std::set<size_t> values_; /// allowed values
public:
typedef boost::shared_ptr<Domain> shared_ptr;
// Constructor on Discrete Key initializes an "all-allowed" domain
Domain(const DiscreteKey& dkey) :
Constraint(dkey.first), cardinality_(dkey.second) {
for (size_t v = 0; v < cardinality_; v++)
values_.insert(v);
Domain(const DiscreteKey& dkey)
: Constraint(dkey.first), cardinality_(dkey.second) {
for (size_t v = 0; v < cardinality_; v++) values_.insert(v);
}
// Constructor on Discrete Key with single allowed value
// Consider SingleValue constraint
Domain(const DiscreteKey& dkey, size_t v) :
Constraint(dkey.first), cardinality_(dkey.second) {
Domain(const DiscreteKey& dkey, size_t v)
: Constraint(dkey.first), cardinality_(dkey.second) {
values_.insert(v);
}
/// Constructor
Domain(const Domain& other) :
Constraint(other.keys_[0]), values_(other.values_) {
}
Domain(const Domain& other)
: Constraint(other.keys_[0]), values_(other.values_) {}
/// insert a value, non const :-(
void insert(size_t value) {
values_.insert(value);
}
void insert(size_t value) { values_.insert(value); }
/// erase a value, non const :-(
void erase(size_t value) {
values_.erase(value);
}
void erase(size_t value) { values_.erase(value); }
size_t nrValues() const {
return values_.size();
}
size_t nrValues() const { return values_.size(); }
bool isSingleton() const {
return nrValues() == 1;
}
bool isSingleton() const { return nrValues() == 1; }
size_t firstValue() const {
return *values_.begin();
}
size_t firstValue() const { return *values_.begin(); }
// print
void print(const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
void print(const std::string& s = "", const KeyFormatter& formatter =
DefaultKeyFormatter) const override;
/// equals
bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const Domain*>(&other))
if (!dynamic_cast<const Domain*>(&other))
return false;
else {
const Domain& f(static_cast<const Domain&>(other));
return (cardinality_==f.cardinality_) && (values_==f.values_);
return (cardinality_ == f.cardinality_) && (values_ == f.values_);
}
}
bool contains(size_t value) const {
return values_.count(value)>0;
}
bool contains(size_t value) const { return values_.count(value) > 0; }
/// Calculate value
double operator()(const Values& values) const override;
@ -97,11 +81,13 @@ namespace gtsam {
* @param j domain to be checked
* @param domains all other domains
*/
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override;
bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override;
/**
* Check for a value in domain that does not occur in any other connected domain.
* If found, we make this a singleton... Called in AllDiff::ensureArcConsistency
* Check for a value in domain that does not occur in any other connected
* domain. If found, we make this a singleton... Called in
* AllDiff::ensureArcConsistency
* @param keys connected domains through alldiff
*/
bool checkAllDiff(const KeyVector keys, std::vector<Domain>& domains);
@ -112,6 +98,6 @@ namespace gtsam {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override;
};
};
} // namespace gtsam

View File

@ -5,24 +5,22 @@
* @author Frank Dellaert
*/
#include <gtsam_unstable/discrete/Scheduler.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/base/debug.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam_unstable/discrete/Scheduler.h>
#include <boost/tokenizer.hpp>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <cmath>
namespace gtsam {
using namespace std;
using namespace std;
Scheduler::Scheduler(size_t maxNrStudents, const string& filename):
maxNrStudents_(maxNrStudents)
{
Scheduler::Scheduler(size_t maxNrStudents, const string& filename)
: maxNrStudents_(maxNrStudents) {
typedef boost::tokenizer<boost::escaped_list_separator<char> > Tokenizer;
// open file
@ -38,8 +36,7 @@ namespace gtsam {
if (getline(is, line, '\r')) {
Tokenizer tok(line);
Tokenizer::iterator it = tok.begin();
for (++it; it != tok.end(); ++it)
addFaculty(*it);
for (++it; it != tok.end(); ++it) addFaculty(*it);
}
// for all remaining lines
@ -50,17 +47,16 @@ namespace gtsam {
Tokenizer::iterator it = tok.begin();
addSlot(*it++); // add slot
// add availability
for (; it != tok.end(); ++it)
available_ += (it->empty()) ? "0 " : "1 ";
for (; it != tok.end(); ++it) available_ += (it->empty()) ? "0 " : "1 ";
available_ += '\n';
}
} // constructor
} // constructor
/** addStudent has to be called after adding slots and faculty */
void Scheduler::addStudent(const string& studentName,
const string& area1, const string& area2,
const string& area3, const string& advisor) {
assert(nrStudents()<maxNrStudents_);
/** addStudent has to be called after adding slots and faculty */
void Scheduler::addStudent(const string& studentName, const string& area1,
const string& area2, const string& area3,
const string& advisor) {
assert(nrStudents() < maxNrStudents_);
assert(facultyInArea_.count(area1));
assert(facultyInArea_.count(area2));
assert(facultyInArea_.count(area3));
@ -69,71 +65,73 @@ namespace gtsam {
student.name_ = studentName;
// We fix the ordering by assigning a higher index to the student
// and numbering the areas lower
Key j = 3*maxNrStudents_ + nrStudents();
Key j = 3 * maxNrStudents_ + nrStudents();
student.key_ = DiscreteKey(j, nrTimeSlots());
Key base = 3*nrStudents();
student.keys_[0] = DiscreteKey(base+0, nrFaculty());
student.keys_[1] = DiscreteKey(base+1, nrFaculty());
student.keys_[2] = DiscreteKey(base+2, nrFaculty());
Key base = 3 * nrStudents();
student.keys_[0] = DiscreteKey(base + 0, nrFaculty());
student.keys_[1] = DiscreteKey(base + 1, nrFaculty());
student.keys_[2] = DiscreteKey(base + 2, nrFaculty());
student.areaName_[0] = area1;
student.areaName_[1] = area2;
student.areaName_[2] = area3;
students_.push_back(student);
}
}
/** get key for student and area, 0 is time slot itself */
const DiscreteKey& Scheduler::key(size_t s, boost::optional<size_t> area) const {
/** get key for student and area, 0 is time slot itself */
const DiscreteKey& Scheduler::key(size_t s,
boost::optional<size_t> area) const {
return area ? students_[s].keys_[*area] : students_[s].key_;
}
}
const string& Scheduler::studentName(size_t i) const {
assert(i<nrStudents());
const string& Scheduler::studentName(size_t i) const {
assert(i < nrStudents());
return students_[i].name_;
}
}
const DiscreteKey& Scheduler::studentKey(size_t i) const {
assert(i<nrStudents());
const DiscreteKey& Scheduler::studentKey(size_t i) const {
assert(i < nrStudents());
return students_[i].key_;
}
}
const string& Scheduler::studentArea(size_t i, size_t area) const {
assert(i<nrStudents());
const string& Scheduler::studentArea(size_t i, size_t area) const {
assert(i < nrStudents());
return students_[i].areaName_[area];
}
}
/** Add student-specific constraints to the graph */
void Scheduler::addStudentSpecificConstraints(size_t i, boost::optional<size_t> slot) {
/** Add student-specific constraints to the graph */
void Scheduler::addStudentSpecificConstraints(size_t i,
boost::optional<size_t> slot) {
bool debug = ISDEBUG("Scheduler::buildGraph");
assert(i<nrStudents());
assert(i < nrStudents());
const Student& s = students_[i];
if (!slot && !slotsAvailable_.empty()) {
if (debug) cout << "Adding availability of slots" << endl;
assert(slotsAvailable_.size()==s.key_.second);
assert(slotsAvailable_.size() == s.key_.second);
CSP::add(s.key_, slotsAvailable_);
}
// For all areas
for (size_t area = 0; area < 3; area++) {
DiscreteKey areaKey = s.keys_[area];
const string& areaName = s.areaName_[area];
if (debug) cout << "Area constraints " << areaName << endl;
assert(facultyInArea_[areaName].size()==areaKey.second);
assert(facultyInArea_[areaName].size() == areaKey.second);
CSP::add(areaKey, facultyInArea_[areaName]);
if (debug) cout << "Advisor constraint " << areaName << endl;
assert(s.advisor_.size()==areaKey.second);
assert(s.advisor_.size() == areaKey.second);
CSP::add(areaKey, s.advisor_);
if (debug) cout << "Availability of faculty " << areaName << endl;
if (slot) {
// get all constraints then specialize to slot
size_t dummyIndex = maxNrStudents_*3+maxNrStudents_;
size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_;
DiscreteKey dummy(dummyIndex, nrTimeSlots());
Potentials::ADT p(dummy & areaKey, available_); // available_ is Doodle string
Potentials::ADT p(dummy & areaKey,
available_); // available_ is Doodle string
Potentials::ADT q = p.choose(dummyIndex, *slot);
DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q));
CSP::push_back(f);
@ -145,25 +143,22 @@ namespace gtsam {
// add mutex
if (debug) cout << "Mutex for faculty" << endl;
addAllDiff(s.keys_[0] & s.keys_[1] & s.keys_[2]);
}
}
/** Main routine that builds factor graph */
void Scheduler::buildGraph(size_t mutexBound) {
/** Main routine that builds factor graph */
void Scheduler::buildGraph(size_t mutexBound) {
bool debug = ISDEBUG("Scheduler::buildGraph");
if (debug) cout << "Adding student-specific constraints" << endl;
for (size_t i = 0; i < nrStudents(); i++)
addStudentSpecificConstraints(i);
for (size_t i = 0; i < nrStudents(); i++) addStudentSpecificConstraints(i);
// special constraint for MN
if (studentName(0) == "Michael N") CSP::add(studentKey(0),
"0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1");
if (studentName(0) == "Michael N")
CSP::add(studentKey(0), "0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1");
if (!mutexBound) {
DiscreteKeys dkeys;
for(const Student& s: students_)
dkeys.push_back(s.key_);
for (const Student& s : students_) dkeys.push_back(s.key_);
addAllDiff(dkeys);
} else {
if (debug) cout << "Mutex for Students" << endl;
@ -175,103 +170,98 @@ namespace gtsam {
}
}
}
} // buildGraph
} // buildGraph
/** print */
void Scheduler::print(const string& s, const KeyFormatter& formatter) const {
/** print */
void Scheduler::print(const string& s, const KeyFormatter& formatter) const {
cout << s << " Faculty:" << endl;
for(const string& name: facultyName_)
cout << name << '\n';
for (const string& name : facultyName_) cout << name << '\n';
cout << endl;
cout << s << " Slots:\n";
size_t i = 0;
for(const string& name: slotName_)
cout << i++ << " " << name << endl;
for (const string& name : slotName_) cout << i++ << " " << name << endl;
cout << endl;
cout << "Availability:\n" << available_ << '\n';
cout << s << " Area constraints:\n";
for(const FacultyInArea::value_type& it: facultyInArea_)
{
for (const FacultyInArea::value_type& it : facultyInArea_) {
cout << setw(12) << it.first << ": ";
for(double v: it.second)
cout << v << " ";
for (double v : it.second) cout << v << " ";
cout << '\n';
}
cout << endl;
cout << s << " Students:\n";
for (const Student& student: students_)
student.print();
for (const Student& student : students_) student.print();
cout << endl;
CSP::print(s + " Factor graph");
cout << endl;
} // print
} // print
/** Print readable form of assignment */
void Scheduler::printAssignment(sharedValues assignment) const {
/** Print readable form of assignment */
void Scheduler::printAssignment(sharedValues assignment) const {
// Not intended to be general! Assumes very particular ordering !
cout << endl;
for (size_t s = 0; s < nrStudents(); s++) {
Key j = 3*maxNrStudents_ + s;
Key j = 3 * maxNrStudents_ + s;
size_t slot = assignment->at(j);
cout << studentName(s) << " slot: " << slotName_[slot] << endl;
Key base = 3*s;
Key base = 3 * s;
for (size_t area = 0; area < 3; area++) {
size_t faculty = assignment->at(base+area);
cout << setw(12) << studentArea(s,area) << ": " << facultyName_[faculty]
size_t faculty = assignment->at(base + area);
cout << setw(12) << studentArea(s, area) << ": " << facultyName_[faculty]
<< endl;
}
cout << endl;
}
}
}
/** Special print for single-student case */
void Scheduler::printSpecial(sharedValues assignment) const {
/** Special print for single-student case */
void Scheduler::printSpecial(sharedValues assignment) const {
Values::const_iterator it = assignment->begin();
for (size_t area = 0; area < 3; area++, it++) {
size_t f = it->second;
cout << setw(12) << studentArea(0,area) << ": " << facultyName_[f] << endl;
cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl;
}
cout << endl;
}
}
/** Accumulate faculty stats */
void Scheduler::accumulateStats(sharedValues assignment, vector<
size_t>& stats) const {
/** Accumulate faculty stats */
void Scheduler::accumulateStats(sharedValues assignment,
vector<size_t>& stats) const {
for (size_t s = 0; s < nrStudents(); s++) {
Key base = 3*s;
Key base = 3 * s;
for (size_t area = 0; area < 3; area++) {
size_t f = assignment->at(base+area);
assert(f<stats.size());
size_t f = assignment->at(base + area);
assert(f < stats.size());
stats[f]++;
} // area
} // s
}
}
/** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
/** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
gttic(my_eliminate);
// TODO: fix this!!
size_t maxKey = keys().size();
Ordering defaultKeyOrdering;
for (size_t i = 0; i<maxKey; ++i)
defaultKeyOrdering += Key(i);
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(defaultKeyOrdering);
for (size_t i = 0; i < maxKey; ++i) defaultKeyOrdering += Key(i);
DiscreteBayesNet::shared_ptr chordal =
this->eliminateSequential(defaultKeyOrdering);
gttoc(my_eliminate);
return chordal;
}
}
/** Find the best total assignment - can be expensive */
Scheduler::sharedValues Scheduler::optimalAssignment() const {
/** Find the best total assignment - can be expensive */
Scheduler::sharedValues Scheduler::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = eliminate();
if (ISDEBUG("Scheduler::optimalAssignment")) {
DiscreteBayesNet::const_iterator it = chordal->end()-1;
const Student & student = students_.front();
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
const Student& student = students_.front();
cout << endl;
(*it)->print(student.name_);
}
@ -280,23 +270,21 @@ namespace gtsam {
sharedValues mpe = chordal->optimize();
gttoc(my_optimize);
return mpe;
}
}
/** find the assignment of students to slots with most possible committees */
Scheduler::sharedValues Scheduler::bestSchedule() const {
/** find the assignment of students to slots with most possible committees */
Scheduler::sharedValues Scheduler::bestSchedule() const {
sharedValues best;
throw runtime_error("bestSchedule not implemented");
return best;
}
}
/** find the corresponding most desirable committee assignment */
Scheduler::sharedValues Scheduler::bestAssignment(
/** find the corresponding most desirable committee assignment */
Scheduler::sharedValues Scheduler::bestAssignment(
sharedValues bestSchedule) const {
sharedValues best;
throw runtime_error("bestAssignment not implemented");
return best;
}
} // gtsam
}
} // namespace gtsam

View File

@ -11,17 +11,15 @@
namespace gtsam {
/**
/**
* Scheduler class
* Creates one variable for each student, and three variables for each
* of the student's areas, for a total of 4*nrStudents variables.
* The "student" variable will determine when the student takes the qual.
* The "area" variables determine which faculty are on his/her committee.
*/
class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
private:
/** Internal data structure for students */
struct Student {
std::string name_;
@ -29,15 +27,14 @@ namespace gtsam {
std::vector<DiscreteKey> keys_; // key for areas
std::vector<std::string> areaName_;
std::vector<double> advisor_;
Student(size_t nrFaculty, size_t advisorIndex) :
keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) {
Student(size_t nrFaculty, size_t advisorIndex)
: keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) {
advisor_[advisorIndex] = 0.0;
}
void print() const {
using std::cout;
cout << name_ << ": ";
for (size_t area = 0; area < 3; area++)
cout << areaName_[area] << " ";
for (size_t area = 0; area < 3; area++) cout << areaName_[area] << " ";
cout << std::endl;
}
};
@ -63,7 +60,6 @@ namespace gtsam {
std::vector<double> slotsAvailable_;
public:
/**
* Constructor
* We need to know the number of students in advance for ordering keys.
@ -79,26 +75,16 @@ namespace gtsam {
facultyName_.push_back(facultyName);
}
size_t nrFaculty() const {
return facultyName_.size();
}
size_t nrFaculty() const { return facultyName_.size(); }
/** boolean std::string of nrTimeSlots * nrFaculty */
void setAvailability(const std::string& available) {
available_ = available;
}
void setAvailability(const std::string& available) { available_ = available; }
void addSlot(const std::string& slotName) {
slotName_.push_back(slotName);
}
void addSlot(const std::string& slotName) { slotName_.push_back(slotName); }
size_t nrTimeSlots() const {
return slotName_.size();
}
size_t nrTimeSlots() const { return slotName_.size(); }
const std::string& slotName(size_t s) const {
return slotName_[s];
}
const std::string& slotName(size_t s) const { return slotName_[s]; }
/** slots available, boolean */
void setSlotsAvailable(const std::vector<double>& slotsAvailable) {
@ -107,7 +93,8 @@ namespace gtsam {
void addArea(const std::string& facultyName, const std::string& areaName) {
areaName_.push_back(areaName);
std::vector<double>& table = facultyInArea_[areaName]; // will create if needed
std::vector<double>& table =
facultyInArea_[areaName]; // will create if needed
if (table.empty()) table.resize(nrFaculty(), 0);
table[facultyIndex_[facultyName]] = 1;
}
@ -119,7 +106,8 @@ namespace gtsam {
Scheduler(size_t maxNrStudents, const std::string& filename);
/** get key for student and area, 0 is time slot itself */
const DiscreteKey& key(size_t s, boost::optional<size_t> area = boost::none) const;
const DiscreteKey& key(size_t s,
boost::optional<size_t> area = boost::none) const;
/** addStudent has to be called after adding slots and faculty */
void addStudent(const std::string& studentName, const std::string& area1,
@ -127,16 +115,15 @@ namespace gtsam {
const std::string& advisor);
/// current number of students
size_t nrStudents() const {
return students_.size();
}
size_t nrStudents() const { return students_.size(); }
const std::string& studentName(size_t i) const;
const DiscreteKey& studentKey(size_t i) const;
const std::string& studentArea(size_t i, size_t area) const;
/** Add student-specific constraints to the graph */
void addStudentSpecificConstraints(size_t i, boost::optional<size_t> slot = boost::none);
void addStudentSpecificConstraints(
size_t i, boost::optional<size_t> slot = boost::none);
/** Main routine that builds factor graph */
void buildGraph(size_t mutexBound = 7);
@ -168,8 +155,6 @@ namespace gtsam {
/** find the corresponding most desirable committee assignment */
sharedValues bestAssignment(sharedValues bestSchedule) const;
}; // Scheduler
} // gtsam
}; // Scheduler
} // namespace gtsam

View File

@ -5,75 +5,74 @@
* @author Frank Dellaert
*/
#include <gtsam_unstable/discrete/SingleValue.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/SingleValue.h>
#include <boost/make_shared.hpp>
namespace gtsam {
using namespace std;
using namespace std;
/* ************************************************************************* */
void SingleValue::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << "SingleValue on " << "j=" << formatter(keys_[0])
<< " with value " << value_ << endl;
}
/* ************************************************************************* */
void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "SingleValue on "
<< "j=" << formatter(keys_[0]) << " with value " << value_ << endl;
}
/* ************************************************************************* */
double SingleValue::operator()(const Values& values) const {
return (double) (values.at(keys_[0]) == value_);
}
/* ************************************************************************* */
double SingleValue::operator()(const Values& values) const {
return (double)(values.at(keys_[0]) == value_);
}
/* ************************************************************************* */
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
/* ************************************************************************* */
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
DiscreteKeys keys;
keys += DiscreteKey(keys_[0],cardinality_);
keys += DiscreteKey(keys_[0], cardinality_);
vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; i1++)
table.push_back(i1 == value_);
for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_);
DecisionTreeFactor converted(keys, table);
return converted;
}
}
/* ************************************************************************* */
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
/* ************************************************************************* */
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f;
}
}
/* ************************************************************************* */
bool SingleValue::ensureArcConsistency(size_t j,
/* ************************************************************************* */
bool SingleValue::ensureArcConsistency(size_t j,
vector<Domain>& domains) const {
if (j != keys_[0]) throw invalid_argument(
"SingleValue check on wrong domain");
if (j != keys_[0])
throw invalid_argument("SingleValue check on wrong domain");
Domain& D = domains[j];
if (D.isSingleton()) {
if (D.firstValue() != value_) throw runtime_error("Unsatisfiable");
return false;
}
D = Domain(discreteKey(),value_);
D = Domain(discreteKey(), value_);
return true;
}
}
/* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
/* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && it->second != value_) throw runtime_error(
"SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared < SingleValue > (keys_[0], cardinality_, value_);
}
if (it != values.end() && it->second != value_)
throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
}
/* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(
/* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(
const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error(
"SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared < SingleValue > (discreteKey(), value_);
}
if (Dk.isSingleton() && !Dk.contains(value_))
throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(discreteKey(), value_);
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -7,16 +7,15 @@
#pragma once
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h>
namespace gtsam {
/**
/**
* SingleValue constraint
*/
class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint {
class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Number of values
size_t cardinality_;
@ -24,34 +23,31 @@ namespace gtsam {
size_t value_;
DiscreteKey discreteKey() const {
return DiscreteKey(keys_[0],cardinality_);
return DiscreteKey(keys_[0], cardinality_);
}
public:
typedef boost::shared_ptr<SingleValue> shared_ptr;
/// Constructor
SingleValue(Key key, size_t n, size_t value) :
Constraint(key), cardinality_(n), value_(value) {
}
SingleValue(Key key, size_t n, size_t value)
: Constraint(key), cardinality_(n), value_(value) {}
/// Constructor
SingleValue(const DiscreteKey& dkey, size_t value) :
Constraint(dkey.first), cardinality_(dkey.second), value_(value) {
}
SingleValue(const DiscreteKey& dkey, size_t value)
: Constraint(dkey.first), cardinality_(dkey.second), value_(value) {}
// print
void print(const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
void print(const std::string& s = "", const KeyFormatter& formatter =
DefaultKeyFormatter) const override;
/// equals
bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const SingleValue*>(&other))
if (!dynamic_cast<const SingleValue*>(&other))
return false;
else {
const SingleValue& f(static_cast<const SingleValue&>(other));
return (cardinality_==f.cardinality_) && (value_==f.value_);
return (cardinality_ == f.cardinality_) && (value_ == f.value_);
}
}
@ -69,7 +65,8 @@ namespace gtsam {
* @param j domain to be checked
* @param domains all other domains
*/
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override;
bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override;
/// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override;
@ -77,6 +74,6 @@ namespace gtsam {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override;
};
};
} // namespace gtsam

View File

@ -7,49 +7,50 @@
#include <gtsam_unstable/discrete/CSP.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/assign/std/map.hpp>
using boost::assign::insert;
#include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <fstream>
#include <iostream>
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST_UNSAFE( BinaryAllDif, allInOne)
{
TEST_UNSAFE(BinaryAllDif, allInOne) {
// Create keys and ordering
size_t nrColors = 2;
// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", nrColors);
// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona",
// nrColors);
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
// Check construction and conversion
BinaryAllDiff c1(ID, UT);
DecisionTreeFactor f1(ID & UT, "0 1 1 0");
EXPECT(assert_equal(f1,c1.toDecisionTreeFactor()));
EXPECT(assert_equal(f1, c1.toDecisionTreeFactor()));
// Check construction and conversion
BinaryAllDiff c2(UT, AZ);
DecisionTreeFactor f2(UT & AZ, "0 1 1 0");
EXPECT(assert_equal(f2,c2.toDecisionTreeFactor()));
EXPECT(assert_equal(f2, c2.toDecisionTreeFactor()));
DecisionTreeFactor f3 = f1*f2;
EXPECT(assert_equal(f3,c1*f2));
EXPECT(assert_equal(f3,c2*f1));
DecisionTreeFactor f3 = f1 * f2;
EXPECT(assert_equal(f3, c1 * f2));
EXPECT(assert_equal(f3, c2 * f1));
}
/* ************************************************************************* */
TEST_UNSAFE( CSP, allInOne)
{
TEST_UNSAFE(CSP, allInOne) {
// Create keys and ordering
size_t nrColors = 2;
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
// Create the CSP
CSP csp;
csp.addAllDiff(ID,UT);
csp.addAllDiff(UT,AZ);
csp.addAllDiff(ID, UT);
csp.addAllDiff(UT, AZ);
// Check an invalid combination, with ID==UT==AZ all same color
DiscreteFactor::Values invalid;
@ -69,67 +70,67 @@ TEST_UNSAFE( CSP, allInOne)
DecisionTreeFactor product = csp.product();
// product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct,product));
EXPECT(assert_equal(expectedProduct, product));
// Solve
CSP::sharedValues mpe = csp.optimalAssignment();
CSP::Values expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected,*mpe));
EXPECT(assert_equal(expected, *mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
}
/* ************************************************************************* */
TEST_UNSAFE( CSP, WesternUS)
{
TEST_UNSAFE(CSP, WesternUS) {
// Create keys
size_t nrColors = 4;
DiscreteKey
// Create ordering according to example in ND-CSP.lyx
WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),NV(2, nrColors),
ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors),
MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors);
WA(0, nrColors),
OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors),
UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors),
CO(7, nrColors), NM(6, nrColors);
// Create the CSP
CSP csp;
csp.addAllDiff(WA,ID);
csp.addAllDiff(WA,OR);
csp.addAllDiff(OR,ID);
csp.addAllDiff(OR,CA);
csp.addAllDiff(OR,NV);
csp.addAllDiff(CA,NV);
csp.addAllDiff(CA,AZ);
csp.addAllDiff(ID,MT);
csp.addAllDiff(ID,WY);
csp.addAllDiff(ID,UT);
csp.addAllDiff(ID,NV);
csp.addAllDiff(NV,UT);
csp.addAllDiff(NV,AZ);
csp.addAllDiff(UT,WY);
csp.addAllDiff(UT,CO);
csp.addAllDiff(UT,NM);
csp.addAllDiff(UT,AZ);
csp.addAllDiff(AZ,CO);
csp.addAllDiff(AZ,NM);
csp.addAllDiff(MT,WY);
csp.addAllDiff(WY,CO);
csp.addAllDiff(CO,NM);
csp.addAllDiff(WA, ID);
csp.addAllDiff(WA, OR);
csp.addAllDiff(OR, ID);
csp.addAllDiff(OR, CA);
csp.addAllDiff(OR, NV);
csp.addAllDiff(CA, NV);
csp.addAllDiff(CA, AZ);
csp.addAllDiff(ID, MT);
csp.addAllDiff(ID, WY);
csp.addAllDiff(ID, UT);
csp.addAllDiff(ID, NV);
csp.addAllDiff(NV, UT);
csp.addAllDiff(NV, AZ);
csp.addAllDiff(UT, WY);
csp.addAllDiff(UT, CO);
csp.addAllDiff(UT, NM);
csp.addAllDiff(UT, AZ);
csp.addAllDiff(AZ, CO);
csp.addAllDiff(AZ, NM);
csp.addAllDiff(MT, WY);
csp.addAllDiff(WY, CO);
csp.addAllDiff(CO, NM);
// Solve
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7),Key(8),Key(9),Key(10);
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
Key(8), Key(9), Key(10);
CSP::sharedValues mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(*mpe);
CSP::Values expected;
insert(expected)
(WA.first,1)(CA.first,1)(NV.first,3)(OR.first,0)
(MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2)
(ID.first,2)(UT.first,1)(AZ.first,0);
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
UT.first, 1)(AZ.first, 0);
// TODO: Fix me! mpe result seems to be right. (See the printing)
// It has the same prob as the expected solution.
// Is mpe another solution, or the expected solution is unique???
EXPECT(assert_equal(expected,*mpe));
EXPECT(assert_equal(expected, *mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Write out the dual graph for hmetis
@ -142,8 +143,7 @@ TEST_UNSAFE( CSP, WesternUS)
}
/* ************************************************************************* */
TEST_UNSAFE( CSP, AllDiff)
{
TEST_UNSAFE(CSP, AllDiff) {
// Create keys and ordering
size_t nrColors = 3;
DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
@ -151,24 +151,25 @@ TEST_UNSAFE( CSP, AllDiff)
// Create the CSP
CSP csp;
vector<DiscreteKey> dkeys;
dkeys += ID,UT,AZ;
dkeys += ID, UT, AZ;
csp.addAllDiff(dkeys);
csp.addSingleValue(AZ,2);
// GTSAM_PRINT(csp);
csp.addSingleValue(AZ, 2);
// GTSAM_PRINT(csp);
// Check construction and conversion
SingleValue s(AZ,2);
DecisionTreeFactor f1(AZ,"0 0 1");
EXPECT(assert_equal(f1,s.toDecisionTreeFactor()));
SingleValue s(AZ, 2);
DecisionTreeFactor f1(AZ, "0 0 1");
EXPECT(assert_equal(f1, s.toDecisionTreeFactor()));
// Check construction and conversion
AllDiff alldiff(dkeys);
DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
// GTSAM_PRINT(actual);
// actual.dot("actual");
DecisionTreeFactor f2(ID & AZ & UT,
// GTSAM_PRINT(actual);
// actual.dot("actual");
DecisionTreeFactor f2(
ID & AZ & UT,
"0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0");
EXPECT(assert_equal(f2,actual));
EXPECT(assert_equal(f2, actual));
// Check an invalid combination, with ID==UT==AZ all same color
DiscreteFactor::Values invalid;
@ -188,36 +189,36 @@ TEST_UNSAFE( CSP, AllDiff)
CSP::sharedValues mpe = csp.optimalAssignment();
CSP::Values expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected,*mpe));
EXPECT(assert_equal(expected, *mpe));
EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
// Arc-consistency
vector<Domain> domains;
domains += Domain(ID), Domain(AZ), Domain(UT);
SingleValue singleValue(AZ,2);
EXPECT(singleValue.ensureArcConsistency(1,domains));
EXPECT(alldiff.ensureArcConsistency(0,domains));
EXPECT(!alldiff.ensureArcConsistency(1,domains));
EXPECT(alldiff.ensureArcConsistency(2,domains));
LONGS_EQUAL(2,domains[0].nrValues());
LONGS_EQUAL(1,domains[1].nrValues());
LONGS_EQUAL(2,domains[2].nrValues());
SingleValue singleValue(AZ, 2);
EXPECT(singleValue.ensureArcConsistency(1, domains));
EXPECT(alldiff.ensureArcConsistency(0, domains));
EXPECT(!alldiff.ensureArcConsistency(1, domains));
EXPECT(alldiff.ensureArcConsistency(2, domains));
LONGS_EQUAL(2, domains[0].nrValues());
LONGS_EQUAL(1, domains[1].nrValues());
LONGS_EQUAL(2, domains[2].nrValues());
// Parial application, version 1
DiscreteFactor::Values known;
known[AZ.first] = 2;
DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known);
DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0");
EXPECT(assert_equal(f3,reduced1->toDecisionTreeFactor()));
EXPECT(assert_equal(f3, reduced1->toDecisionTreeFactor()));
DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known);
DecisionTreeFactor f4(AZ, "0 0 1");
EXPECT(assert_equal(f4,reduced2->toDecisionTreeFactor()));
EXPECT(assert_equal(f4, reduced2->toDecisionTreeFactor()));
// Parial application, version 2
DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains);
EXPECT(assert_equal(f3,reduced3->toDecisionTreeFactor()));
EXPECT(assert_equal(f3, reduced3->toDecisionTreeFactor()));
DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains);
EXPECT(assert_equal(f4,reduced4->toDecisionTreeFactor()));
EXPECT(assert_equal(f4, reduced4->toDecisionTreeFactor()));
// full arc-consistency test
csp.runArcConsistency(nrColors);
@ -229,4 +230,3 @@ int main() {
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -5,14 +5,15 @@
* @date Oct 11, 2013
*/
#include <gtsam/inference/VariableIndex.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <CppUnitLite/TestHarness.h>
#include <boost/range/adaptor/map.hpp>
#include <gtsam/inference/VariableIndex.h>
#include <boost/assign/list_of.hpp>
#include <iostream>
#include <boost/range/adaptor/map.hpp>
#include <fstream>
#include <iostream>
using namespace std;
using namespace boost;
@ -23,11 +24,12 @@ using namespace gtsam;
* Loopy belief solver for graphs with only binary and unary factors
*/
class LoopyBelief {
/** Star graph struct for each node, containing
* - the star graph itself
* - the product of original unary factors so we don't have to recompute it later, and
* - the factor indices of the corrected belief factors of the neighboring nodes
* - the product of original unary factors so we don't have to recompute it
* later, and
* - the factor indices of the corrected belief factors of the neighboring
* nodes
*/
typedef std::map<Key, size_t> CorrectedBeliefIndices;
struct StarGraph {
@ -37,40 +39,40 @@ class LoopyBelief {
VariableIndex varIndex_;
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
const CorrectedBeliefIndices& _beliefIndices,
const DecisionTreeFactor::shared_ptr& _unary) :
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
*_star) {
}
const DecisionTreeFactor::shared_ptr& _unary)
: star(_star),
correctedBeliefIndices(_beliefIndices),
unary(_unary),
varIndex_(*_star) {}
void print(const std::string& s = "") const {
cout << s << ":" << endl;
star->print("Star graph: ");
for(Key key: correctedBeliefIndices | boost::adaptors::map_keys) {
for (Key key : correctedBeliefIndices | boost::adaptors::map_keys) {
cout << "Belief factor index for " << key << ": "
<< correctedBeliefIndices.at(key) << endl;
}
if (unary)
unary->print("Unary: ");
if (unary) unary->print("Unary: ");
}
};
typedef std::map<Key, StarGraph> StarGraphs;
StarGraphs starGraphs_; ///< star graph at each variable
public:
public:
/** Constructor
* Need all discrete keys to access node's cardinality for creating belief factors
* Need all discrete keys to access node's cardinality for creating belief
* factors
* TODO: so troublesome!!
*/
LoopyBelief(const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
}
const std::map<Key, DiscreteKey>& allDiscreteKeys)
: starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {}
/// print
void print(const std::string& s = "") const {
cout << s << ":" << endl;
for(Key key: starGraphs_ | boost::adaptors::map_keys) {
for (Key key : starGraphs_ | boost::adaptors::map_keys) {
starGraphs_.at(key).print((boost::format("Node %d:") % key).str());
}
}
@ -79,12 +81,13 @@ public:
DiscreteFactorGraph::shared_ptr iterate(
const std::map<Key, DiscreteKey>& allDiscreteKeys) {
static const bool debug = false;
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
static DiscreteConditional::shared_ptr
dummyCond; // unused by-product of elimination
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
// Eliminate each star graph
for(Key key: starGraphs_ | boost::adaptors::map_keys) {
// cout << "***** Node " << key << "*****" << endl;
for (Key key : starGraphs_ | boost::adaptors::map_keys) {
// cout << "***** Node " << key << "*****" << endl;
// initialize belief to the unary factor from the original graph
DecisionTreeFactor::shared_ptr beliefAtKey;
@ -92,15 +95,16 @@ public:
std::map<Key, DiscreteFactor::shared_ptr> messages;
// eliminate each neighbor in this star graph one by one
for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices |
boost::adaptors::map_keys) {
DiscreteFactorGraph subGraph;
for(size_t factor: starGraphs_.at(key).varIndex_[neighbor]) {
for (size_t factor : starGraphs_.at(key).varIndex_[neighbor]) {
subGraph.push_back(starGraphs_.at(key).star->at(factor));
}
if (debug) subGraph.print("------- Subgraph:");
DiscreteFactor::shared_ptr message;
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
Ordering(list_of(neighbor)));
boost::tie(dummyCond, message) =
EliminateDiscrete(subGraph, Ordering(list_of(neighbor)));
// store the new factor into messages
messages.insert(make_pair(neighbor, message));
if (debug) message->print("------- Message: ");
@ -108,14 +112,12 @@ public:
// Belief is the product of all messages and the unary factor
// Incorporate new the factor to belief
if (!beliefAtKey)
beliefAtKey = boost::dynamic_pointer_cast<DecisionTreeFactor>(
message);
else
beliefAtKey =
boost::make_shared<DecisionTreeFactor>(
(*beliefAtKey)
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
message)));
boost::dynamic_pointer_cast<DecisionTreeFactor>(message);
else
beliefAtKey = boost::make_shared<DecisionTreeFactor>(
(*beliefAtKey) *
(*boost::dynamic_pointer_cast<DecisionTreeFactor>(message)));
}
if (starGraphs_.at(key).unary)
beliefAtKey = boost::make_shared<DecisionTreeFactor>(
@ -133,7 +135,8 @@ public:
sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str();
DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable);
if (debug) sumFactor.print("denomFactor: ");
beliefAtKey = boost::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
beliefAtKey =
boost::make_shared<DecisionTreeFactor>((*beliefAtKey) / sumFactor);
if (debug) beliefAtKey->print("New belief at key normalized: ");
beliefs->push_back(beliefAtKey);
allMessages[key] = messages;
@ -141,17 +144,20 @@ public:
// Update corrected beliefs
VariableIndex beliefFactors(*beliefs);
for(Key key: starGraphs_ | boost::adaptors::map_keys) {
for (Key key : starGraphs_ | boost::adaptors::map_keys) {
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast<
DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices |
boost::adaptors::map_keys) {
DecisionTreeFactor correctedBelief =
(*boost::dynamic_pointer_cast<DecisionTreeFactor>(
beliefs->at(beliefFactors[key].front()))) /
(*boost::dynamic_pointer_cast<DecisionTreeFactor>(
messages.at(neighbor)));
if (debug) correctedBelief.print("correctedBelief: ");
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
key);
starGraphs_.at(neighbor).star->replace(beliefIndex,
size_t beliefIndex =
starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
starGraphs_.at(neighbor).star->replace(
beliefIndex,
boost::make_shared<DecisionTreeFactor>(correctedBelief));
}
}
@ -161,21 +167,22 @@ public:
return beliefs;
}
private:
private:
/**
* Build star graphs for each node.
*/
StarGraphs buildStarGraphs(const DiscreteFactorGraph& graph,
StarGraphs buildStarGraphs(
const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) const {
StarGraphs starGraphs;
VariableIndex varIndex(graph); ///< access to all factors of each node
for(Key key: varIndex | boost::adaptors::map_keys) {
for (Key key : varIndex | boost::adaptors::map_keys) {
// initialize to multiply with other unary factors later
DecisionTreeFactor::shared_ptr prodOfUnaries;
// collect all factors involving this key in the original graph
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
for(size_t factorIndex: varIndex[key]) {
for (size_t factorIndex : varIndex[key]) {
star->push_back(graph.at(factorIndex));
// accumulate unary factors
@ -185,8 +192,8 @@ private:
graph.at(factorIndex));
else
prodOfUnaries = boost::make_shared<DecisionTreeFactor>(
*prodOfUnaries
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
*prodOfUnaries *
(*boost::dynamic_pointer_cast<DecisionTreeFactor>(
graph.at(factorIndex))));
}
}
@ -196,7 +203,7 @@ private:
KeySet neighbors = star->keys();
neighbors.erase(key);
CorrectedBeliefIndices correctedBeliefIndices;
for(Key neighbor: neighbors) {
for (Key neighbor : neighbors) {
// TODO: default table for keys with more than 2 values?
string initialBelief;
for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) {
@ -207,9 +214,8 @@ private:
DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief));
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
}
starGraphs.insert(
make_pair(key,
StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
starGraphs.insert(make_pair(
key, StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
}
return starGraphs;
}
@ -249,7 +255,6 @@ TEST_UNSAFE(LoopyBelief, construction) {
DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys);
beliefs->print();
}
}
/* ************************************************************************* */

View File

@ -5,14 +5,13 @@
*/
//#define ENABLE_TIMING
#include <gtsam_unstable/discrete/Scheduler.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/timing.h>
#include <gtsam_unstable/discrete/Scheduler.h>
#include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/vector.hpp>
#include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp>
#include <boost/optional.hpp>
using namespace boost::assign;
@ -22,7 +21,6 @@ using namespace gtsam;
/* ************************************************************************* */
// Create the expected graph of constraints
DiscreteFactorGraph createExpected() {
// Start building
size_t nrFaculty = 4, nrTimeSlots = 3;
@ -79,8 +77,7 @@ DiscreteFactorGraph createExpected() {
}
/* ************************************************************************* */
TEST( schedulingExample, test)
{
TEST(schedulingExample, test) {
Scheduler s(2);
// add faculty
@ -121,7 +118,7 @@ TEST( schedulingExample, test)
// Do brute force product and output that to file
DecisionTreeFactor product = s.product();
//product.dot("scheduling", false);
// product.dot("scheduling", false);
// Do exact inference
gttic(small);
@ -129,25 +126,24 @@ TEST( schedulingExample, test)
gttoc(small);
// print MPE, commented out as unit tests don't print
// s.printAssignment(MPE);
// s.printAssignment(MPE);
// Commented out as does not work yet
// s.runArcConsistency(8,10,true);
// find the assignment of students to slots with most possible committees
// Commented out as not implemented yet
// sharedValues bestSchedule = s.bestSchedule();
// GTSAM_PRINT(*bestSchedule);
// sharedValues bestSchedule = s.bestSchedule();
// GTSAM_PRINT(*bestSchedule);
// find the corresponding most desirable committee assignment
// Commented out as not implemented yet
// sharedValues bestAssignment = s.bestAssignment(bestSchedule);
// GTSAM_PRINT(*bestAssignment);
// sharedValues bestAssignment = s.bestAssignment(bestSchedule);
// GTSAM_PRINT(*bestAssignment);
}
/* ************************************************************************* */
TEST( schedulingExample, smallFromFile)
{
TEST(schedulingExample, smallFromFile) {
string path(TOPSRCDIR "/gtsam_unstable/discrete/examples/");
Scheduler s(2, path + "small.csv");
@ -179,4 +175,3 @@ int main() {
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -5,21 +5,22 @@
* @author Frank Dellaert
*/
#include <gtsam_unstable/discrete/CSP.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam_unstable/discrete/CSP.h>
#include <boost/assign/std/map.hpp>
using boost::assign::insert;
#include <stdarg.h>
#include <iostream>
#include <sstream>
#include <stdarg.h>
using namespace std;
using namespace gtsam;
#define PRINT false
class Sudoku: public CSP {
class Sudoku : public CSP {
/// sudoku size
size_t n_;
@ -27,25 +28,21 @@ class Sudoku: public CSP {
typedef std::pair<size_t, size_t> IJ;
std::map<IJ, DiscreteKey> dkeys_;
public:
public:
/// return DiscreteKey for cell(i,j)
const DiscreteKey& dkey(size_t i, size_t j) const {
return dkeys_.at(IJ(i, j));
}
/// return Key for cell(i,j)
Key key(size_t i, size_t j) const {
return dkey(i, j).first;
}
Key key(size_t i, size_t j) const { return dkey(i, j).first; }
/// Constructor
Sudoku(size_t n, ...) :
n_(n) {
Sudoku(size_t n, ...) : n_(n) {
// Create variables, ordering, and unary constraints
va_list ap;
va_start(ap, n);
Key k=0;
Key k = 0;
for (size_t i = 0; i < n; ++i) {
for (size_t j = 0; j < n; ++j, ++k) {
// create the key
@ -56,23 +53,21 @@ public:
// cout << value << " ";
if (value != 0) addSingleValue(dkeys_[ij], value - 1);
}
//cout << endl;
// cout << endl;
}
va_end(ap);
// add row constraints
for (size_t i = 0; i < n; i++) {
DiscreteKeys dkeys;
for (size_t j = 0; j < n; j++)
dkeys += dkey(i, j);
for (size_t j = 0; j < n; j++) dkeys += dkey(i, j);
addAllDiff(dkeys);
}
// add col constraints
for (size_t j = 0; j < n; j++) {
DiscreteKeys dkeys;
for (size_t i = 0; i < n; i++)
dkeys += dkey(i, j);
for (size_t i = 0; i < n; i++) dkeys += dkey(i, j);
addAllDiff(dkeys);
}
@ -84,8 +79,7 @@ public:
// Box I,J
DiscreteKeys dkeys;
for (size_t i = i0; i < i0 + N; i++)
for (size_t j = j0; j < j0 + N; j++)
dkeys += dkey(i, j);
for (size_t j = j0; j < j0 + N; j++) dkeys += dkey(i, j);
addAllDiff(dkeys);
j0 += N;
}
@ -109,74 +103,59 @@ public:
DiscreteFactor::sharedValues MPE = optimalAssignment();
printAssignment(MPE);
}
};
/* ************************************************************************* */
TEST_UNSAFE( Sudoku, small)
{
Sudoku csp(4,
1,0, 0,4,
0,0, 0,0,
TEST_UNSAFE(Sudoku, small) {
Sudoku csp(4, 1, 0, 0, 4, 0, 0, 0, 0,
4,0, 2,0,
0,1, 0,0);
4, 0, 2, 0, 0, 1, 0, 0);
// Do BP
csp.runArcConsistency(4,10,PRINT);
csp.runArcConsistency(4, 10, PRINT);
// optimize and check
CSP::sharedValues solution = csp.optimalAssignment();
CSP::Values expected;
insert(expected)
(csp.key(0,0), 0)(csp.key(0,1), 1)(csp.key(0,2), 2)(csp.key(0,3), 3)
(csp.key(1,0), 2)(csp.key(1,1), 3)(csp.key(1,2), 0)(csp.key(1,3), 1)
(csp.key(2,0), 3)(csp.key(2,1), 2)(csp.key(2,2), 1)(csp.key(2,3), 0)
(csp.key(3,0), 1)(csp.key(3,1), 0)(csp.key(3,2), 3)(csp.key(3,3), 2);
EXPECT(assert_equal(expected,*solution));
//csp.printAssignment(solution);
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)(
csp.key(2, 3), 0)(csp.key(3, 0), 1)(csp.key(3, 1), 0)(csp.key(3, 2), 3)(
csp.key(3, 3), 2);
EXPECT(assert_equal(expected, *solution));
// csp.printAssignment(solution);
}
/* ************************************************************************* */
TEST_UNSAFE( Sudoku, easy)
{
Sudoku sudoku(9,
0,0,5, 0,9,0, 0,0,1,
0,0,0, 0,0,2, 0,7,3,
7,6,0, 0,0,8, 2,0,0,
TEST_UNSAFE(Sudoku, easy) {
Sudoku sudoku(9, 0, 0, 5, 0, 9, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 7, 3, 7, 6,
0, 0, 0, 8, 2, 0, 0,
0,1,2, 0,0,9, 0,0,4,
0,0,0, 2,0,3, 0,0,0,
3,0,0, 1,0,0, 9,6,0,
0, 1, 2, 0, 0, 9, 0, 0, 4, 0, 0, 0, 2, 0, 3, 0, 0, 0, 3, 0, 0,
1, 0, 0, 9, 6, 0,
0,0,1, 9,0,0, 0,5,8,
9,7,0, 5,0,0, 0,0,0,
5,0,0, 0,3,0, 7,0,0);
0, 0, 1, 9, 0, 0, 0, 5, 8, 9, 7, 0, 5, 0, 0, 0, 0, 0, 5, 0, 0,
0, 3, 0, 7, 0, 0);
// Do BP
sudoku.runArcConsistency(4,10,PRINT);
sudoku.runArcConsistency(4, 10, PRINT);
// sudoku.printSolution(); // don't do it
}
/* ************************************************************************* */
TEST_UNSAFE( Sudoku, extreme)
{
Sudoku sudoku(9,
0,0,9, 7,4,8, 0,0,0,
7,0,0, 0,0,0, 0,0,0,
0,2,0, 1,0,9, 0,0,0,
TEST_UNSAFE(Sudoku, extreme) {
Sudoku sudoku(9, 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
0, 1, 0, 9, 0, 0, 0,
0,0,7, 0,0,0, 2,4,0,
0,6,4, 0,1,0, 5,9,0,
0,9,8, 0,0,0, 3,0,0,
0, 0, 7, 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, 1, 0, 5, 9, 0, 0, 9, 8,
0, 0, 0, 3, 0, 0,
0,0,0, 8,0,3, 0,2,0,
0,0,0, 0,0,0, 0,0,6,
0,0,0, 2,7,5, 9,0,0);
0, 0, 0, 8, 0, 3, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0,
2, 7, 5, 9, 0, 0);
// Do BP
sudoku.runArcConsistency(9,10,PRINT);
sudoku.runArcConsistency(9, 10, PRINT);
#ifdef METIS
VariableIndexOrdered index(sudoku);
@ -185,29 +164,24 @@ TEST_UNSAFE( Sudoku, extreme)
index.outputMetisFormat(os);
#endif
//sudoku.printSolution(); // don't do it
// sudoku.printSolution(); // don't do it
}
/* ************************************************************************* */
TEST_UNSAFE( Sudoku, AJC_3star_Feb8_2012)
{
Sudoku sudoku(9,
9,5,0, 0,0,6, 0,0,0,
0,8,4, 0,7,0, 0,0,0,
6,2,0, 5,0,0, 4,0,0,
TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) {
Sudoku sudoku(9, 9, 5, 0, 0, 0, 6, 0, 0, 0, 0, 8, 4, 0, 7, 0, 0, 0, 0, 6, 2,
0, 5, 0, 0, 4, 0, 0,
0,0,0, 2,9,0, 6,0,0,
0,9,0, 0,0,0, 0,2,0,
0,0,2, 0,6,3, 0,0,0,
0, 0, 0, 2, 9, 0, 6, 0, 0, 0, 9, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2,
0, 6, 3, 0, 0, 0,
0,0,9, 0,0,7, 0,6,8,
0,0,0, 0,3,0, 2,9,0,
0,0,0, 1,0,0, 0,3,7);
0, 0, 9, 0, 0, 7, 0, 6, 8, 0, 0, 0, 0, 3, 0, 2, 9, 0, 0, 0, 0,
1, 0, 0, 0, 3, 7);
// Do BP
sudoku.runArcConsistency(9,10,PRINT);
sudoku.runArcConsistency(9, 10, PRINT);
//sudoku.printSolution(); // don't do it
// sudoku.printSolution(); // don't do it
}
/* ************************************************************************* */
@ -216,4 +190,3 @@ int main() {
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */