Made into class
							parent
							
								
									7d86b073e6
								
							
						
					
					
						commit
						f8af4a465d
					
				|  | @ -16,6 +16,9 @@ | |||
|  * @date    May 2019 | ||||
|  **/ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <cmath> | ||||
| #include <queue> | ||||
| #include <random> | ||||
| #include <stdexcept> | ||||
|  | @ -25,92 +28,110 @@ | |||
| namespace gtsam { | ||||
| /*
 | ||||
|  * Fast sampling without replacement. | ||||
|  * Example usage: | ||||
|  *   std::mt19937 rng(42); | ||||
|  *   WeightedSampler<std::mt19937> sampler(&rng); | ||||
|  *   auto samples = sampler.sampleWithoutReplacement(5, weights); | ||||
|  */ | ||||
| template <class Engine> | ||||
| std::vector<size_t> sampleWithoutReplacement(Engine& rng, size_t s, | ||||
|                                              std::vector<double> weights) { | ||||
|   // Implementation adapted from paper at
 | ||||
|   // https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
 | ||||
|   const size_t n = weights.size(); | ||||
|   if (n < s) { | ||||
|     throw std::runtime_error("s must be smaller than weights.size()"); | ||||
|   } | ||||
| template <class Engine = std::mt19937> | ||||
| class WeightedSampler { | ||||
|  private: | ||||
|   Engine* engine_;  // random number generation engine
 | ||||
| 
 | ||||
|   // Return empty array if s==0
 | ||||
|   std::vector<size_t> result(s); | ||||
|   if (s == 0) return result; | ||||
|  public: | ||||
|   /**
 | ||||
|    * Construct from random number generation engine | ||||
|    * We only store a pointer to it. | ||||
|    */ | ||||
|   explicit WeightedSampler(Engine* engine) : engine_(engine) {} | ||||
| 
 | ||||
|   // Step 1: The first m items of V are inserted into reservoir
 | ||||
|   // Step 2: For each item v_i ∈ reservoir: Calculate a key k_i = u_i^(1/w),
 | ||||
|   // where u_i = random(0, 1)
 | ||||
|   // (Modification: Calculate and store -log k_i = e_i / w where e_i = exp(1),
 | ||||
|   //  reservoir is a priority queue that pops the *maximum* elements)
 | ||||
|   std::priority_queue<std::pair<double, size_t> > reservoir; | ||||
|   std::vector<size_t> sampleWithoutReplacement(size_t numSamples, | ||||
|                                                std::vector<double> weights) { | ||||
|     // Implementation adapted from code accompanying paper at
 | ||||
|     // https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
 | ||||
|     const size_t n = weights.size(); | ||||
|     if (n < numSamples) { | ||||
|       throw std::runtime_error( | ||||
|           "numSamples must be smaller than weights.size()"); | ||||
|     } | ||||
| 
 | ||||
|   static const double kexp1 = exp(1.0); | ||||
|   for (auto iprob = weights.begin(); iprob != weights.begin() + s; ++iprob) { | ||||
|     double k_i = kexp1 / *iprob; | ||||
|     reservoir.push(std::make_pair(k_i, iprob - weights.begin() + 1)); | ||||
|   } | ||||
|     // Return empty array if numSamples==0
 | ||||
|     std::vector<size_t> result(numSamples); | ||||
|     if (numSamples == 0) return result; | ||||
| 
 | ||||
|   // Step 4: Repeat Steps 5–10 until the population is exhausted
 | ||||
|   { | ||||
|     // Step 3: The threshold T_w is the minimum key of reservoir
 | ||||
|     // (Modification: This is now the logarithm)
 | ||||
|     // Step 10: The new threshold T w is the new minimum key of reservoir
 | ||||
|     const std::pair<double, size_t>& T_w = reservoir.top(); | ||||
|     // Step 1: The first m items of V are inserted into reservoir
 | ||||
|     // Step 2: For each item v_i ∈ reservoir: Calculate a key k_i = u_i^(1/w),
 | ||||
|     // where u_i = random(0, 1)
 | ||||
|     // (Modification: Calculate and store -log k_i = e_i / w where e_i = exp(1),
 | ||||
|     //  reservoir is a priority queue that pops the *maximum* elements)
 | ||||
|     std::priority_queue<std::pair<double, size_t> > reservoir; | ||||
| 
 | ||||
|     // Incrementing iprob is part of Step 7
 | ||||
|     for (auto iprob = weights.begin() + s; iprob != weights.end(); ++iprob) { | ||||
|       // Step 5: Let r = random(0, 1) and X_w = log(r) / log(T_w)
 | ||||
|       // (Modification: Use e = -exp(1) instead of log(r))
 | ||||
|       double X_w = kexp1 / T_w.first; | ||||
|     static const double kexp1 = std::exp(1.0); | ||||
|     for (auto it = weights.begin(); it != weights.begin() + numSamples; ++it) { | ||||
|       const double k_i = kexp1 / *it; | ||||
|       reservoir.push(std::make_pair(k_i, it - weights.begin() + 1)); | ||||
|     } | ||||
| 
 | ||||
|       // Step 6: From the current item v_c skip items until item v_i, such that:
 | ||||
|       double w = 0.0; | ||||
|     // Step 4: Repeat Steps 5–10 until the population is exhausted
 | ||||
|     { | ||||
|       // Step 3: The threshold T_w is the minimum key of reservoir
 | ||||
|       // (Modification: This is now the logarithm)
 | ||||
|       // Step 10: The new threshold T w is the new minimum key of reservoir
 | ||||
|       const std::pair<double, size_t>& T_w = reservoir.top(); | ||||
| 
 | ||||
|       // Step 7: w_c + w_{c+1} + ··· + w_{i−1} < X_w <= w_c + w_{c+1} + ··· +
 | ||||
|       // w_{i−1} + w_i
 | ||||
|       for (; iprob != weights.end(); ++iprob) { | ||||
|         w += *iprob; | ||||
|         if (X_w <= w) break; | ||||
|       // Incrementing it is part of Step 7
 | ||||
|       for (auto it = weights.begin() + numSamples; it != weights.end(); ++it) { | ||||
|         // Step 5: Let r = random(0, 1) and X_w = log(r) / log(T_w)
 | ||||
|         // (Modification: Use e = -exp(1) instead of log(r))
 | ||||
|         const double X_w = kexp1 / T_w.first; | ||||
| 
 | ||||
|         // Step 6: From the current item v_c skip items until item v_i, such
 | ||||
|         // that:
 | ||||
|         double w = 0.0; | ||||
| 
 | ||||
|         // Step 7: w_c + w_{c+1} + ··· + w_{i−1} < X_w <= w_c + w_{c+1} + ··· +
 | ||||
|         // w_{i−1} + w_i
 | ||||
|         for (; it != weights.end(); ++it) { | ||||
|           w += *it; | ||||
|           if (X_w <= w) break; | ||||
|         } | ||||
| 
 | ||||
|         // Step 7: No such item, terminate
 | ||||
|         if (it == weights.end()) break; | ||||
| 
 | ||||
|         // Step 9: Let t_w = T_w^{w_i}, r_2 = random(t_w, 1) and v_i’s key: k_i
 | ||||
|         // = (r_2)^{1/w_i} (Mod: Let t_w = log(T_w) * {w_i}, e_2 =
 | ||||
|         // log(random(e^{t_w}, 1)) and v_i’s key: k_i = -e_2 / w_i)
 | ||||
|         const double t_w = -T_w.first * *it; | ||||
|         std::uniform_real_distribution<double> randomAngle(std::exp(t_w), 1.0); | ||||
|         const double e_2 = std::log(randomAngle(*engine_)); | ||||
|         const double k_i = -e_2 / *it; | ||||
| 
 | ||||
|         // Step 8: The item in reservoir with the minimum key is replaced by
 | ||||
|         // item v_i
 | ||||
|         reservoir.pop(); | ||||
|         reservoir.push(std::make_pair(k_i, it - weights.begin() + 1)); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     for (auto iret = result.end(); iret != result.begin();) { | ||||
|       --iret; | ||||
| 
 | ||||
|       if (reservoir.empty()) { | ||||
|         throw std::runtime_error( | ||||
|             "Reservoir empty before all elements have been filled"); | ||||
|       } | ||||
| 
 | ||||
|       // Step 7: No such item, terminate
 | ||||
|       if (iprob == weights.end()) break; | ||||
| 
 | ||||
|       // Step 9: Let t_w = T_w^{w_i}, r_2 = random(t_w, 1) and v_i’s key: k_i =
 | ||||
|       // (r_2)^{1/w_i} (Mod: Let t_w = log(T_w) * {w_i}, e_2 =
 | ||||
|       // log(random(e^{t_w}, 1)) and v_i’s key: k_i = -e_2 / w_i)
 | ||||
|       double t_w = -T_w.first * *iprob; | ||||
|       std::uniform_real_distribution<double> randomAngle(exp(t_w), 1.0); | ||||
|       double e_2 = log(randomAngle(rng)); | ||||
|       double k_i = -e_2 / *iprob; | ||||
| 
 | ||||
|       // Step 8: The item in reservoir with the minimum key is replaced by item
 | ||||
|       // v_i
 | ||||
|       *iret = reservoir.top().second; | ||||
|       reservoir.pop(); | ||||
|       reservoir.push(std::make_pair(k_i, iprob - weights.begin() + 1)); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   for (auto iret = result.end(); iret != result.begin();) { | ||||
|     --iret; | ||||
| 
 | ||||
|     if (reservoir.empty()) { | ||||
|     if (!reservoir.empty()) { | ||||
|       throw std::runtime_error( | ||||
|           "Reservoir empty before all elements have been filled"); | ||||
|           "Reservoir not empty after all elements have been filled"); | ||||
|     } | ||||
| 
 | ||||
|     *iret = reservoir.top().second; | ||||
|     reservoir.pop(); | ||||
|     return result; | ||||
|   } | ||||
| 
 | ||||
|   if (!reservoir.empty()) { | ||||
|     throw std::runtime_error( | ||||
|         "Reservoir not empty after all elements have been filled"); | ||||
|   } | ||||
| 
 | ||||
|   return result; | ||||
| } | ||||
| };  // namespace gtsam
 | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -27,8 +27,9 @@ using namespace gtsam; | |||
| 
 | ||||
| TEST(WeightedSampler, sampleWithoutReplacement) { | ||||
|   vector<double> weights{1, 2, 3, 4, 3, 2, 1}; | ||||
|   mt19937 rng(42); | ||||
|   auto samples = sampleWithoutReplacement(rng, 5, weights); | ||||
|   std::mt19937 rng(42); | ||||
|   WeightedSampler<std::mt19937> sampler(&rng); | ||||
|   auto samples = sampler.sampleWithoutReplacement(5, weights); | ||||
|   EXPECT_LONGS_EQUAL(5, samples.size()); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue