Use DiscreteValues everywhere

release/4.3a0
Frank Dellaert 2021-12-13 13:46:53 -05:00
parent c63c1167ba
commit e89a294376
35 changed files with 125 additions and 136 deletions

View File

@ -80,7 +80,7 @@ namespace gtsam {
/// @{ /// @{
/// Value is just look up in AlgebraicDecisonTree /// Value is just look up in AlgebraicDecisonTree
double operator()(const Values& values) const override { double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values); return Potentials::operator()(values);
} }

View File

@ -45,7 +45,7 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const { double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply
double result = 1.0; double result = 1.0;
for(DiscreteConditional::shared_ptr conditional: *this) for(DiscreteConditional::shared_ptr conditional: *this)
@ -54,18 +54,18 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteFactor::Values DiscreteBayesNet::optimize() const { DiscreteValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first) // solve each node in turn in topological sort order (parents first)
DiscreteFactor::Values result; DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this)) for (auto conditional: boost::adaptors::reverse(*this))
conditional->solveInPlace(&result); conditional->solveInPlace(&result);
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteFactor::Values DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first) // sample each node in turn in topological sort order (parents first)
DiscreteFactor::Values result; DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this)) for (auto conditional: boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result); conditional->sampleInPlace(&result);
return result; return result;

View File

