/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateful_random_ops.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
namespace {

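// Returns an xla::BitGeneratorTy for `alg`: a callable that prepends the key
// to the counter state in the layout expected by xla::RngBitGenerator, and
// strips the key back off the updated state it returns.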
xla::BitGeneratorTy BitGen(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
      state =
          xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
      xla::XlaOp result =
          xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape);
      xla::XlaOp data = xla::GetTupleElement(result, 1);
      xla::XlaOp new_state =
          xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1});
      return xla::RngOutput{data, new_state};
    };
  } else {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
      state = xla::ConcatScalars(key.builder(), {key, state});
      xla::XlaOp result = xla::RngBitGenerator(
          xla::RandomAlgorithm::RNG_THREE_FRY, state, shape);
      xla::XlaOp data = xla::GetTupleElement(result, 1);
      xla::XlaOp new_state = xla::Reshape(
          xla::Slice(xla::GetTupleElement(result, 0), {1}, {2}, {1}), {});
      return xla::RngOutput{data, new_state};
    };
  }
}

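// Samples uniformly from [minval, maxval), threading the RNG state through
// the bit generator for `alg`. Only F32/F64 and 32/64-bit integer element
// types are supported.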
xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key,
                                  xla::XlaOp initial_state,
                                  const xla::Shape& shape, xla::XlaOp minval,
                                  xla::XlaOp maxval) {
  xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(
          key, initial_state, BitGen(alg), minval, maxval, shape);
    case xla::U32:
    case xla::S32:
    case xla::U64:
    case xla::S64:
      return UniformIntDistribution(key, initial_state, BitGen(alg), minval,
                                    maxval, shape);
    default:
      return {key.builder()->ReportError(xla::Unimplemented(
                  "Types other than F32, F64, U32, S32, U64 and S64 "
                  "are not implemented by "
                  "StatefulRngUniform; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              initial_state};
  }
}

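// Samples uniformly over the full range of the integer element type by
// returning the raw bit-generator output, bitcast to the signed type when
// necessary.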
xla::RngOutput StatefulRngUniformFullInt(Algorithm alg, xla::XlaOp key,
                                         xla::XlaOp initial_state,
                                         const xla::Shape& shape) {
  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGen(alg)(key, initial_state, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      output.value = BitcastConvertType(output.value, type);
      return output;
    default:
      return {
          key.builder()->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatefulRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          initial_state};
  }
}

using SamplerReturnType = xla::StatusOr<xla::RngOutput>;

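// Returns the minimum allowed size (counter plus key) of the state vector
// for `alg`.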
int64 GetMinStateSize(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return PHILOX_MIN_STATE_SIZE;
  }
  return THREEFRY_MIN_STATE_SIZE;
}

Status CheckStateShape(Algorithm alg, const TensorShape& shape) {
  if (shape.dims() != 1) {
    return errors::InvalidArgument(
        "RNG state must have one and only one dimension, not ", shape.dims());
  }
  auto state_size = shape.dim_size(0);
  auto min_state_size = GetMinStateSize(alg);
  if (state_size < min_state_size) {
    return errors::InvalidArgument("The size of the state must be at least ",
                                   min_state_size, "; got ", state_size);
  }
  return Status::OK();
}

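// Splits the state variable into (counter, key) as U64 values. The variable
// layout is [counter, key]: one counter element for ThreeFry, two for Philox,
// followed by a single key element.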
std::pair<xla::XlaOp, xla::XlaOp> StateAndKeyFromVariable(Algorithm alg,
                                                          xla::XlaOp var) {
  if (alg == RNG_ALG_THREEFRY) {
    static constexpr int kStateSize = 1;
    auto state = BitcastConvertType(
        xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
    auto key = BitcastConvertType(
        xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
        xla::U64);
    return std::make_pair(state, key);
  } else {
    static constexpr int kStateSize = 2;
    auto state =
        BitcastConvertType(xla::Slice(var, {0}, {kStateSize}, {1}), xla::U64);
    auto key = xla::Reshape(
        BitcastConvertType(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}),
                           xla::U64),
        {});
    return std::make_pair(state, key);
  }
}

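// Inverse of StateAndKeyFromVariable: packs (state, key) back into the flat
// [counter, key] variable layout.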
xla::XlaOp StateAndKeyToVariable(Algorithm alg, xla::XlaOp state,
                                 xla::XlaOp key) {
  auto builder = state.builder();
  if (alg == RNG_ALG_THREEFRY) {
    return ConcatScalars(builder, {state, key});
  } else {
    return ConcatInDim(builder, {state, xla::Reshape(key, {1})}, 0);
  }
}

// A helper function containing the common part of several kernels below: it
// reads and validates the algorithm and the RNG state variable, runs
// `sampler` to produce the output and the new state, and writes the updated
// state back into the variable.
// Precondition: 'algorithm' and 'shape' are compile-time constants.
Status CompileImpl(
    XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
    int shape_input_idx,
    std::function<SamplerReturnType(Algorithm, xla::XlaOp, xla::XlaOp,
                                    TensorShape)> const& sampler) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  Algorithm alg = Algorithm(alg_literal.Get<int64>({}));
  if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
    return errors::InvalidArgument("Unsupported algorithm id: ", alg);
  }

  xla::XlaOp var;
  TensorShape var_shape;
  TF_RETURN_IF_ERROR(ctx->ReadVariableInput(
      state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var));
  TF_RETURN_IF_ERROR(CheckStateShape(alg, var_shape));
  TensorShape shape;
  TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
  xla::XlaOp state;
  xla::XlaOp key;
  std::tie(state, key) = StateAndKeyFromVariable(alg, var);
  auto status_or_value = sampler(alg, state, key, shape);
  if (!status_or_value.ok()) {
    return status_or_value.status();
  }
  xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
  state = value_state.state;
  ctx->SetOutput(0, value_state.value);
  var = StateAndKeyToVariable(alg, state, key);
  xla::PrimitiveType state_element_type;
  TF_RETURN_IF_ERROR(
      DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
  var = BitcastConvertType(var, state_element_type);
  TF_RETURN_IF_ERROR(
      ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
  return Status::OK();
}

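// Kernel for StatefulUniform: samples uniformly from [0, 1). Sampling is done
// in F32 (or F64 for DT_DOUBLE) and the result is converted to bfloat16
// afterwards if the requested dtype calls for it.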
class StatefulUniformOp : public XlaOpKernel {
 public:
  explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    auto sampler = [builder, this](Algorithm alg, xla::XlaOp state,
                                   xla::XlaOp key,
                                   TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
      xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
      xla::RngOutput uniform_state = StatefulRngUniform(
          alg, key, state, xla_shape,
          xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
          xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
      xla::XlaOp uniform = uniform_state.value;
      state = uniform_state.state;
      uniform = MaybeConvertF32ToBF16(uniform, dtype_);
      return {{uniform, state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulUniform")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulUniformOp);

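// Kernel for StatefulStandardNormalV2: samples from the standard normal
// distribution N(0, 1), with the same F32/F64-then-convert strategy as above.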
class StatefulStandardNormalOp : public XlaOpKernel {
 public:
  explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto sampler =
        // Needs an explicit lambda return type because it cannot be inferred.
        [this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
               TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
      xla::RngOutput value_state = xla::NormalFloatingPointDistribution(
          key, state, BitGen(alg), xla_shape);
      xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
      return {{normal, value_state.state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulStandardNormalOp);

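// Kernel for StatefulTruncatedNormal: samples a uniform value in (0, 1) and
// maps it through TruncatedNormal, which produces normal values restricted to
// two standard deviations around the mean.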
class StatefulTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    auto sampler =
        // Needs an explicit lambda return type because it cannot be inferred.
        [builder, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                        TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

      xla::RngOutput uniform_result = StatefulRngUniform(
          alg, key, state, xla_shape,
          xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
          xla::One(builder, xla_shape.element_type()));
      xla::XlaOp uniform = uniform_result.value;
      state = uniform_result.state;
      xla::XlaOp truncated_normal = TruncatedNormal(uniform);
      truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
      return {{truncated_normal, state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulTruncatedNormal")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulTruncatedNormalOp);

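// Kernel for StatefulUniformInt: samples integers uniformly from the
// caller-provided [minval, maxval) range.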
class StatefulUniformIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp minval = ctx->Input(3);
    xla::XlaOp maxval = ctx->Input(4);
    auto sampler =
        [minval, maxval, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                               TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniform(alg, key, state, xla_shape, minval, maxval);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformIntOp);

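// Kernel for StatefulUniformFullInt: samples integers uniformly over the
// full range of the output dtype.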
class StatefulUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto sampler = [this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                          TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniformFullInt(alg, key, state, xla_shape);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformFullInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformFullIntOp);

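// Advances the RNG counter by `delta`, using 128-bit arithmetic across the
// two counter words for Philox and plain 64-bit addition for ThreeFry.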
xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
                           xla::XlaOp delta) {
  // Multiply by 256 to be consistent with the CPU/GPU kernels.
  delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256);
  if (alg == RNG_ALG_PHILOX) {
    return xla::PhiloxIncreaseCounter(counter, delta);
  } else {
    return counter + delta;
  }
}

xla::XlaOp PadRight(xla::XlaOp a, int n) {
  return xla::Pad(a, xla::ScalarLike(a, 0),
                  xla::MakeEdgePaddingConfig({{0, n}}));
}

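// Kernel for RngSkip and RngReadAndSkip. `AlgEnumType` is the integral type
// of the algorithm input; when `read_old_value` is true, the kernel also
// outputs the pre-skip state, zero-padded on the right so that the output
// shape does not depend on the algorithm.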
template <typename AlgEnumType = int64, bool read_old_value = false>
class RngSkipOp : public XlaOpKernel {
 public:
  explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const int state_input_idx = 0;
    const int alg_input_idx = 1;
    const int delta_input_idx = 2;
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
                                          &var_shape, &var));
    xla::Literal alg_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
    Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
    OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
                errors::InvalidArgument("Unsupported algorithm id: ", alg));
    OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
    if (read_old_value) {
      auto counter_size = GetCounterSize(alg);
      xla::XlaOp output = var;
      if (RNG_MAX_COUNTER_SIZE > counter_size) {
        // Because the size of `var` depends on the algorithm while we want the
        // output to have a fixed size (to help shape inference), we fix the
        // output size to be the maximal state size among algorithms, and
        // right-pad it with zeros if var's size is smaller than that.
        output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
      }
      ctx->SetOutput(0, output);
    }
    xla::XlaOp counter;
    xla::XlaOp key;
    std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
    xla::XlaOp delta = ctx->Input(delta_input_idx);
    delta = BitcastConvertType(delta, xla::U64);
    auto new_counter = IncreaseCounter(alg, counter, delta);
    var = StateAndKeyToVariable(alg, new_counter, key);
    xla::PrimitiveType state_element_type;
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
    var = BitcastConvertType(var, state_element_type);
    OP_REQUIRES_OK(
        ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
};

REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
                RngSkipOp<>);

using RngReadAndSkipOp = RngSkipOp<int32, true>;

REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
                RngReadAndSkipOp);

}  // namespace
}  // namespace tensorflow