Use DiscreteValues everywhere
parent
c63c1167ba
commit
e89a294376
|
@ -80,7 +80,7 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Value is just look up in AlgebraicDecisonTree
|
/// Value is just look up in AlgebraicDecisonTree
|
||||||
double operator()(const Values& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return Potentials::operator()(values);
|
return Potentials::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const {
|
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
||||||
// evaluate all conditionals and multiply
|
// evaluate all conditionals and multiply
|
||||||
double result = 1.0;
|
double result = 1.0;
|
||||||
for(DiscreteConditional::shared_ptr conditional: *this)
|
for(DiscreteConditional::shared_ptr conditional: *this)
|
||||||
|
@ -54,18 +54,18 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteFactor::Values DiscreteBayesNet::optimize() const {
|
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||||
// solve each node in turn in topological sort order (parents first)
|
// solve each node in turn in topological sort order (parents first)
|
||||||
DiscreteFactor::Values result;
|
DiscreteValues result;
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
for (auto conditional: boost::adaptors::reverse(*this))
|
||||||
conditional->solveInPlace(&result);
|
conditional->solveInPlace(&result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteFactor::Values DiscreteBayesNet::sample() const {
|
DiscreteValues DiscreteBayesNet::sample() const {
|
||||||
// sample each node in turn in topological sort order (parents first)
|
// sample each node in turn in topological sort order (parents first)
|
||||||
DiscreteFactor::Values result;
|
DiscreteValues result;
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
for (auto conditional: boost::adaptors::reverse(*this))
|
||||||
conditional->sampleInPlace(&result);
|
conditional->sampleInPlace(&result);
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -77,16 +77,16 @@ namespace gtsam {
|
||||||
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
||||||
// GTSAM_EXPORT void add_front(const Signature& s);
|
// GTSAM_EXPORT void add_front(const Signature& s);
|
||||||
|
|
||||||
//** evaluate for given Values */
|
//** evaluate for given DiscreteValues */
|
||||||
double evaluate(const DiscreteConditional::Values & values) const;
|
double evaluate(const DiscreteValues & values) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Solve the DiscreteBayesNet by back-substitution
|
* Solve the DiscreteBayesNet by back-substitution
|
||||||
*/
|
*/
|
||||||
DiscreteFactor::Values optimize() const;
|
DiscreteValues optimize() const;
|
||||||
|
|
||||||
/** Do ancestral sampling */
|
/** Do ancestral sampling */
|
||||||
DiscreteFactor::Values sample() const;
|
DiscreteValues sample() const;
|
||||||
|
|
||||||
///@}
|
///@}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteBayesTreeClique::evaluate(
|
double DiscreteBayesTreeClique::evaluate(
|
||||||
const DiscreteConditional::Values& values) const {
|
const DiscreteValues& values) const {
|
||||||
// evaluate all conditionals and multiply
|
// evaluate all conditionals and multiply
|
||||||
double result = (*conditional_)(values);
|
double result = (*conditional_)(values);
|
||||||
for (const auto& child : children) {
|
for (const auto& child : children) {
|
||||||
|
@ -47,7 +47,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteBayesTree::evaluate(
|
double DiscreteBayesTree::evaluate(
|
||||||
const DiscreteConditional::Values& values) const {
|
const DiscreteValues& values) const {
|
||||||
double result = 1.0;
|
double result = 1.0;
|
||||||
for (const auto& root : roots_) {
|
for (const auto& root : roots_) {
|
||||||
result *= root->evaluate(values);
|
result *= root->evaluate(values);
|
||||||
|
|
|
@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
|
||||||
conditional_->printSignature(s, formatter);
|
conditional_->printSignature(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
//** evaluate conditional probability of subtree for given Values */
|
//** evaluate conditional probability of subtree for given DiscreteValues */
|
||||||
double evaluate(const DiscreteConditional::Values& values) const;
|
double evaluate(const DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -78,8 +78,8 @@ class GTSAM_EXPORT DiscreteBayesTree
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& other, double tol = 1e-9) const;
|
bool equals(const This& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
//** evaluate probability for given Values */
|
//** evaluate probability for given DiscreteValues */
|
||||||
double evaluate(const DiscreteConditional::Values& values) const;
|
double evaluate(const DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -97,7 +97,7 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
|
Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const {
|
||||||
ADT pFS(*this);
|
ADT pFS(*this);
|
||||||
Key j; size_t value;
|
Key j; size_t value;
|
||||||
for(Key key: parents()) {
|
for(Key key: parents()) {
|
||||||
|
@ -117,12 +117,12 @@ Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
void DiscreteConditional::solveInPlace(Values* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
||||||
ADT pFS = choose(*values); // P(F|S=parentsValues)
|
ADT pFS = choose(*values); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
Values mpe;
|
DiscreteValues mpe;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
|
||||||
DiscreteKeys keys;
|
DiscreteKeys keys;
|
||||||
|
@ -131,10 +131,10 @@ void DiscreteConditional::solveInPlace(Values* values) const {
|
||||||
keys & dk;
|
keys & dk;
|
||||||
}
|
}
|
||||||
// Get all Possible Configurations
|
// Get all Possible Configurations
|
||||||
vector<Values> allPosbValues = cartesianProduct(keys);
|
vector<DiscreteValues> allPosbValues = cartesianProduct(keys);
|
||||||
|
|
||||||
// Find the MPE
|
// Find the MPE
|
||||||
for(Values& frontalVals: allPosbValues) {
|
for(DiscreteValues& frontalVals: allPosbValues) {
|
||||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
|
@ -150,7 +150,7 @@ void DiscreteConditional::solveInPlace(Values* values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
void DiscreteConditional::sampleInPlace(Values* values) const {
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key j = (firstFrontalKey());
|
||||||
size_t sampled = sample(*values); // Sample variable given parents
|
size_t sampled = sample(*values); // Sample variable given parents
|
||||||
|
@ -158,7 +158,7 @@ void DiscreteConditional::sampleInPlace(Values* values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
|
|
||||||
// TODO: is this really the fastest way? I think it is.
|
// TODO: is this really the fastest way? I think it is.
|
||||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||||
|
@ -166,7 +166,7 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
// TODO, only works for one key now, seems horribly slow this way
|
// TODO, only works for one key now, seems horribly slow this way
|
||||||
size_t mpe = 0;
|
size_t mpe = 0;
|
||||||
Values frontals;
|
DiscreteValues frontals;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key j = (firstFrontalKey());
|
||||||
|
@ -183,7 +183,7 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
static mt19937 rng(2); // random number generator
|
static mt19937 rng(2); // random number generator
|
||||||
|
|
||||||
// Get the correct conditional density
|
// Get the correct conditional density
|
||||||
|
@ -194,7 +194,7 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
||||||
Key key = firstFrontalKey();
|
Key key = firstFrontalKey();
|
||||||
size_t nj = cardinality(key);
|
size_t nj = cardinality(key);
|
||||||
vector<double> p(nj);
|
vector<double> p(nj);
|
||||||
Values frontals;
|
DiscreteValues frontals;
|
||||||
for (size_t value = 0; value < nj; value++) {
|
for (size_t value = 0; value < nj; value++) {
|
||||||
frontals[key] = value;
|
frontals[key] = value;
|
||||||
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
|
|
@ -42,9 +42,7 @@ public:
|
||||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||||
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
|
||||||
/** A map from keys to values..
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
* TODO: Again, do we need this??? */
|
|
||||||
typedef Assignment<Key> Values;
|
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -101,7 +99,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
/// Evaluate, just look up in AlgebraicDecisonTree
|
||||||
double operator()(const Values& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return Potentials::operator()(values);
|
return Potentials::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,31 +109,31 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
|
||||||
ADT choose(const Assignment<Key>& parentsValues) const;
|
ADT choose(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* solve a conditional
|
* solve a conditional
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
* @return MPE value of the child (1 frontal variable).
|
* @return MPE value of the child (1 frontal variable).
|
||||||
*/
|
*/
|
||||||
size_t solve(const Values& parentsValues) const;
|
size_t solve(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* sample
|
* sample
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
size_t sample(const Values& parentsValues) const;
|
size_t sample(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// solve a conditional, in place
|
/// solve a conditional, in place
|
||||||
void solveInPlace(Values* parentsValues) const;
|
void solveInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
/// sample in place, stores result in partial solution
|
/// sample in place, stores result in partial solution
|
||||||
void sampleInPlace(Values* parentsValues) const;
|
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
|
||||||
|
@ -40,17 +40,7 @@ public:
|
||||||
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
|
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
|
||||||
typedef Factor Base; ///< Our base class
|
typedef Factor Base; ///< Our base class
|
||||||
|
|
||||||
/** A map from keys to values
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
* TODO: Do we need this? Should we just use gtsam::Values?
|
|
||||||
* We just need another special DiscreteValue to represent labels,
|
|
||||||
* However, all other Lie's operators are undefined in this class.
|
|
||||||
* The good thing is we can have a Hybrid graph of discrete/continuous variables
|
|
||||||
* together..
|
|
||||||
* Another good thing is we don't need to have the special DiscreteKey which stores
|
|
||||||
* cardinality of a Discrete variable. It should be handled naturally in
|
|
||||||
* the new class DiscreteValue, as the varible's type (domain)
|
|
||||||
*/
|
|
||||||
typedef Assignment<Key> Values;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -91,7 +81,7 @@ public:
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const Values&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
@ -104,6 +94,6 @@ public:
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||||
template<> struct traits<DiscreteFactor::Values> : public Testable<DiscreteFactor::Values> {};
|
template<> struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
|
||||||
|
|
||||||
}// namespace gtsam
|
}// namespace gtsam
|
||||||
|
|
|
@ -56,7 +56,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteFactorGraph::operator()(
|
double DiscreteFactorGraph::operator()(
|
||||||
const DiscreteFactor::Values &values) const {
|
const DiscreteValues &values) const {
|
||||||
double product = 1.0;
|
double product = 1.0;
|
||||||
for( const sharedFactor& factor: factors_ )
|
for( const sharedFactor& factor: factors_ )
|
||||||
product *= (*factor)(values);
|
product *= (*factor)(values);
|
||||||
|
@ -94,7 +94,7 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteFactorGraph::Values DiscreteFactorGraph::optimize() const
|
DiscreteValues DiscreteFactorGraph::optimize() const
|
||||||
{
|
{
|
||||||
gttic(DiscreteFactorGraph_optimize);
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
return BaseEliminateable::eliminateSequential()->optimize();
|
return BaseEliminateable::eliminateSequential()->optimize();
|
||||||
|
|
|
@ -71,9 +71,10 @@ public:
|
||||||
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
||||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
|
|
||||||
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
|
|
||||||
/** A map from keys to values */
|
/** A map from keys to values */
|
||||||
typedef KeyVector Indices;
|
typedef KeyVector Indices;
|
||||||
typedef Assignment<Key> Values;
|
|
||||||
|
|
||||||
/** Default constructor */
|
/** Default constructor */
|
||||||
DiscreteFactorGraph() {}
|
DiscreteFactorGraph() {}
|
||||||
|
@ -130,7 +131,7 @@ public:
|
||||||
DecisionTreeFactor product() const;
|
DecisionTreeFactor product() const;
|
||||||
|
|
||||||
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
|
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
|
||||||
double operator()(const DiscreteFactor::Values & values) const;
|
double operator()(const DiscreteValues & values) const;
|
||||||
|
|
||||||
/// print
|
/// print
|
||||||
void print(
|
void print(
|
||||||
|
@ -141,7 +142,7 @@ public:
|
||||||
* the dense elimination function specified in \c function,
|
* the dense elimination function specified in \c function,
|
||||||
* followed by back-substitution resulting from elimination. Is equivalent
|
* followed by back-substitution resulting from elimination. Is equivalent
|
||||||
* to calling graph.eliminateSequential()->optimize(). */
|
* to calling graph.eliminateSequential()->optimize(). */
|
||||||
Values optimize() const;
|
DiscreteValues optimize() const;
|
||||||
|
|
||||||
|
|
||||||
// /** Permute the variables in the factors */
|
// /** Permute the variables in the factors */
|
||||||
|
|
|
@ -64,7 +64,7 @@ namespace gtsam {
|
||||||
//Create result
|
//Create result
|
||||||
Vector vResult(key.second);
|
Vector vResult(key.second);
|
||||||
for (size_t state = 0; state < key.second ; ++ state) {
|
for (size_t state = 0; state < key.second ; ++ state) {
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
values[key.first] = state;
|
values[key.first] = state;
|
||||||
vResult(state) = (*marginalFactor)(values);
|
vResult(state) = (*marginalFactor)(values);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
// headers first to make sure no missing headers
|
// headers first to make sure no missing headers
|
||||||
//#define DT_NO_PRUNING
|
//#define DT_NO_PRUNING
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
|
@ -445,7 +446,7 @@ TEST(ADT, equality_parser)
|
||||||
TEST(ADT, constructor)
|
TEST(ADT, constructor)
|
||||||
{
|
{
|
||||||
DiscreteKey v0(0,2), v1(1,3);
|
DiscreteKey v0(0,2), v1(1,3);
|
||||||
Assignment<Key> x00, x01, x02, x10, x11, x12;
|
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||||
x00[0] = 0, x00[1] = 0;
|
x00[0] = 0, x00[1] = 0;
|
||||||
x01[0] = 0, x01[1] = 1;
|
x01[0] = 0, x01[1] = 1;
|
||||||
x02[0] = 0, x02[1] = 2;
|
x02[0] = 0, x02[1] = 2;
|
||||||
|
@ -475,7 +476,7 @@ TEST(ADT, constructor)
|
||||||
for(double& t: table)
|
for(double& t: table)
|
||||||
t = x++;
|
t = x++;
|
||||||
ADT f3(z0 & z1 & z2 & z3, table);
|
ADT f3(z0 & z1 & z2 & z3, table);
|
||||||
Assignment<Key> assignment;
|
DiscreteValues assignment;
|
||||||
assignment[0] = 0;
|
assignment[0] = 0;
|
||||||
assignment[1] = 0;
|
assignment[1] = 0;
|
||||||
assignment[2] = 0;
|
assignment[2] = 0;
|
||||||
|
@ -501,7 +502,7 @@ TEST(ADT, conversion)
|
||||||
// f2.print("f2");
|
// f2.print("f2");
|
||||||
dot(fIndexKey, "conversion-f2");
|
dot(fIndexKey, "conversion-f2");
|
||||||
|
|
||||||
Assignment<Key> x00, x01, x02, x10, x11, x12;
|
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||||
x00[5] = 0, x00[2] = 0;
|
x00[5] = 0, x00[2] = 0;
|
||||||
x01[5] = 0, x01[2] = 1;
|
x01[5] = 0, x01[2] = 1;
|
||||||
x10[5] = 1, x10[2] = 0;
|
x10[5] = 1, x10[2] = 0;
|
||||||
|
@ -577,7 +578,7 @@ TEST(ADT, zero)
|
||||||
ADT notb(B, 1, 0);
|
ADT notb(B, 1, 0);
|
||||||
ADT anotb = a * notb;
|
ADT anotb = a * notb;
|
||||||
// GTSAM_PRINT(anotb);
|
// GTSAM_PRINT(anotb);
|
||||||
Assignment<Key> x00, x01, x10, x11;
|
DiscreteValues x00, x01, x10, x11;
|
||||||
x00[0] = 0, x00[1] = 0;
|
x00[0] = 0, x00[1] = 0;
|
||||||
x01[0] = 0, x01[1] = 1;
|
x01[0] = 0, x01[1] = 1;
|
||||||
x10[0] = 1, x10[1] = 0;
|
x10[0] = 1, x10[1] = 0;
|
||||||
|
|
|
@ -43,7 +43,7 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
// f2.print("f2:");
|
// f2.print("f2:");
|
||||||
// f3.print("f3:");
|
// f3.print("f3:");
|
||||||
|
|
||||||
DecisionTreeFactor::Values values;
|
DiscreteValues values;
|
||||||
values[0] = 1; // x
|
values[0] = 1; // x
|
||||||
values[1] = 2; // y
|
values[1] = 2; // y
|
||||||
values[2] = 1; // z
|
values[2] = 1; // z
|
||||||
|
|
|
@ -105,7 +105,7 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
|
|
||||||
// solve
|
// solve
|
||||||
auto actualMPE = chordal->optimize();
|
auto actualMPE = chordal->optimize();
|
||||||
DiscreteFactor::Values expectedMPE;
|
DiscreteValues expectedMPE;
|
||||||
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
||||||
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
||||||
LungCancer.first, 0)(Bronchitis.first, 0);
|
LungCancer.first, 0)(Bronchitis.first, 0);
|
||||||
|
@ -118,14 +118,14 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||||
auto actualMPE2 = chordal2->optimize();
|
auto actualMPE2 = chordal2->optimize();
|
||||||
DiscreteFactor::Values expectedMPE2;
|
DiscreteValues expectedMPE2;
|
||||||
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
||||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
||||||
LungCancer.first, 0)(Bronchitis.first, 1);
|
LungCancer.first, 0)(Bronchitis.first, 1);
|
||||||
EXPECT(assert_equal(expectedMPE2, actualMPE2));
|
EXPECT(assert_equal(expectedMPE2, actualMPE2));
|
||||||
|
|
||||||
// now sample from it
|
// now sample from it
|
||||||
DiscreteFactor::Values expectedSample;
|
DiscreteValues expectedSample;
|
||||||
SETDEBUG("DiscreteConditional::sample", false);
|
SETDEBUG("DiscreteConditional::sample", false);
|
||||||
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
|
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
|
||||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
|
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(
|
||||||
|
|
|
@ -89,24 +89,24 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
auto R = bayesTree->roots().front();
|
auto R = bayesTree->roots().front();
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
// Check whether BN and BT give the same answer on all configurations
|
||||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
vector<DiscreteValues> allPosbValues = cartesianProduct(
|
||||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
|
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
|
||||||
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
double expected = bayesNet.evaluate(x);
|
double expected = bayesNet.evaluate(x);
|
||||||
double actual = bayesTree->evaluate(x);
|
double actual = bayesTree->evaluate(x);
|
||||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate all some marginals for Values==all1
|
// Calculate all some marginals for DiscreteValues==all1
|
||||||
Vector marginals = Vector::Zero(15);
|
Vector marginals = Vector::Zero(15);
|
||||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||||
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
||||||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
double px = bayesTree->evaluate(x);
|
double px = bayesTree->evaluate(x);
|
||||||
for (size_t i = 0; i < 15; i++)
|
for (size_t i = 0; i < 15; i++)
|
||||||
if (x[i]) marginals[i] += px;
|
if (x[i]) marginals[i] += px;
|
||||||
|
@ -138,7 +138,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
DiscreteValues all1 = allPosbValues.back();
|
||||||
|
|
||||||
// check separator marginal P(S0)
|
// check separator marginal P(S0)
|
||||||
auto clique = (*bayesTree)[0];
|
auto clique = (*bayesTree)[0];
|
||||||
|
|
|
@ -81,8 +81,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
||||||
graph.add(P2, "0.9 0.6");
|
graph.add(P2, "0.9 0.6");
|
||||||
graph.add(P1 & P2, "4 1 10 4");
|
graph.add(P1 & P2, "4 1 10 4");
|
||||||
|
|
||||||
// Instantiate Values
|
// Instantiate DiscreteValues
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
values[0] = 1;
|
values[0] = 1;
|
||||||
values[1] = 1;
|
values[1] = 1;
|
||||||
|
|
||||||
|
@ -167,7 +167,7 @@ TEST( DiscreteFactorGraph, test)
|
||||||
// EXPECT(assert_equal(expected, *actual2));
|
// EXPECT(assert_equal(expected, *actual2));
|
||||||
|
|
||||||
// Test optimization
|
// Test optimization
|
||||||
DiscreteFactor::Values expectedValues;
|
DiscreteValues expectedValues;
|
||||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
||||||
auto actualValues = graph.optimize();
|
auto actualValues = graph.optimize();
|
||||||
EXPECT(assert_equal(expectedValues, actualValues));
|
EXPECT(assert_equal(expectedValues, actualValues));
|
||||||
|
@ -188,7 +188,7 @@ TEST( DiscreteFactorGraph, testMPE)
|
||||||
|
|
||||||
auto actualMPE = graph.optimize();
|
auto actualMPE = graph.optimize();
|
||||||
|
|
||||||
DiscreteFactor::Values expectedMPE;
|
DiscreteValues expectedMPE;
|
||||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||||
}
|
}
|
||||||
|
@ -211,7 +211,7 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||||
// graph.product().potentials().dot("Darwiche-product");
|
// graph.product().potentials().dot("Darwiche-product");
|
||||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||||
|
|
||||||
DiscreteFactor::Values expectedMPE;
|
DiscreteValues expectedMPE;
|
||||||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
||||||
|
|
||||||
// Use the solver machinery.
|
// Use the solver machinery.
|
||||||
|
|
|
@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
|
||||||
|
|
||||||
DiscreteMarginals marginals(graph);
|
DiscreteMarginals marginals(graph);
|
||||||
DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
|
DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
|
|
||||||
values[Cathy.first] = 0;
|
values[Cathy.first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
|
EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
|
||||||
|
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
|
||||||
|
|
||||||
DiscreteMarginals marginals(graph);
|
DiscreteMarginals marginals(graph);
|
||||||
DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
|
DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
|
||||||
DiscreteFactor::Values values;
|
DiscreteValues values;
|
||||||
|
|
||||||
values[key[2].first] = 0;
|
values[key[2].first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
|
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
|
||||||
|
@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
|
||||||
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
|
||||||
|
|
||||||
// Calculate the marginals by brute force
|
// Calculate the marginals by brute force
|
||||||
vector<DiscreteFactor::Values> allPosbValues =
|
vector<DiscreteValues> allPosbValues =
|
||||||
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
|
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
|
||||||
Vector T = Z_5x1, F = Z_5x1;
|
Vector T = Z_5x1, F = Z_5x1;
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||||
DiscreteFactor::Values x = allPosbValues[i];
|
DiscreteValues x = allPosbValues[i];
|
||||||
double px = graph(x);
|
double px = graph(x);
|
||||||
for (size_t j = 0; j < 5; j++)
|
for (size_t j = 0; j < 5; j++)
|
||||||
if (x[j])
|
if (x[j])
|
||||||
|
|
|
@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double AllDiff::operator()(const Values& values) const {
|
double AllDiff::operator()(const DiscreteValues& values) const {
|
||||||
std::set<size_t> taken; // record values taken by keys
|
std::set<size_t> taken; // record values taken by keys
|
||||||
for (Key dkey : keys_) {
|
for (Key dkey : keys_) {
|
||||||
size_t value = values.at(dkey); // get the value for that key
|
size_t value = values.at(dkey); // get the value for that key
|
||||||
|
@ -88,7 +88,7 @@ bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
|
Constraint::shared_ptr AllDiff::partiallyApply(const DiscreteValues& values) const {
|
||||||
DiscreteKeys newKeys;
|
DiscreteKeys newKeys;
|
||||||
// loop over keys and add them only if they do not appear in values
|
// loop over keys and add them only if they do not appear in values
|
||||||
for (Key k : keys_)
|
for (Key k : keys_)
|
||||||
|
@ -101,7 +101,7 @@ Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Constraint::shared_ptr AllDiff::partiallyApply(
|
Constraint::shared_ptr AllDiff::partiallyApply(
|
||||||
const Domains& domains) const {
|
const Domains& domains) const {
|
||||||
DiscreteFactor::Values known;
|
DiscreteValues known;
|
||||||
for (Key k : keys_) {
|
for (Key k : keys_) {
|
||||||
const Domain& Dk = domains.at(k);
|
const Domain& Dk = domains.at(k);
|
||||||
if (Dk.isSingleton()) known[k] = Dk.firstValue();
|
if (Dk.isSingleton()) known[k] = Dk.firstValue();
|
||||||
|
|
|
@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate value = expensive !
|
/// Calculate value = expensive !
|
||||||
double operator()(const Values& values) const override;
|
double operator()(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// Convert into a decisiontree, can be *very* expensive !
|
/// Convert into a decisiontree, can be *very* expensive !
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
@ -62,7 +62,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
||||||
bool ensureArcConsistency(Key j, Domains* domains) const override;
|
bool ensureArcConsistency(Key j, Domains* domains) const override;
|
||||||
|
|
||||||
/// Partially apply known values
|
/// Partially apply known values
|
||||||
Constraint::shared_ptr partiallyApply(const Values&) const override;
|
Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override;
|
||||||
|
|
||||||
/// Partially apply known values, domain version
|
/// Partially apply known values, domain version
|
||||||
Constraint::shared_ptr partiallyApply(
|
Constraint::shared_ptr partiallyApply(
|
||||||
|
|
|
@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double operator()(const Values& values) const override {
|
double operator()(const DiscreteValues& values) const override {
|
||||||
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
|
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class BinaryAllDiff : public Constraint {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Partially apply known values
|
/// Partially apply known values
|
||||||
Constraint::shared_ptr partiallyApply(const Values&) const override {
|
Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override {
|
||||||
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,13 +14,13 @@ using namespace std;
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive
|
/// Find the best total assignment - can be expensive
|
||||||
CSP::Values CSP::optimalAssignment() const {
|
DiscreteValues CSP::optimalAssignment() const {
|
||||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
|
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
|
||||||
return chordal->optimize();
|
return chordal->optimize();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive
|
/// Find the best total assignment - can be expensive
|
||||||
CSP::Values CSP::optimalAssignment(const Ordering& ordering) const {
|
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
|
||||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
|
||||||
return chordal->optimize();
|
return chordal->optimize();
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,10 +20,8 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
|
class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
|
||||||
public:
|
public:
|
||||||
/** A map from keys to values */
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
typedef Assignment<Key> Values;
|
|
||||||
|
|
||||||
public:
|
|
||||||
/// Add a unary constraint, allowing only a single value
|
/// Add a unary constraint, allowing only a single value
|
||||||
void addSingleValue(const DiscreteKey& dkey, size_t value) {
|
void addSingleValue(const DiscreteKey& dkey, size_t value) {
|
||||||
emplace_shared<SingleValue>(dkey, value);
|
emplace_shared<SingleValue>(dkey, value);
|
||||||
|
@ -46,10 +44,10 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive.
|
/// Find the best total assignment - can be expensive.
|
||||||
Values optimalAssignment() const;
|
DiscreteValues optimalAssignment() const;
|
||||||
|
|
||||||
/// Find the best total assignment, with given ordering - can be expensive.
|
/// Find the best total assignment, with given ordering - can be expensive.
|
||||||
Values optimalAssignment(const Ordering& ordering) const;
|
DiscreteValues optimalAssignment(const Ordering& ordering) const;
|
||||||
|
|
||||||
// /*
|
// /*
|
||||||
// * Perform loopy belief propagation
|
// * Perform loopy belief propagation
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam_unstable/dllexport.h>
|
#include <gtsam_unstable/dllexport.h>
|
||||||
|
|
||||||
#include <boost/assign.hpp>
|
#include <boost/assign.hpp>
|
||||||
|
@ -75,7 +76,7 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor {
|
||||||
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
|
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
|
||||||
|
|
||||||
/// Partially apply known values
|
/// Partially apply known values
|
||||||
virtual shared_ptr partiallyApply(const Values&) const = 0;
|
virtual shared_ptr partiallyApply(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
/// Partially apply known values, domain version
|
/// Partially apply known values, domain version
|
||||||
virtual shared_ptr partiallyApply(const Domains&) const = 0;
|
virtual shared_ptr partiallyApply(const Domains&) const = 0;
|
||||||
|
|
|
@ -31,7 +31,7 @@ string Domain::base1Str() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double Domain::operator()(const Values& values) const {
|
double Domain::operator()(const DiscreteValues& values) const {
|
||||||
return contains(values.at(key()));
|
return contains(values.at(key()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,8 +79,8 @@ boost::optional<Domain> Domain::checkAllDiff(const KeyVector keys,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Constraint::shared_ptr Domain::partiallyApply(const Values& values) const {
|
Constraint::shared_ptr Domain::partiallyApply(const DiscreteValues& values) const {
|
||||||
Values::const_iterator it = values.find(key());
|
DiscreteValues::const_iterator it = values.find(key());
|
||||||
if (it != values.end() && !contains(it->second))
|
if (it != values.end() && !contains(it->second))
|
||||||
throw runtime_error("Domain::partiallyApply: unsatisfiable");
|
throw runtime_error("Domain::partiallyApply: unsatisfiable");
|
||||||
return boost::make_shared<Domain>(*this);
|
return boost::make_shared<Domain>(*this);
|
||||||
|
|
|
@ -76,7 +76,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||||
bool contains(size_t value) const { return values_.count(value) > 0; }
|
bool contains(size_t value) const { return values_.count(value) > 0; }
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double operator()(const Values& values) const override;
|
double operator()(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
@ -104,7 +104,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||||
const Domains& domains) const;
|
const Domains& domains) const;
|
||||||
|
|
||||||
/// Partially apply known values
|
/// Partially apply known values
|
||||||
Constraint::shared_ptr partiallyApply(const Values& values) const override;
|
Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// Partially apply known values, domain version
|
/// Partially apply known values, domain version
|
||||||
Constraint::shared_ptr partiallyApply(const Domains& domains) const override;
|
Constraint::shared_ptr partiallyApply(const Domains& domains) const override;
|
||||||
|
|
|
@ -202,7 +202,7 @@ void Scheduler::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
} // print
|
} // print
|
||||||
|
|
||||||
/** Print readable form of assignment */
|
/** Print readable form of assignment */
|
||||||
void Scheduler::printAssignment(const Values& assignment) const {
|
void Scheduler::printAssignment(const DiscreteValues& assignment) const {
|
||||||
// Not intended to be general! Assumes very particular ordering !
|
// Not intended to be general! Assumes very particular ordering !
|
||||||
cout << endl;
|
cout << endl;
|
||||||
for (size_t s = 0; s < nrStudents(); s++) {
|
for (size_t s = 0; s < nrStudents(); s++) {
|
||||||
|
@ -220,8 +220,8 @@ void Scheduler::printAssignment(const Values& assignment) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Special print for single-student case */
|
/** Special print for single-student case */
|
||||||
void Scheduler::printSpecial(const Values& assignment) const {
|
void Scheduler::printSpecial(const DiscreteValues& assignment) const {
|
||||||
Values::const_iterator it = assignment.begin();
|
DiscreteValues::const_iterator it = assignment.begin();
|
||||||
for (size_t area = 0; area < 3; area++, it++) {
|
for (size_t area = 0; area < 3; area++, it++) {
|
||||||
size_t f = it->second;
|
size_t f = it->second;
|
||||||
cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl;
|
cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl;
|
||||||
|
@ -230,7 +230,7 @@ void Scheduler::printSpecial(const Values& assignment) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Accumulate faculty stats */
|
/** Accumulate faculty stats */
|
||||||
void Scheduler::accumulateStats(const Values& assignment,
|
void Scheduler::accumulateStats(const DiscreteValues& assignment,
|
||||||
vector<size_t>& stats) const {
|
vector<size_t>& stats) const {
|
||||||
for (size_t s = 0; s < nrStudents(); s++) {
|
for (size_t s = 0; s < nrStudents(); s++) {
|
||||||
Key base = 3 * s;
|
Key base = 3 * s;
|
||||||
|
@ -256,7 +256,7 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Find the best total assignment - can be expensive */
|
/** Find the best total assignment - can be expensive */
|
||||||
Scheduler::Values Scheduler::optimalAssignment() const {
|
DiscreteValues Scheduler::optimalAssignment() const {
|
||||||
DiscreteBayesNet::shared_ptr chordal = eliminate();
|
DiscreteBayesNet::shared_ptr chordal = eliminate();
|
||||||
|
|
||||||
if (ISDEBUG("Scheduler::optimalAssignment")) {
|
if (ISDEBUG("Scheduler::optimalAssignment")) {
|
||||||
|
@ -267,21 +267,21 @@ Scheduler::Values Scheduler::optimalAssignment() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
gttic(my_optimize);
|
gttic(my_optimize);
|
||||||
Values mpe = chordal->optimize();
|
DiscreteValues mpe = chordal->optimize();
|
||||||
gttoc(my_optimize);
|
gttoc(my_optimize);
|
||||||
return mpe;
|
return mpe;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** find the assignment of students to slots with most possible committees */
|
/** find the assignment of students to slots with most possible committees */
|
||||||
Scheduler::Values Scheduler::bestSchedule() const {
|
DiscreteValues Scheduler::bestSchedule() const {
|
||||||
Values best;
|
DiscreteValues best;
|
||||||
throw runtime_error("bestSchedule not implemented");
|
throw runtime_error("bestSchedule not implemented");
|
||||||
return best;
|
return best;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** find the corresponding most desirable committee assignment */
|
/** find the corresponding most desirable committee assignment */
|
||||||
Scheduler::Values Scheduler::bestAssignment(const Values& bestSchedule) const {
|
DiscreteValues Scheduler::bestAssignment(const DiscreteValues& bestSchedule) const {
|
||||||
Values best;
|
DiscreteValues best;
|
||||||
throw runtime_error("bestAssignment not implemented");
|
throw runtime_error("bestAssignment not implemented");
|
||||||
return best;
|
return best;
|
||||||
}
|
}
|
||||||
|
|
|
@ -134,26 +134,26 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/** Print readable form of assignment */
|
/** Print readable form of assignment */
|
||||||
void printAssignment(const Values& assignment) const;
|
void printAssignment(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
/** Special print for single-student case */
|
/** Special print for single-student case */
|
||||||
void printSpecial(const Values& assignment) const;
|
void printSpecial(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
/** Accumulate faculty stats */
|
/** Accumulate faculty stats */
|
||||||
void accumulateStats(const Values& assignment,
|
void accumulateStats(const DiscreteValues& assignment,
|
||||||
std::vector<size_t>& stats) const;
|
std::vector<size_t>& stats) const;
|
||||||
|
|
||||||
/** Eliminate, return a Bayes net */
|
/** Eliminate, return a Bayes net */
|
||||||
DiscreteBayesNet::shared_ptr eliminate() const;
|
DiscreteBayesNet::shared_ptr eliminate() const;
|
||||||
|
|
||||||
/** Find the best total assignment - can be expensive */
|
/** Find the best total assignment - can be expensive */
|
||||||
Values optimalAssignment() const;
|
DiscreteValues optimalAssignment() const;
|
||||||
|
|
||||||
/** find the assignment of students to slots with most possible committees */
|
/** find the assignment of students to slots with most possible committees */
|
||||||
Values bestSchedule() const;
|
DiscreteValues bestSchedule() const;
|
||||||
|
|
||||||
/** find the corresponding most desirable committee assignment */
|
/** find the corresponding most desirable committee assignment */
|
||||||
Values bestAssignment(const Values& bestSchedule) const;
|
DiscreteValues bestAssignment(const DiscreteValues& bestSchedule) const;
|
||||||
|
|
||||||
}; // Scheduler
|
}; // Scheduler
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double SingleValue::operator()(const Values& values) const {
|
double SingleValue::operator()(const DiscreteValues& values) const {
|
||||||
return (double)(values.at(keys_[0]) == value_);
|
return (double)(values.at(keys_[0]) == value_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,8 +57,8 @@ bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const {
|
Constraint::shared_ptr SingleValue::partiallyApply(const DiscreteValues& values) const {
|
||||||
Values::const_iterator it = values.find(keys_[0]);
|
DiscreteValues::const_iterator it = values.find(keys_[0]);
|
||||||
if (it != values.end() && it->second != value_)
|
if (it != values.end() && it->second != value_)
|
||||||
throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
|
throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
|
||||||
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
|
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);
|
||||||
|
|
|
@ -50,7 +50,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double operator()(const Values& values) const override;
|
double operator()(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
@ -67,7 +67,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
||||||
bool ensureArcConsistency(Key j, Domains* domains) const override;
|
bool ensureArcConsistency(Key j, Domains* domains) const override;
|
||||||
|
|
||||||
/// Partially apply known values
|
/// Partially apply known values
|
||||||
Constraint::shared_ptr partiallyApply(const Values& values) const override;
|
Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// Partially apply known values, domain version
|
/// Partially apply known values, domain version
|
||||||
Constraint::shared_ptr partiallyApply(
|
Constraint::shared_ptr partiallyApply(
|
||||||
|
|
|
@ -165,7 +165,7 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
Scheduler::Values values;
|
DiscreteValues values;
|
||||||
size_t bestSlot = root->solve(values);
|
size_t bestSlot = root->solve(values);
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
|
@ -225,7 +225,7 @@ void sampleSolutions() {
|
||||||
// now, sample schedules
|
// now, sample schedules
|
||||||
for (size_t n = 0; n < 500; n++) {
|
for (size_t n = 0; n < 500; n++) {
|
||||||
vector<size_t> stats(19, 0);
|
vector<size_t> stats(19, 0);
|
||||||
vector<Scheduler::Values> samples;
|
vector<DiscreteValues> samples;
|
||||||
for (size_t i = 0; i < 7; i++) {
|
for (size_t i = 0; i < 7; i++) {
|
||||||
samples.push_back(samplers[i]->sample());
|
samples.push_back(samplers[i]->sample());
|
||||||
schedulers[i].accumulateStats(samples[i], stats);
|
schedulers[i].accumulateStats(samples[i], stats);
|
||||||
|
@ -319,7 +319,7 @@ void accomodateStudent() {
|
||||||
// GTSAM_PRINT(*chordal);
|
// GTSAM_PRINT(*chordal);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
Scheduler::Values values;
|
DiscreteValues values;
|
||||||
size_t bestSlot = root->solve(values);
|
size_t bestSlot = root->solve(values);
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
|
|
|
@ -190,7 +190,7 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
Scheduler::Values values;
|
DiscreteValues values;
|
||||||
size_t bestSlot = root->solve(values);
|
size_t bestSlot = root->solve(values);
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
|
@ -234,7 +234,7 @@ void sampleSolutions() {
|
||||||
// now, sample schedules
|
// now, sample schedules
|
||||||
for (size_t n = 0; n < 500; n++) {
|
for (size_t n = 0; n < 500; n++) {
|
||||||
vector<size_t> stats(19, 0);
|
vector<size_t> stats(19, 0);
|
||||||
vector<Scheduler::Values> samples;
|
vector<DiscreteValues> samples;
|
||||||
for (size_t i = 0; i < NRSTUDENTS; i++) {
|
for (size_t i = 0; i < NRSTUDENTS; i++) {
|
||||||
samples.push_back(samplers[i]->sample());
|
samples.push_back(samplers[i]->sample());
|
||||||
schedulers[i].accumulateStats(samples[i], stats);
|
schedulers[i].accumulateStats(samples[i], stats);
|
||||||
|
|
|
@ -212,7 +212,7 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
Scheduler::Values values;
|
DiscreteValues values;
|
||||||
size_t bestSlot = root->solve(values);
|
size_t bestSlot = root->solve(values);
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
|
@ -259,7 +259,7 @@ void sampleSolutions() {
|
||||||
// now, sample schedules
|
// now, sample schedules
|
||||||
for (size_t n = 0; n < 10000; n++) {
|
for (size_t n = 0; n < 10000; n++) {
|
||||||
vector<size_t> stats(nrFaculty, 0);
|
vector<size_t> stats(nrFaculty, 0);
|
||||||
vector<Scheduler::Values> samples;
|
vector<DiscreteValues> samples;
|
||||||
for (size_t i = 0; i < NRSTUDENTS; i++) {
|
for (size_t i = 0; i < NRSTUDENTS; i++) {
|
||||||
samples.push_back(samplers[i]->sample());
|
samples.push_back(samplers[i]->sample());
|
||||||
schedulers[i].accumulateStats(samples[i], stats);
|
schedulers[i].accumulateStats(samples[i], stats);
|
||||||
|
|
|
@ -112,14 +112,14 @@ TEST(CSP, allInOne) {
|
||||||
csp.addAllDiff(UT, AZ);
|
csp.addAllDiff(UT, AZ);
|
||||||
|
|
||||||
// Check an invalid combination, with ID==UT==AZ all same color
|
// Check an invalid combination, with ID==UT==AZ all same color
|
||||||
DiscreteFactor::Values invalid;
|
DiscreteValues invalid;
|
||||||
invalid[ID.first] = 0;
|
invalid[ID.first] = 0;
|
||||||
invalid[UT.first] = 0;
|
invalid[UT.first] = 0;
|
||||||
invalid[AZ.first] = 0;
|
invalid[AZ.first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
|
||||||
|
|
||||||
// Check a valid combination
|
// Check a valid combination
|
||||||
DiscreteFactor::Values valid;
|
DiscreteValues valid;
|
||||||
valid[ID.first] = 0;
|
valid[ID.first] = 0;
|
||||||
valid[UT.first] = 1;
|
valid[UT.first] = 1;
|
||||||
valid[AZ.first] = 0;
|
valid[AZ.first] = 0;
|
||||||
|
@ -133,7 +133,7 @@ TEST(CSP, allInOne) {
|
||||||
|
|
||||||
// Solve
|
// Solve
|
||||||
auto mpe = csp.optimalAssignment();
|
auto mpe = csp.optimalAssignment();
|
||||||
CSP::Values expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(expected, mpe));
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
||||||
|
@ -179,7 +179,7 @@ TEST(CSP, WesternUS) {
|
||||||
// Solve using that ordering:
|
// Solve using that ordering:
|
||||||
auto mpe = csp.optimalAssignment(ordering);
|
auto mpe = csp.optimalAssignment(ordering);
|
||||||
// GTSAM_PRINT(mpe);
|
// GTSAM_PRINT(mpe);
|
||||||
CSP::Values expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
|
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
|
||||||
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
|
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
|
||||||
UT.first, 1)(AZ.first, 0);
|
UT.first, 1)(AZ.first, 0);
|
||||||
|
@ -213,14 +213,14 @@ TEST(CSP, ArcConsistency) {
|
||||||
// GTSAM_PRINT(csp);
|
// GTSAM_PRINT(csp);
|
||||||
|
|
||||||
// Check an invalid combination, with ID==UT==AZ all same color
|
// Check an invalid combination, with ID==UT==AZ all same color
|
||||||
DiscreteFactor::Values invalid;
|
DiscreteValues invalid;
|
||||||
invalid[ID.first] = 0;
|
invalid[ID.first] = 0;
|
||||||
invalid[UT.first] = 1;
|
invalid[UT.first] = 1;
|
||||||
invalid[AZ.first] = 0;
|
invalid[AZ.first] = 0;
|
||||||
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
|
||||||
|
|
||||||
// Check a valid combination
|
// Check a valid combination
|
||||||
DiscreteFactor::Values valid;
|
DiscreteValues valid;
|
||||||
valid[ID.first] = 0;
|
valid[ID.first] = 0;
|
||||||
valid[UT.first] = 1;
|
valid[UT.first] = 1;
|
||||||
valid[AZ.first] = 2;
|
valid[AZ.first] = 2;
|
||||||
|
@ -228,7 +228,7 @@ TEST(CSP, ArcConsistency) {
|
||||||
|
|
||||||
// Solve
|
// Solve
|
||||||
auto mpe = csp.optimalAssignment();
|
auto mpe = csp.optimalAssignment();
|
||||||
CSP::Values expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(expected, mpe));
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
||||||
|
@ -250,7 +250,7 @@ TEST(CSP, ArcConsistency) {
|
||||||
LONGS_EQUAL(2, domains.at(2).nrValues());
|
LONGS_EQUAL(2, domains.at(2).nrValues());
|
||||||
|
|
||||||
// Parial application, version 1
|
// Parial application, version 1
|
||||||
DiscreteFactor::Values known;
|
DiscreteValues known;
|
||||||
known[AZ.first] = 2;
|
known[AZ.first] = 2;
|
||||||
DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known);
|
DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known);
|
||||||
DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0");
|
DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0");
|
||||||
|
|
|
@ -126,7 +126,7 @@ class LoopyBelief {
|
||||||
// normalize belief
|
// normalize belief
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
|
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
|
||||||
DiscreteFactor::Values val;
|
DiscreteValues val;
|
||||||
val[key] = v;
|
val[key] = v;
|
||||||
sum += (*beliefAtKey)(val);
|
sum += (*beliefAtKey)(val);
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,7 +88,7 @@ class Sudoku : public CSP {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Print readable form of assignment
|
/// Print readable form of assignment
|
||||||
void printAssignment(const DiscreteFactor::Values& assignment) const {
|
void printAssignment(const DiscreteValues& assignment) const {
|
||||||
for (size_t i = 0; i < n_; i++) {
|
for (size_t i = 0; i < n_; i++) {
|
||||||
for (size_t j = 0; j < n_; j++) {
|
for (size_t j = 0; j < n_; j++) {
|
||||||
Key k = key(i, j);
|
Key k = key(i, j);
|
||||||
|
@ -127,7 +127,7 @@ TEST(Sudoku, small) {
|
||||||
|
|
||||||
// optimize and check
|
// optimize and check
|
||||||
auto solution = csp.optimalAssignment();
|
auto solution = csp.optimalAssignment();
|
||||||
CSP::Values expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
|
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
|
||||||
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
|
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
|
||||||
csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)(
|
csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)(
|
||||||
|
|
Loading…
Reference in New Issue