Modernized sample function

release/4.3a0
Frank dellaert 2020-07-12 23:25:07 -04:00
parent 621e79f06c
commit 947d7377b4
1 changed files with 12 additions and 39 deletions

View File

@ -184,55 +184,28 @@ size_t DiscreteConditional::solve(const Values& parentsValues) const {
/* ******************************************************************************** */
size_t DiscreteConditional::sample(const Values& parentsValues) const {
static mt19937 rng(2); // random number generator
bool debug = ISDEBUG("DiscreteConditional::sample");
static mt19937 rng(2); // random number generator
// Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
if (debug)
GTSAM_PRINT(pFS);
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
// get cumulative distribution function (cdf)
// TODO, only works for one key now, seems horribly slow this way
// TODO(Duy): only works for one key now, seems horribly slow this way
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t nj = cardinality(j);
vector<double> cdf(nj);
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
Values frontals;
double sum = 0;
for (size_t value = 0; value < nj; value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
sum += pValueS; // accumulate
if (debug)
cout << sum << " ";
if (pValueS == 1) {
if (debug)
cout << "--> " << value << endl;
return value; // shortcut exit
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
if (p[value] == 1.0) {
return value; // shortcut exit
}
cdf[value] = sum;
}
// inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html
uniform_real_distribution<double> dist(0, cdf.back());
size_t sampled = lower_bound(cdf.begin(), cdf.end(), dist(rng)) - cdf.begin();
if (debug)
cout << "-> " << sampled << endl;
return sampled;
return 0;
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
}
/* ******************************************************************************** */
//void DiscreteConditional::permuteWithInverse(
// const Permutation& inversePermutation) {
// IndexConditionalOrdered::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
//}
/* ******************************************************************************** */
}// namespace