1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SRC_WEIGHTED_RESERVOIR_SAMPLER_H_
16 #define SRC_WEIGHTED_RESERVOIR_SAMPLER_H_
17 
#include <cassert>
#include <cstdint>
#include <limits>
#include <random>
20 
21 namespace protobuf_mutator {
22 
// Picks one item from a sequence of weighted items in a single pass, without
// storing the sequence (weighted reservoir sampling).
// https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
//
// Example:
//   std::default_random_engine random;
//   WeightedReservoirSampler<int> sampler(&random);
//   for (int i = 0; i < size; ++i)
//     sampler.Try(weight[i], i);
//   if (!sampler.IsEmpty()) return sampler.selected();
//
// Items offered with weight 0 are never selected. The sampler does not own
// the random engine it is given; the engine must outlive the sampler.
template <class T, class RandomEngine = std::default_random_engine>
class WeightedReservoirSampler {
 public:
  // |random| is not owned; it is dereferenced whenever a positive-weight item
  // is offered after the first one, so it must stay valid while in use.
  explicit WeightedReservoirSampler(RandomEngine* random) : random_(random) {}

  // Offers |item| with the given |weight|. A zero weight is ignored.
  // Otherwise |item| replaces the current selection with probability
  // weight / (total weight offered so far, including this one).
  void Try(uint64_t weight, const T& item) {
    if (Pick(weight)) selected_ = item;
  }

  // Returns the currently selected item, or a value-initialized T if no item
  // with a positive weight has been offered yet (see IsEmpty()).
  const T& selected() const { return selected_; }

  // True until an item with a positive weight has been offered.
  bool IsEmpty() const { return total_weight_ == 0; }

 private:
  // Adds |weight| to the running total and decides whether the newly offered
  // item replaces the current selection.
  bool Pick(uint64_t weight) {
    // Saturate instead of letting total_weight_ wrap around. A wrapped total
    // could become 0, which would make IsEmpty() lie and give
    // uniform_int_distribution the invalid range [1, 0] (undefined behavior).
    // It could also drop below |weight|, which would force selection.
    // Saturation under-weights only items offered after the total reaches the
    // uint64_t maximum.
    const uint64_t headroom =
        std::numeric_limits<uint64_t>::max() - total_weight_;
    if (weight > headroom) weight = headroom;
    if (weight == 0) return false;
    total_weight_ += weight;
    // The first positive-weight item is always selected; no draw needed.
    if (weight == total_weight_) return true;
    // Uniform draw in [1, total]; it lands in [1, weight] with probability
    // weight / total.
    return std::uniform_int_distribution<uint64_t>(1, total_weight_)(
               *random_) <= weight;
  }

  T selected_ = {};
  uint64_t total_weight_ = 0;
  RandomEngine* random_;
};
56 
57 }  // namespace protobuf_mutator
58 
59 #endif  // SRC_WEIGHTED_RESERVOIR_SAMPLER_H_
60