Revamped arc consistency
							parent
							
								
									770fda9a26
								
							
						
					
					
						commit
						dd50975668
					
				| 
						 | 
				
			
			@ -5,105 +5,115 @@
 | 
			
		|||
 * @author Frank Dellaert
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/AllDiff.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Domain.h>
 | 
			
		||||
 | 
			
		||||
#include <gtsam_unstable/discrete/AllDiff.h>
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <boost/make_shared.hpp>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
 | 
			
		||||
  for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
 | 
			
		||||
  std::cout << s << "AllDiff on ";
 | 
			
		||||
  for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
 | 
			
		||||
  std::cout << std::endl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
double AllDiff::operator()(const Values& values) const {
 | 
			
		||||
  std::set<size_t> taken;  // record values taken by keys
 | 
			
		||||
  for (Key dkey : keys_) {
 | 
			
		||||
    size_t value = values.at(dkey);      // get the value for that key
 | 
			
		||||
    if (taken.count(value)) return 0.0;  // check if value alreday taken
 | 
			
		||||
    taken.insert(value);  // if not, record it as taken and keep checking
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  AllDiff::AllDiff(const DiscreteKeys& dkeys) :
 | 
			
		||||
    Constraint(dkeys.indices()) {
 | 
			
		||||
    for(const DiscreteKey& dkey: dkeys)
 | 
			
		||||
        cardinalities_.insert(dkey);
 | 
			
		||||
  }
 | 
			
		||||
  return 1.0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
 | 
			
		||||
  // We will do this by converting the allDif into many BinaryAllDiff
 | 
			
		||||
  // constraints
 | 
			
		||||
  DecisionTreeFactor converted;
 | 
			
		||||
  size_t nrKeys = keys_.size();
 | 
			
		||||
  for (size_t i1 = 0; i1 < nrKeys; i1++)
 | 
			
		||||
    for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
 | 
			
		||||
      BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2));
 | 
			
		||||
      converted = converted * binary12.toDecisionTreeFactor();
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  void AllDiff::print(const std::string& s,
 | 
			
		||||
      const KeyFormatter& formatter) const {
 | 
			
		||||
    std::cout << s << "AllDiff on ";
 | 
			
		||||
    for (Key dkey: keys_)
 | 
			
		||||
      std::cout << formatter(dkey) << " ";
 | 
			
		||||
    std::cout << std::endl;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  double AllDiff::operator()(const Values& values) const {
 | 
			
		||||
    std::set < size_t > taken; // record values taken by keys
 | 
			
		||||
    for(Key dkey: keys_) {
 | 
			
		||||
      size_t value = values.at(dkey); // get the value for that key
 | 
			
		||||
      if (taken.count(value)) return 0.0;// check if value alreday taken
 | 
			
		||||
      taken.insert(value);// if not, record it as taken and keep checking
 | 
			
		||||
    }
 | 
			
		||||
  return converted;
 | 
			
		||||
}
 | 
			
		||||
    return 1.0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
  // TODO: can we do this more efficiently?
 | 
			
		||||
  return toDecisionTreeFactor() * f;
 | 
			
		||||
}
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
 | 
			
		||||
    // We will do this by converting the allDif into many BinaryAllDiff constraints
 | 
			
		||||
    DecisionTreeFactor converted;
 | 
			
		||||
    size_t nrKeys = keys_.size();
 | 
			
		||||
    for (size_t i1 = 0; i1 < nrKeys; i1++)
 | 
			
		||||
      for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
 | 
			
		||||
        BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2));
 | 
			
		||||
        converted = converted * binary12.toDecisionTreeFactor();
 | 
			
		||||
      }
 | 
			
		||||
    return converted;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
