1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for dealing with Literal protobufs.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20 
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <vector>
29 
30 #include "absl/memory/memory.h"
31 #include "absl/strings/string_view.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/array2d.h"
34 #include "tensorflow/compiler/xla/array3d.h"
35 #include "tensorflow/compiler/xla/array4d.h"
36 #include "tensorflow/compiler/xla/index_util.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/primitive_util.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/status_macros.h"
42 #include "tensorflow/compiler/xla/types.h"
43 #include "tensorflow/compiler/xla/util.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/lib/core/bitmap.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/macros.h"
49 #include "tensorflow/core/platform/protobuf.h"
50 #include "tensorflow/core/platform/types.h"
51 
52 namespace xla {
53 
54 class LiteralUtil {
55  public:
56   LiteralUtil() = delete;
57 
58   // Returns a literal scalar representing the first element.
59   static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
60 
61   // Creates a new literal of a given rank. To minimize ambiguity (for users
62   // and the compiler) these CreateR[0-2] methods should explicitly specify the
63   // native type. For example:
64   //
65   //  CreateR1<float>({1.0, 42.0});
66   //  CreateR2<uint32>({{1, 2}, {3, 4}});
67   //
68   // The variants not ending with WithLayout use the default XLA layout for the
69   // literal's linear representation in memory.
70   template <typename NativeT>
71   static Literal CreateR0(NativeT value);
72   template <typename NativeT>
73   static Literal CreateR1(absl::Span<const NativeT> values);
74   static Literal CreateR1(const tensorflow::core::Bitmap& values);
75   template <typename NativeT>
76   static Literal CreateR2(
77       std::initializer_list<std::initializer_list<NativeT>> values);
78   template <typename NativeT>
79   static Literal CreateR2WithLayout(
80       std::initializer_list<std::initializer_list<NativeT>> values,
81       const Layout& layout);
82   template <typename NativeT>
83   static Literal CreateR3(std::initializer_list<
84                           std::initializer_list<std::initializer_list<NativeT>>>
85                               values);
86   template <typename NativeT>
87   static Literal CreateR3WithLayout(
88       std::initializer_list<
89           std::initializer_list<std::initializer_list<NativeT>>>
90           values,
91       const Layout& layout);
92   template <typename NativeT>
93   static Literal CreateR4(
94       std::initializer_list<std::initializer_list<
95           std::initializer_list<std::initializer_list<NativeT>>>>
96           values);
97   template <typename NativeT>
98   static Literal CreateR4WithLayout(
99       std::initializer_list<std::initializer_list<
100           std::initializer_list<std::initializer_list<NativeT>>>>
101           values,
102       const Layout& layout);
103 
104   // Creates a scalar literal value zero of the given primitive type.
105   static Literal Zero(PrimitiveType primitive_type);
106   // Creates a scalar literal value one of the given primitive type.
107   static Literal One(PrimitiveType primitive_type);
108   // Creates a scalar literal value containing the minimum value of the given
109   // primitive type. For floating-point types, returns -inf.
110   static Literal MinValue(PrimitiveType primitive_type);
111   // Creates a scalar literal value containing the maximum value of the given
112   // primitive type. For floating-point types, returns inf.
113   static Literal MaxValue(PrimitiveType primitive_type);
114   // Creates a scalar literal value containing the NaN value of the given
115   // primitive type. Fail for non-inexact types. For complex types, returns a
116   // nan + nan * j value.
117   static StatusOr<Literal> NanValue(PrimitiveType primitive_type);
118   // Creates a literal of the given shape where each element is `value`.
119   template <typename NativeT>
120   static Literal CreateFullWithDescendingLayout(
121       absl::Span<const int64> dimensions, NativeT value);
122 
123   // Creates a new literal from an Array type. The variants not ending with
124   // WithLayout use the default XLA layout for the literal's linear
125   // representation in memory.
126   template <typename NativeT>
127   static Literal CreateFromArray(const Array<NativeT>& values);
128   template <typename NativeT>
129   static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
130                                            const Layout& layout);
131   template <typename NativeT>
132   static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
133   template <typename NativeT>
134   static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
135                                                const Layout& layout);
136   template <typename NativeT>
137   static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
138   template <typename NativeT>
139   static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
140                                                const Layout& layout);
141   template <typename NativeT>
142   static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
143   template <typename NativeT>
144   static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
145                                                const Layout& layout);
146 
147   // Creates a new vector of U8s literal value from a string.
148   static Literal CreateR1U8(absl::string_view value);
149 
150   // Creates a linspace-populated literal with the given number of rows and
151   // columns.
152   static Literal CreateR2F32Linspace(float from, float to, int64 rows,
153                                      int64 cols);
154 
155   // Creates a literal that projects the (x, y) dimensions given in values into
156   // the z dimension given by "projection".
157   template <typename NativeT>
158   static Literal CreateR3Projected(
159       std::initializer_list<std::initializer_list<NativeT>> values,
160       int64 projection);
161 
162   // Creates a literal that projects the (x, y) dimensions given in values into
163   // the z and p dimensions given.
164   template <typename NativeT>
165   static Literal CreateR4Projected(
166       std::initializer_list<std::initializer_list<NativeT>> values,
167       int64 projection_p, int64 projection_z);
168 
169   // Returns an identity matrix (rank 2) with the given row and column count.
170   template <typename NativeT>
171   static Literal MakeIdentityR2(int64 size);
172 
173   // Returns a tuple literal composed of given literals. Data is copied from the
174   // given elements into the returned literal.
175   static Literal MakeTuple(absl::Span<const Literal* const> elements);
176 
177   static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
178 
179   // As above, but intended to be invoked with move semantics; i.e.
180   //
181   //  std::vector<Literal> elements = ...;
182   //  auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
183   //
184   // This would have been declared as an overload, but there is ambiguity
185   // in invocation between the above signature and this one.
186   static Literal MakeTupleOwned(std::vector<Literal> elements);
187 
188   // This overload lets you pass a list of Literals to MakeTupleOwned:
189   //
190   //   LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
191   //
192   // Simply relying on the MakeTupleOwned(std::vector<Literal>)
193   // overload doesn't work because std::initializer_list's elements are always
194   // const.
195   //
196   // The arguments to this function must all be Literal.
197   template <typename... Ts>
MakeTupleOwned(Ts...elements)198   static Literal MakeTupleOwned(Ts... elements) {
199     std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
200     std::vector<Literal> v;
201     v.insert(v.begin(), std::make_move_iterator(arr.begin()),
202              std::make_move_iterator(arr.end()));
203     return MakeTupleOwned(std::move(v));
204   }
205 
206   // Create a constant token literal. Token types have no value.
207   static Literal CreateToken();
208 
209   // Creates a new Literal object with its values havings the primitive_type
210   // type, and with dimensions defined by the dimensions parameter.
211   // The content of the literal values is the default value of the primitive
212   // type of literal itself (0 for numeric types, and false for predicates).
213   static Literal CreateFromDimensions(PrimitiveType primitive_type,
214                                       absl::Span<const int64> dimensions);
215 
216   // If the given literal's data type is bfloat16, converts it to a float
217   // literal; otherwise, returns a copy of it. If the literal is a tuple,
218   // recursively converts its elements.
219   static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
220 
221   // If the given literal's data type is bfloat16, converts it to a double
222   // literal; otherwise, returns a copy of it. If the literal is a tuple,
223   // recursively converts its elements.
224   static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal);
225 
226   // If the given literal's data type is float, converts it to a bfloat16
227   // literal; otherwise, returns a copy of it. If the literal is a tuple,
228   // recursively converts its elements.
229   static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
230 
231   // If the given literal's data type is float, converts it to a double
232   // literal; otherwise, returns a copy of it. If the literal is a tuple,
233   // recursively converts its elements.
234   static Literal ConvertF32ToF64(const LiteralSlice& f32_literal);
235 
236   // If the given literal's data type is double, converts it to a bfloat16
237   // literal; otherwise, returns a copy of it. If the literal is a tuple,
238   // recursively converts its elements.
239   static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal);
240 
241   // If the given literal's data type is double, converts it to a bfloat16
242   // literal; otherwise, returns a copy of it. If the literal is a tuple,
243   // recursively converts its elements.
244   static Literal ConvertF64ToF32(const LiteralSlice& f64_literal);
245 
246   // Creates a literal with a new shape with the given new dimensions using the
247   // data in the given input literal. For reshaping purposes the (flat) data
248   // buffer of the input literal is assumed to have the given minor_to_major
249   // layout order.
250   static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
251                               absl::Span<const int64> minor_to_major,
252                               const LiteralSlice& literal);
253 
254   // Creates a literal with the supplied shape, and uses the provided value
255   // generator to populate the literal's values.
256   // Returns the new literal object, or an error Status if failed.
257   template <
258       PrimitiveType type,
259       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
260   static StatusOr<Literal> CreateLiteralWithGenerator(
261       const Shape& shape,
262       const std::function<T(absl::Span<const int64>)>& generator);
263 
264   // Creates a literal with the supplied shape, and initializes the literal
265   // values using a normal distribution with given mean and stddev standard
266   // deviation, and using the engine as entropy generator.
267   // Returns the new literal object, or an error Status if failed.
268   template <
269       PrimitiveType type, typename E,
270       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
271   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
272                                                T mean, T stddev);
273 
274   // Creates a literal with the supplied shape, and initializes the literal
275   // values using a normal distribution with given mean and stddev standard
276   // deviation.
277   // Returns the new literal object, or an error Status if failed.
278   template <
279       PrimitiveType type,
280       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
281   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
282                                                T stddev);
283 
284   //
285   // End of factory methods.
286 
287   // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
288   // be returned for a 2-dimensional index with dimension 0 index equal to 7,
289   // dimension 1 equal to 8.
290   static string MultiIndexAsString(absl::Span<const int64> multi_index);
291 };
292 
293 std::ostream& operator<<(std::ostream& out, const Literal& literal);
294 
295 template <typename NativeT>
CreateR0(NativeT value)296 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
297   Literal literal(ShapeUtil::MakeShape(
298       primitive_util::NativeToPrimitiveType<NativeT>(), {}));
299   literal.Set({}, value);
300   return literal;
301 }
302 
303 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)304 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
305   Literal literal(
306       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
307                            {static_cast<int64>(values.size())}));
308   literal.PopulateR1(values);
309   return literal;
310 }
311 
312 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)313 /* static */ Literal LiteralUtil::CreateR2WithLayout(
314     std::initializer_list<std::initializer_list<NativeT>> values,
315     const Layout& layout) {
316   Literal literal(ShapeUtil::MakeShapeWithLayout(
317       primitive_util::NativeToPrimitiveType<NativeT>(),
318       {static_cast<int64>(values.size()),
319        static_cast<int64>(values.begin()->size())},
320       AsInt64Slice(layout.minor_to_major())));
321   literal.PopulateR2(values);
322   return literal;
323 }
324 
325 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)326 /* static */ Literal LiteralUtil::CreateR2(
327     std::initializer_list<std::initializer_list<NativeT>> values) {
328   return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
329 }
330 
331 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)332 /* static */ Literal LiteralUtil::CreateR3WithLayout(
333     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
334         values,
335     const Layout& layout) {
336   const int64 d0 = values.size();
337   const int64 d1 = values.begin()->size();
338   const int64 d2 = values.begin()->begin()->size();
339   Array3D<NativeT> tmp(d0, d1, d2);
340   int64 i0 = 0;
341   for (auto d1_values : values) {
342     int64 i1 = 0;
343     for (auto d2_values : d1_values) {
344       int64 i2 = 0;
345       for (auto value : d2_values) {
346         tmp(i0, i1, i2) = value;
347         ++i2;
348       }
349       ++i1;
350     }
351     ++i0;
352   }
353   return CreateR3FromArray3DWithLayout(tmp, layout);
354 }
355 
356 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)357 /* static */ Literal LiteralUtil::CreateR3(
358     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
359         values) {
360   return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
361 }
362 
363 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)364 /* static */ Literal LiteralUtil::CreateR4WithLayout(
365     std::initializer_list<std::initializer_list<
366         std::initializer_list<std::initializer_list<NativeT>>>>
367         values,
368     const Layout& layout) {
369   const int64 d0 = values.size();
370   const int64 d1 = values.begin()->size();
371   const int64 d2 = values.begin()->begin()->size();
372   const int64 d3 = values.begin()->begin()->begin()->size();
373   Array4D<NativeT> tmp(d0, d1, d2, d3);
374   int64 i0 = 0;
375   for (auto d1_values : values) {
376     int64 i1 = 0;
377     for (auto d2_values : d1_values) {
378       int64 i2 = 0;
379       for (auto d3_values : d2_values) {
380         int64 i3 = 0;
381         for (auto value : d3_values) {
382           tmp(i0, i1, i2, i3) = value;
383           ++i3;
384         }
385         ++i2;
386       }
387       ++i1;
388     }
389     ++i0;
390   }
391   return CreateR4FromArray4DWithLayout(tmp, layout);
392 }
393 
394 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)395 /* static */ Literal LiteralUtil::CreateR4(
396     std::initializer_list<std::initializer_list<
397         std::initializer_list<std::initializer_list<NativeT>>>>
398         values) {
399   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
400 }
401 
402 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)403 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
404     const Array<NativeT>& values, const Layout& layout) {
405   Literal literal(ShapeUtil::MakeShapeWithLayout(
406       primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
407       AsInt64Slice(layout.minor_to_major())));
408   literal.PopulateFromArray(values);
409   return literal;
410 }
411 
412 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)413 /* static */ Literal LiteralUtil::CreateFromArray(
414     const Array<NativeT>& values) {
415   return CreateFromArrayWithLayout(
416       values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
417 }
418 
419 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)420 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
421     const Array2D<NativeT>& values, const Layout& layout) {
422   return CreateFromArrayWithLayout(values, layout);
423 }
424 
425 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)426 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
427     const Array2D<NativeT>& values) {
428   return CreateFromArray(values);
429 }
430 
431 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)432 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
433     const Array3D<NativeT>& values, const Layout& layout) {
434   return CreateFromArrayWithLayout(values, layout);
435 }
436 
437 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)438 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
439     const Array3D<NativeT>& values) {
440   return CreateFromArray(values);
441 }
442 
443 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection)444 /* static */ Literal LiteralUtil::CreateR3Projected(
445     std::initializer_list<std::initializer_list<NativeT>> values,
446     int64 projection) {
447   int64 dim0_size = projection;
448   int64 dim1_size = values.size();
449   int64 dim2_size = values.begin()->size();
450 
451   Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
452   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
453     int64 dim1 = 0;
454     for (auto inner_list : values) {
455       int64 dim2 = 0;
456       for (auto value : inner_list) {
457         array(dim0, dim1, dim2) = value;
458         ++dim2;
459       }
460       CHECK_EQ(dim2_size, dim2);
461       ++dim1;
462     }
463     CHECK_EQ(dim1_size, dim1);
464   }
465   return CreateR3FromArray3D(array);
466 }
467 
468 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection_p,int64 projection_z)469 /* static */ Literal LiteralUtil::CreateR4Projected(
470     std::initializer_list<std::initializer_list<NativeT>> values,
471     int64 projection_p, int64 projection_z) {
472   int64 dim0_size = projection_p;
473   int64 dim1_size = projection_z;
474   int64 dim2_size = values.size();
475   int64 dim3_size = values.begin()->size();
476 
477   Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
478   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
479     for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
480       int64 dim2 = 0;
481       for (auto inner_list : values) {
482         int64 dim3 = 0;
483         for (auto value : inner_list) {
484           array(dim0, dim1, dim2, dim3) = value;
485           ++dim3;
486         }
487         CHECK_EQ(dim3_size, dim3);
488         ++dim2;
489       }
490       CHECK_EQ(dim2_size, dim2);
491     }
492   }
493   return CreateR4FromArray4D(array);
494 }
495 
496 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)497 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
498     const Array4D<NativeT>& values) {
499   return CreateFromArray(values);
500 }
501 
502 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)503 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
504     const Array4D<NativeT>& values, const Layout& layout) {
505   return CreateFromArrayWithLayout(values, layout);
506 }
507 
508 // Returns an identity matrix (rank 2) with the given row and column count.
509 template <typename NativeT>
MakeIdentityR2(int64 size)510 /* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
511   Array2D<NativeT> array(size, size, 0);
512   for (int64 i = 0; i < size; ++i) {
513     array(i, i) = 1;
514   }
515   return CreateR2FromArray2D(array);
516 }
517 
518 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,NativeT value)519 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
520     absl::Span<const int64> dimensions, NativeT value) {
521   Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
522       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
523   literal.PopulateWithValue(value);
524   return literal;
525 }
526 
527 template <PrimitiveType type, typename T>
CreateLiteralWithGenerator(const Shape & shape,const std::function<T (absl::Span<const int64>)> & generator)528 /* static */ StatusOr<Literal> LiteralUtil::CreateLiteralWithGenerator(
529     const Shape& shape,
530     const std::function<T(absl::Span<const int64>)>& generator) {
531   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
532   TF_RET_CHECK(shape.element_type() == type);
533   Literal literal(shape);
534   TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
535       [&](absl::Span<const int64> indexes) { return generator(indexes); }));
536   return std::move(literal);
537 }
538 
539 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)540 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
541     const Shape& shape, E* engine, T mean, T stddev) {
542   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
543   std::normal_distribution<NativeT> generator(mean, stddev);
544   return CreateLiteralWithGenerator<type, NativeT>(
545       shape,
546       [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
547 }
548 
549 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)550 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
551     const Shape& shape, T mean, T stddev) {
552   std::minstd_rand0 engine;
553   return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
554 }
555 
556 }  // namespace xla
557 
558 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
559