1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/lib/random/weighted_picker.h"
17 
18 #include <string.h>
19 #include <algorithm>
20 
21 #include "tensorflow/core/lib/random/simple_philox.h"
22 
23 namespace tensorflow {
24 namespace random {
25 
WeightedPicker(int N)26 WeightedPicker::WeightedPicker(int N) {
27   CHECK_GE(N, 0);
28   N_ = N;
29 
30   // Find the number of levels
31   num_levels_ = 1;
32   while (LevelSize(num_levels_ - 1) < N) {
33     num_levels_++;
34   }
35 
36   // Initialize the levels
37   level_ = new int32*[num_levels_];
38   for (int l = 0; l < num_levels_; l++) {
39     level_[l] = new int32[LevelSize(l)];
40   }
41 
42   SetAllWeights(1);
43 }
44 
~WeightedPicker()45 WeightedPicker::~WeightedPicker() {
46   for (int l = 0; l < num_levels_; l++) {
47     delete[] level_[l];
48   }
49   delete[] level_;
50 }
51 
UnbiasedUniform(SimplePhilox * r,int32 n)52 static int32 UnbiasedUniform(SimplePhilox* r, int32 n) {
53   CHECK_LE(0, n);
54   const uint32 range = ~static_cast<uint32>(0);
55   if (n == 0) {
56     return r->Rand32() * n;
57   } else if (0 == (n & (n - 1))) {
58     // N is a power of two, so just mask off the lower bits.
59     return r->Rand32() & (n - 1);
60   } else {
61     // Reject all numbers that skew the distribution towards 0.
62 
63     // Rand32's output is uniform in the half-open interval [0, 2^{32}).
64     // For any interval [m,n), the number of elements in it is n-m.
65 
66     uint32 rem = (range % n) + 1;
67     uint32 rnd;
68 
69     // rem = ((2^{32}-1) \bmod n) + 1
70     // 1 <= rem <= n
71 
72     // NB: rem == n is impossible, since n is not a power of 2 (from
73     // earlier check).
74 
75     do {
76       rnd = r->Rand32();  // rnd uniform over [0, 2^{32})
77     } while (rnd < rem);  // reject [0, rem)
78     // rnd is uniform over [rem, 2^{32})
79     //
80     // The number of elements in the half-open interval is
81     //
82     //  2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1
83     //               = 2^{32}-1 - ((2^{32}-1) \bmod n)
84     //               = n \cdot \lfloor (2^{32}-1)/n \rfloor
85     //
86     // therefore n evenly divides the number of integers in the
87     // interval.
88     //
89     // The function v \rightarrow v % n takes values from [bias,
90     // 2^{32}) to [0, n).  Each integer in the range interval [0, n)
91     // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from
92     // the domain interval.
93     //
94     // Therefore, v % n is uniform over [0, n).  QED.
95 
96     return rnd % n;
97   }
98 }
99 
Pick(SimplePhilox * rnd) const100 int WeightedPicker::Pick(SimplePhilox* rnd) const {
101   if (total_weight() == 0) return -1;
102 
103   // using unbiased uniform distribution to avoid bias
104   // toward low elements resulting from a possible use
105   // of big weights.
106   return PickAt(UnbiasedUniform(rnd, total_weight()));
107 }
108 
PickAt(int32 weight_index) const109 int WeightedPicker::PickAt(int32 weight_index) const {
110   if (weight_index < 0 || weight_index >= total_weight()) return -1;
111 
112   int32 position = weight_index;
113   int index = 0;
114 
115   for (int l = 1; l < num_levels_; l++) {
116     // Pick left or right child of "level_[l-1][index]"
117     const int32 left_weight = level_[l][2 * index];
118     if (position < left_weight) {
119       // Descend to left child
120       index = 2 * index;
121     } else {
122       // Descend to right child
123       index = 2 * index + 1;
124       position -= left_weight;
125     }
126   }
127   CHECK_GE(index, 0);
128   CHECK_LT(index, N_);
129   CHECK_LE(position, level_[num_levels_ - 1][index]);
130   return index;
131 }
132 
set_weight(int index,int32 weight)133 void WeightedPicker::set_weight(int index, int32 weight) {
134   assert(index >= 0);
135   assert(index < N_);
136 
137   // Adjust the sums all the way up to the root
138   const int32 delta = weight - get_weight(index);
139   for (int l = num_levels_ - 1; l >= 0; l--) {
140     level_[l][index] += delta;
141     index >>= 1;
142   }
143 }
144 
SetAllWeights(int32 weight)145 void WeightedPicker::SetAllWeights(int32 weight) {
146   // Initialize leaves
147   int32* leaves = level_[num_levels_ - 1];
148   for (int i = 0; i < N_; i++) leaves[i] = weight;
149   for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
150 
151   // Now sum up towards the root
152   RebuildTreeWeights();
153 }
154 
SetWeightsFromArray(int N,const int32 * weights)155 void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) {
156   Resize(N);
157 
158   // Initialize leaves
159   int32* leaves = level_[num_levels_ - 1];
160   for (int i = 0; i < N_; i++) leaves[i] = weights[i];
161   for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
162 
163   // Now sum up towards the root
164   RebuildTreeWeights();
165 }
166 
RebuildTreeWeights()167 void WeightedPicker::RebuildTreeWeights() {
168   for (int l = num_levels_ - 2; l >= 0; l--) {
169     int32* level = level_[l];
170     int32* children = level_[l + 1];
171     for (int i = 0; i < LevelSize(l); i++) {
172       level[i] = children[2 * i] + children[2 * i + 1];
173     }
174   }
175 }
176 
Append(int32 weight)177 void WeightedPicker::Append(int32 weight) {
178   Resize(num_elements() + 1);
179   set_weight(num_elements() - 1, weight);
180 }
181 
Resize(int new_size)182 void WeightedPicker::Resize(int new_size) {
183   CHECK_GE(new_size, 0);
184   if (new_size <= LevelSize(num_levels_ - 1)) {
185     // The new picker fits in the existing levels.
186 
187     // First zero out any of the weights that are being dropped so
188     // that the levels are correct (only needed when shrinking)
189     for (int i = new_size; i < N_; i++) {
190       set_weight(i, 0);
191     }
192 
193     // We do not need to set any new weights when enlarging because
194     // the unneeded entries always have weight zero.
195     N_ = new_size;
196     return;
197   }
198 
199   // We follow the simple strategy of just copying the old
200   // WeightedPicker into a new WeightedPicker.  The cost is
201   // O(N) regardless.
202   assert(new_size > N_);
203   WeightedPicker new_picker(new_size);
204   int32* dst = new_picker.level_[new_picker.num_levels_ - 1];
205   int32* src = this->level_[this->num_levels_ - 1];
206   memcpy(dst, src, sizeof(dst[0]) * N_);
207   memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_));
208   new_picker.RebuildTreeWeights();
209 
210   // Now swap the two pickers
211   std::swap(new_picker.N_, this->N_);
212   std::swap(new_picker.num_levels_, this->num_levels_);
213   std::swap(new_picker.level_, this->level_);
214   assert(this->N_ == new_size);
215 }
216 
217 }  // namespace random
218 }  // namespace tensorflow
219