Formatting only

release/4.3a0
Frank Dellaert 2021-11-18 15:08:01 -05:00
parent dd50975668
commit b7f43906bc
9 changed files with 487 additions and 509 deletions

View File

@ -5,61 +5,59 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
AllDiff::AllDiff(const DiscreteKeys& dkeys) : AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
Constraint(dkeys.indices()) { for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey);
for(const DiscreteKey& dkey: dkeys) }
cardinalities_.insert(dkey);
}
/* ************************************************************************* */ /* ************************************************************************* */
void AllDiff::print(const std::string& s, void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
std::cout << s << "AllDiff on "; std::cout << s << "AllDiff on ";
for (Key dkey: keys_) for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
std::cout << formatter(dkey) << " ";
std::cout << std::endl; std::cout << std::endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */
double AllDiff::operator()(const Values& values) const { double AllDiff::operator()(const Values& values) const {
std::set < size_t > taken; // record values taken by keys std::set<size_t> taken; // record values taken by keys
for(Key dkey: keys_) { for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key size_t value = values.at(dkey); // get the value for that key
if (taken.count(value)) return 0.0;// check if value alreday taken if (taken.count(value)) return 0.0; // check if value alreday taken
taken.insert(value);// if not, record it as taken and keep checking taken.insert(value); // if not, record it as taken and keep checking
} }
return 1.0; return 1.0;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
// We will do this by converting the allDif into many BinaryAllDiff constraints // We will do this by converting the allDif into many BinaryAllDiff
// constraints
DecisionTreeFactor converted; DecisionTreeFactor converted;
size_t nrKeys = keys_.size(); size_t nrKeys = keys_.size();
for (size_t i1 = 0; i1 < nrKeys; i1++) for (size_t i1 = 0; i1 < nrKeys; i1++)
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2));
converted = converted * binary12.toDecisionTreeFactor(); converted = converted * binary12.toDecisionTreeFactor();
} }
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j, bool AllDiff::ensureArcConsistency(size_t j,
std::vector<Domain>* domains) const { std::vector<Domain>* domains) const {
// We are changing the domain of variable j. // We are changing the domain of variable j.
// TODO(dellaert): confusing, I thought we were changing others... // TODO(dellaert): confusing, I thought we were changing others...
@ -90,30 +88,29 @@ namespace gtsam {
} }
} }
return changed; return changed;
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
DiscreteKeys newKeys; DiscreteKeys newKeys;
// loop over keys and add them only if they do not appear in values // loop over keys and add them only if they do not appear in values
for(Key k: keys_) for (Key k : keys_)
if (values.find(k) == values.end()) { if (values.find(k) == values.end()) {
newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); newKeys.push_back(DiscreteKey(k, cardinalities_.at(k)));
} }
return boost::make_shared<AllDiff>(newKeys); return boost::make_shared<AllDiff>(newKeys);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply( Constraint::shared_ptr AllDiff::partiallyApply(
const std::vector<Domain>& domains) const { const std::vector<Domain>& domains) const {
DiscreteFactor::Values known; DiscreteFactor::Values known;
for(Key k: keys_) { for (Key k : keys_) {
const Domain& Dk = domains[k]; const Domain& Dk = domains[k];
if (Dk.isSingleton()) if (Dk.isSingleton()) known[k] = Dk.firstValue();
known[k] = Dk.firstValue();
} }
return partiallyApply(known); return partiallyApply(known);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,41 +7,39 @@
#pragma once #pragma once
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/BinaryAllDiff.h>
namespace gtsam { namespace gtsam {
/** /**
* General AllDiff constraint. * General AllDiff constraint.
* Returns 1 if values for all keys are different, 0 otherwise. * Returns 1 if values for all keys are different, 0 otherwise.
*/ */
class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint { class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
std::map<Key, size_t> cardinalities_;
std::map<Key,size_t> cardinalities_;
DiscreteKey discreteKey(size_t i) const { DiscreteKey discreteKey(size_t i) const {
Key j = keys_[i]; Key j = keys_[i];
return DiscreteKey(j,cardinalities_.at(j)); return DiscreteKey(j, cardinalities_.at(j));
} }
public: public:
/// Construct from keys. /// Construct from keys.
AllDiff(const DiscreteKeys& dkeys); AllDiff(const DiscreteKeys& dkeys);
// print // print
void print(const std::string& s = "", void print(const std::string& s = "", const KeyFormatter& formatter =
const KeyFormatter& formatter = DefaultKeyFormatter) const override; DefaultKeyFormatter) const override;
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const AllDiff*>(&other)) if (!dynamic_cast<const AllDiff*>(&other))
return false; return false;
else { else {
const AllDiff& f(static_cast<const AllDiff&>(other)); const AllDiff& f(static_cast<const AllDiff&>(other));
return cardinalities_.size() == f.cardinalities_.size() return cardinalities_.size() == f.cardinalities_.size() &&
&& std::equal(cardinalities_.begin(), cardinalities_.end(), std::equal(cardinalities_.begin(), cardinalities_.end(),
f.cardinalities_.begin()); f.cardinalities_.begin());
} }
} }
@ -71,6 +69,6 @@ namespace gtsam {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>&) const override; const std::vector<Domain>&) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -7,30 +7,29 @@
#pragma once #pragma once
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam_unstable/discrete/Domain.h>
namespace gtsam { namespace gtsam {
/** /**
* Binary AllDiff constraint * Binary AllDiff constraint
* Returns 1 if values for two keys are different, 0 otherwise. * Returns 1 if values for two keys are different, 0 otherwise.
*/ */
class BinaryAllDiff: public Constraint { class BinaryAllDiff : public Constraint {
size_t cardinality0_, cardinality1_; /// cardinality size_t cardinality0_, cardinality1_; /// cardinality
public: public:
/// Constructor /// Constructor
BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2)
Constraint(key1.first, key2.first), : Constraint(key1.first, key2.first),
cardinality0_(key1.second), cardinality1_(key2.second) { cardinality0_(key1.second),
} cardinality1_(key2.second) {}
// print // print
void print(const std::string& s = "", void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override { const KeyFormatter& formatter = DefaultKeyFormatter) const override {
std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and "
<< formatter(keys_[1]) << std::endl; << formatter(keys_[1]) << std::endl;
@ -38,28 +37,28 @@ namespace gtsam {
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const BinaryAllDiff*>(&other)) if (!dynamic_cast<const BinaryAllDiff*>(&other))
return false; return false;
else { else {
const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other)); const BinaryAllDiff& f(static_cast<const BinaryAllDiff&>(other));
return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_); return (cardinality0_ == f.cardinality0_) &&
(cardinality1_ == f.cardinality1_);
} }
} }
/// Calculate value /// Calculate value
double operator()(const Values& values) const override { double operator()(const Values& values) const override {
return (double) (values.at(keys_[0]) != values.at(keys_[1])); return (double)(values.at(keys_[0]) != values.at(keys_[1]));
} }
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override {
DiscreteKeys keys; DiscreteKeys keys;
keys.push_back(DiscreteKey(keys_[0],cardinality0_)); keys.push_back(DiscreteKey(keys_[0], cardinality0_));
keys.push_back(DiscreteKey(keys_[1],cardinality1_)); keys.push_back(DiscreteKey(keys_[1], cardinality1_));
std::vector<double> table; std::vector<double> table;
for (size_t i1 = 0; i1 < cardinality0_; i1++) for (size_t i1 = 0; i1 < cardinality0_; i1++)
for (size_t i2 = 0; i2 < cardinality1_; i2++) for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2);
table.push_back(i1 != i2);
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
@ -93,6 +92,6 @@ namespace gtsam {
const std::vector<Domain>&) const override { const std::vector<Domain>&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
} }
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -17,49 +17,40 @@
#pragma once #pragma once
#include <gtsam_unstable/dllexport.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam_unstable/dllexport.h>
#include <boost/assign.hpp> #include <boost/assign.hpp>
namespace gtsam { namespace gtsam {
class Domain; class Domain;
/** /**
* Base class for constraint factors * Base class for constraint factors
* Derived classes include SingleValue, BinaryAllDiff, and AllDiff. * Derived classes include SingleValue, BinaryAllDiff, and AllDiff.
*/ */
class GTSAM_EXPORT Constraint : public DiscreteFactor { class GTSAM_EXPORT Constraint : public DiscreteFactor {
public: public:
typedef boost::shared_ptr<Constraint> shared_ptr; typedef boost::shared_ptr<Constraint> shared_ptr;
protected: protected:
/// Construct unary constraint factor. /// Construct unary constraint factor.
Constraint(Key j) : Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {}
DiscreteFactor(boost::assign::cref_list_of<1>(j)) {
}
/// Construct binary constraint factor. /// Construct binary constraint factor.
Constraint(Key j1, Key j2) : Constraint(Key j1, Key j2)
DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) { : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {}
}
/// Construct n-way constraint factor. /// Construct n-way constraint factor.
Constraint(const KeyVector& js) : Constraint(const KeyVector& js) : DiscreteFactor(js) {}
DiscreteFactor(js) {
}
/// construct from container /// construct from container
template<class KeyIterator> template <class KeyIterator>
Constraint(KeyIterator beginKey, KeyIterator endKey) : Constraint(KeyIterator beginKey, KeyIterator endKey)
DiscreteFactor(beginKey, endKey) { : DiscreteFactor(beginKey, endKey) {}
}
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -85,11 +76,10 @@ namespace gtsam {
/// Partially apply known values /// Partially apply known values
virtual shared_ptr partiallyApply(const Values&) const = 0; virtual shared_ptr partiallyApply(const Values&) const = 0;
/// Partially apply known values, domain version /// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0; virtual shared_ptr partiallyApply(const std::vector<Domain>&) const = 0;
/// @} /// @}
}; };
// DiscreteFactor // DiscreteFactor
}// namespace gtsam } // namespace gtsam

View File

@ -5,58 +5,57 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */ /* ************************************************************************* */
void Domain::print(const string& s, void Domain::print(const string& s, const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const { cout << s << ": Domain on " << formatter(keys_[0])
cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << << " (j=" << formatter(keys_[0]) << ") with values";
formatter(keys_[0]) << ") with values"; for (size_t v : values_) cout << " " << v;
for (size_t v: values_) cout << " " << v;
cout << endl; cout << endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */
double Domain::operator()(const Values& values) const { double Domain::operator()(const Values& values) const {
return contains(values.at(keys_[0])); return contains(values.at(keys_[0]));
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor Domain::toDecisionTreeFactor() const { DecisionTreeFactor Domain::toDecisionTreeFactor() const {
DiscreteKeys keys; DiscreteKeys keys;
keys += DiscreteKey(keys_[0],cardinality_); keys += DiscreteKey(keys_[0], cardinality_);
vector<double> table; vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; ++i1) for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1));
table.push_back(contains(i1));
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool Domain::ensureArcConsistency(size_t j, vector<Domain>* domains) const { bool Domain::ensureArcConsistency(size_t j, vector<Domain>* domains) const {
if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain");
Domain& D = domains->at(j); Domain& D = domains->at(j);
for(size_t value: values_) for (size_t value : values_)
if (!D.contains(value)) throw runtime_error("Unsatisfiable"); if (!D.contains(value)) throw runtime_error("Unsatisfiable");
D = *this; D = *this;
return true; return true;
} }
/* ************************************************************************* */ /* ************************************************************************* */
boost::optional<Domain> Domain::checkAllDiff( boost::optional<Domain> Domain::checkAllDiff(
const KeyVector keys, const vector<Domain>& domains) const { const KeyVector keys, const vector<Domain>& domains) const {
Key j = keys_[0]; Key j = keys_[0];
// for all values in this domain // for all values in this domain
@ -70,25 +69,24 @@ namespace gtsam {
found:; found:;
} }
return boost::none; // we did not change it return boost::none; // we did not change it
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply( Constraint::shared_ptr Domain::partiallyApply(const Values& values) const {
const Values& values) const {
Values::const_iterator it = values.find(keys_[0]); Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && !contains(it->second)) throw runtime_error( if (it != values.end() && !contains(it->second))
"Domain::partiallyApply: unsatisfiable"); throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared < Domain > (*this); return boost::make_shared<Domain>(*this);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply( Constraint::shared_ptr Domain::partiallyApply(
const vector<Domain>& domains) const { const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]]; const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( if (Dk.isSingleton() && !contains(*Dk.begin()))
"Domain::partiallyApply: unsatisfiable"); throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared < Domain > (Dk); return boost::make_shared<Domain>(Dk);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,8 +7,8 @@
#pragma once #pragma once
#include <gtsam_unstable/discrete/Constraint.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h>
namespace gtsam { namespace gtsam {
@ -101,6 +101,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override; const std::vector<Domain>& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -5,47 +5,46 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam_unstable/discrete/SingleValue.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/SingleValue.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
using namespace std; using namespace std;
/* ************************************************************************* */ /* ************************************************************************* */
void SingleValue::print(const string& s, void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const { cout << s << "SingleValue on "
cout << s << "SingleValue on " << "j=" << formatter(keys_[0]) << "j=" << formatter(keys_[0]) << " with value " << value_ << endl;
<< " with value " << value_ << endl; }
}
/* ************************************************************************* */ /* ************************************************************************* */
double SingleValue::operator()(const Values& values) const { double SingleValue::operator()(const Values& values) const {
return (double) (values.at(keys_[0]) == value_); return (double)(values.at(keys_[0]) == value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { DecisionTreeFactor SingleValue::toDecisionTreeFactor() const {
DiscreteKeys keys; DiscreteKeys keys;
keys += DiscreteKey(keys_[0],cardinality_); keys += DiscreteKey(keys_[0], cardinality_);
vector<double> table; vector<double> table;
for (size_t i1 = 0; i1 < cardinality_; i1++) for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_);
table.push_back(i1 == value_);
DecisionTreeFactor converted(keys, table); DecisionTreeFactor converted(keys, table);
return converted; return converted;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently? // TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool SingleValue::ensureArcConsistency(size_t j, bool SingleValue::ensureArcConsistency(size_t j,
vector<Domain>* domains) const { vector<Domain>* domains) const {
if (j != keys_[0]) if (j != keys_[0])
throw invalid_argument("SingleValue check on wrong domain"); throw invalid_argument("SingleValue check on wrong domain");
@ -56,24 +55,24 @@ namespace gtsam {
} }
D = Domain(discreteKey(), value_); D = Domain(discreteKey(), value_);
return true; return true;
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
Values::const_iterator it = values.find(keys_[0]); Values::const_iterator it = values.find(keys_[0]);
if (it != values.end() && it->second != value_) throw runtime_error( if (it != values.end() && it->second != value_)
"SingleValue::partiallyApply: unsatisfiable"); throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_); return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply( Constraint::shared_ptr SingleValue::partiallyApply(
const vector<Domain>& domains) const { const vector<Domain>& domains) const {
const Domain& Dk = domains[keys_[0]]; const Domain& Dk = domains[keys_[0]];
if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( if (Dk.isSingleton() && !Dk.contains(value_))
"SingleValue::partiallyApply: unsatisfiable"); throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(discreteKey(), value_); return boost::make_shared<SingleValue>(discreteKey(), value_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -7,48 +7,45 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/Constraint.h> #include <gtsam_unstable/discrete/Constraint.h>
namespace gtsam { namespace gtsam {
/** /**
* SingleValue constraint: ensures a variable takes on a certain value. * SingleValue constraint: ensures a variable takes on a certain value.
* This could of course also be implemented by changing its `Domain`. * This could of course also be implemented by changing its `Domain`.
*/ */
class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint { class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
size_t cardinality_; /// < Number of values size_t cardinality_; /// < Number of values
size_t value_; ///< allowed value size_t value_; ///< allowed value
DiscreteKey discreteKey() const { DiscreteKey discreteKey() const {
return DiscreteKey(keys_[0],cardinality_); return DiscreteKey(keys_[0], cardinality_);
} }
public: public:
typedef boost::shared_ptr<SingleValue> shared_ptr; typedef boost::shared_ptr<SingleValue> shared_ptr;
/// Construct from key, cardinality, and given value. /// Construct from key, cardinality, and given value.
SingleValue(Key key, size_t n, size_t value) : SingleValue(Key key, size_t n, size_t value)
Constraint(key), cardinality_(n), value_(value) { : Constraint(key), cardinality_(n), value_(value) {}
}
/// Construct from DiscreteKey and given value. /// Construct from DiscreteKey and given value.
SingleValue(const DiscreteKey& dkey, size_t value) : SingleValue(const DiscreteKey& dkey, size_t value)
Constraint(dkey.first), cardinality_(dkey.second), value_(value) { : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {}
}
// print // print
void print(const std::string& s = "", void print(const std::string& s = "", const KeyFormatter& formatter =
const KeyFormatter& formatter = DefaultKeyFormatter) const override; DefaultKeyFormatter) const override;
/// equals /// equals
bool equals(const DiscreteFactor& other, double tol) const override { bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const SingleValue*>(&other)) if (!dynamic_cast<const SingleValue*>(&other))
return false; return false;
else { else {
const SingleValue& f(static_cast<const SingleValue&>(other)); const SingleValue& f(static_cast<const SingleValue&>(other));
return (cardinality_==f.cardinality_) && (value_==f.value_); return (cardinality_ == f.cardinality_) && (value_ == f.value_);
} }
} }
@ -75,6 +72,6 @@ namespace gtsam {
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(
const std::vector<Domain>& domains) const override; const std::vector<Domain>& domains) const override;
}; };
} // namespace gtsam } // namespace gtsam