@ -77,16 +77,16 @@ namespace gtsam {
// /** Add a DiscreteCondtional in front, when listing parents first*/ // /** Add a DiscreteCondtional in front, when listing parents first*/
// GTSAM_EXPORT void add_front(const Signature& s); // GTSAM_EXPORT void add_front(const Signature& s);
//** evaluate for given Values */ //** evaluate for given DiscreteValues */
double evaluate(const DiscreteConditional::Values & values) const; double evaluate(const DiscreteValues & values) const;
/** /**
* Solve the DiscreteBayesNet by back-substitution * Solve the DiscreteBayesNet by back-substitution
*/ */
DiscreteFactor::Values optimize() const; DiscreteValues optimize() const;
/** Do ancestral sampling */ /** Do ancestral sampling */
DiscreteFactor::Values sample() const; DiscreteValues sample() const;
///@} ///@}

View File

@ -31,7 +31,7 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteBayesTreeClique::evaluate( double DiscreteBayesTreeClique::evaluate(
const DiscreteConditional::Values& values) const { const DiscreteValues& values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply
double result = (*conditional_)(values); double result = (*conditional_)(values);
for (const auto& child : children) { for (const auto& child : children) {
@ -47,7 +47,7 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteBayesTree::evaluate( double DiscreteBayesTree::evaluate(
const DiscreteConditional::Values& values) const { const DiscreteValues& values) const {
double result = 1.0; double result = 1.0;
for (const auto& root : roots_) { for (const auto& root : roots_) {
result *= root->evaluate(values); result *= root->evaluate(values);

View File

@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
conditional_->printSignature(s, formatter); conditional_->printSignature(s, formatter);
} }
//** evaluate conditional probability of subtree for given Values */ //** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteConditional::Values& values) const; double evaluate(const DiscreteValues& values) const;
}; };
/* ************************************************************************* */ /* ************************************************************************* */
@ -78,8 +78,8 @@ class GTSAM_EXPORT DiscreteBayesTree
/** Check equality */ /** Check equality */
bool equals(const This& other, double tol = 1e-9) const; bool equals(const This& other, double tol = 1e-9) const;
//** evaluate probability for given Values */ //** evaluate probability for given DiscreteValues */
double evaluate(const DiscreteConditional::Values& values) const; double evaluate(const DiscreteValues& values) const;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -97,7 +97,7 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const { Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const {
ADT pFS(*this); ADT pFS(*this);
Key j; size_t value; Key j; size_t value;
for(Key key: parents()) { for(Key key: parents()) {
@ -117,12 +117,12 @@ Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
void DiscreteConditional::solveInPlace(Values* values) const { void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is. // TODO: Abhijit asks: is this really the fastest way? He thinks it is.
ADT pFS = choose(*values); // P(F|S=parentsValues) ADT pFS = choose(*values); // P(F|S=parentsValues)
// Initialize // Initialize
Values mpe; DiscreteValues mpe;
double maxP = 0; double maxP = 0;
DiscreteKeys keys; DiscreteKeys keys;
@ -131,10 +131,10 @@ void DiscreteConditional::solveInPlace(Values* values) const {
keys & dk; keys & dk;
} }
// Get all Possible Configurations // Get all Possible Configurations
vector<Values> allPosbValues = cartesianProduct(keys); vector<DiscreteValues> allPosbValues = cartesianProduct(keys);
// Find the MPE // Find the MPE
for(Values& frontalVals: allPosbValues) { for(DiscreteValues& frontalVals: allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better // Update MPE solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
@ -150,7 +150,7 @@ void DiscreteConditional::solveInPlace(Values* values) const {
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(Values* values) const { void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents size_t sampled = sample(*values); // Sample variable given parents
@ -158,7 +158,7 @@ void DiscreteConditional::sampleInPlace(Values* values) const {
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
size_t DiscreteConditional::solve(const Values& parentsValues) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
// TODO: is this really the fastest way? I think it is. // TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
@ -166,7 +166,7 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const {
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way // TODO, only works for one key now, seems horribly slow this way
size_t mpe = 0; size_t mpe = 0;
Values frontals; DiscreteValues frontals;
double maxP = 0; double maxP = 0;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
@ -183,7 +183,7 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const {
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
size_t DiscreteConditional::sample(const Values& parentsValues) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator static mt19937 rng(2); // random number generator
// Get the correct conditional density // Get the correct conditional density
@ -194,7 +194,7 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const {
Key key = firstFrontalKey(); Key key = firstFrontalKey();
size_t nj = cardinality(key); size_t nj = cardinality(key);
vector<double> p(nj); vector<double> p(nj);
Values frontals; DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) { for (size_t value = 0; value < nj; value++) {
frontals[key] = value; frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues) p[value] = pFS(frontals); // P(F=value|S=parentsValues)

View File

@ -42,9 +42,7 @@ public:
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
/** A map from keys to values.. using Values = DiscreteValues; ///< backwards compatibility
* TODO: Again, do we need this??? */
typedef Assignment<Key> Values;
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -101,7 +99,7 @@ public:
} }
/// Evaluate, just look up in AlgebraicDecisonTree /// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const Values& values) const override { double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values); return Potentials::operator()(values);
} }
@ -111,31 +109,31 @@ public:
} }
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */ /** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const Assignment<Key>& parentsValues) const; ADT choose(const DiscreteValues& parentsValues) const;
/** /**
* solve a conditional * solve a conditional
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable). * @return MPE value of the child (1 frontal variable).
*/ */
size_t solve(const Values& parentsValues) const; size_t solve(const DiscreteValues& parentsValues) const;
/** /**
* sample * sample
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
* @return sample from conditional * @return sample from conditional
*/ */
size_t sample(const Values& parentsValues) const; size_t sample(const DiscreteValues& parentsValues) const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// solve a conditional, in place /// solve a conditional, in place
void solveInPlace(Values* parentsValues) const; void solveInPlace(DiscreteValues* parentsValues) const;
/// sample in place, stores result in partial solution /// sample in place, stores result in partial solution
void sampleInPlace(Values* parentsValues) const; void sampleInPlace(DiscreteValues* parentsValues) const;
/// @} /// @}

View File

@ -18,7 +18,7 @@
#pragma once #pragma once
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
@ -40,17 +40,7 @@ public:
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class typedef Factor Base; ///< Our base class
/** A map from keys to values using Values = DiscreteValues; ///< backwards compatibility
* TODO: Do we need this? Should we just use gtsam::Values?
* We just need another special DiscreteValue to represent labels,
* However, all other Lie's operators are undefined in this class.
* The good thing is we can have a Hybrid graph of discrete/continuous variables
* together..
* Another good thing is we don't need to have the special DiscreteKey which stores
* cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the varible's type (domain)
*/
typedef Assignment<Key> Values;
public: public:
@ -91,7 +81,7 @@ public:
/// @{ /// @{
/// Find value for given assignment of values to variables /// Find value for given assignment of values to variables
virtual double operator()(const Values&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
@ -104,6 +94,6 @@ public:
// traits // traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {}; template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
template<> struct traits<DiscreteFactor::Values> : public Testable<DiscreteFactor::Values> {}; template<> struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
}// namespace gtsam }// namespace gtsam

View File

@ -56,7 +56,7 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteFactorGraph::operator()( double DiscreteFactorGraph::operator()(
const DiscreteFactor::Values &values) const { const DiscreteValues &values) const {
double product = 1.0; double product = 1.0;
for( const sharedFactor& factor: factors_ ) for( const sharedFactor& factor: factors_ )
product *= (*factor)(values); product *= (*factor)(values);
@ -94,7 +94,7 @@ namespace gtsam {
// } // }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteFactorGraph::Values DiscreteFactorGraph::optimize() const DiscreteValues DiscreteFactorGraph::optimize() const
{ {
gttic(DiscreteFactorGraph_optimize); gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize(); return BaseEliminateable::eliminateSequential()->optimize();

View File

@ -71,9 +71,10 @@ public:
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
using Values = DiscreteValues; ///< backwards compatibility
/** A map from keys to values */ /** A map from keys to values */
typedef KeyVector Indices; typedef KeyVector Indices;
typedef Assignment<Key> Values;
/** Default constructor */ /** Default constructor */
DiscreteFactorGraph() {} DiscreteFactorGraph() {}
@ -130,7 +131,7 @@ public:
DecisionTreeFactor product() const; DecisionTreeFactor product() const;
/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
double operator()(const DiscreteFactor::Values & values) const; double operator()(const DiscreteValues & values) const;
/// print /// print
void print( void print(
@ -141,7 +142,7 @@ public:
* the dense elimination function specified in \c function, * the dense elimination function specified in \c function,
* followed by back-substitution resulting from elimination. Is equivalent * followed by back-substitution resulting from elimination. Is equivalent
* to calling graph.eliminateSequential()->optimize(). */ * to calling graph.eliminateSequential()->optimize(). */
Values optimize() const; DiscreteValues optimize() const;
// /** Permute the variables in the factors */ // /** Permute the variables in the factors */

View File

@ -64,7 +64,7 @@ namespace gtsam {
//Create result //Create result
Vector vResult(key.second); Vector vResult(key.second);
for (size_t state = 0; state < key.second ; ++ state) { for (size_t state = 0; state < key.second ; ++ state) {
DiscreteFactor::Values values; DiscreteValues values;
values[key.first] = state; values[key.first] = state;
vResult(state) = (*marginalFactor)(values); vResult(state) = (*marginalFactor)(values);
} }

View File

@ -18,6 +18,7 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define DT_NO_PRUNING //#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
@ -445,7 +446,7 @@ TEST(ADT, equality_parser)
TEST(ADT, constructor) TEST(ADT, constructor)
{ {
DiscreteKey v0(0,2), v1(1,3); DiscreteKey v0(0,2), v1(1,3);
Assignment<Key> x00, x01, x02, x10, x11, x12; DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0; x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1; x01[0] = 0, x01[1] = 1;
x02[0] = 0, x02[1] = 2; x02[0] = 0, x02[1] = 2;
@ -475,7 +476,7 @@ TEST(ADT, constructor)
for(double& t: table) for(double& t: table)
t = x++; t = x++;
ADT f3(z0 & z1 & z2 & z3, table); ADT f3(z0 & z1 & z2 & z3, table);
Assignment<Key> assignment; DiscreteValues assignment;
assignment[0] = 0; assignment[0] = 0;
assignment[1] = 0; assignment[1] = 0;
assignment[2] = 0; assignment[2] = 0;
@ -501,7 +502,7 @@ TEST(ADT, conversion)
// f2.print("f2"); // f2.print("f2");
dot(fIndexKey, "conversion-f2"); dot(fIndexKey, "conversion-f2");
Assignment<Key> x00, x01, x02, x10, x11, x12; DiscreteValues x00, x01, x02, x10, x11, x12;
x00[5] = 0, x00[2] = 0; x00[5] = 0, x00[2] = 0;
x01[5] = 0, x01[2] = 1; x01[5] = 0, x01[2] = 1;
x10[5] = 1, x10[2] = 0; x10[5] = 1, x10[2] = 0;
@ -577,7 +578,7 @@ TEST(ADT, zero)
ADT notb(B, 1, 0); ADT notb(B, 1, 0);
ADT anotb = a * notb; ADT anotb = a * notb;
// GTSAM_PRINT(anotb); // GTSAM_PRINT(anotb);
Assignment<Key> x00, x01, x10, x11; DiscreteValues x00, x01, x10, x11;
x00[0] = 0, x00[1] = 0; x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1; x01[0] = 0, x01[1] = 1;
x10[0] = 1, x10[1] = 0; x10[0] = 1, x10[1] = 0;

View File

@ -43,7 +43,7 @@ TEST( DecisionTreeFactor, constructors)
// f2.print("f2:"); // f2.print("f2:");
// f3.print("f3:"); // f3.print("f3:");
DecisionTreeFactor::Values values; DiscreteValues values;
values[0] = 1; // x values[0] = 1; // x
values[1] = 2; // y values[1] = 2; // y
values[2] = 1; // z values[2] = 1; // z

View File

@ -105,7 +105,7 @@ TEST(DiscreteBayesNet, Asia) {
// solve // solve
auto actualMPE = chordal->optimize(); auto actualMPE = chordal->optimize();
DiscreteFactor::Values expectedMPE; DiscreteValues expectedMPE;
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0); LungCancer.first, 0)(Bronchitis.first, 0);
@ -118,14 +118,14 @@ TEST(DiscreteBayesNet, Asia) {
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto actualMPE2 = chordal2->optimize(); auto actualMPE2 = chordal2->optimize();
DiscreteFactor::Values expectedMPE2; DiscreteValues expectedMPE2;
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1); LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, actualMPE2)); EXPECT(assert_equal(expectedMPE2, actualMPE2));
// now sample from it // now sample from it
DiscreteFactor::Values expectedSample; DiscreteValues expectedSample;
SETDEBUG("DiscreteConditional::sample", false); SETDEBUG("DiscreteConditional::sample", false);
insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)( insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)( Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)(

View File

@ -89,24 +89,24 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
auto R = bayesTree->roots().front(); auto R = bayesTree->roots().front();
// Check whether BN and BT give the same answer on all configurations // Check whether BN and BT give the same answer on all configurations
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct( vector<DiscreteValues> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) { for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i]; DiscreteValues x = allPosbValues[i];
double expected = bayesNet.evaluate(x); double expected = bayesNet.evaluate(x);
double actual = bayesTree->evaluate(x); double actual = bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9); DOUBLES_EQUAL(expected, actual, 1e-9);
} }
// Calculate all some marginals for Values==all1 // Calculate all some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15); Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0, joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (size_t i = 0; i < allPosbValues.size(); ++i) { for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i]; DiscreteValues x = allPosbValues[i];
double px = bayesTree->evaluate(x); double px = bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++) for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px; if (x[i]) marginals[i] += px;
@ -138,7 +138,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
} }
} }
} }
DiscreteFactor::Values all1 = allPosbValues.back(); DiscreteValues all1 = allPosbValues.back();
// check separator marginal P(S0) // check separator marginal P(S0)
auto clique = (*bayesTree)[0]; auto clique = (*bayesTree)[0];

View File

@ -81,8 +81,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
graph.add(P2, "0.9 0.6"); graph.add(P2, "0.9 0.6");
graph.add(P1 & P2, "4 1 10 4"); graph.add(P1 & P2, "4 1 10 4");
// Instantiate Values // Instantiate DiscreteValues
DiscreteFactor::Values values; DiscreteValues values;
values[0] = 1; values[0] = 1;
values[1] = 1; values[1] = 1;
@ -167,7 +167,7 @@ TEST( DiscreteFactorGraph, test)
// EXPECT(assert_equal(expected, *actual2)); // EXPECT(assert_equal(expected, *actual2));
// Test optimization // Test optimization
DiscreteFactor::Values expectedValues; DiscreteValues expectedValues;
insert(expectedValues)(0, 0)(1, 0)(2, 0); insert(expectedValues)(0, 0)(1, 0)(2, 0);
auto actualValues = graph.optimize(); auto actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, actualValues)); EXPECT(assert_equal(expectedValues, actualValues));
@ -188,7 +188,7 @@ TEST( DiscreteFactorGraph, testMPE)
auto actualMPE = graph.optimize(); auto actualMPE = graph.optimize();
DiscreteFactor::Values expectedMPE; DiscreteValues expectedMPE;
insert(expectedMPE)(0, 0)(1, 1)(2, 1); insert(expectedMPE)(0, 0)(1, 1)(2, 1);
EXPECT(assert_equal(expectedMPE, actualMPE)); EXPECT(assert_equal(expectedMPE, actualMPE));
} }
@ -211,7 +211,7 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
// graph.product().potentials().dot("Darwiche-product"); // graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print(); // DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteFactor::Values expectedMPE; DiscreteValues expectedMPE;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
// Use the solver machinery. // Use the solver machinery.

View File

@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
DiscreteMarginals marginals(graph); DiscreteMarginals marginals(graph);
DiscreteFactor::shared_ptr actualC = marginals(Cathy.first); DiscreteFactor::shared_ptr actualC = marginals(Cathy.first);
DiscreteFactor::Values values; DiscreteValues values;
values[Cathy.first] = 0; values[Cathy.first] = 0;
EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6); EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6);
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
DiscreteMarginals marginals(graph); DiscreteMarginals marginals(graph);
DiscreteFactor::shared_ptr actualC = marginals(key[2].first); DiscreteFactor::shared_ptr actualC = marginals(key[2].first);
DiscreteFactor::Values values; DiscreteValues values;
values[key[2].first] = 0; values[key[2].first] = 0;
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4); EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8"); graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");
// Calculate the marginals by brute force // Calculate the marginals by brute force
vector<DiscreteFactor::Values> allPosbValues = vector<DiscreteValues> allPosbValues =
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]); cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
Vector T = Z_5x1, F = Z_5x1; Vector T = Z_5x1, F = Z_5x1;
for (size_t i = 0; i < allPosbValues.size(); ++i) { for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i]; DiscreteValues x = allPosbValues[i];
double px = graph(x); double px = graph(x);
for (size_t j = 0; j < 5; j++) for (size_t j = 0; j < 5; j++)
if (x[j]) if (x[j])

