1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/lib/random/random_distributions.h"
17
18 #include <algorithm>
19 #include <cmath>
20 #include <functional>
21 #include <numeric>
22 #include <unordered_map>
23 #include <vector>
24
25 #include "tensorflow/core/lib/math/math_util.h"
26 #include "tensorflow/core/lib/random/philox_random.h"
27 #include "tensorflow/core/lib/random/philox_random_test_utils.h"
28 #include "tensorflow/core/lib/random/random.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/test.h"
31
32 namespace tensorflow {
33 namespace random {
34 namespace {
35
// Upper bound on an acceptable z-score. The z-test statistic is approximately
// a unit normal variable, so a correct generator should essentially never
// push it past 6.
static constexpr float kZLimit = 6.0f;

// bfloat16 has much less precision than float32, so a correspondingly larger
// z-score is tolerated for bfloat16 samples.
static constexpr float kZLimitBfloat16 = 20.0f;
43
44 // A utility function to fill the given array with samples from the given
45 // distribution, using the single adapter of the underlying generator
46 template <class Distribution>
FillRandomsWithSingles(PhiloxRandom gen,typename Distribution::ResultElementType * p,int64 size)47 void FillRandomsWithSingles(PhiloxRandom gen,
48 typename Distribution::ResultElementType* p,
49 int64 size) {
50 int granularity = Distribution::kResultElementCount;
51
52 CHECK(size % granularity == 0)
53 << " size: " << size << " granularity: " << granularity;
54
55 SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
56
57 Distribution dist;
58 for (int i = 0; i < size; i += granularity) {
59 auto sample = dist(&single_samples);
60 std::copy(&sample[0], &sample[0] + granularity, &p[i]);
61 }
62 }
63
64 // Check the given array of samples matches the given theoretical moment
65 // function at different orders. The test is considered passing if the z-tests
66 // of all statistical moments are all below z_limit.
67 // typename T in the template argument could be either float or double.
68 // Arguments:
69 // samples: an array of samples to be tested for their statistical properties;
70 // theoretical_moments: a functor that can calculate arbitrary order of
71 // of the given distribution;
72 // max_moments: the largest moments of the uniform distribution to be tested;
73 // stride: the distance between samples to check for statistical properties
74 // 0 means the n-th moment of each sample
75 // any other strides tests for spatial correlation between samples;
76 // z_limit: the maximum z-test we would consider the test to pass;
77 template <typename T>
CheckSamplesMoments(const std::vector<T> & samples,const std::function<double (int)> & theoretical_moments,int max_moments,int stride,T z_limit)78 bool CheckSamplesMoments(const std::vector<T>& samples,
79 const std::function<double(int)>& theoretical_moments,
80 int max_moments, int stride, T z_limit) {
81 const T* const samples_data = &samples[0];
82 const int samples_size = samples.size();
83 std::vector<double> moments(max_moments + 1);
84 double* const moments_data = &moments[0];
85 std::vector<int> moments_sample_count(max_moments + 1);
86 int* const moments_sample_count_data = &moments_sample_count[0];
87
88 for (int k = 0; k < samples_size; ++k) {
89 double moment = 1.;
90 for (int i = 0; i <= max_moments; ++i) {
91 int index = k + i * stride;
92 if (index >= samples_size) {
93 break;
94 }
95 // moments[i] store the i-th order measured moments.
96 // bypass std::vector::operator[] because they are too slow in the debug
97 // mode, given the large number of samples.
98 moments_data[i] += moment;
99 ++moments_sample_count_data[i];
100 moment *= static_cast<double>(samples_data[index]);
101 }
102 }
103
104 // normalize the moments
105 for (int i = 0; i <= max_moments; ++i) {
106 moments[i] /= moments_sample_count[i];
107 }
108
109 bool status = true;
110
111 for (int i = 1; i <= max_moments; ++i) {
112 // Calculate the theoretical mean and variance
113 const double moments_i_mean =
114 (stride == 0) ? theoretical_moments(i)
115 : MathUtil::IPow(theoretical_moments(1), i);
116 const double moments_i_squared =
117 (stride == 0) ? theoretical_moments(2 * i)
118 : MathUtil::IPow(theoretical_moments(2), i);
119 const double moments_i_var =
120 moments_i_squared - moments_i_mean * moments_i_mean;
121
122 // assume every operation has a small numerical error.
123 static const double kNumericalError = 1e-6;
124 // it takes i multiplications to calculate one i-th moment.
125 const double error_per_moment = i * kNumericalError;
126 const double total_variance =
127 moments_i_var / moments_sample_count[i] + error_per_moment;
128 // z_test is approximately a unit normal distribution.
129 const double z_test =
130 fabs((moments[i] - moments_i_mean) / sqrt(total_variance));
131
132 if (z_test > static_cast<double>(z_limit)) {
133 LOG(ERROR) << "failing z_test:"
134 << " moment: " << i << " stride: " << stride
135 << " z_test: " << z_test << " z_limit: " << z_limit
136 << " measured moments: " << moments[i]
137 << " theoretical mean of the moments: " << moments_i_mean
138 << " theoretical var of the moments: " << moments_i_var
139 << " sample count: " << moments_sample_count[i];
140 status = false;
141 }
142 }
143
144 return status;
145 }
146
147 // This tests checks that the generated samples match the theoretical moments
148 // of the uniform distribution.
149 template <typename T>
UniformMomentsTest(int count,int max_moments,const std::vector<int> & strides,T z_limit)150 void UniformMomentsTest(int count, int max_moments,
151 const std::vector<int>& strides, T z_limit) {
152 auto uniform_moments = [](int n) -> double { return 1. / (n + 1); };
153
154 std::vector<T> v1(count);
155 uint64 seed = GetTestSeed();
156 PhiloxRandom gen(seed);
157 FillRandoms<UniformDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
158 for (int stride : strides) {
159 bool status =
160 CheckSamplesMoments(v1, uniform_moments, max_moments, stride, z_limit);
161 ASSERT_TRUE(status) << " UniformMomentsTest failing. seed: " << seed;
162 }
163 }
164
165 // This test checks that the generated samples match the theoretical moments
166 // of the unit normal distribution.
167 template <typename T>
NormalMomentsTest(int count,int max_moments,const std::vector<int> & strides,T z_limit)168 void NormalMomentsTest(int count, int max_moments,
169 const std::vector<int>& strides, T z_limit) {
170 auto normal_moments = [](int n) -> double {
171 if (n % 2 == 1) {
172 // For an odd order, the moment of a unit normal distribution is zero.
173 return 0.;
174 } else {
175 // For an even order, the moment of a unit normal distribution is.
176 // (n-1)!!
177 double v = 1.;
178 for (int i = n - 1; i >= 1; i -= 2) {
179 v *= i;
180 }
181 return v;
182 }
183 };
184
185 std::vector<T> v1(count);
186 uint64 seed = GetTestSeed();
187 PhiloxRandom gen(seed);
188 FillRandoms<NormalDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
189
190 for (int stride : strides) {
191 bool status =
192 CheckSamplesMoments(v1, normal_moments, max_moments, stride, z_limit);
193 ASSERT_TRUE(status) << " NormalMomentsTest failing. seed: " << seed;
194 }
195 }
196
// A functor returning the moments of the standard normal distribution
// truncated to [-v, v]. v defaults to 2, the truncation value this file's
// truncated-normal test assumes.
// Every odd-order moment is zero by symmetry. Integrating by parts gives, for
// even n, the recurrence
//   m(n) = (n - 1) * m(n - 2) - 2 * v ^ (n - 1) * f(v) / (2 * Phi(v) - 1)
// with m(0) = 1, where f(v) is the p.d.f. and Phi(v) the c.d.f. of the
// standard normal. Computed even moments are cached, so operator() mutates
// the object and is not const.
class TruncatedNormalMoments {
 public:
  // Truncation at v = 2.
  TruncatedNormalMoments() : TruncatedNormalMoments(2.0) {}

  // Truncation at `truncate_value`, which is expected to be positive.
  explicit TruncatedNormalMoments(double truncate_value)
      : v_(truncate_value),
        // f(v) = exp(-v^2 / 2) / sqrt(2 * pi); acos(-1) == pi avoids relying
        // on the non-standard M_PI macro.
        pdf_at_v_(std::exp(-truncate_value * truncate_value / 2.0) /
                  std::sqrt(2.0 * std::acos(-1.0))),
        // Phi(v) = (1 + erf(v / sqrt(2))) / 2, derived from v so the two can
        // never disagree.
        cdf_at_v_(0.5 * (1.0 + std::erf(truncate_value / std::sqrt(2.0)))) {}

  // Returns the n-th order moment. Negative orders are not defined by the
  // recurrence and yield NaN (they used to recurse without terminating).
  double operator()(int n) {
    if (n < 0) {
      return std::nan("");
    }
    if (n % 2 == 1) {
      // For an odd order, the moment is always zero.
      return 0.;
    }

    // even_moments_[k] holds m(2k); extend the table up to order n on demand.
    if (even_moments_.empty()) {
      even_moments_.push_back(1.0);
    }
    while (static_cast<int>(even_moments_.size()) <= n / 2) {
      const int order = 2 * static_cast<int>(even_moments_.size());
      const double bias = 2.0 * std::pow(v_, order - 1) * pdf_at_v_ /
                          (2.0 * cdf_at_v_ - 1.0);
      even_moments_.push_back((order - 1) * even_moments_.back() - bias);
    }
    return even_moments_[n / 2];
  }

 private:
  // The truncation value v.
  double v_;
  // f(v): the standard normal p.d.f. at v.
  double pdf_at_v_;
  // Phi(v): the standard normal c.d.f. at v.
  double cdf_at_v_;
  // Cache of even-order moments, indexed by order / 2.
  std::vector<double> even_moments_;
};
239
240 // This test checks that the generated samples matche the theoretical moments
241 // of the truncated normal distribution.
242 template <typename T>
RandomParametersMomentsTest(int count,int max_moments,const std::vector<int> & strides,T z_limit)243 void RandomParametersMomentsTest(int count, int max_moments,
244 const std::vector<int>& strides, T z_limit) {
245 std::vector<T> v1(count);
246 uint64 seed = GetTestSeed();
247 PhiloxRandom gen(seed);
248 FillRandomsWithSingles<
249 TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, T> >(
250 gen, &v1[0], v1.size());
251
252 for (int stride : strides) {
253 bool status = CheckSamplesMoments(v1, TruncatedNormalMoments(), max_moments,
254 stride, z_limit);
255 ASSERT_TRUE(status) << " NormalMomentsTest failing. seed: " << seed;
256 }
257 }
258
// Moment tests for bfloat16 samples, judged against the looser bfloat16
// z-limit.
TEST(PhiloxRandomTest, UniformBfloat16MomentsTest) {
  const bfloat16 z_limit(kZLimitBfloat16);
  UniformMomentsTest<bfloat16>(/*count=*/1 << 20, /*max_moments=*/40,
                               /*strides=*/{0, 1, 4, 17}, z_limit);
}

TEST(PhiloxRandomTest, NormalBfloat16MomentsTest) {
  const bfloat16 z_limit(kZLimitBfloat16);
  NormalMomentsTest<bfloat16>(/*count=*/8 << 20, /*max_moments=*/25,
                              /*strides=*/{0, 1, 4, 17}, z_limit);
}

TEST(PhiloxRandomTest, RandomParametersBfloat16MomentsTest) {
  const bfloat16 z_limit(kZLimitBfloat16);
  RandomParametersMomentsTest<bfloat16>(/*count=*/1 << 20, /*max_moments=*/40,
                                        /*strides=*/{0, 1, 4, 17}, z_limit);
}
274
// Moment tests for float samples.
TEST(PhiloxRandomTest, UniformFloatMomentsTest) {
  UniformMomentsTest<float>(/*count=*/1 << 20, /*max_moments=*/40,
                            /*strides=*/{0, 1, 4, 17}, kZLimit);
}

TEST(PhiloxRandomTest, NormalFloatMomentsTest) {
  NormalMomentsTest<float>(/*count=*/8 << 20, /*max_moments=*/25,
                           /*strides=*/{0, 1, 4, 17}, kZLimit);
}

TEST(PhiloxRandomTest, RandomParametersFloatMomentsTest) {
  RandomParametersMomentsTest<float>(/*count=*/1 << 20, /*max_moments=*/40,
                                     /*strides=*/{0, 1, 4, 17}, kZLimit);
}
289
// Moment tests for double samples; the float z-limit widens to double.
TEST(PhiloxRandomTest, UniformDoubleMomentsTest) {
  UniformMomentsTest<double>(/*count=*/1 << 20, /*max_moments=*/40,
                             /*strides=*/{0, 1, 4, 17}, kZLimit);
}

TEST(PhiloxRandomTest, NormalDoubleMomentsTest) {
  NormalMomentsTest<double>(/*count=*/8 << 20, /*max_moments=*/25,
                            /*strides=*/{0, 1, 4, 17}, kZLimit);
}

TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) {
  RandomParametersMomentsTest<double>(/*count=*/1 << 20, /*max_moments=*/40,
                                      /*strides=*/{0, 1, 4, 17}, kZLimit);
}
304
305 class MockGenerator {
306 public:
MockGenerator(uint64 seed)307 explicit MockGenerator(uint64 seed) : counter_(seed) {}
308 using ResultType = std::vector<uint32>;
309 using ResultElementType = uint32;
310 static constexpr int kResultElementCount = 1;
operator ()()311 ResultType operator()() {
312 ResultType result;
313 result.push_back(counter_++);
314 return result;
315 }
316
317 private:
318 uint32 counter_;
319 };
320
321 template <typename T>
SingleSampleAdapterSkipTest()322 void SingleSampleAdapterSkipTest() {
323 std::vector<uint64> skips(10);
324 std::vector<uint64> skip_afters(10);
325 std::iota(skips.begin(), skips.end(), 0);
326 std::iota(skip_afters.begin(), skip_afters.end(), 0);
327 uint64 total_samples = 100;
328 uint64 seed = GetTestSeed();
329
330 for (uint64 skip : skips) {
331 for (uint64 skip_after : skip_afters) {
332 // Baseline rngs.
333 T parent_gen(seed);
334 SingleSampleAdapter<T> gen(&parent_gen);
335
336 // Rng on which Skip() is performed.
337 T parent_gen_to_skip(seed);
338 SingleSampleAdapter<T> gen_to_skip(&parent_gen_to_skip);
339
340 // Skip over `skip_after` samples from both `gen` and `gen_to_skip`.
341 int cur = 0;
342 for (; cur < skip_after; cur++) {
343 gen();
344 gen_to_skip();
345 }
346
347 // Skip over `skip_` samples from `gen` iteratively.
348 for (; cur < skip_after + skip; cur++) {
349 gen();
350 }
351
352 // Skip over `skip_` samples from `gen_to_skip` by calling `Skip()`.
353 gen_to_skip.Skip(skip);
354
355 // Assert that they produce same outputs afterwards.
356 for (; cur < total_samples; cur++) {
357 ASSERT_EQ(gen(), gen_to_skip());
358 }
359 }
360 }
361 }
362
// Runs the Skip() consistency check against the real Philox generator and
// against MockGenerator, whose results hold a single element each.
TEST(SingleSampleAdapterTest, PhiloxRandomSkip) {
  using Generator = PhiloxRandom;
  SingleSampleAdapterSkipTest<Generator>();
}

TEST(SingleSampleAdapterTest, MockGeneratorSkip) {
  using Generator = MockGenerator;
  SingleSampleAdapterSkipTest<Generator>();
}
370
371 } // namespace
372 } // namespace random
373 } // namespace tensorflow
374