/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateful_random_ops.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {
namespace {

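// Returns a bit-generator callable that wraps xla::RngBitGenerator for the
// given algorithm: it packs the key and counter into the state layout the
// generator expects, and unpacks the new counter from the returned state.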
xla::BitGeneratorTy BitGen(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
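      // XLA's Philox generator takes a length-3 U64 state [key, counter0,
      // counter1]; prepend the key here, and below slice it back off the
      // returned state so that only the counter is threaded through.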
      state =
          xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
      xla::XlaOp result =
          xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape);
      xla::XlaOp data = xla::GetTupleElement(result, 1);
      xla::XlaOp new_state =
          xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1});
      return xla::RngOutput{data, new_state};
    };
  } else {
    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
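      // ThreeFry uses a 2-element U64 state [key, counter]. The counter is a
      // scalar, so the sliced result is reshaped back to rank 0.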
      state = xla::ConcatScalars(key.builder(), {key, state});
      xla::XlaOp result = xla::RngBitGenerator(
          xla::RandomAlgorithm::RNG_THREE_FRY, state, shape);
      xla::XlaOp data = xla::GetTupleElement(result, 1);
      xla::XlaOp new_state = xla::Reshape(
          xla::Slice(xla::GetTupleElement(result, 0), {1}, {2}, {1}), {});
      return xla::RngOutput{data, new_state};
    };
  }
}

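// Samples a uniform distribution over [minval, maxval) for the element type
// of `shape`, threading the RNG state through XLA's distribution helpers.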
xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key,
                                  xla::XlaOp initial_state,
                                  const xla::Shape& shape, xla::XlaOp minval,
                                  xla::XlaOp maxval) {
  xla::PrimitiveType type = shape.element_type();
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(
          key, initial_state, BitGen(alg), minval, maxval, shape);
    case xla::U32:
    case xla::S32:
    case xla::U64:
    case xla::S64:
      return UniformIntDistribution(key, initial_state, BitGen(alg), minval,
                                    maxval, shape);
    default:
      return {key.builder()->ReportError(xla::Unimplemented(
                  "Types other than F32, F64, U32, S32, U64 and S64 "
                  "are not implemented by StatefulRngUniform; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              initial_state};
  }
}

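// Samples uniform bits covering the full range of the integer element type;
// signed outputs are produced by bitcasting the raw generated bits.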
xla::RngOutput StatefulRngUniformFullInt(Algorithm alg, xla::XlaOp key,
                                         xla::XlaOp initial_state,
                                         const xla::Shape& shape) {
  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGen(alg)(key, initial_state, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      output.value = BitcastConvertType(output.value, type);
      return output;
    default:
      return {
          key.builder()->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatefulRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          initial_state};
  }
}

using SamplerReturnType = xla::StatusOr<xla::RngOutput>;

int64 GetMinStateSize(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return PHILOX_MIN_STATE_SIZE;
  }
  return THREEFRY_MIN_STATE_SIZE;
}

Status CheckStateShape(Algorithm alg, const TensorShape& shape) {
  if (shape.dims() != 1) {
    return errors::InvalidArgument(
        "RNG state must have one and only one dimension, not ", shape.dims());
  }
  auto state_size = shape.dim_size(0);
  auto min_state_size = GetMinStateSize(alg);
  if (state_size < min_state_size) {
    return errors::InvalidArgument("The size of the state must be at least ",
                                   min_state_size, "; got ", state_size);
  }
  return Status::OK();
}

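// The state variable stores the counter followed by the key: ThreeFry keeps a
// scalar counter ([counter, key]), Philox a 2-element one
// ([counter0, counter1, key]). Both parts are bitcast to U64 for the
// generator.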
std::pair<xla::XlaOp, xla::XlaOp> StateAndKeyFromVariable(Algorithm alg,
                                                          xla::XlaOp var) {
  if (alg == RNG_ALG_THREEFRY) {
    static constexpr int kStateSize = 1;
    auto state = BitcastConvertType(
        xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
    auto key = BitcastConvertType(
        xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
        xla::U64);
    return std::make_pair(state, key);
  } else {
    static constexpr int kStateSize = 2;
    auto state =
        BitcastConvertType(xla::Slice(var, {0}, {kStateSize}, {1}), xla::U64);
    auto key = xla::Reshape(
        BitcastConvertType(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}),
                           xla::U64),
        {});
    return std::make_pair(state, key);
  }
}

xla::XlaOp StateAndKeyToVariable(Algorithm alg, xla::XlaOp state,
                                 xla::XlaOp key) {
  auto builder = state.builder();
  if (alg == RNG_ALG_THREEFRY) {
    return ConcatScalars(builder, {state, key});
  } else {
    return ConcatInDim(builder, {state, xla::Reshape(key, {1})}, 0);
  }
}

// A helper function containing the common part of several kernels below.
// Precondition: 'algorithm' and 'shape' are compile-time constants.
Status CompileImpl(
    XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
    int shape_input_idx,
    std::function<SamplerReturnType(Algorithm, xla::XlaOp, xla::XlaOp,
                                    TensorShape)> const& sampler) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  Algorithm alg = Algorithm(alg_literal.Get<int64>({}));
  if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
    return errors::InvalidArgument("Unsupported algorithm id: ", alg);
  }

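  // Read the RNG state out of the variable, run the sampler, then write the
  // updated counter (with the unchanged key) back to the variable.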
  xla::XlaOp var;
  TensorShape var_shape;
  TF_RETURN_IF_ERROR(ctx->ReadVariableInput(
      state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var));
  TF_RETURN_IF_ERROR(CheckStateShape(alg, var_shape));
  TensorShape shape;
  TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
  xla::XlaOp state;
  xla::XlaOp key;
  std::tie(state, key) = StateAndKeyFromVariable(alg, var);
  auto status_or_value = sampler(alg, state, key, shape);
  if (!status_or_value.ok()) {
    return status_or_value.status();
  }
  xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
  state = value_state.state;
  ctx->SetOutput(0, value_state.value);
  var = StateAndKeyToVariable(alg, state, key);
  xla::PrimitiveType state_element_type;
  TF_RETURN_IF_ERROR(
      DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
  var = BitcastConvertType(var, state_element_type);
  TF_RETURN_IF_ERROR(
      ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
  return Status::OK();
}

class StatefulUniformOp : public XlaOpKernel {
 public:
  explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    auto sampler = [builder, this](Algorithm alg, xla::XlaOp state,
                                   xla::XlaOp key,
                                   TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
      xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
      xla::RngOutput uniform_state = StatefulRngUniform(
          alg, key, state, xla_shape,
          xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
          xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
      xla::XlaOp uniform = uniform_state.value;
      state = uniform_state.state;
      uniform = MaybeConvertF32ToBF16(uniform, dtype_);
      return {{uniform, state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulUniform")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulUniformOp);

class StatefulStandardNormalOp : public XlaOpKernel {
 public:
  explicit StatefulStandardNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto sampler =
        // Needs an explicit lambda return type because it cannot be inferred.
        [this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
               TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
      xla::RngOutput value_state = xla::NormalFloatingPointDistribution(
          key, state, BitGen(alg), xla_shape);
      xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
      return {{normal, value_state.state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulStandardNormalOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulStandardNormalV2")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulStandardNormalOp);

class StatefulTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    auto sampler =
        // Needs an explicit lambda return type because it cannot be inferred.
        [builder, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                        TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

      xla::RngOutput uniform_result = StatefulRngUniform(
          alg, key, state, xla_shape,
          xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
          xla::One(builder, xla_shape.element_type()));
      xla::XlaOp uniform = uniform_result.value;
      state = uniform_result.state;
      xla::XlaOp truncated_normal = TruncatedNormal(uniform);
      truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
      return {{truncated_normal, state}};
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp);
};

// TODO(wangpeng): Support plain float16 to get rid of the `TypeConstraint`.
REGISTER_XLA_OP(Name("StatefulTruncatedNormal")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatefulTruncatedNormalOp);

class StatefulUniformIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp minval = ctx->Input(3);
    xla::XlaOp maxval = ctx->Input(4);
    // Named `sampler` (not `sample_with_threefry`) since it dispatches on
    // `alg` and handles Philox as well.
    auto sampler =
        [minval, maxval, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key,
                               TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniform(alg, key, state, xla_shape, minval, maxval);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformIntOp);

class StatefulUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto sampler = [this](Algorithm alg, xla::XlaOp state,
                          xla::XlaOp key,
                          TensorShape shape) -> SamplerReturnType {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
      return StatefulRngUniformFullInt(alg, key, state, xla_shape);
    };
    OP_REQUIRES_OK(ctx,
                   CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
                               /*shape_input_idx=*/2, sampler));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatefulUniformFullInt")
                    .CompileTimeConstantInput("algorithm")
                    .CompileTimeConstantInput("shape")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformFullIntOp);

xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
                           xla::XlaOp delta) {
  // Multiply by 256 to be consistent with the CPU/GPU kernels.
  delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256);
  if (alg == RNG_ALG_PHILOX) {
    return xla::PhiloxIncreaseCounter(counter, delta);
  } else {
    return counter + delta;
  }
}

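// Appends `n` zeros to the end of the rank-1 operand `a`.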
xla::XlaOp PadRight(xla::XlaOp a, int n) {
  return xla::Pad(a, xla::ScalarLike(a, 0),
                  xla::MakeEdgePaddingConfig({{0, n}}));
}

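// Advances the RNG counter by a given delta. `AlgEnumType` is the element
// type of the algorithm input; when `read_old_value` is true the op also
// outputs the pre-skip state (used by RngReadAndSkip below).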
template <typename AlgEnumType = int64, bool read_old_value = false>
class RngSkipOp : public XlaOpKernel {
 public:
  explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const int state_input_idx = 0;
    const int alg_input_idx = 1;
    const int delta_input_idx = 2;
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
                                          &var_shape, &var));
    xla::Literal alg_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
    Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
    OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
                errors::InvalidArgument("Unsupported algorithm id: ", alg));
    OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
    if (read_old_value) {
      auto counter_size = GetCounterSize(alg);
      xla::XlaOp output = var;
      if (RNG_MAX_COUNTER_SIZE > counter_size) {
        // Because the size of `var` depends on the algorithm while we want the
        // output to have a fixed size (to help shape inference), we fix the
        // output size to be the maximal state size among algorithms, and right-
        // pad it with zeros if var's size is smaller than that.
        output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
      }
      ctx->SetOutput(0, output);
    }
    xla::XlaOp counter;
    xla::XlaOp key;
    std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
    xla::XlaOp delta = ctx->Input(delta_input_idx);
    delta = BitcastConvertType(delta, xla::U64);
    auto new_counter = IncreaseCounter(alg, counter, delta);
    var = StateAndKeyToVariable(alg, new_counter, key);
    xla::PrimitiveType state_element_type;
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
    var = BitcastConvertType(var, state_element_type);
    OP_REQUIRES_OK(
        ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
};

REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
                RngSkipOp<>);

using RngReadAndSkipOp = RngSkipOp<int32, true>;

REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
                RngReadAndSkipOp);

}  // namespace
}  // namespace tensorflow