bool AllDiff::ensureArcConsistency(size_t j,
 | 
			
		||||
                                   std::vector<Domain>& domains) const {
 | 
			
		||||
  // Though strictly not part of allDiff, we check for
 | 
			
		||||
  // a value in domains[j] that does not occur in any other connected domain.
 | 
			
		||||
  // If found, we make this a singleton...
 | 
			
		||||
  // TODO: make a new constraint where this really is true
 | 
			
		||||
  Domain& Dj = domains[j];
 | 
			
		||||
  if (Dj.checkAllDiff(keys_, domains)) return true;
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
    // TODO: can we do this more efficiently?
 | 
			
		||||
    return toDecisionTreeFactor() * f;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Check all other domains for singletons and erase corresponding values
 | 
			
		||||
  // This is the same as arc-consistency on the equivalent binary constraints
 | 
			
		||||
  bool changed = false;
 | 
			
		||||
  for (Key k : keys_)
 | 
			
		||||
    if (k != j) {
 | 
			
		||||
      const Domain& Dk = domains[k];
 | 
			
		||||
      if (Dk.isSingleton()) {  // check if singleton
 | 
			
		||||
        size_t value = Dk.firstValue();
 | 
			
		||||
        if (Dj.contains(value)) {
 | 
			
		||||
          Dj.erase(value);  // erase value if true
 | 
			
		||||
          changed = true;
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  bool AllDiff::ensureArcConsistency(size_t j,
 | 
			
		||||
                                     std::vector<Domain>* domains) const {
 | 
			
		||||
    // We are changing the domain of variable j. 
 | 
			
		||||
    // TODO(dellaert): confusing, I thought we were changing others...
 | 
			
		||||
    Domain& Dj = domains->at(j);
 | 
			
		||||
 | 
			
		||||
    // Though strictly not part of allDiff, we check for
 | 
			
		||||
    // a value in domains[j] that does not occur in any other connected domain.
 | 
			
		||||
    // If found, we make this a singleton...
 | 
			
		||||
    // TODO: make a new constraint where this really is true
 | 
			
		||||
    boost::optional<Domain> maybeChanged = Dj.checkAllDiff(keys_, *domains);
 | 
			
		||||
    if (maybeChanged) {
 | 
			
		||||
      Dj = *maybeChanged;
 | 
			
		||||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Check all other domains for singletons and erase corresponding values.
 | 
			
		||||
    // This is the same as arc-consistency on the equivalent binary constraints
 | 
			
		||||
    bool changed = false;
 | 
			
		||||
    for (Key k : keys_)
 | 
			
		||||
      if (k != j) {
 | 
			
		||||
        const Domain& Dk = domains->at(k);
 | 
			
		||||
        if (Dk.isSingleton()) {  // check if singleton
 | 
			
		||||
          size_t value = Dk.firstValue();
 | 
			
		||||
          if (Dj.contains(value)) {
 | 
			
		||||
            Dj.erase(value);  // erase value if true
 | 
			
		||||
            changed = true;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  return changed;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
 | 
			
		||||
  DiscreteKeys newKeys;
 | 
			
		||||
  // loop over keys and add them only if they do not appear in values
 | 
			
		||||
  for (Key k : keys_)
 | 
			
		||||
    if (values.find(k) == values.end()) {
 | 
			
		||||
      newKeys.push_back(DiscreteKey(k, cardinalities_.at(k)));
 | 
			
		||||
    }
 | 
			
		||||
  return boost::make_shared<AllDiff>(newKeys);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
Constraint::shared_ptr AllDiff::partiallyApply(
 | 
			
		||||
    const std::vector<Domain>& domains) const {
 | 
			
		||||
  DiscreteFactor::Values known;
 | 
			
		||||
  for (Key k : keys_) {
 | 
			
		||||
    const Domain& Dk = domains[k];
 | 
			
		||||
    if (Dk.isSingleton()) known[k] = Dk.firstValue();
 | 
			
		||||
    return changed;
 | 
			
		||||
  }
 | 
			
		||||
  return partiallyApply(known);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
 | 
			
		||||
    DiscreteKeys newKeys;
 | 
			
		||||
    // loop over keys and add them only if they do not appear in values
 | 
			
		||||
    for(Key k: keys_)
 | 
			
		||||
      if (values.find(k) == values.end()) {
 | 
			
		||||
        newKeys.push_back(DiscreteKey(k,cardinalities_.at(k)));
 | 
			
		||||
      }
 | 
			
		||||
    return boost::make_shared<AllDiff>(newKeys);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  Constraint::shared_ptr AllDiff::partiallyApply(
 | 
			
		||||
      const std::vector<Domain>& domains) const {
 | 
			
		||||
    DiscreteFactor::Values known;
 | 
			
		||||
    for(Key k: keys_) {
 | 
			
		||||
        const Domain& Dk = domains[k];
 | 
			
		||||
        if (Dk.isSingleton())
 | 
			
		||||
          known[k] = Dk.firstValue();
 | 
			
		||||
      }
 | 
			
		||||
    return partiallyApply(known);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
} // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,71 +7,70 @@
 | 
			
		|||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteKey.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteKey.h>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * General AllDiff constraint
 | 
			
		||||
 * Returns 1 if values for all keys are different, 0 otherwise
 | 
			
		||||
 * DiscreteFactors are all awkward in that they have to store two types of keys:
 | 
			
		||||
 * for each variable we have a Key and an Key. In this factor, we
 | 
			
		||||
 * keep the Indices locally, and the Indices are stored in IndexFactor.
 | 
			
		||||
 */
 | 
			
		||||
class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
 | 
			
		||||
  std::map<Key, size_t> cardinalities_;
 | 
			
		||||
 | 
			
		||||
  DiscreteKey discreteKey(size_t i) const {
 | 
			
		||||
    Key j = keys_[i];
 | 
			
		||||
    return DiscreteKey(j, cardinalities_.at(j));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  /// Constructor
 | 
			
		||||
  AllDiff(const DiscreteKeys& dkeys);
 | 
			
		||||
 | 
			
		||||
  // print
 | 
			
		||||
  void print(const std::string& s = "", const KeyFormatter& formatter =
 | 
			
		||||
                                            DefaultKeyFormatter) const override;
 | 
			
		||||
 | 
			
		||||
  /// equals
 | 
			
		||||
  bool equals(const DiscreteFactor& other, double tol) const override {
 | 
			
		||||
    if (!dynamic_cast<const AllDiff*>(&other))
 | 
			
		||||
      return false;
 | 
			
		||||
    else {
 | 
			
		||||
      const AllDiff& f(static_cast<const AllDiff&>(other));
 | 
			
		||||
      return cardinalities_.size() == f.cardinalities_.size() &&
 | 
			
		||||
             std::equal(cardinalities_.begin(), cardinalities_.end(),
 | 
			
		||||
                        f.cardinalities_.begin());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Calculate value = expensive !
 | 
			
		||||
  double operator()(const Values& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Convert into a decisiontree, can be *very* expensive !
 | 
			
		||||
  DecisionTreeFactor toDecisionTreeFactor() const override;
 | 
			
		||||
 | 
			
		||||
  /// Multiply into a decisiontree
 | 
			
		||||
  DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
   * Ensure Arc-consistency
 | 
			
		||||
   * Arc-consistency involves creating binaryAllDiff constraints
 | 
			
		||||
   * In which case the combinatorial hyper-arc explosion disappears.
 | 
			
		||||
   * @param j domain to be checked
 | 
			
		||||
   * @param domains all other domains
 | 
			
		||||
  /**
 | 
			
		||||
   * General AllDiff constraint.
 | 
			
		||||
   * Returns 1 if values for all keys are different, 0 otherwise.
 | 
			
		||||
   */
 | 
			
		||||
  bool ensureArcConsistency(size_t j,
 | 
			
		||||
                            std::vector<Domain>& domains) const override;
 | 
			
		||||
  class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint {
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(const Values&) const override;
 | 
			
		||||
    std::map<Key,size_t> cardinalities_;
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values, domain version
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(
 | 
			
		||||
      const std::vector<Domain>&) const override;
 | 
			
		||||
};
 | 
			
		||||
    DiscreteKey discreteKey(size_t i) const {
 | 
			
		||||
      Key j = keys_[i];
 | 
			
		||||
      return DiscreteKey(j,cardinalities_.at(j));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
  public:
 | 
			
		||||
 | 
			
		||||
    /// Construct from keys.
 | 
			
		||||
    AllDiff(const DiscreteKeys& dkeys);
 | 
			
		||||
 | 
			
		||||
    // print
 | 
			
		||||
    void print(const std::string& s = "",
 | 
			
		||||
        const KeyFormatter& formatter = DefaultKeyFormatter) const override;
 | 
			
		||||
 | 
			
		||||
    /// equals
 | 
			
		||||
    bool equals(const DiscreteFactor& other, double tol) const override {
 | 
			
		||||
      if(!dynamic_cast<const AllDiff*>(&other))
 | 
			
		||||
        return false;
 | 
			
		||||
      else {
 | 
			
		||||
        const AllDiff& f(static_cast<const AllDiff&>(other));
 | 
			
		||||
        return cardinalities_.size() == f.cardinalities_.size()
 | 
			
		||||
            && std::equal(cardinalities_.begin(), cardinalities_.end(),
 | 
			
		||||
                          f.cardinalities_.begin());
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Calculate value = expensive !
 | 
			
		||||
    double operator()(const Values& values) const override;
 | 
			
		||||
 | 
			
		||||
    /// Convert into a decisiontree, can be *very* expensive !
 | 
			
		||||
    DecisionTreeFactor toDecisionTreeFactor() const override;
 | 
			
		||||
 | 
			
		||||
    /// Multiply into a decisiontree
 | 
			
		||||
    DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 | 
			
		||||
 | 
			
		||||
    /*
 | 
			
		||||
     * Ensure Arc-consistency
 | 
			
		||||
     * Arc-consistency involves creating binaryAllDiff constraints
 | 
			
		||||
     * In which case the combinatorial hyper-arc explosion disappears.
 | 
			
		||||
     * @param j domain to be checked
 | 
			
		||||
     * @param (in/out) domains all other domains
 | 
			
		||||
     */
 | 
			
		||||
    bool ensureArcConsistency(size_t j,
 | 
			
		||||
                              std::vector<Domain>* domains) const override;
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values
 | 
			
		||||
    Constraint::shared_ptr partiallyApply(const Values&) const override;
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values, domain version
 | 
			
		||||
    Constraint::shared_ptr partiallyApply(
 | 
			
		||||
        const std::vector<Domain>&) const override;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
} // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,93 +7,92 @@
 | 
			
		|||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Constraint.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Domain.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Constraint.h>
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Binary AllDiff constraint
 | 
			
		||||
 * Returns 1 if values for two keys are different, 0 otherwise
 | 
			
		||||
 * DiscreteFactors are all awkward in that they have to store two types of keys:
 | 
			
		||||
 * for each variable we have a Index and an Index. In this factor, we
 | 
			
		||||
 * keep the Indices locally, and the Indices are stored in IndexFactor.
 | 
			
		||||
 */
 | 
			
		||||
class BinaryAllDiff : public Constraint {
 | 
			
		||||
  size_t cardinality0_, cardinality1_;  /// cardinality
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  /// Constructor
 | 
			
		||||
  BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2)
 | 
			
		||||
      : Constraint(key1.first, key2.first),
 | 
			
		||||
        cardinality0_(key1.second),
 | 
			
		||||
        cardinality1_(key2.second) {}
 | 
			
		||||
 | 
			
		||||
  // print
 | 
			
		||||
  void print(
 | 
			
		||||
      const std::string& s = "",
 | 
			
		||||
      const KeyFormatter& formatter = DefaultKeyFormatter) const override {
 | 
			
		||||
    std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
 | 
			
		||||
              << formatter(keys_[1]) << std::endl;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// equals
 | 
			
		||||
  bool equals(const DiscreteFactor& other, double tol) const override {
 | 
			
		||||
    if (!dynamic_cast<const BinaryAllDiff*>(&other))
 | 
			
		||||
      return false;
 | 
			
		||||
    else {
 | 
			
		||||
      const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
 | 
			
		||||
      return (cardinality0_ == f.cardinality0_) &&
 | 
			
		||||
             (cardinality1_ == f.cardinality1_);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Calculate value
 | 
			
		||||
  double operator()(const Values& values) const override {
 | 
			
		||||
    return (double)(values.at(keys_[0]) != values.at(keys_[1]));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Convert into a decisiontree
 | 
			
		||||
  DecisionTreeFactor toDecisionTreeFactor() const override {
 | 
			
		||||
    DiscreteKeys keys;
 | 
			
		||||
    keys.push_back(DiscreteKey(keys_[0], cardinality0_));
 | 
			
		||||
    keys.push_back(DiscreteKey(keys_[1], cardinality1_));
 | 
			
		||||
    std::vector<double> table;
 | 
			
		||||
    for (size_t i1 = 0; i1 < cardinality0_; i1++)
 | 
			
		||||
      for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2);
 | 
			
		||||
    DecisionTreeFactor converted(keys, table);
 | 
			
		||||
    return converted;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Multiply into a decisiontree
 | 
			
		||||
  DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
 | 
			
		||||
    // TODO: can we do this more efficiently?
 | 
			
		||||
    return toDecisionTreeFactor() * f;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
   * Ensure Arc-consistency
 | 
			
		||||
   * @param j domain to be checked
 | 
			
		||||
   * @param domains all other domains
 | 
			
		||||
  /**
 | 
			
		||||
   * Binary AllDiff constraint
 | 
			
		||||
   * Returns 1 if values for two keys are different, 0 otherwise.
 | 
			
		||||
   */
 | 
			
		||||
  bool ensureArcConsistency(size_t j,
 | 
			
		||||
                            std::vector<Domain>& domains) const override {
 | 
			
		||||
    //      throw std::runtime_error(
 | 
			
		||||
    //          "BinaryAllDiff::ensureArcConsistency not implemented");
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  class BinaryAllDiff: public Constraint {
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(const Values&) const override {
 | 
			
		||||
    throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
 | 
			
		||||
  }
 | 
			
		||||
    size_t cardinality0_, cardinality1_; /// cardinality
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values, domain version
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(
 | 
			
		||||
      const std::vector<Domain>&) const override {
 | 
			
		||||
    throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
  public:
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
    /// Constructor
 | 
			
		||||
    BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) :
 | 
			
		||||
      Constraint(key1.first, key2.first),
 | 
			
		||||
        cardinality0_(key1.second), cardinality1_(key2.second) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // print
 | 
			
		||||
    void print(const std::string& s = "",
 | 
			
		||||
        const KeyFormatter& formatter = DefaultKeyFormatter) const override {
 | 
			
		||||
      std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
 | 
			
		||||
          << formatter(keys_[1]) << std::endl;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// equals
 | 
			
		||||
    bool equals(const DiscreteFactor& other, double tol) const override {
 | 
			
		||||
      if(!dynamic_cast<const BinaryAllDiff*>(&other))
 | 
			
		||||
        return false;
 | 
			
		||||
      else {
 | 
			
		||||
        const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
 | 
			
		||||
        return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Calculate value
 | 
			
		||||
    double operator()(const Values& values) const override {
 | 
			
		||||
      return (double) (values.at(keys_[0]) != values.at(keys_[1]));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Convert into a decisiontree
 | 
			
		||||
    DecisionTreeFactor toDecisionTreeFactor() const override {
 | 
			
		||||
      DiscreteKeys keys;
 | 
			
		||||
      keys.push_back(DiscreteKey(keys_[0],cardinality0_));
 | 
			
		||||
      keys.push_back(DiscreteKey(keys_[1],cardinality1_));
 | 
			
		||||
      std::vector<double> table;
 | 
			
		||||
      for (size_t i1 = 0; i1 < cardinality0_; i1++)
 | 
			
		||||
        for (size_t i2 = 0; i2 < cardinality1_; i2++)
 | 
			
		||||
          table.push_back(i1 != i2);
 | 
			
		||||
      DecisionTreeFactor converted(keys, table);
 | 
			
		||||
      return converted;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Multiply into a decisiontree
 | 
			
		||||
    DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
 | 
			
		||||
      // TODO: can we do this more efficiently?
 | 
			
		||||
      return toDecisionTreeFactor() * f;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /*
 | 
			
		||||
     * Ensure Arc-consistency
 | 
			
		||||
     * @param j domain to be checked
 | 
			
		||||
     * @param domains all other domains
 | 
			
		||||
     */
 | 
			
		||||
    ///
 | 
			
		||||
    bool ensureArcConsistency(size_t j,
 | 
			
		||||
                              std::vector<Domain>* domains) const override {
 | 
			
		||||
      throw std::runtime_error(
 | 
			
		||||
          "BinaryAllDiff::ensureArcConsistency not implemented");
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values
 | 
			
		||||
    Constraint::shared_ptr partiallyApply(const Values&) const override {
 | 
			
		||||
      throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values, domain version
 | 
			
		||||
    Constraint::shared_ptr partiallyApply(
 | 
			
		||||
        const std::vector<Domain>&) const override {
 | 
			
		||||
      throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
} // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -56,12 +56,11 @@ void CSP::runArcConsistency(size_t cardinality, size_t nrIterations,
 | 
			
		|||
        // if not already a singleton
 | 
			
		||||
        if (!domains[v].isSingleton()) {
 | 
			
		||||
          // get the constraint and call its ensureArcConsistency method
 | 
			
		||||
          Constraint::shared_ptr constraint =
 | 
			
		||||
              boost::dynamic_pointer_cast<Constraint>((*this)[f]);
 | 
			
		||||
          auto constraint = boost::dynamic_pointer_cast<Constraint>((*this)[f]);
 | 
			
		||||
          if (!constraint)
 | 
			
		||||
            throw runtime_error("CSP:runArcConsistency: non-constraint factor");
 | 
			
		||||
          changed[v] =
 | 
			
		||||
              constraint->ensureArcConsistency(v, domains) || changed[v];
 | 
			
		||||
              constraint->ensureArcConsistency(v, &domains) || changed[v];
 | 
			
		||||
        }
 | 
			
		||||
      }  // f
 | 
			
		||||
      if (changed[v]) anyChange = true;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,68 +17,79 @@
 | 
			
		|||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteFactor.h>
 | 
			
		||||
#include <gtsam_unstable/dllexport.h>
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteFactor.h>
 | 
			
		||||
#include <boost/assign.hpp>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
class Domain;
 | 
			
		||||
  class Domain;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Base class for discrete probabilistic factors
 | 
			
		||||
 * The most general one is the derived DecisionTreeFactor
 | 
			
		||||
 */
 | 
			
		||||
class Constraint : public DiscreteFactor {
 | 
			
		||||
 public:
 | 
			
		||||
  typedef boost::shared_ptr<Constraint> shared_ptr;
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  /// Construct n-way factor
 | 
			
		||||
  Constraint(const KeyVector& js) : DiscreteFactor(js) {}
 | 
			
		||||
 | 
			
		||||
  /// Construct unary factor
 | 
			
		||||
  Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {}
 | 
			
		||||
 | 
			
		||||
  /// Construct binary factor
 | 
			
		||||
  Constraint(Key j1, Key j2)
 | 
			
		||||
      : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {}
 | 
			
		||||
 | 
			
		||||
  /// construct from container
 | 
			
		||||
  template <class KeyIterator>
 | 
			
		||||
  Constraint(KeyIterator beginKey, KeyIterator endKey)
 | 
			
		||||
      : DiscreteFactor(beginKey, endKey) {}
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  /// @name Standard Constructors
 | 
			
		||||
  /// @{
 | 
			
		||||
 | 
			
		||||
  /// Default constructor for I/O
 | 
			
		||||
  Constraint();
 | 
			
		||||
 | 
			
		||||
  /// Virtual destructor
 | 
			
		||||
  ~Constraint() override {}
 | 
			
		||||
 | 
			
		||||
  /// @}
 | 
			
		||||
  /// @name Standard Interface
 | 
			
		||||
  /// @{
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
   * Ensure Arc-consistency
 | 
			
		||||
   * @param j domain to be checked
 | 
			
		||||
   * @param domains all other domains
 | 
			
		||||
  /**
 | 
			
		||||
   * Base class for constraint factors
 | 
			
		||||
   * Derived classes include SingleValue, BinaryAllDiff, and AllDiff.
 | 
			
		||||
   */
 | 
			
		||||
  virtual bool ensureArcConsistency(size_t j,
 | 
			
		||||
                                    std::vector<Domain>& domains) const = 0;
 | 
			
		||||
  class GTSAM_EXPORT Constraint : public DiscreteFactor {
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values
 | 
			
		||||
  virtual shared_ptr partiallyApply(const Values&) const = 0;
 | 
			
		||||
  public:
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values, domain version
 | 
			
		||||
  virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
 | 
			
		||||
  /// @}
 | 
			
		||||
};
 | 
			
		||||
    typedef boost::shared_ptr<Constraint> shared_ptr;
 | 
			
		||||
 | 
			
		||||
  protected:
 | 
			
		||||
 | 
			
		||||
    /// Construct unary constraint factor.
 | 
			
		||||
    Constraint(Key j) :
 | 
			
		||||
      DiscreteFactor(boost::assign::cref_list_of<1>(j)) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Construct binary constraint factor.
 | 
			
		||||
    Constraint(Key j1, Key j2) :
 | 
			
		||||
      DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Construct n-way constraint factor.
 | 
			
		||||
    Constraint(const KeyVector& js) :
 | 
			
		||||
      DiscreteFactor(js) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// construct from container
 | 
			
		||||
    template<class KeyIterator>
 | 
			
		||||
    Constraint(KeyIterator beginKey, KeyIterator endKey) :
 | 
			
		||||
      DiscreteFactor(beginKey, endKey) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  public:
 | 
			
		||||
 | 
			
		||||
    /// @name Standard Constructors
 | 
			
		||||
    /// @{
 | 
			
		||||
 | 
			
		||||
    /// Default constructor for I/O
 | 
			
		||||
    Constraint();
 | 
			
		||||
 | 
			
		||||
    /// Virtual destructor
 | 
			
		||||
    ~Constraint() override {}
 | 
			
		||||
 | 
			
		||||
    /// @}
 | 
			
		||||
    /// @name Standard Interface
 | 
			
		||||
    /// @{
 | 
			
		||||
 | 
			
		||||
    /*
 | 
			
		||||
     * Ensure Arc-consistency, possibly changing domains of connected variables.
 | 
			
		||||
     * @param j domain to be checked
 | 
			
		||||
     * @param (in/out) domains all other domains
 | 
			
		||||
     * @return true if domains were changed, false otherwise.
 | 
			
		||||
     */
 | 
			
		||||
    virtual bool ensureArcConsistency(size_t j,
 | 
			
		||||
                                      std::vector<Domain>* domains) const = 0;
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values
 | 
			
		||||
    virtual shared_ptr partiallyApply(const Values&) const = 0;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values, domain version
 | 
			
		||||
    virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
 | 
			
		||||
    /// @}
 | 
			
		||||
  };
 | 
			
		||||
// DiscreteFactor
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
}// namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,89 +5,90 @@
 | 
			
		|||
 * @author Frank Dellaert
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Domain.h>
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <boost/make_shared.hpp>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
using namespace std;
 | 
			
		||||
  using namespace std;
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
void Domain::print(const string& s, const KeyFormatter& formatter) const {
 | 
			
		||||
  //    cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
 | 
			
		||||
  //    formatter(keys_[0]) << ") with values";
 | 
			
		||||
  //    for (size_t v: values_) cout << " " << v;
 | 
			
		||||
  //    cout << endl;
 | 
			
		||||
  for (size_t v : values_) cout << v;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
double Domain::operator()(const Values& values) const {
 | 
			
		||||
  return contains(values.at(keys_[0]));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
DecisionTreeFactor Domain::toDecisionTreeFactor() const {
 | 
			
		||||
  DiscreteKeys keys;
 | 
			
		||||
  keys += DiscreteKey(keys_[0], cardinality_);
 | 
			
		||||
  vector<double> table;
 | 
			
		||||
  for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1));
 | 
			
		||||
  DecisionTreeFactor converted(keys, table);
 | 
			
		||||
  return converted;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
  // TODO: can we do this more efficiently?
 | 
			
		||||
  return toDecisionTreeFactor() * f;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
bool Domain::ensureArcConsistency(size_t j, vector<Domain>& domains) const {
 | 
			
		||||
  if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
 | 
			
		||||
  Domain& D = domains[j];
 | 
			
		||||
  for (size_t value : values_)
 | 
			
		||||
    if (!D.contains(value)) throw runtime_error("Unsatisfiable");
 | 
			
		||||
  D = *this;
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
bool Domain::checkAllDiff(const KeyVector keys, vector<Domain>& domains) {
 | 
			
		||||
  Key j = keys_[0];
 | 
			
		||||
  // for all values in this domain
 | 
			
		||||
  for (size_t value : values_) {
 | 
			
		||||
    // for all connected domains
 | 
			
		||||
    for (Key k : keys)
 | 
			
		||||
      // if any domain contains the value we cannot make this domain singleton
 | 
			
		||||
      if (k != j && domains[k].contains(value)) goto found;
 | 
			
		||||
    values_.clear();
 | 
			
		||||
    values_.insert(value);
 | 
			
		||||
    return true;  // we changed it
 | 
			
		||||
  found:;
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  void Domain::print(const string& s,
 | 
			
		||||
      const KeyFormatter& formatter) const {
 | 
			
		||||
   cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" <<
 | 
			
		||||
   formatter(keys_[0]) << ") with values";
 | 
			
		||||
   for (size_t v: values_) cout << " " << v;
 | 
			
		||||
   cout << endl;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  double Domain::operator()(const Values& values) const {
 | 
			
		||||
    return contains(values.at(keys_[0]));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  DecisionTreeFactor Domain::toDecisionTreeFactor() const {
 | 
			
		||||
    DiscreteKeys keys;
 | 
			
		||||
    keys += DiscreteKey(keys_[0],cardinality_);
 | 
			
		||||
    vector<double> table;
 | 
			
		||||
    for (size_t i1 = 0; i1 < cardinality_; ++i1)
 | 
			
		||||
      table.push_back(contains(i1));
 | 
			
		||||
    DecisionTreeFactor converted(keys, table);
 | 
			
		||||
    return converted;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
    // TODO: can we do this more efficiently?
 | 
			
		||||
    return toDecisionTreeFactor() * f;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  bool Domain::ensureArcConsistency(size_t j, vector<Domain>* domains) const {
 | 
			
		||||
    if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
 | 
			
		||||
    Domain& D = domains->at(j);
 | 
			
		||||
    for(size_t value: values_)
 | 
			
		||||
      if (!D.contains(value)) throw runtime_error("Unsatisfiable");
 | 
			
		||||
    D = *this;
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  boost::optional<Domain> Domain::checkAllDiff(
 | 
			
		||||
      const KeyVector keys, const vector<Domain>& domains) const {
 | 
			
		||||
    Key j = keys_[0];
 | 
			
		||||
    // for all values in this domain
 | 
			
		||||
    for (const size_t value : values_) {
 | 
			
		||||
      // for all connected domains
 | 
			
		||||
      for (const Key k : keys)
 | 
			
		||||
        // if any domain contains the value we cannot make this domain singleton
 | 
			
		||||
        if (k != j && domains[k].contains(value)) goto found;
 | 
			
		||||
      // Otherwise: return a singleton:
 | 
			
		||||
      return Domain(this->discreteKey(), value);
 | 
			
		||||
    found:;
 | 
			
		||||
    }
 | 
			
		||||
    return boost::none;  // we did not change it
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  Constraint::shared_ptr Domain::partiallyApply(
 | 
			
		||||
      const Values& values) const {
 | 
			
		||||
    Values::const_iterator it = values.find(keys_[0]);
 | 
			
		||||
    if (it != values.end() && !contains(it->second)) throw runtime_error(
 | 
			
		||||
        "Domain::partiallyApply: unsatisfiable");
 | 
			
		||||
    return boost::make_shared < Domain > (*this);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  Constraint::shared_ptr Domain::partiallyApply(
 | 
			
		||||
      const vector<Domain>& domains) const {
 | 
			
		||||
    const Domain& Dk = domains[keys_[0]];
 | 
			
		||||
    if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error(
 | 
			
		||||
        "Domain::partiallyApply: unsatisfiable");
 | 
			
		||||
    return boost::make_shared < Domain > (Dk);
 | 
			
		||||
  }
 | 
			
		||||
  return false;  // we did not change it
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
Constraint::shared_ptr Domain::partiallyApply(const Values& values) const {
 | 
			
		||||
  Values::const_iterator it = values.find(keys_[0]);
 | 
			
		||||
  if (it != values.end() && !contains(it->second))
 | 
			
		||||
    throw runtime_error("Domain::partiallyApply: unsatisfiable");
 | 
			
		||||
  return boost::make_shared<Domain>(*this);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
Constraint::shared_ptr Domain::partiallyApply(
 | 
			
		||||
    const vector<Domain>& domains) const {
 | 
			
		||||
  const Domain& Dk = domains[keys_[0]];
 | 
			
		||||
  if (Dk.isSingleton() && !contains(*Dk.begin()))
 | 
			
		||||
    throw runtime_error("Domain::partiallyApply: unsatisfiable");
 | 
			
		||||
  return boost::make_shared<Domain>(Dk);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
} // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,18 +7,23 @@
 | 
			
		|||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteKey.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Constraint.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteKey.h>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Domain restriction constraint
 | 
			
		||||
 * The Domain class represents a constraint that restricts the possible values a
 | 
			
		||||
 * particular variable, with given key, can take on.
 | 
			
		||||
 */
 | 
			
		||||
class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
 | 
			
		||||
  size_t cardinality_;       /// Cardinality
 | 
			
		||||
  std::set<size_t> values_;  /// allowed values
 | 
			
		||||
 | 
			
		||||
  DiscreteKey discreteKey() const {
 | 
			
		||||
    return DiscreteKey(keys_[0], cardinality_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  typedef boost::shared_ptr<Domain> shared_ptr;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -35,14 +40,10 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
 | 
			
		|||
    values_.insert(v);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Constructor
 | 
			
		||||
  Domain(const Domain& other)
 | 
			
		||||
      : Constraint(other.keys_[0]), values_(other.values_) {}
 | 
			
		||||
 | 
			
		||||
  /// insert a value, non const :-(
 | 
			
		||||
  /// Insert a value, non const :-(
 | 
			
		||||
  void insert(size_t value) { values_.insert(value); }
 | 
			
		||||
 | 
			
		||||
  /// erase a value, non const :-(
 | 
			
		||||
  /// Erase a value, non const :-(
 | 
			
		||||
  void erase(size_t value) { values_.erase(value); }
 | 
			
		||||
 | 
			
		||||
  size_t nrValues() const { return values_.size(); }
 | 
			
		||||
| 
						 | 
				
			
			@ -82,15 +83,17 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
 | 
			
		|||
   * @param domains all other domains
 | 
			
		||||
   */
 | 
			
		||||
  bool ensureArcConsistency(size_t j,
 | 
			
		||||
                            std::vector<Domain>& domains) const override;
 | 
			
		||||
                            std::vector<Domain>* domains) const override;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   *  Check for a value in domain that does not occur in any other connected
 | 
			
		||||
   * domain. If found, we make this a singleton... Called in
 | 
			
		||||
   * AllDiff::ensureArcConsistency
 | 
			
		||||
   *  @param keys connected domains through alldiff
 | 
			
		||||
   * Check for a value in domain that does not occur in any other connected
 | 
			
		||||
   * domain. If found, return a a new singleton domain...
 | 
			
		||||
   * Called in AllDiff::ensureArcConsistency
 | 
			
		||||
   * @param keys connected domains through alldiff
 | 
			
		||||
   * @param keys other domains
 | 
			
		||||
   */
 | 
			
		||||
  bool checkAllDiff(const KeyVector keys, std::vector<Domain>& domains);
 | 
			
		||||
  boost::optional<Domain> checkAllDiff(
 | 
			
		||||
      const KeyVector keys, const std::vector<Domain>& domains) const;
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(const Values& values) const override;
 | 
			
		||||
| 
						 | 
				
			
			@ -98,6 +101,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
 | 
			
		|||
  /// Partially apply known values, domain version
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(
 | 
			
		||||
      const std::vector<Domain>& domains) const override;
 | 
			
		||||
};
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
} // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,74 +5,75 @@
 | 
			
		|||
 * @author Frank Dellaert
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Domain.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/SingleValue.h>
 | 
			
		||||
 | 
			
		||||
#include <gtsam_unstable/discrete/Domain.h>
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <boost/make_shared.hpp>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
using namespace std;
 | 
			
		||||
  using namespace std;
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
 | 
			
		||||
  cout << s << "SingleValue on "
 | 
			
		||||
       << "j=" << formatter(keys_[0]) << " with value " << value_ << endl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
double SingleValue::operator()(const Values& values) const {
 | 
			
		||||
  return (double)(values.at(keys_[0]) == value_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
 | 
			
		||||
  DiscreteKeys keys;
 | 
			
		||||
  keys += DiscreteKey(keys_[0], cardinality_);
 | 
			
		||||
  vector<double> table;
 | 
			
		||||
  for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_);
 | 
			
		||||
  DecisionTreeFactor converted(keys, table);
 | 
			
		||||
  return converted;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
  // TODO: can we do this more efficiently?
 | 
			
		||||
  return toDecisionTreeFactor() * f;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
bool SingleValue::ensureArcConsistency(size_t j,
 | 
			
		||||
                                       vector<Domain>& domains) const {
 | 
			
		||||
  if (j != keys_[0])
 | 
			
		||||
    throw invalid_argument("SingleValue check on wrong domain");
 | 
			
		||||
  Domain& D = domains[j];
 | 
			
		||||
  if (D.isSingleton()) {
 | 
			
		||||
    if (D.firstValue() != value_) throw runtime_error("Unsatisfiable");
 | 
			
		||||
    return false;
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  void SingleValue::print(const string& s,
 | 
			
		||||
      const KeyFormatter& formatter) const {
 | 
			
		||||
    cout << s << "SingleValue on " << "j=" << formatter(keys_[0])
 | 
			
		||||
        << " with value " << value_ << endl;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  double SingleValue::operator()(const Values& values) const {
 | 
			
		||||
    return (double) (values.at(keys_[0]) == value_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
 | 
			
		||||
    DiscreteKeys keys;
 | 
			
		||||
    keys += DiscreteKey(keys_[0],cardinality_);
 | 
			
		||||
    vector<double> table;
 | 
			
		||||
    for (size_t i1 = 0; i1 < cardinality_; i1++)
 | 
			
		||||
      table.push_back(i1 == value_);
 | 
			
		||||
    DecisionTreeFactor converted(keys, table);
 | 
			
		||||
    return converted;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
 | 
			
		||||
    // TODO: can we do this more efficiently?
 | 
			
		||||
    return toDecisionTreeFactor() * f;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  bool SingleValue::ensureArcConsistency(size_t j,
 | 
			
		||||
                                         vector<Domain>* domains) const {
 | 
			
		||||
    if (j != keys_[0])
 | 
			
		||||
      throw invalid_argument("SingleValue check on wrong domain");
 | 
			
		||||
    Domain& D = domains->at(j);
 | 
			
		||||
    if (D.isSingleton()) {
 | 
			
		||||
      if (D.firstValue() != value_) throw runtime_error("Unsatisfiable");
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
    D = Domain(discreteKey(), value_);
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
 | 
			
		||||
    Values::const_iterator it = values.find(keys_[0]);
 | 
			
		||||
    if (it != values.end() && it->second != value_) throw runtime_error(
 | 
			
		||||
        "SingleValue::partiallyApply: unsatisfiable");
 | 
			
		||||
    return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  Constraint::shared_ptr SingleValue::partiallyApply(
 | 
			
		||||
      const vector<Domain>& domains) const {
 | 
			
		||||
    const Domain& Dk = domains[keys_[0]];
 | 
			
		||||
    if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error(
 | 
			
		||||
        "SingleValue::partiallyApply: unsatisfiable");
 | 
			
		||||
    return boost::make_shared<SingleValue>(discreteKey(), value_);
 | 
			
		||||
  }
 | 
			
		||||
  D = Domain(discreteKey(), value_);
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
 | 
			
		||||
  Values::const_iterator it = values.find(keys_[0]);
 | 
			
		||||
  if (it != values.end() && it->second != value_)
 | 
			
		||||
    throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
 | 
			
		||||
  return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
Constraint::shared_ptr SingleValue::partiallyApply(
 | 
			
		||||
    const vector<Domain>& domains) const {
 | 
			
		||||
  const Domain& Dk = domains[keys_[0]];
 | 
			
		||||
  if (Dk.isSingleton() && !Dk.contains(value_))
 | 
			
		||||
    throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
 | 
			
		||||
  return boost::make_shared<SingleValue>(discreteKey(), value_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
} // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,73 +7,74 @@
 | 
			
		|||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteKey.h>
 | 
			
		||||
#include <gtsam_unstable/discrete/Constraint.h>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * SingleValue constraint
 | 
			
		||||
 */
 | 
			
		||||
class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
 | 
			
		||||
  /// Number of values
 | 
			
		||||
  size_t cardinality_;
 | 
			
		||||
 | 
			
		||||
  /// allowed value
 | 
			
		||||
  size_t value_;
 | 
			
		||||
 | 
			
		||||
  DiscreteKey discreteKey() const {
 | 
			
		||||
    return DiscreteKey(keys_[0], cardinality_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  typedef boost::shared_ptr<SingleValue> shared_ptr;
 | 
			
		||||
 | 
			
		||||
  /// Constructor
 | 
			
		||||
  SingleValue(Key key, size_t n, size_t value)
 | 
			
		||||
      : Constraint(key), cardinality_(n), value_(value) {}
 | 
			
		||||
 | 
			
		||||
  /// Constructor
 | 
			
		||||
  SingleValue(const DiscreteKey& dkey, size_t value)
 | 
			
		||||
      : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {}
 | 
			
		||||
 | 
			
		||||
  // print
 | 
			
		||||
  void print(const std::string& s = "", const KeyFormatter& formatter =
 | 
			
		||||
                                            DefaultKeyFormatter) const override;
 | 
			
		||||
 | 
			
		||||
  /// equals
 | 
			
		||||
  bool equals(const DiscreteFactor& other, double tol) const override {
 | 
			
		||||
    if (!dynamic_cast<const SingleValue*>(&other))
 | 
			
		||||
      return false;
 | 
			
		||||
    else {
 | 
			
		||||
      const SingleValue& f(static_cast<const SingleValue&>(other));
 | 
			
		||||
      return (cardinality_ == f.cardinality_) && (value_ == f.value_);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Calculate value
 | 
			
		||||
  double operator()(const Values& values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Convert into a decisiontree
 | 
			
		||||
  DecisionTreeFactor toDecisionTreeFactor() const override;
 | 
			
		||||
 | 
			
		||||
  /// Multiply into a decisiontree
 | 
			
		||||
  DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
   * Ensure Arc-consistency
 | 
			
		||||
   * @param j domain to be checked
 | 
			
		||||
   * @param domains all other domains
 | 
			
		||||
  /**
 | 
			
		||||
   * SingleValue constraint: ensures a variable takes on a certain value.
 | 
			
		||||
   * This could of course also be implemented by changing its `Domain`.
 | 
			
		||||
   */
 | 
			
		||||
  bool ensureArcConsistency(size_t j,
 | 
			
		||||
                            std::vector<Domain>& domains) const override;
 | 
			
		||||
  class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint {
 | 
			
		||||
    
 | 
			
		||||
    size_t cardinality_; /// < Number of values
 | 
			
		||||
    size_t value_;       ///<  allowed value
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(const Values& values) const override;
 | 
			
		||||
    DiscreteKey discreteKey() const {
 | 
			
		||||
      return DiscreteKey(keys_[0],cardinality_);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  /// Partially apply known values, domain version
 | 
			
		||||
  Constraint::shared_ptr partiallyApply(
 | 
			
		||||
      const std::vector<Domain>& domains) const override;
 | 
			
		||||
};
 | 
			
		||||
  public:
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
    typedef boost::shared_ptr<SingleValue> shared_ptr;
 | 
			
		||||
 | 
			
		||||
    /// Construct from key, cardinality, and given value.
 | 
			
		||||
    SingleValue(Key key, size_t n, size_t value) :
 | 
			
		||||
      Constraint(key), cardinality_(n), value_(value) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Construct from DiscreteKey and given value.
 | 
			
		||||
    SingleValue(const DiscreteKey& dkey, size_t value) :
 | 
			
		||||
      Constraint(dkey.first), cardinality_(dkey.second), value_(value) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // print
 | 
			
		||||
    void print(const std::string& s = "",
 | 
			
		||||
        const KeyFormatter& formatter = DefaultKeyFormatter) const override;
 | 
			
		||||
 | 
			
		||||
    /// equals
 | 
			
		||||
    bool equals(const DiscreteFactor& other, double tol) const override {
 | 
			
		||||
      if(!dynamic_cast<const SingleValue*>(&other))
 | 
			
		||||
        return false;
 | 
			
		||||
      else {
 | 
			
		||||
        const SingleValue& f(static_cast<const SingleValue&>(other));
 | 
			
		||||
        return (cardinality_==f.cardinality_) && (value_==f.value_);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Calculate value
 | 
			
		||||
    double operator()(const Values& values) const override;
 | 
			
		||||
 | 
			
		||||
    /// Convert into a decisiontree
 | 
			
		||||
    DecisionTreeFactor toDecisionTreeFactor() const override;
 | 
			
		||||
 | 
			
		||||
    /// Multiply into a decisiontree
 | 
			
		||||
    DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 | 
			
		||||
 | 
			
		||||
    /*
 | 
			
		||||
     * Ensure Arc-consistency: just sets domain[j] to {value_}
 | 
			
		||||
     * @param j domain to be checked
 | 
			
		||||
     * @param domains all other domains
 | 
			
		||||
     */
 | 
			
		||||
    bool ensureArcConsistency(size_t j,
 | 
			
		||||
                              std::vector<Domain>* domains) const override;
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values
 | 
			
		||||
    Constraint::shared_ptr partiallyApply(const Values& values) const override;
 | 
			
		||||
 | 
			
		||||
    /// Partially apply known values, domain version
 | 
			
		||||
    Constraint::shared_ptr partiallyApply(
 | 
			
		||||
        const std::vector<Domain>& domains) const override;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
} // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,12 +19,33 @@ using namespace std;
 | 
			
		|||
using namespace gtsam;
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST_UNSAFE(BinaryAllDif, allInOne) {
 | 
			
		||||
  // Create keys and ordering
 | 
			
		||||
TEST(CSP, SingleValue) {
 | 
			
		||||
  // Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
 | 
			
		||||
  size_t nrColors = 3;
 | 
			
		||||
  DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
 | 
			
		||||
 | 
			
		||||
  // Check that a single value is equal to a decision stump with only one "1":
 | 
			
		||||
  SingleValue singleValue(AZ, 2);
 | 
			
		||||
  DecisionTreeFactor f1(AZ, "0 0 1");
 | 
			
		||||
  EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor()));
 | 
			
		||||
 | 
			
		||||
  // Create domains, laid out as a vector.
 | 
			
		||||
  // TODO(dellaert): should be map??  
 | 
			
		||||
  vector<Domain> domains;
 | 
			
		||||
  domains += Domain(ID), Domain(AZ), Domain(UT);
 | 
			
		||||
 | 
			
		||||
  // Ensure arc-consistency: just wipes out values in AZ domain:
 | 
			
		||||
  EXPECT(singleValue.ensureArcConsistency(1, &domains));
 | 
			
		||||
  LONGS_EQUAL(3, domains[0].nrValues());
 | 
			
		||||
  LONGS_EQUAL(1, domains[1].nrValues());
 | 
			
		||||
  LONGS_EQUAL(3, domains[2].nrValues());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(CSP, BinaryAllDif) {
 | 
			
		||||
  // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each:
 | 
			
		||||
  size_t nrColors = 2;
 | 
			
		||||
  //  DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona",
 | 
			
		||||
  //  nrColors);
 | 
			
		||||
  DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
 | 
			
		||||
  DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
 | 
			
		||||
 | 
			
		||||
  // Check construction and conversion
 | 
			
		||||
  BinaryAllDiff c1(ID, UT);
 | 
			
		||||
| 
						 | 
				
			
			@ -36,16 +57,51 @@ TEST_UNSAFE(BinaryAllDif, allInOne) {
 | 
			
		|||
  DecisionTreeFactor f2(UT & AZ, "0 1 1 0");
 | 
			
		||||
  EXPECT(assert_equal(f2, c2.toDecisionTreeFactor()));
 | 
			
		||||
 | 
			
		||||
  // Check multiplication of factors with constraint:
 | 
			
		||||
  DecisionTreeFactor f3 = f1 * f2;
 | 
			
		||||
  EXPECT(assert_equal(f3, c1 * f2));
 | 
			
		||||
  EXPECT(assert_equal(f3, c2 * f1));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST_UNSAFE(CSP, allInOne) {
 | 
			
		||||
  // Create keys and ordering
 | 
			
		||||
TEST(CSP, AllDiff) {
 | 
			
		||||
  // Create keys for Idaho, Arizona, and Utah, allowing two colors for each:
 | 
			
		||||
  size_t nrColors = 3;
 | 
			
		||||
  DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
 | 
			
		||||
 | 
			
		||||
  // Check construction and conversion
 | 
			
		||||
  vector<DiscreteKey> dkeys{ID, UT, AZ};
 | 
			
		||||
  AllDiff alldiff(dkeys);
 | 
			
		||||
  DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
 | 
			
		||||
  // GTSAM_PRINT(actual);
 | 
			
		||||
  actual.dot("actual");
 | 
			
		||||
  DecisionTreeFactor f2(
 | 
			
		||||
      ID & AZ & UT,
 | 
			
		||||
      "0 0 0  0 0 1  0 1 0   0 0 1  0 0 0  1 0 0   0 1 0  1 0 0  0 0 0");
 | 
			
		||||
  EXPECT(assert_equal(f2, actual));
 | 
			
		||||
 | 
			
		||||
  // Create domains.
 | 
			
		||||
  vector<Domain> domains;
 | 
			
		||||
  domains += Domain(ID), Domain(AZ), Domain(UT);
 | 
			
		||||
 | 
			
		||||
  // First constrict AZ domain:
 | 
			
		||||
  SingleValue singleValue(AZ, 2);
 | 
			
		||||
  EXPECT(singleValue.ensureArcConsistency(1, &domains));
 | 
			
		||||
 | 
			
		||||
  // Arc-consistency
 | 
			
		||||
  EXPECT(alldiff.ensureArcConsistency(0, &domains));
 | 
			
		||||
  EXPECT(!alldiff.ensureArcConsistency(1, &domains));
 | 
			
		||||
  EXPECT(alldiff.ensureArcConsistency(2, &domains));
 | 
			
		||||
  LONGS_EQUAL(2, domains[0].nrValues());
 | 
			
		||||
  LONGS_EQUAL(1, domains[1].nrValues());
 | 
			
		||||
  LONGS_EQUAL(2, domains[2].nrValues());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(CSP, allInOne) {
 | 
			
		||||
  // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each:
 | 
			
		||||
  size_t nrColors = 2;
 | 
			
		||||
  DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
 | 
			
		||||
  DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
 | 
			
		||||
 | 
			
		||||
  // Create the CSP
 | 
			
		||||
  CSP csp;
 | 
			
		||||
| 
						 | 
				
			
			@ -81,15 +137,12 @@ TEST_UNSAFE(CSP, allInOne) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST_UNSAFE(CSP, WesternUS) {
 | 
			
		||||
  // Create keys
 | 
			
		||||
TEST(CSP, WesternUS) {
 | 
			
		||||
  // Create keys for all states in Western US, with 4 color possibilities.
 | 
			
		||||
  size_t nrColors = 4;
 | 
			
		||||
  DiscreteKey
 | 
			
		||||
      // Create ordering according to example in ND-CSP.lyx
 | 
			
		||||
      WA(0, nrColors),
 | 
			
		||||
      OR(3, nrColors), CA(1, nrColors), NV(2, nrColors), ID(8, nrColors),
 | 
			
		||||
      UT(9, nrColors), AZ(10, nrColors), MT(4, nrColors), WY(5, nrColors),
 | 
			
		||||
      CO(7, nrColors), NM(6, nrColors);
 | 
			
		||||
  DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),
 | 
			
		||||
      NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors),
 | 
			
		||||
      MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors);
 | 
			
		||||
 | 
			
		||||
  // Create the CSP
 | 
			
		||||
  CSP csp;
 | 
			
		||||
| 
						 | 
				
			
			@ -116,10 +169,12 @@ TEST_UNSAFE(CSP, WesternUS) {
 | 
			
		|||
  csp.addAllDiff(WY, CO);
 | 
			
		||||
  csp.addAllDiff(CO, NM);
 | 
			
		||||
 | 
			
		||||
  // Solve
 | 
			
		||||
  // Create ordering according to example in ND-CSP.lyx
 | 
			
		||||
  Ordering ordering;
 | 
			
		||||
  ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
 | 
			
		||||
      Key(8), Key(9), Key(10);
 | 
			
		||||
 | 
			
		||||
  // Solve using that ordering:
 | 
			
		||||
  CSP::sharedValues mpe = csp.optimalAssignment(ordering);
 | 
			
		||||
  // GTSAM_PRINT(*mpe);
 | 
			
		||||
  CSP::Values expected;
 | 
			
		||||
| 
						 | 
				
			
			@ -143,33 +198,17 @@ TEST_UNSAFE(CSP, WesternUS) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST_UNSAFE(CSP, AllDiff) {
 | 
			
		||||
  // Create keys and ordering
 | 
			
		||||
TEST(CSP, ArcConsistency) {
 | 
			
		||||
  // Create keys for Idaho, Arizona, and Utah, allowing three colors for each:
 | 
			
		||||
  size_t nrColors = 3;
 | 
			
		||||
  DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors);
 | 
			
		||||
  DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors);
 | 
			
		||||
 | 
			
		||||
  // Create the CSP
 | 
			
		||||
  // Create the CSP using just one all-diff constraint, plus constrain Arizona.
 | 
			
		||||
  CSP csp;
 | 
			
		||||
  vector<DiscreteKey> dkeys;
 | 
			
		||||
  dkeys += ID, UT, AZ;
 | 
			
		||||
  vector<DiscreteKey> dkeys{ID, UT, AZ};
 | 
			
		||||
  csp.addAllDiff(dkeys);
 | 
			
		||||
  csp.addSingleValue(AZ, 2);
 | 
			
		||||
  //  GTSAM_PRINT(csp);
 | 
			
		||||
 | 
			
		||||
  // Check construction and conversion
 | 
			
		||||
  SingleValue s(AZ, 2);
 | 
			
		||||
  DecisionTreeFactor f1(AZ, "0 0 1");
 | 
			
		||||
  EXPECT(assert_equal(f1, s.toDecisionTreeFactor()));
 | 
			
		||||
 | 
			
		||||
  // Check construction and conversion
 | 
			
		||||
  AllDiff alldiff(dkeys);
 | 
			
		||||
  DecisionTreeFactor actual = alldiff.toDecisionTreeFactor();
 | 
			
		||||
  //  GTSAM_PRINT(actual);
 | 
			
		||||
  //  actual.dot("actual");
 | 
			
		||||
  DecisionTreeFactor f2(
 | 
			
		||||
      ID & AZ & UT,
 | 
			
		||||
      "0 0 0  0 0 1  0 1 0   0 0 1  0 0 0  1 0 0   0 1 0  1 0 0  0 0 0");
 | 
			
		||||
  EXPECT(assert_equal(f2, actual));
 | 
			
		||||
  // GTSAM_PRINT(csp);
 | 
			
		||||
 | 
			
		||||
  // Check an invalid combination, with ID==UT==AZ all same color
 | 
			
		||||
  DiscreteFactor::Values invalid;
 | 
			
		||||
| 
						 | 
				
			
			@ -192,14 +231,15 @@ TEST_UNSAFE(CSP, AllDiff) {
 | 
			
		|||
  EXPECT(assert_equal(expected, *mpe));
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9);
 | 
			
		||||
 | 
			
		||||
  // Arc-consistency
 | 
			
		||||
  // ensure arc-consistency, i.e., narrow domains...
 | 
			
		||||
  vector<Domain> domains;
 | 
			
		||||
  domains += Domain(ID), Domain(AZ), Domain(UT);
 | 
			
		||||
  SingleValue singleValue(AZ, 2);
 | 
			
		||||
  EXPECT(singleValue.ensureArcConsistency(1, domains));
 | 
			
		||||
  EXPECT(alldiff.ensureArcConsistency(0, domains));
 | 
			
		||||
  EXPECT(!alldiff.ensureArcConsistency(1, domains));
 | 
			
		||||
  EXPECT(alldiff.ensureArcConsistency(2, domains));
 | 
			
		||||
  AllDiff alldiff(dkeys);
 | 
			
		||||
  EXPECT(singleValue.ensureArcConsistency(1, &domains));
 | 
			
		||||
  EXPECT(alldiff.ensureArcConsistency(0, &domains));
 | 
			
		||||
  EXPECT(!alldiff.ensureArcConsistency(1, &domains));
 | 
			
		||||
  EXPECT(alldiff.ensureArcConsistency(2, &domains));
 | 
			
		||||
  LONGS_EQUAL(2, domains[0].nrValues());
 | 
			
		||||
  LONGS_EQUAL(1, domains[1].nrValues());
 | 
			
		||||
  LONGS_EQUAL(2, domains[2].nrValues());
 | 
			
		||||
| 
						 | 
				
			
			@ -222,6 +262,7 @@ TEST_UNSAFE(CSP, AllDiff) {
 | 
			
		|||
 | 
			
		||||
  // full arc-consistency test
 | 
			
		||||
  csp.runArcConsistency(nrColors);
 | 
			
		||||
  // GTSAM_PRINT(csp);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue