1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/prng.h"
17 
18 #include <cmath>
19 #include <vector>
20 
21 #include "absl/base/casts.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/util.h"
26 
27 namespace xla {
28 
ConcatScalars(xla::XlaBuilder * builder,absl::Span<const xla::XlaOp> scalars)29 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
30                          absl::Span<const xla::XlaOp> scalars) {
31   std::vector<xla::XlaOp> vectors;
32   absl::c_transform(scalars, std::back_inserter(vectors),
33                     [](xla::XlaOp x) { return xla::Reshape(x, {1}); });
34   return ConcatInDim(builder, vectors, 0);
35 }
36 
37 namespace {
38 
39 // Rotates a 32-bit integer 'v' left by 'distance' bits.
RotateLeftU32(XlaOp v,int distance)40 XlaOp RotateLeftU32(XlaOp v, int distance) {
41   return (v << ConstantR0<uint32>(v.builder(), distance)) |
42          ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
43 }
44 
45 // The internal state of the Three Fry implementation.
46 using ThreeFry2x32State = std::array<XlaOp, 2>;
47 
48 // Implements the ThreeFry counter-based PRNG algorithm.
49 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
50 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
ThreeFry2x32(ThreeFry2x32State input,ThreeFry2x32State key)51 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
52   XlaBuilder* builder = input[0].builder();
53   key[0] = BitcastConvertType(key[0], U32);
54   key[1] = BitcastConvertType(key[1], U32);
55 
56   // Rotation distances specified by the Threefry2x32 algorithm.
57   constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
58   ThreeFry2x32State x;
59 
60   std::array<XlaOp, 3> ks;
61   // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
62   ks[2] = ConstantR0<uint32>(builder, 0x1BD11BDA);
63   for (int i = 0; i < 2; ++i) {
64     ks[i] = key[i];
65     x[i] = input[i];
66     ks[2] = ks[2] ^ key[i];
67   }
68 
69   x[0] = x[0] + ks[0];
70   x[1] = x[1] + ks[1];
71 
72   // Performs a single round of the Threefry2x32 algorithm, with a rotation
73   // amount 'rotation'.
74   auto round = [](ThreeFry2x32State v, int rotation) {
75     v[0] = v[0] + v[1];
76     v[1] = RotateLeftU32(v[1], rotation);
77     v[1] = v[0] ^ v[1];
78     return v;
79   };
80 
81   // There are no known statistical flaws with 13 rounds of Threefry2x32.
82   // We are conservative and use 20 rounds.
83   x = round(x, rotations[0]);
84   x = round(x, rotations[1]);
85   x = round(x, rotations[2]);
86   x = round(x, rotations[3]);
87   x[0] = x[0] + ks[1];
88   x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 1);
89 
90   x = round(x, rotations[4]);
91   x = round(x, rotations[5]);
92   x = round(x, rotations[6]);
93   x = round(x, rotations[7]);
94   x[0] = x[0] + ks[2];
95   x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 2);
96 
97   x = round(x, rotations[0]);
98   x = round(x, rotations[1]);
99   x = round(x, rotations[2]);
100   x = round(x, rotations[3]);
101   x[0] = x[0] + ks[0];
102   x[1] = x[1] + ks[1] + ConstantR0<uint32>(builder, 3);
103 
104   x = round(x, rotations[4]);
105   x = round(x, rotations[5]);
106   x = round(x, rotations[6]);
107   x = round(x, rotations[7]);
108   x[0] = x[0] + ks[1];
109   x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 4);
110 
111   x = round(x, rotations[0]);
112   x = round(x, rotations[1]);
113   x = round(x, rotations[2]);
114   x = round(x, rotations[3]);
115   x[0] = x[0] + ks[2];
116   x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 5);
117 
118   return x;
119 }
120 
121 // Converts a uint64 to two uint32s.
Uint64ToUint32s(XlaOp u64)122 std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) {
123   XlaBuilder* builder = u64.builder();
124   XlaOp const32 = ConstantR0WithType(builder, U64, 32);
125   XlaOp fst = ConvertElementType(u64, U32);
126   XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
127   return {fst, snd};
128 }
129 
130 // Converts two uint32s to a uint64.
Uint32sToUint64(std::array<XlaOp,2> u32s)131 XlaOp Uint32sToUint64(std::array<XlaOp, 2> u32s) {
132   XlaBuilder* builder = u32s[0].builder();
133   return ConvertElementType(u32s[0], U64) |
134          ShiftLeft(ConvertElementType(u32s[1], U64),
135                    ConstantR0WithType(builder, U64, 32));
136 }
137 
138 // Given the initial state and the request shape of random numbers to be
139 // generated, returns the input for the random number generator and a new state.
GetThreeFryInputsAndUpdatedState(XlaOp initial_state,const Shape & shape)140 std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
141     XlaOp initial_state, const Shape& shape) {
142   XlaBuilder* builder = initial_state.builder();
143   auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions());
144   // initial_state is an R1, so reshape it to a scalar.
145   auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions());
146   int64 trailing_dims_product = 1;
147   for (int64 i = shape.rank() - 1; i >= 0; --i) {
148     if (shape.dimensions(i) < 2) {
149       continue;
150     }
151     input_u64 =
152         input_u64 + (Iota(builder, u64_shape, i) *
153                      ConstantR0<uint64>(builder, trailing_dims_product));
154     trailing_dims_product *= shape.dimensions(i);
155   }
156   XlaOp new_state =
157       initial_state + ConstantR0<uint64>(builder, ShapeUtil::ElementsIn(shape));
158   return std::make_pair(Uint64ToUint32s(input_u64), new_state);
159 }
160 
161 // Result for SplitShapeIntoHalves().
162 struct SplitShapePair {
163   Shape half_shape;
164   Shape concat_shape;
165   int64 split_dim;
166   int64 new_concat_dim;
167 };
168 
169 // Split the shape on a dimension > 1 into two halves.
SplitShapeIntoHalves(const Shape & shape)170 SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
171   SplitShapePair pair;
172   if (shape.rank() == 0) {
173     pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
174     pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
175     pair.split_dim = 0;
176     pair.new_concat_dim = 0;
177     return pair;
178   }
179   pair.split_dim = -1;
180   for (int64 i = 0; i < shape.rank(); ++i) {
181     if (shape.dimensions(i) % 2 == 0) {
182       pair.split_dim = i;
183       break;
184     }
185   }
186   if (pair.split_dim == -1) {
187     // No even dims. Find a dimension with maximum size.
188     for (int64 i = 0; i < shape.rank(); ++i) {
189       if (pair.split_dim == -1 ||
190           shape.dimensions(i) > shape.dimensions(pair.split_dim)) {
191         pair.split_dim = i;
192       }
193     }
194   }
195   CHECK_GE(pair.split_dim, 0);
196   std::vector<int64> half_shape_dims;
197   std::vector<int64> concat_shape_dims;
198   for (int64 i = 0; i < shape.rank(); ++i) {
199     if (i == pair.split_dim) {
200       // Create a new trivial dim for the later concat, which is more friendly
201       // to sharding propagation.
202       half_shape_dims.push_back(CeilOfRatio<int64>(shape.dimensions(i), 2));
203       half_shape_dims.push_back(1);
204       concat_shape_dims.push_back(half_shape_dims[i]);
205       concat_shape_dims.push_back(2);
206     } else {
207       half_shape_dims.push_back(shape.dimensions(i));
208       concat_shape_dims.push_back(shape.dimensions(i));
209     }
210   }
211   pair.new_concat_dim = pair.split_dim + 1;
212   pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
213   pair.concat_shape =
214       ShapeUtil::MakeShape(shape.element_type(), concat_shape_dims);
215   return pair;
216 }
217 
218 // Combines a pair of split shapes. It works with scalar and non-scalar shapes.
CombineShapePair(absl::Span<const XlaOp> pair,const SplitShapePair & shape_pair,const Shape & original_shape)219 XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
220                        const SplitShapePair& shape_pair,
221                        const Shape& original_shape) {
222   if (original_shape.rank() == 0) {
223     return Reshape(pair[0], {});
224   }
225   XlaBuilder* builder = pair[0].builder();
226   XlaOp result = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
227   const int64 pre_split_size = original_shape.dimensions(shape_pair.split_dim);
228   std::vector<int64> reshape_dims(original_shape.dimensions().begin(),
229                                   original_shape.dimensions().end());
230   reshape_dims[shape_pair.split_dim] =
231       RoundUpToNearest<int64>(pre_split_size, 2);
232   result = Reshape(result, reshape_dims);
233   if (reshape_dims[shape_pair.split_dim] != pre_split_size) {
234     result = Slice(result, std::vector<int64>(original_shape.rank(), 0),
235                    original_shape.dimensions(),
236                    std::vector<int64>(original_shape.rank(), 1));
237   }
238   return result;
239 }
240 
241 // Generates random 32bits with the given shape using the Three Fry
242 // implementation. Returns the random bits and the new state.
ThreeFryRngBit32(XlaOp key,XlaOp initial_state,const Shape & shape)243 RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
244   auto shape_pair = SplitShapeIntoHalves(shape);
245   std::pair<ThreeFry2x32State, XlaOp> inputs_state =
246       GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
247   ThreeFry2x32State inputs = inputs_state.first;
248   ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
249   XlaOp result = CombineShapePair(outputs, shape_pair, shape);
250   return {result, inputs_state.second};
251 }
252 
253 // Generates random 64bits with the given shape using the Three Fry
254 // implementation. Returns the random bits and the new state.
ThreeFryRngBit64(XlaOp key,XlaOp initial_state,const Shape & shape)255 RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
256   std::pair<ThreeFry2x32State, XlaOp> inputs_state =
257       GetThreeFryInputsAndUpdatedState(initial_state, shape);
258   ThreeFry2x32State inputs = inputs_state.first;
259   ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
260   XlaOp result = Uint32sToUint64(outputs);
261   return {result, inputs_state.second};
262 }
263 
264 // The key of the Philox random number generator.
265 using Philox4x32Key = std::array<XlaOp, 2>;
266 // The internal state of the Philox random number generator.
267 using Philox4x32State = std::array<XlaOp, 4>;
268 
269 // Computes the Philox4x32 algorithm using 10 rounds.
Philox4x32(Philox4x32State state,Philox4x32Key key)270 Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) {
271   // Constants specified by the Philox algorithm.
272   static const uint32 kPhiloxW32A = 0x9E3779B9;
273   static const uint32 kPhiloxW32B = 0xBB67AE85;
274   static const uint32 kPhiloxM4x32A = 0xD2511F53;
275   static const uint32 kPhiloxM4x32B = 0xCD9E8D57;
276 
277   struct HighLowPair {
278     XlaOp high;
279     XlaOp low;
280   };
281 
282   // Compute the high and low words from multiplying two 32-bit integers.
283   auto mul_hi_low = [](XlaOp x, uint32 k) {
284     auto product =
285         ConvertElementType(x, U64) * ConstantR0<uint64>(x.builder(), k);
286     auto low = ConvertElementType(product, U32);
287     auto high =
288         ConvertElementType(product >> ConstantR0<uint64>(x.builder(), 32), U32);
289     return HighLowPair{high, low};
290   };
291 
292   // Perform a single round of the Philox algorithm.
293   auto philox_round = [&](Philox4x32State x, Philox4x32Key key) {
294     auto product0 = mul_hi_low(x[0], kPhiloxM4x32A);
295     auto product1 = mul_hi_low(x[2], kPhiloxM4x32B);
296     return Philox4x32State{product1.high ^ x[1] ^ key[0], product1.low,
297                            product0.high ^ x[3] ^ key[1], product0.low};
298   };
299 
300   // Update the key after a round of Philox algorithm.
301   auto raise_key = [](Philox4x32Key key) {
302     XlaBuilder* builder = key[0].builder();
303     return Philox4x32Key{key[0] + ConstantR0<uint32>(builder, kPhiloxW32A),
304                          key[1] + ConstantR0<uint32>(builder, kPhiloxW32B)};
305   };
306 
307   static const int kNumRounds = 10;
308   for (int round = 0; round < kNumRounds; ++round, key = raise_key(key)) {
309     state = philox_round(state, key);
310   }
311   return state;
312 }
313 
314 // Scrambles the input key so that users don't need to worry about which part
315 // of the key needs to be strong.
ScramblePhiloxKey(Philox4x32Key key)316 std::pair<Philox4x32State, Philox4x32Key> ScramblePhiloxKey(Philox4x32Key key) {
317   XlaBuilder* builder = key[0].builder();
318   XlaOp key0 = ConvertElementType(key[0], U64);
319   XlaOp key1 = ConvertElementType(key[1], U64);
320 
321   Philox4x32State state = {
322       ConvertElementType(key0, U32),
323       ConvertElementType(key0 >> ScalarLike(key0, 32), U32),
324       ConvertElementType(key1, U32),
325       ConvertElementType(key1 >> ScalarLike(key1, 32), U32),
326   };
327   key = {ConstantR0<uint32>(builder, 0x3ec8f720),
328          ConstantR0<uint32>(builder, 0x02461e29)};
329   state = Philox4x32(state, key);
330   XlaOp zero = ConstantR0<uint32>(builder, 0);
331   return {Philox4x32State{zero, zero, state[2], state[3]},
332           Philox4x32Key{state[0], state[1]}};
333 }
334 
335 // Adds an U128 tensor with an U64 tensor. The U128 tensor is represented as two
336 // U64s with the low 64bits in the front. This routine supports explicit
337 // broadcasting of the U128 tensor, with `broadcast_sizes` representing the
338 // dimensions prepended to its shape.
Uint128AddUint64(const std::array<XlaOp,2> & u128,XlaOp u64,absl::Span<const int64> broadcast_sizes={})339 std::array<XlaOp, 2> Uint128AddUint64(
340     const std::array<XlaOp, 2>& u128, XlaOp u64,
341     absl::Span<const int64> broadcast_sizes = {}) {
342   auto u128_low = u128[0];
343   auto u128_high = u128[1];
344   XlaOp new_u128_low = u128_low + u64;
345   XlaOp one = ConstantR0<uint64>(u128[0].builder(), 1);
346   XlaOp new_u128_high = Select(Lt(new_u128_low, u128_low),
347                                Broadcast(u128_high + one, broadcast_sizes),
348                                Broadcast(u128_high, broadcast_sizes));
349   return {new_u128_low, new_u128_high};
350 }
351 
Uint32sToUint128(const std::array<XlaOp,4> & u32s)352 std::array<XlaOp, 2> Uint32sToUint128(const std::array<XlaOp, 4>& u32s) {
353   return {Uint32sToUint64({u32s[0], u32s[1]}),
354           Uint32sToUint64({u32s[2], u32s[3]})};
355 }
356 
Uint128ToUint32s(const std::array<XlaOp,2> & u128)357 std::array<XlaOp, 4> Uint128ToUint32s(const std::array<XlaOp, 2>& u128) {
358   std::array<XlaOp, 2> u128_low_32s = Uint64ToUint32s(u128[0]);
359   std::array<XlaOp, 2> u128_high_32s = Uint64ToUint32s(u128[1]);
360   return {u128_low_32s[0], u128_low_32s[1], u128_high_32s[0], u128_high_32s[1]};
361 }
362 
Uint128FromOp(XlaOp op)363 std::array<XlaOp, 2> Uint128FromOp(XlaOp op) {
364   auto u128_low = xla::Reshape(xla::Slice(op, {0}, {1}, {1}), {});
365   auto u128_high = xla::Reshape(xla::Slice(op, {1}, {2}, {1}), {});
366   return {u128_low, u128_high};
367 }
368 
Uint128ToOp(std::array<XlaOp,2> u128)369 XlaOp Uint128ToOp(std::array<XlaOp, 2> u128) {
370   return ConcatScalars(u128[0].builder(), {u128[0], u128[1]});
371 }
372 
373 // Returns the pair (state + [0, 1, ..., n-1], state + n), which should be used
374 // as the inputs fed to `Philox4x32` and the updated state. `state` is an U128
375 // represented as 4 U32s in the order from the least significant one to the most
376 // significant one.
GetPhiloxInputsAndUpdatedState(const Philox4x32State & state,int64 n)377 std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
378     const Philox4x32State& state, int64 n) {
379   XlaBuilder* builder = state[0].builder();
380   XlaOp iota = Iota(builder, U64, n);
381   auto state_u128 = Uint32sToUint128(state);
382   auto inputs = Uint128ToUint32s(Uint128AddUint64(state_u128, iota, {n}));
383   XlaOp new_state =
384       Uint128ToOp(Uint128AddUint64(state_u128, ConstantR0<uint64>(builder, n)));
385   return std::make_pair(inputs, new_state);
386 }
387 
388 // Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox
389 // numbers are generated in the unit of 128bits.
GeneratePhiloxBits(int64 num_elems,XlaOp initial_state,Philox4x32Key key)390 std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems,
391                                                      XlaOp initial_state,
392                                                      Philox4x32Key key) {
393   Philox4x32State state;
394   state = Uint128ToUint32s(Uint128FromOp(initial_state));
395   const int64 num_vector4 = CeilOfRatio<int64>(num_elems, 4);
396   Philox4x32State inputs;
397   XlaOp new_state;
398   std::tie(inputs, new_state) =
399       GetPhiloxInputsAndUpdatedState(state, num_vector4);
400   auto outputs = Philox4x32(inputs, key);
401   return std::make_pair(outputs, new_state);
402 }
403 
404 // Generates an array of primitive type U32 with the given shape containing
405 // random bits generated by the Philox algorithm. Returns the array and the new
406 // state of the random number generator.
PhiloxRngBit32(XlaOp op_key,XlaOp initial_state,const Shape & shape)407 RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
408                          const Shape& shape) {
409   XlaBuilder* builder = op_key.builder();
410   const int64 num_elems = ShapeUtil::ElementsIn(shape);
411 
412   Philox4x32Key key = Uint64ToUint32s(op_key);
413   Philox4x32State bits;
414   XlaOp new_state;
415   std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key);
416   // Combining bits[i] in a round-robin fashion, to align with non-XLA
417   // implementations
418   int64 bits_len = (num_elems + 3) / 4;
419   for (auto i = 0; i < 4; ++i) {
420     bits[i] = Reshape(bits[i], {bits_len, 1});
421   }
422   XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]},
423                               /*dimension=*/1);
424   numbers = Reshape(numbers, {bits_len * 4});
425   numbers = Slice(numbers, /*start_indices=*/{0},
426                   /*limit_indices=*/{num_elems},
427                   /*strides=*/{1});
428   return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
429 }
430 
431 // Generates an array of primitive type U64 with the given shape containing
432 // random bits generated by the Philox algorithm. Returns the array and the new
433 // state of the random number generator.
PhiloxRngBit64(XlaOp op_key,XlaOp initial_state,const Shape & shape)434 RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
435                          const Shape& shape) {
436   XlaBuilder* builder = op_key.builder();
437   const int64 num_elems = ShapeUtil::ElementsIn(shape);
438 
439   Philox4x32Key key = Uint64ToUint32s(op_key);
440   Philox4x32State bits32;
441   XlaOp new_state;
442   std::tie(bits32, new_state) =
443       GeneratePhiloxBits(num_elems * 2, initial_state, key);
444 
445   std::array<XlaOp, 2> bits64;
446   bits64[0] = Uint32sToUint64({bits32[0], bits32[1]});
447   bits64[1] = Uint32sToUint64({bits32[2], bits32[3]});
448 
449   // Combining bits64[i] in a round-robin fashion, to align with non-XLA
450   // implementations
451   int64 bits64_len = (num_elems + 1) / 2;
452   for (auto i = 0; i < 2; ++i) {
453     bits64[i] = Reshape(bits64[i], {bits64_len, 1});
454   }
455   XlaOp numbers = ConcatInDim(builder, {bits64[0], bits64[1]},
456                               /*dimension=*/1);
457   numbers = Reshape(numbers, {bits64_len * 2});
458   numbers = Slice(numbers, /*start_indices=*/{0},
459                   /*limit_indices=*/{num_elems},
460                   /*strides=*/{1});
461   return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state};
462 }
463 
ConvertRandomBitsToUniformFloatingPoint(XlaOp bits,XlaOp minval,XlaOp maxval)464 XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
465                                               XlaOp maxval) {
466   XlaBuilder* builder = bits.builder();
467   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
468     TF_ASSIGN_OR_RETURN(const Shape* minval_shape,
469                         builder->GetShapePtr(minval));
470     TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits));
471     PrimitiveType value_type = minval_shape->element_type();
472     PrimitiveType bit_type = bits_shape->element_type();
473     CHECK((value_type == F32 && bit_type == U32) ||
474           (value_type == F64 && bit_type == U64));
475 
476     // Form random mantissa bits for float/double, with a leading 1 bit.
477     int num_float_bits = primitive_util::BitWidth(value_type);
478     // Subtract one as SignificandWidth includes the leading 1 bit.
479     int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
480 
481     // Ignore the exponent bits and convert the mantissa bits to the floating
482     // point type.
483     bits = ShiftRightLogical(
484         bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));
485 
486     // We have an integer-valued floating point number in the range
487     // [0, 2**{num_mantissa_bits}).
488     XlaOp values = ConvertElementType(bits, value_type);
489 
490     // Divide by 2**{-num_mantissa_bits} to get a number in the range
491     // [0.0, 1.0).
492     values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));
493 
494     // Multiply and add to shift to the range [minval, maxval).
495     return values * (maxval - minval) + minval;
496   });
497 }
498 
ConvertRandomBitsToUniformInt(XlaOp bits,XlaOp minval,XlaOp maxval,PrimitiveType type,PrimitiveType unsigned_type)499 XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
500                                     PrimitiveType type,
501                                     PrimitiveType unsigned_type) {
502   XlaBuilder* builder = bits.builder();
503   XlaOp range = BitcastConvertType(maxval, unsigned_type) -
504                 BitcastConvertType(minval, unsigned_type);
505   XlaOp dist = Rem(bits, range);
506   XlaOp dist_div_2 =
507       ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
508 
509   return minval + BitcastConvertType(dist_div_2, type) +
510          BitcastConvertType(dist - dist_div_2, type);
511 }
512 
513 // Implements the Box-Muller transform, which converts random floats in the
514 // range of [0, 1] from uniform distribution to normal distribution with mean 0
515 // and variance 1. For more detail on the Box-Muller transform, see
516 // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
BoxMullerTransform(XlaOp x0,XlaOp x1)517 std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
518   // Do not send a really small number to log().
519   XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
520 
521   XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
522   XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
523   return {Sin(v1) * u2, Cos(v1) * u2};
524 }
525 
526 }  // namespace
527 
PhiloxIncreaseCounter(XlaOp counter,XlaOp delta)528 XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
529   return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
530 }
531 
ThreeFryBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)532 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
533                                const Shape& shape) {
534   PrimitiveType type = shape.element_type();
535   switch (type) {
536     case F32:
537     case U32:
538     case S32:
539       return ThreeFryRngBit32(key, initial_state, shape);
540     case F64:
541     case U64:
542     case S64:
543       return ThreeFryRngBit64(key, initial_state, shape);
544     default:
545       return {key.builder()->ReportError(Unimplemented(
546                   "Types other than F32, F64, U32, S32, U64 and S64 "
547                   "are not implemented by ThreeFryBitGenerator; got %s",
548                   primitive_util::LowercasePrimitiveTypeName(type))),
549               initial_state};
550   }
551 }
552 
PhiloxBitGenerator(XlaOp key,XlaOp initial_state,const Shape & shape)553 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
554                              const Shape& shape) {
555   PrimitiveType type = shape.element_type();
556   switch (type) {
557     case F32:
558     case U32:
559     case S32:
560       return PhiloxRngBit32(key, initial_state, shape);
561     case F64:
562     case U64:
563     case S64:
564       return PhiloxRngBit64(key, initial_state, shape);
565     default:
566       return {key.builder()->ReportError(Unimplemented(
567                   "Types other than F32, F64, U32, S32, U64 and S64 "
568                   "are not implemented by PhiloxFryBitGenerator; got %s",
569                   primitive_util::LowercasePrimitiveTypeName(type))),
570               initial_state};
571   }
572 }
573 
ScramblePhiloxKey(XlaOp key)574 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
575   Philox4x32Key pkey = Uint64ToUint32s(key);
576   auto state_key = ScramblePhiloxKey(pkey);
577   return std::make_pair(Uint128ToOp(Uint32sToUint128(state_key.first)),
578                         Uint32sToUint64(state_key.second));
579 }
580 
UniformFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)581 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
582                                            BitGeneratorTy bit_generator,
583                                            XlaOp minval, XlaOp maxval,
584                                            const Shape& shape) {
585   RngOutput bits_state = bit_generator(key, initial_state, shape);
586   XlaOp bits = bits_state.value;
587   XlaOp new_state = bits_state.state;
588   return {ConvertRandomBitsToUniformFloatingPoint(bits, minval, maxval),
589           new_state};
590 }
591 
UniformIntDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,XlaOp minval,XlaOp maxval,const Shape & shape)592 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
593                                  BitGeneratorTy bit_generator, XlaOp minval,
594                                  XlaOp maxval, const Shape& shape) {
595   RngOutput bits_state = bit_generator(key, initial_state, shape);
596   XlaOp bits = bits_state.value;
597   XlaOp new_state = bits_state.state;
598   PrimitiveType type = shape.element_type();
599   PrimitiveType unsigned_type;
600   if (type == U32 || type == S32) {
601     unsigned_type = U32;
602   } else {
603     DCHECK(type == U64 || type == S64);
604     unsigned_type = U64;
605   }
606   return {
607       ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
608       new_state};
609 }
610 
NormalFloatingPointDistribution(XlaOp key,XlaOp initial_state,BitGeneratorTy bit_generator,const Shape & shape)611 RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
612                                           BitGeneratorTy bit_generator,
613                                           const Shape& shape) {
614   PrimitiveType primitive_type = shape.element_type();
615   DCHECK(primitive_type == F32 || primitive_type == F64);
616 
617   XlaBuilder* builder = key.builder();
618   auto shape_pair = SplitShapeIntoHalves(shape);
619   RngOutput bits_state = UniformFloatingPointDistribution(
620       key, initial_state, bit_generator,
621       xla::ConstantR0WithType(builder, primitive_type, 0.0),
622       xla::ConstantR0WithType(builder, primitive_type, 1.0),
623       shape_pair.concat_shape);
624 
625   // Separate the bits into two groups to perform the Box-Muller transform.
626   XlaOp bits_0 = Slice(bits_state.value,
627                        std::vector<int64>(shape_pair.half_shape.rank(), 0),
628                        shape_pair.half_shape.dimensions(),
629                        std::vector<int64>(shape_pair.half_shape.rank(), 1));
630   std::vector<int64> bits_1_starts(shape_pair.half_shape.rank(), 0);
631   bits_1_starts[shape_pair.new_concat_dim] = 1;
632   XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
633                        shape_pair.concat_shape.dimensions(),
634                        std::vector<int64>(shape_pair.half_shape.rank(), 1));
635   std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
636 
637   // Put the numbers in the two groups back to form the requested shape.
638   XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
639   return {normal, bits_state.state};
640 }
641 
642 }  // namespace xla
643