1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
18 
#include <array>
#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
23 
24 namespace xla {
25 
26 // Records the bits and state generated by a random number generator.
27 struct RngOutput {
28   XlaOp value;
29   XlaOp state;
30 };
31 
32 // A BitGenerator returns random bits and updated random bit generator state.
33 //
34 // key: is a value input to a random number generator that can affect the
35 //   sequence of number it will generate. A random number generator constructs
36 //   its seed using the key and the initial state. The tf2xla bridge passes the
37 //   seed operand of a tensorflow random operation as a key to the random bit
38 //   generator, for example.
39 // initial_state: initial_state is the initial state of the current random
40 //   number generation. It could be 0 for a stateless random operation, and
41 //   the returned state from a previous execution for a stateful random
42 //   operation.
43 // shape: the shape of the random bits.
44 using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
45                                                const xla::Shape& shape)>;
46 
47 // Implements the ThreeFry counter-based PRNG algorithm.
48 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
49 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
50 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
51                                const xla::Shape& shape);
52 
53 // Implements the Philox algorithm to generate random numbers in parallel.
54 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
55 //   http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
56 //
57 // The paper presents a few variants of the Philox algorithm, we picked the
58 // 4x32_10 version of the algorithm for the following reasons:
59 //   . 4x32 uses 32-bit multiplication which is fast on GPUs.
60 //   . The authors recommend the 10-round variant, and TensorFlow also uses it.
61 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
62                              const Shape& shape);
63 // Returns a scrambled pair of (state, key) from a single key.
64 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key);
65 
66 // Uses the given bit generator to generate random bits and then converts the
67 // random bits to random numbers of uniform distribution in the given range.
68 // Returns the random numbers and the state of the random number generator.
69 // This function is for shape with floating point element types.
70 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
71                                            BitGeneratorTy bit_generator,
72                                            XlaOp minval, XlaOp maxval,
73                                            const xla::Shape& shape);
74 
75 // Similar to UniformFloatingPointDistribution but for shape with integer
76 // element types.
77 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
78                                  BitGeneratorTy bit_generator, XlaOp minval,
79                                  XlaOp maxval, const xla::Shape& shape);
80 
81 // Uses the given bit generator to generate random bits and then converts the
82 // random bits to random numbers of normal distribution.
83 // Returns the random numbers and the state of the random number generator.
84 RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
85                                           BitGeneratorTy bit_generator,
86                                           const xla::Shape& shape);
87 
88 // Concatenates scalars into a vector.
89 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
90                          absl::Span<const xla::XlaOp> scalars);
91 
92 // Increases Philox counter (an uint128) by a delta (an uint64).
93 xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta);
94 
95 }  // namespace xla
96 
97 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
98