Made into class

release/4.3a0
Frank Dellaert 2019-06-15 18:42:54 -04:00 committed by Frank Dellaert
parent 7d86b073e6
commit f8af4a465d
2 changed files with 95 additions and 73 deletions

View File

@ -16,6 +16,9 @@
* @date May 2019 * @date May 2019
**/ **/
#pragma once
#include <cmath>
#include <queue> #include <queue>
#include <random> #include <random>
#include <stdexcept> #include <stdexcept>
@ -25,20 +28,36 @@
namespace gtsam { namespace gtsam {
/* /*
* Fast sampling without replacement. * Fast sampling without replacement.
* Example usage:
* std::mt19937 rng(42);
* WeightedSampler<std::mt19937> sampler(&rng);
* auto samples = sampler.sampleWithoutReplacement(5, weights);
*/ */
template <class Engine> template <class Engine = std::mt19937>
std::vector<size_t> sampleWithoutReplacement(Engine& rng, size_t s, class WeightedSampler {
private:
Engine* engine_; // random number generation engine
public:
/**
* Construct from random number generation engine
* We only store a pointer to it.
*/
explicit WeightedSampler(Engine* engine) : engine_(engine) {}
std::vector<size_t> sampleWithoutReplacement(size_t numSamples,
std::vector<double> weights) { std::vector<double> weights) {
// Implementation adapted from paper at // Implementation adapted from code accompanying paper at
// https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf // https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
const size_t n = weights.size(); const size_t n = weights.size();
if (n < s) { if (n < numSamples) {
throw std::runtime_error("s must be smaller than weights.size()"); throw std::runtime_error(
"numSamples must be smaller than weights.size()");
} }
// Return empty array if s==0 // Return empty array if numSamples==0
std::vector<size_t> result(s); std::vector<size_t> result(numSamples);
if (s == 0) return result; if (numSamples == 0) return result;
// Step 1: The first m items of V are inserted into reservoir // Step 1: The first m items of V are inserted into reservoir
// Step 2: For each item v_i ∈ reservoir: Calculate a key k_i = u_i^(1/w), // Step 2: For each item v_i ∈ reservoir: Calculate a key k_i = u_i^(1/w),
@ -47,10 +66,10 @@ std::vector<size_t> sampleWithoutReplacement(Engine& rng, size_t s,
// reservoir is a priority queue that pops the *maximum* elements) // reservoir is a priority queue that pops the *maximum* elements)
std::priority_queue<std::pair<double, size_t> > reservoir; std::priority_queue<std::pair<double, size_t> > reservoir;
static const double kexp1 = exp(1.0); static const double kexp1 = std::exp(1.0);
for (auto iprob = weights.begin(); iprob != weights.begin() + s; ++iprob) { for (auto it = weights.begin(); it != weights.begin() + numSamples; ++it) {
double k_i = kexp1 / *iprob; const double k_i = kexp1 / *it;
reservoir.push(std::make_pair(k_i, iprob - weights.begin() + 1)); reservoir.push(std::make_pair(k_i, it - weights.begin() + 1));
} }
// Step 4: Repeat Steps 510 until the population is exhausted // Step 4: Repeat Steps 510 until the population is exhausted
@ -60,37 +79,38 @@ std::vector<size_t> sampleWithoutReplacement(Engine& rng, size_t s,
// Step 10: The new threshold T w is the new minimum key of reservoir // Step 10: The new threshold T w is the new minimum key of reservoir
const std::pair<double, size_t>& T_w = reservoir.top(); const std::pair<double, size_t>& T_w = reservoir.top();
// Incrementing iprob is part of Step 7 // Incrementing it is part of Step 7
for (auto iprob = weights.begin() + s; iprob != weights.end(); ++iprob) { for (auto it = weights.begin() + numSamples; it != weights.end(); ++it) {
// Step 5: Let r = random(0, 1) and X_w = log(r) / log(T_w) // Step 5: Let r = random(0, 1) and X_w = log(r) / log(T_w)
// (Modification: Use e = -exp(1) instead of log(r)) // (Modification: Use e = -exp(1) instead of log(r))
double X_w = kexp1 / T_w.first; const double X_w = kexp1 / T_w.first;
// Step 6: From the current item v_c skip items until item v_i, such that: // Step 6: From the current item v_c skip items until item v_i, such
// that:
double w = 0.0; double w = 0.0;
// Step 7: w_c + w_{c+1} + ··· + w_{i1} < X_w <= w_c + w_{c+1} + ··· + // Step 7: w_c + w_{c+1} + ··· + w_{i1} < X_w <= w_c + w_{c+1} + ··· +
// w_{i1} + w_i // w_{i1} + w_i
for (; iprob != weights.end(); ++iprob) { for (; it != weights.end(); ++it) {
w += *iprob; w += *it;
if (X_w <= w) break; if (X_w <= w) break;
} }
// Step 7: No such item, terminate // Step 7: No such item, terminate
if (iprob == weights.end()) break; if (it == weights.end()) break;
// Step 9: Let t_w = T_w^{w_i}, r_2 = random(t_w, 1) and v_is key: k_i = // Step 9: Let t_w = T_w^{w_i}, r_2 = random(t_w, 1) and v_is key: k_i
// (r_2)^{1/w_i} (Mod: Let t_w = log(T_w) * {w_i}, e_2 = // = (r_2)^{1/w_i} (Mod: Let t_w = log(T_w) * {w_i}, e_2 =
// log(random(e^{t_w}, 1)) and v_is key: k_i = -e_2 / w_i) // log(random(e^{t_w}, 1)) and v_is key: k_i = -e_2 / w_i)
double t_w = -T_w.first * *iprob; const double t_w = -T_w.first * *it;
std::uniform_real_distribution<double> randomAngle(exp(t_w), 1.0); std::uniform_real_distribution<double> randomAngle(std::exp(t_w), 1.0);
double e_2 = log(randomAngle(rng)); const double e_2 = std::log(randomAngle(*engine_));
double k_i = -e_2 / *iprob; const double k_i = -e_2 / *it;
// Step 8: The item in reservoir with the minimum key is replaced by item // Step 8: The item in reservoir with the minimum key is replaced by
// v_i // item v_i
reservoir.pop(); reservoir.pop();
reservoir.push(std::make_pair(k_i, iprob - weights.begin() + 1)); reservoir.push(std::make_pair(k_i, it - weights.begin() + 1));
} }
} }
@ -113,4 +133,5 @@ std::vector<size_t> sampleWithoutReplacement(Engine& rng, size_t s,
return result; return result;
} }
}; // namespace gtsam
} // namespace gtsam } // namespace gtsam

View File

@ -27,8 +27,9 @@ using namespace gtsam;
TEST(WeightedSampler, sampleWithoutReplacement) { TEST(WeightedSampler, sampleWithoutReplacement) {
vector<double> weights{1, 2, 3, 4, 3, 2, 1}; vector<double> weights{1, 2, 3, 4, 3, 2, 1};
mt19937 rng(42); std::mt19937 rng(42);
auto samples = sampleWithoutReplacement(rng, 5, weights); WeightedSampler<std::mt19937> sampler(&rng);
auto samples = sampler.sampleWithoutReplacement(5, weights);
EXPECT_LONGS_EQUAL(5, samples.size()); EXPECT_LONGS_EQUAL(5, samples.size());
} }