View File

@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double AllDiff::operator()(const Values& values) const { double AllDiff::operator()(const DiscreteValues& values) const {
std::set<size_t> taken; // record values taken by keys std::set<size_t> taken; // record values taken by keys
for (Key dkey : keys_) { for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key size_t value = values.at(dkey); // get the value for that key
@ -88,7 +88,7 @@ bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { Constraint::shared_ptr AllDiff::partiallyApply(const DiscreteValues& values) const {
DiscreteKeys newKeys; DiscreteKeys newKeys;
// loop over keys and add them only if they do not appear in values // loop over keys and add them only if they do not appear in values
for (Key k : keys_) for (Key k : keys_)
@ -101,7 +101,7 @@ Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply( Constraint::shared_ptr AllDiff::partiallyApply(
const Domains& domains) const { const Domains& domains) const {
DiscreteFactor::Values known; DiscreteValues known;
for (Key k : keys_) { for (Key k : keys_) {
const Domain& Dk = domains.at(k); const Domain& Dk = domains.at(k);
if (Dk.isSingleton()) known[k] = Dk.firstValue(); if (Dk.isSingleton()) known[k] = Dk.firstValue();

View File

@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
} }
/// Calculate value = expensive ! /// Calculate value = expensive !
double operator()(const Values& values) const override; double operator()(const DiscreteValues& values) const override;
/// Convert into a decisiontree, can be *very* expensive ! /// Convert into a decisiontree, can be *very* expensive !
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;
@ -62,7 +62,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
bool ensureArcConsistency(Key j, Domains* domains) const override; bool ensureArcConsistency(Key j, Domains* domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override; Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override;
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(

View File

@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint {
} }
/// Calculate value /// Calculate value
double operator()(const Values& values) const override { double operator()(const DiscreteValues& values) const override {
return (double)(values.at(keys_[0]) != values.at(keys_[1])); return (double)(values.at(keys_[0]) != values.at(keys_[1]));
} }
@ -82,7 +82,7 @@ class BinaryAllDiff : public Constraint {
} }
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override { Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
} }

View File

@ -14,13 +14,13 @@ using namespace std;
namespace gtsam { namespace gtsam {
/// Find the best total assignment - can be expensive /// Find the best total assignment - can be expensive
CSP::Values CSP::optimalAssignment() const { DiscreteValues CSP::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(); DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
return chordal->optimize(); return chordal->optimize();
} }
/// Find the best total assignment - can be expensive /// Find the best total assignment - can be expensive
CSP::Values CSP::optimalAssignment(const Ordering& ordering) const { DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
return chordal->optimize(); return chordal->optimize();
} }

View File

@ -20,10 +20,8 @@ namespace gtsam {
*/ */
class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
public: public:
/** A map from keys to values */ using Values = DiscreteValues; ///< backwards compatibility
typedef Assignment<Key> Values;
public:
/// Add a unary constraint, allowing only a single value /// Add a unary constraint, allowing only a single value
void addSingleValue(const DiscreteKey& dkey, size_t value) { void addSingleValue(const DiscreteKey& dkey, size_t value) {
emplace_shared<SingleValue>(dkey, value); emplace_shared<SingleValue>(dkey, value);
@ -46,10 +44,10 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
// } // }
/// Find the best total assignment - can be expensive. /// Find the best total assignment - can be expensive.
Values optimalAssignment() const; DiscreteValues optimalAssignment() const;
/// Find the best total assignment, with given ordering - can be expensive. /// Find the best total assignment, with given ordering - can be expensive.
Values optimalAssignment(const Ordering& ordering) const; DiscreteValues optimalAssignment(const Ordering& ordering) const;
// /* // /*
// * Perform loopy belief propagation // * Perform loopy belief propagation

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam_unstable/dllexport.h> #include <gtsam_unstable/dllexport.h>
#include <boost/assign.hpp> #include <boost/assign.hpp>
@ -75,7 +76,7 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor {
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
/// Partially apply known values /// Partially apply known values
virtual shared_ptr partiallyApply(const Values&) const = 0; virtual shared_ptr partiallyApply(const DiscreteValues&) const = 0;
/// Partially apply known values, domain version /// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0; virtual shared_ptr partiallyApply(const Domains&) const = 0;

View File

@ -31,7 +31,7 @@ string Domain::base1Str() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double Domain::operator()(const Values& values) const { double Domain::operator()(const DiscreteValues& values) const {
return contains(values.at(key())); return contains(values.at(key()));
} }
@ -79,8 +79,8 @@ boost::optional<Domain> Domain::checkAllDiff(const KeyVector keys,
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr Domain::partiallyApply(const Values& values) const { Constraint::shared_ptr Domain::partiallyApply(const DiscreteValues& values) const {
Values::const_iterator it = values.find(key()); DiscreteValues::const_iterator it = values.find(key());
if (it != values.end() && !contains(it->second)) if (it != values.end() && !contains(it->second))
throw runtime_error("Domain::partiallyApply: unsatisfiable"); throw runtime_error("Domain::partiallyApply: unsatisfiable");
return boost::make_shared<Domain>(*this); return boost::make_shared<Domain>(*this);

View File

@ -76,7 +76,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
bool contains(size_t value) const { return values_.count(value) > 0; } bool contains(size_t value) const { return values_.count(value) > 0; }
/// Calculate value /// Calculate value
double operator()(const Values& values) const override; double operator()(const DiscreteValues& values) const override;
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;
@ -104,7 +104,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
const Domains& domains) const; const Domains& domains) const;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override; Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override;
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(const Domains& domains) const override; Constraint::shared_ptr partiallyApply(const Domains& domains) const override;

View File

@ -202,7 +202,7 @@ void Scheduler::print(const string& s, const KeyFormatter& formatter) const {
} // print } // print
/** Print readable form of assignment */ /** Print readable form of assignment */
void Scheduler::printAssignment(const Values& assignment) const { void Scheduler::printAssignment(const DiscreteValues& assignment) const {
// Not intended to be general! Assumes very particular ordering ! // Not intended to be general! Assumes very particular ordering !
cout << endl; cout << endl;
for (size_t s = 0; s < nrStudents(); s++) { for (size_t s = 0; s < nrStudents(); s++) {
@ -220,8 +220,8 @@ void Scheduler::printAssignment(const Values& assignment) const {
} }
/** Special print for single-student case */ /** Special print for single-student case */
void Scheduler::printSpecial(const Values& assignment) const { void Scheduler::printSpecial(const DiscreteValues& assignment) const {
Values::const_iterator it = assignment.begin(); DiscreteValues::const_iterator it = assignment.begin();
for (size_t area = 0; area < 3; area++, it++) { for (size_t area = 0; area < 3; area++, it++) {
size_t f = it->second; size_t f = it->second;
cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl; cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl;
@ -230,7 +230,7 @@ void Scheduler::printSpecial(const Values& assignment) const {
} }
/** Accumulate faculty stats */ /** Accumulate faculty stats */
void Scheduler::accumulateStats(const Values& assignment, void Scheduler::accumulateStats(const DiscreteValues& assignment,
vector<size_t>& stats) const { vector<size_t>& stats) const {
for (size_t s = 0; s < nrStudents(); s++) { for (size_t s = 0; s < nrStudents(); s++) {
Key base = 3 * s; Key base = 3 * s;
@ -256,7 +256,7 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
} }
/** Find the best total assignment - can be expensive */ /** Find the best total assignment - can be expensive */
Scheduler::Values Scheduler::optimalAssignment() const { DiscreteValues Scheduler::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = eliminate(); DiscreteBayesNet::shared_ptr chordal = eliminate();
if (ISDEBUG("Scheduler::optimalAssignment")) { if (ISDEBUG("Scheduler::optimalAssignment")) {
@ -267,21 +267,21 @@ Scheduler::Values Scheduler::optimalAssignment() const {
} }
gttic(my_optimize); gttic(my_optimize);
Values mpe = chordal->optimize(); DiscreteValues mpe = chordal->optimize();
gttoc(my_optimize); gttoc(my_optimize);
return mpe; return mpe;
} }
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
Scheduler::Values Scheduler::bestSchedule() const { DiscreteValues Scheduler::bestSchedule() const {
Values best; DiscreteValues best;
throw runtime_error("bestSchedule not implemented"); throw runtime_error("bestSchedule not implemented");
return best; return best;
} }
/** find the corresponding most desirable committee assignment */ /** find the corresponding most desirable committee assignment */
Scheduler::Values Scheduler::bestAssignment(const Values& bestSchedule) const { DiscreteValues Scheduler::bestAssignment(const DiscreteValues& bestSchedule) const {
Values best; DiscreteValues best;
throw runtime_error("bestAssignment not implemented"); throw runtime_error("bestAssignment not implemented");
return best; return best;
} }

View File

@ -134,26 +134,26 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Print readable form of assignment */ /** Print readable form of assignment */
void printAssignment(const Values& assignment) const; void printAssignment(const DiscreteValues& assignment) const;
/** Special print for single-student case */ /** Special print for single-student case */
void printSpecial(const Values& assignment) const; void printSpecial(const DiscreteValues& assignment) const;
/** Accumulate faculty stats */ /** Accumulate faculty stats */
void accumulateStats(const Values& assignment, void accumulateStats(const DiscreteValues& assignment,
std::vector<size_t>& stats) const; std::vector<size_t>& stats) const;
/** Eliminate, return a Bayes net */ /** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr eliminate() const; DiscreteBayesNet::shared_ptr eliminate() const;
/** Find the best total assignment - can be expensive */ /** Find the best total assignment - can be expensive */
Values optimalAssignment() const; DiscreteValues optimalAssignment() const;
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
Values bestSchedule() const; DiscreteValues bestSchedule() const;
/** find the corresponding most desirable committee assignment */ /** find the corresponding most desirable committee assignment */
Values bestAssignment(const Values& bestSchedule) const; DiscreteValues bestAssignment(const DiscreteValues& bestSchedule) const;
}; // Scheduler }; // Scheduler

View File

@ -23,7 +23,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double SingleValue::operator()(const Values& values) const { double SingleValue::operator()(const DiscreteValues& values) const {
return (double)(values.at(keys_[0]) == value_); return (double)(values.at(keys_[0]) == value_);
} }
@ -57,8 +57,8 @@ bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { Constraint::shared_ptr SingleValue::partiallyApply(const DiscreteValues& values) const {
Values::const_iterator it = values.find(keys_[0]); DiscreteValues::const_iterator it = values.find(keys_[0]);
if (it != values.end() && it->second != value_) if (it != values.end() && it->second != value_)
throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); throw runtime_error("SingleValue::partiallyApply: unsatisfiable");
return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_); return boost::make_shared<SingleValue>(keys_[0], cardinality_, value_);

View File

@ -50,7 +50,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
} }
/// Calculate value /// Calculate value
double operator()(const Values& values) const override; double operator()(const DiscreteValues& values) const override;
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;
@ -67,7 +67,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
bool ensureArcConsistency(Key j, Domains* domains) const override; bool ensureArcConsistency(Key j, Domains* domains) const override;
/// Partially apply known values /// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values& values) const override; Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override;
/// Partially apply known values, domain version /// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply( Constraint::shared_ptr partiallyApply(

View File

@ -165,7 +165,7 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
Scheduler::Values values; DiscreteValues values;
size_t bestSlot = root->solve(values); size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
@ -225,7 +225,7 @@ void sampleSolutions() {
// now, sample schedules // now, sample schedules
for (size_t n = 0; n < 500; n++) { for (size_t n = 0; n < 500; n++) {
vector<size_t> stats(19, 0); vector<size_t> stats(19, 0);
vector<Scheduler::Values> samples; vector<DiscreteValues> samples;
for (size_t i = 0; i < 7; i++) { for (size_t i = 0; i < 7; i++) {
samples.push_back(samplers[i]->sample()); samples.push_back(samplers[i]->sample());
schedulers[i].accumulateStats(samples[i], stats); schedulers[i].accumulateStats(samples[i], stats);
@ -319,7 +319,7 @@ void accomodateStudent() {
// GTSAM_PRINT(*chordal); // GTSAM_PRINT(*chordal);
// solve root node only // solve root node only
Scheduler::Values values; DiscreteValues values;
size_t bestSlot = root->solve(values); size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count

View File

@ -190,7 +190,7 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
Scheduler::Values values; DiscreteValues values;
size_t bestSlot = root->solve(values); size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
@ -234,7 +234,7 @@ void sampleSolutions() {
// now, sample schedules // now, sample schedules
for (size_t n = 0; n < 500; n++) { for (size_t n = 0; n < 500; n++) {
vector<size_t> stats(19, 0); vector<size_t> stats(19, 0);
vector<Scheduler::Values> samples; vector<DiscreteValues> samples;
for (size_t i = 0; i < NRSTUDENTS; i++) { for (size_t i = 0; i < NRSTUDENTS; i++) {
samples.push_back(samplers[i]->sample()); samples.push_back(samplers[i]->sample());
schedulers[i].accumulateStats(samples[i], stats); schedulers[i].accumulateStats(samples[i], stats);

View File

@ -212,7 +212,7 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
Scheduler::Values values; DiscreteValues values;
size_t bestSlot = root->solve(values); size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
@ -259,7 +259,7 @@ void sampleSolutions() {
// now, sample schedules // now, sample schedules
for (size_t n = 0; n < 10000; n++) { for (size_t n = 0; n < 10000; n++) {
vector<size_t> stats(nrFaculty, 0); vector<size_t> stats(nrFaculty, 0);
vector<Scheduler::Values> samples; vector<DiscreteValues> samples;
for (size_t i = 0; i < NRSTUDENTS; i++) { for (size_t i = 0; i < NRSTUDENTS; i++) {
samples.push_back(samplers[i]->sample()); samples.push_back(samplers[i]->sample());
schedulers[i].accumulateStats(samples[i], stats); schedulers[i].accumulateStats(samples[i], stats);

View File

@ -112,14 +112,14 @@ TEST(CSP, allInOne) {
csp.addAllDiff(UT, AZ); csp.addAllDiff(UT, AZ);
// Check an invalid combination, with ID==UT==AZ all same color // Check an invalid combination, with ID==UT==AZ all same color
DiscreteFactor::Values invalid; DiscreteValues invalid;
invalid[ID.first] = 0; invalid[ID.first] = 0;
invalid[UT.first] = 0; invalid[UT.first] = 0;
invalid[AZ.first] = 0; invalid[AZ.first] = 0;
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
// Check a valid combination // Check a valid combination
DiscreteFactor::Values valid; DiscreteValues valid;
valid[ID.first] = 0; valid[ID.first] = 0;
valid[UT.first] = 1; valid[UT.first] = 1;
valid[AZ.first] = 0; valid[AZ.first] = 0;
@ -133,7 +133,7 @@ TEST(CSP, allInOne) {
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimalAssignment();
CSP::Values expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
@ -179,7 +179,7 @@ TEST(CSP, WesternUS) {
// Solve using that ordering: // Solve using that ordering:
auto mpe = csp.optimalAssignment(ordering); auto mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(mpe); // GTSAM_PRINT(mpe);
CSP::Values expected; DiscreteValues expected;
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)( insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)( MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
UT.first, 1)(AZ.first, 0); UT.first, 1)(AZ.first, 0);
@ -213,14 +213,14 @@ TEST(CSP, ArcConsistency) {
// GTSAM_PRINT(csp); // GTSAM_PRINT(csp);
// Check an invalid combination, with ID==UT==AZ all same color // Check an invalid combination, with ID==UT==AZ all same color
DiscreteFactor::Values invalid; DiscreteValues invalid;
invalid[ID.first] = 0; invalid[ID.first] = 0;
invalid[UT.first] = 1; invalid[UT.first] = 1;
invalid[AZ.first] = 0; invalid[AZ.first] = 0;
EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9);
// Check a valid combination // Check a valid combination
DiscreteFactor::Values valid; DiscreteValues valid;
valid[ID.first] = 0; valid[ID.first] = 0;
valid[UT.first] = 1; valid[UT.first] = 1;
valid[AZ.first] = 2; valid[AZ.first] = 2;
@ -228,7 +228,7 @@ TEST(CSP, ArcConsistency) {
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimalAssignment();
CSP::Values expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
@ -250,7 +250,7 @@ TEST(CSP, ArcConsistency) {
LONGS_EQUAL(2, domains.at(2).nrValues()); LONGS_EQUAL(2, domains.at(2).nrValues());
// Parial application, version 1 // Parial application, version 1
DiscreteFactor::Values known; DiscreteValues known;
known[AZ.first] = 2; known[AZ.first] = 2;
DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known);
DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0"); DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0");

View File

@ -126,7 +126,7 @@ class LoopyBelief {
// normalize belief // normalize belief
double sum = 0.0; double sum = 0.0;
for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) { for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) {
DiscreteFactor::Values val; DiscreteValues val;
val[key] = v; val[key] = v;
sum += (*beliefAtKey)(val); sum += (*beliefAtKey)(val);
} }

View File

@ -88,7 +88,7 @@ class Sudoku : public CSP {
} }
/// Print readable form of assignment /// Print readable form of assignment
void printAssignment(const DiscreteFactor::Values& assignment) const { void printAssignment(const DiscreteValues& assignment) const {
for (size_t i = 0; i < n_; i++) { for (size_t i = 0; i < n_; i++) {
for (size_t j = 0; j < n_; j++) { for (size_t j = 0; j < n_; j++) {
Key k = key(i, j); Key k = key(i, j);
@ -127,7 +127,7 @@ TEST(Sudoku, small) {
// optimize and check // optimize and check
auto solution = csp.optimalAssignment(); auto solution = csp.optimalAssignment();
CSP::Values expected; DiscreteValues expected;
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)( csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)(