1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // DistributionSampler allows generating a discrete random variable with a given
17 // distribution.
18 // The values taken by the variable are [0, N) and relative weights for each
19 // value are specified using a vector of size N.
20 //
// The algorithm takes O(N) time to precompute its tables at construction time
// and O(1) time per sample (two random number draws and two table lookups).
23 // The data structure takes O(N) memory.
24 //
25 // In contrast, util/random/weighted-picker.h provides O(lg N) sampling.
26 // The advantage of that implementation is that weights can be adjusted
27 // dynamically, while DistributionSampler doesn't allow weight adjustment.
28 //
// The algorithm used is Walker's alias method, described in Knuth, The Art of
// Computer Programming, Vol. 2.
30 
31 #ifndef TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
32 #define TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
33 
34 #include <memory>
35 #include <utility>
36 
37 #include "tensorflow/core/lib/gtl/array_slice.h"
38 #include "tensorflow/core/lib/random/simple_philox.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/macros.h"
41 #include "tensorflow/core/platform/types.h"
42 
43 namespace tensorflow {
44 namespace random {
45 
46 class DistributionSampler {
47  public:
48   explicit DistributionSampler(const gtl::ArraySlice<float>& weights);
49 
~DistributionSampler()50   ~DistributionSampler() {}
51 
Sample(SimplePhilox * rand)52   int Sample(SimplePhilox* rand) const {
53     float r = rand->RandFloat();
54     // Since n is typically low, we don't bother with UnbiasedUniform.
55     int idx = rand->Uniform(num_);
56     if (r < prob(idx)) return idx;
57     // else pick alt from that bucket.
58     DCHECK_NE(-1, alt(idx));
59     return alt(idx);
60   }
61 
num()62   int num() const { return num_; }
63 
64  private:
prob(int idx)65   float prob(int idx) const {
66     DCHECK_LT(idx, num_);
67     return data_[idx].first;
68   }
69 
alt(int idx)70   int alt(int idx) const {
71     DCHECK_LT(idx, num_);
72     return data_[idx].second;
73   }
74 
set_prob(int idx,float f)75   void set_prob(int idx, float f) {
76     DCHECK_LT(idx, num_);
77     data_[idx].first = f;
78   }
79 
set_alt(int idx,int val)80   void set_alt(int idx, int val) {
81     DCHECK_LT(idx, num_);
82     data_[idx].second = val;
83   }
84 
85   int num_;
86   std::unique_ptr<std::pair<float, int>[]> data_;
87 
88   TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler);
89 };
90 
91 }  // namespace random
92 }  // namespace tensorflow
93 
94 #endif  // TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
95