1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for dealing with Literal protobufs.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20 
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
52 
53 namespace xla {
54 
55 class LiteralUtil {
56  public:
57   LiteralUtil() = delete;
58 
59   // Returns a literal scalar representing the first element.
60   static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
61 
62   // Creates a new literal of a given rank. To minimize ambiguity (for users
63   // and the compiler) these CreateR[0-2] methods should explicitly specify the
64   // native type. For example:
65   //
66   //  CreateR1<float>({1.0, 42.0});
67   //  CreateR2<uint32>({{1, 2}, {3, 4}});
68   //
69   // The variants not ending with WithLayout use the default XLA layout for the
70   // literal's linear representation in memory.
71   template <typename NativeT>
72   static Literal CreateR0(NativeT value);
73   template <typename NativeT>
74   static Literal CreateR1(absl::Span<const NativeT> values);
75   static Literal CreateR1(const tensorflow::core::Bitmap& values);
76   template <typename NativeT>
77   static Literal CreateR2(
78       std::initializer_list<std::initializer_list<NativeT>> values);
79   template <typename NativeT>
80   static Literal CreateR2WithLayout(
81       std::initializer_list<std::initializer_list<NativeT>> values,
82       const Layout& layout);
83   template <typename NativeT>
84   static Literal CreateR3(std::initializer_list<
85                           std::initializer_list<std::initializer_list<NativeT>>>
86                               values);
87   template <typename NativeT>
88   static Literal CreateR3WithLayout(
89       std::initializer_list<
90           std::initializer_list<std::initializer_list<NativeT>>>
91           values,
92       const Layout& layout);
93   template <typename NativeT>
94   static Literal CreateR4(
95       std::initializer_list<std::initializer_list<
96           std::initializer_list<std::initializer_list<NativeT>>>>
97           values);
98   template <typename NativeT>
99   static Literal CreateR4WithLayout(
100       std::initializer_list<std::initializer_list<
101           std::initializer_list<std::initializer_list<NativeT>>>>
102           values,
103       const Layout& layout);
104 
105   // Creates a literal with a sparse layout and the given indices and values.
106   // The shape is initialized from the given dimensions.  The minor dimension of
107   // the indices array must equal the rank of the shape (i.e. size of the
108   // dimensions array). The major dimension of the indices array must equal the
109   // number of elements in the values array. The maximum number of elements in
110   // the array is taken from the max_indices() value of the index array.
111   //
112   // XLA assumes that sparse literals are in sorted order for all operations. If
113   // the `sort` argument is true, then the indices and values will be sorted
114   // while copying them into the literal. If you have ensured that the indices
115   // and values are already sorted, then you may set the `sort` argument to
116   // false to skip the sorting step.
117   //
118   // For example:
119   //
120   //   CreateSparse(
121   //     {12, 12, 12},
122   //     SparseIndexArray(10, 3,
123   //                      Array2D{
124   //                        {0, 1, 2},
125   //                        {3, 4, 5},
126   //                        {6, 7, 8},
127   //                        {9, 10, 11},
128   //                      }),
129   //     {1.0, 2.0 3.0, 4.0})
130   //
131   // This creates an array with shape F64[12,12,12]sparse{10}, that has the
132   // following non-zero values:
133   //
134   //     [0,  1,  2]: 1.0
135   //     [3,  4,  5]: 2.0
136   //     [6,  7,  8]: 3.0
137   //     [9, 10, 11]: 4.0
138   //
139   template <typename NativeT>
140   static Literal CreateSparse(absl::Span<const int64> dimensions,
141                               SparseIndexArray indices,
142                               absl::Span<const NativeT> values,
143                               bool sort = true);
144 
145   // Creates a scalar literal value zero of the given primitive type.
146   static Literal Zero(PrimitiveType primitive_type);
147   // Creates a scalar literal value one of the given primitive type.
148   static Literal One(PrimitiveType primitive_type);
149   // Creates a scalar literal value containing the minimum value of the given
150   // primitive type. For floating-point types, returns -inf.
151   static Literal MinValue(PrimitiveType primitive_type);
152   // Creates a scalar literal value containing the maximum value of the given
153   // primitive type. For floating-point types, returns inf.
154   static Literal MaxValue(PrimitiveType primitive_type);
155   // Creates a literal of the given shape where each element is `value`.
156   template <typename NativeT>
157   static Literal CreateFullWithDescendingLayout(
158       absl::Span<const int64> dimensions, NativeT value);
159 
160   // Creates a new literal from an Array type. The variants not ending with
161   // WithLayout use the default XLA layout for the literal's linear
162   // representation in memory.
163   template <typename NativeT>
164   static Literal CreateFromArray(const Array<NativeT>& values);
165   template <typename NativeT>
166   static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
167                                            const Layout& layout);
168   template <typename NativeT>
169   static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
170   template <typename NativeT>
171   static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
172                                                const Layout& layout);
173   template <typename NativeT>
174   static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
175   template <typename NativeT>
176   static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
177                                                const Layout& layout);
178   template <typename NativeT>
179   static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
180   template <typename NativeT>
181   static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
182                                                const Layout& layout);
183 
184   // Creates a new vector of U8s literal value from a string.
185   static Literal CreateR1U8(absl::string_view value);
186 
187   // Creates a linspace-populated literal with the given number of rows and
188   // columns.
189   static Literal CreateR2F32Linspace(float from, float to, int64 rows,
190                                      int64 cols);
191 
192   // Creates a literal that projects the (x, y) dimensions given in values into
193   // the z dimension given by "projection".
194   template <typename NativeT>
195   static Literal CreateR3Projected(
196       std::initializer_list<std::initializer_list<NativeT>> values,
197       int64 projection);
198 
199   // Creates a literal that projects the (x, y) dimensions given in values into
200   // the z and p dimensions given.
201   template <typename NativeT>
202   static Literal CreateR4Projected(
203       std::initializer_list<std::initializer_list<NativeT>> values,
204       int64 projection_p, int64 projection_z);
205 
206   // Returns an identity matrix (rank 2) with the given row and column count.
207   template <typename NativeT>
208   static Literal MakeIdentityR2(int64 size);
209 
210   // Returns a tuple literal composed of given literals. Data is copied from the
211   // given elements into the returned literal.
212   static Literal MakeTuple(absl::Span<const Literal* const> elements);
213 
214   static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
215 
216   // As above, but intended to be invoked with move semantics; i.e.
217   //
218   //  std::vector<Literal> elements = ...;
219   //  auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
220   //
221   // This would have been declared as an overload, but there is ambiguity
222   // in invocation between the above signature and this one.
223   static Literal MakeTupleOwned(std::vector<Literal> elements);
224 
225   // This overload lets you pass a braced list of Literals to
226   // MakeTupleOwned:
227   //
228   //   LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
229   //
230   // Simply relying on the MakeTupleOwned(std::vector<Literal>)
231   // overload doesn't work because std::initializer_list's elements are always
232   // const.
233   //
234   // The arguments to this function must all be Literal.
235   template <typename... Ts>
MakeTupleOwned(Ts...elements)236   static Literal MakeTupleOwned(Ts... elements) {
237     std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
238     std::vector<Literal> v;
239     v.insert(v.begin(), std::make_move_iterator(arr.begin()),
240              std::make_move_iterator(arr.end()));
241     return MakeTupleOwned(std::move(v));
242   }
243 
244   // Create a constant token literal. Token types have no value.
245   static Literal CreateToken();
246 
247   // Creates a new Literal object with its values havings the primitive_type
248   // type, and with dimensions defined by the dimensions parameter.
249   // The content of the literal values is the default value of the primitive
250   // type of literal itself (0 for numeric types, and false for predicates).
251   static Literal CreateFromDimensions(PrimitiveType primitive_type,
252                                       absl::Span<const int64> dimensions);
253 
254   // If the given literal's data type is bfloat16, converts it to a float
255   // literal; otherwise, returns a copy of it. If the literal is a tuple,
256   // recursively converts its elements.
257   static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
258 
259   // If the given literal's data type is float, converts it to a bfloat16
260   // literal; otherwise, returns a copy of it. If the literal is a tuple,
261   // recursively converts its elements.
262   static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
263 
264   // Creates a literal with a new shape with the given new dimensions using the
265   // data in the given input literal. For reshaping purposes the (flat) data
266   // buffer of the input literal is assumed to have the given minor_to_major
267   // layout order.
268   static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
269                               absl::Span<const int64> minor_to_major,
270                               const LiteralSlice& literal);
271 
272   // Creates a literal with the supplied shape, and uses the provided value
273   // generator to populate the literal's values.
274   // Returns the new literal object, or an error Status if failed.
275   template <
276       PrimitiveType type,
277       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
278   static StatusOr<Literal> CreateRandomLiteral(
279       const Shape& shape,
280       const std::function<T(absl::Span<const int64>)>& generator);
281 
282   // Creates a literal with the supplied shape, and initializes the literal
283   // values using a normal distribution with given mean and stddev standard
284   // deviation, and using the engine as entropy generator.
285   // Returns the new literal object, or an error Status if failed.
286   template <
287       PrimitiveType type, typename E,
288       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
289   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
290                                                T mean, T stddev);
291 
292   // Creates a literal with the supplied shape, and initializes the literal
293   // values using a normal distribution with given mean and stddev standard
294   // deviation.
295   // Returns the new literal object, or an error Status if failed.
296   template <
297       PrimitiveType type,
298       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
299   static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
300                                                T stddev);
301 
302   //
303   // End of factory methods.
304 
305   // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
306   // be returned for a 2-dimensional index with dimension 0 index equal to 7,
307   // dimension 1 equal to 8.
308   static string MultiIndexAsString(absl::Span<const int64> multi_index);
309 };
310 
// Stream-insertion operator for Literal. Only the declaration is visible in
// this header; the definition lives elsewhere.
// NOTE(review): presumably writes a human-readable rendering of `literal` to
// `out` and returns `out` -- confirm against the definition.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
312 
313 template <typename NativeT>
CreateR0(NativeT value)314 /* static */ Literal LiteralUtil::CreateR0(NativeT value) {
315   Literal literal(ShapeUtil::MakeShape(
316       primitive_util::NativeToPrimitiveType<NativeT>(), {}));
317   literal.Set({}, value);
318   return literal;
319 }
320 
321 template <typename NativeT>
CreateR1(absl::Span<const NativeT> values)322 /* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
323   Literal literal(
324       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
325                            {static_cast<int64>(values.size())}));
326   literal.PopulateR1(values);
327   return literal;
328 }
329 
330 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)331 /* static */ Literal LiteralUtil::CreateR2WithLayout(
332     std::initializer_list<std::initializer_list<NativeT>> values,
333     const Layout& layout) {
334   Literal literal(ShapeUtil::MakeShapeWithLayout(
335       primitive_util::NativeToPrimitiveType<NativeT>(),
336       {static_cast<int64>(values.size()),
337        static_cast<int64>(values.begin()->size())},
338       AsInt64Slice(layout.minor_to_major())));
339   literal.PopulateR2(values);
340   return literal;
341 }
342 
343 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)344 /* static */ Literal LiteralUtil::CreateR2(
345     std::initializer_list<std::initializer_list<NativeT>> values) {
346   return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
347 }
348 
349 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)350 /* static */ Literal LiteralUtil::CreateR3WithLayout(
351     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
352         values,
353     const Layout& layout) {
354   const int64 d0 = values.size();
355   const int64 d1 = values.begin()->size();
356   const int64 d2 = values.begin()->begin()->size();
357   Array3D<NativeT> tmp(d0, d1, d2);
358   int64 i0 = 0;
359   for (auto d1_values : values) {
360     int64 i1 = 0;
361     for (auto d2_values : d1_values) {
362       int64 i2 = 0;
363       for (auto value : d2_values) {
364         tmp(i0, i1, i2) = value;
365         ++i2;
366       }
367       ++i1;
368     }
369     ++i0;
370   }
371   return CreateR3FromArray3DWithLayout(tmp, layout);
372 }
373 
374 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)375 /* static */ Literal LiteralUtil::CreateR3(
376     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
377         values) {
378   return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
379 }
380 
381 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)382 /* static */ Literal LiteralUtil::CreateR4WithLayout(
383     std::initializer_list<std::initializer_list<
384         std::initializer_list<std::initializer_list<NativeT>>>>
385         values,
386     const Layout& layout) {
387   const int64 d0 = values.size();
388   const int64 d1 = values.begin()->size();
389   const int64 d2 = values.begin()->begin()->size();
390   const int64 d3 = values.begin()->begin()->begin()->size();
391   Array4D<NativeT> tmp(d0, d1, d2, d3);
392   int64 i0 = 0;
393   for (auto d1_values : values) {
394     int64 i1 = 0;
395     for (auto d2_values : d1_values) {
396       int64 i2 = 0;
397       for (auto d3_values : d2_values) {
398         int64 i3 = 0;
399         for (auto value : d3_values) {
400           tmp(i0, i1, i2, i3) = value;
401           ++i3;
402         }
403         ++i2;
404       }
405       ++i1;
406     }
407     ++i0;
408   }
409   return CreateR4FromArray4DWithLayout(tmp, layout);
410 }
411 
412 template <typename NativeT>
CreateSparse(absl::Span<const int64> dimensions,SparseIndexArray indices,absl::Span<const NativeT> values,bool sort)413 /* static */ Literal LiteralUtil::CreateSparse(
414     absl::Span<const int64> dimensions, SparseIndexArray indices,
415     absl::Span<const NativeT> values, bool sort) {
416   int64 num_elements = values.size();
417   int64 rank = dimensions.size();
418   CHECK_EQ(num_elements, indices.index_count());
419   CHECK_EQ(rank, indices.rank());
420   Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
421       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
422       indices.max_indices()));
423   literal.PopulateSparse(indices, values, sort);
424   return literal;
425 }
426 
427 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)428 /* static */ Literal LiteralUtil::CreateR4(
429     std::initializer_list<std::initializer_list<
430         std::initializer_list<std::initializer_list<NativeT>>>>
431         values) {
432   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
433 }
434 
435 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)436 /* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
437     const Array<NativeT>& values, const Layout& layout) {
438   Literal literal(ShapeUtil::MakeShapeWithLayout(
439       primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
440       AsInt64Slice(layout.minor_to_major())));
441   literal.PopulateFromArray(values);
442   return literal;
443 }
444 
445 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)446 /* static */ Literal LiteralUtil::CreateFromArray(
447     const Array<NativeT>& values) {
448   return CreateFromArrayWithLayout(
449       values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
450 }
451 
452 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)453 /* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
454     const Array2D<NativeT>& values, const Layout& layout) {
455   return CreateFromArrayWithLayout(values, layout);
456 }
457 
458 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)459 /* static */ Literal LiteralUtil::CreateR2FromArray2D(
460     const Array2D<NativeT>& values) {
461   return CreateFromArray(values);
462 }
463 
464 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)465 /* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
466     const Array3D<NativeT>& values, const Layout& layout) {
467   return CreateFromArrayWithLayout(values, layout);
468 }
469 
470 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)471 /* static */ Literal LiteralUtil::CreateR3FromArray3D(
472     const Array3D<NativeT>& values) {
473   return CreateFromArray(values);
474 }
475 
476 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection)477 /* static */ Literal LiteralUtil::CreateR3Projected(
478     std::initializer_list<std::initializer_list<NativeT>> values,
479     int64 projection) {
480   int64 dim0_size = projection;
481   int64 dim1_size = values.size();
482   int64 dim2_size = values.begin()->size();
483 
484   Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
485   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
486     int64 dim1 = 0;
487     for (auto inner_list : values) {
488       int64 dim2 = 0;
489       for (auto value : inner_list) {
490         array(dim0, dim1, dim2) = value;
491         ++dim2;
492       }
493       CHECK_EQ(dim2_size, dim2);
494       ++dim1;
495     }
496     CHECK_EQ(dim1_size, dim1);
497   }
498   return CreateR3FromArray3D(array);
499 }
500 
501 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection_p,int64 projection_z)502 /* static */ Literal LiteralUtil::CreateR4Projected(
503     std::initializer_list<std::initializer_list<NativeT>> values,
504     int64 projection_p, int64 projection_z) {
505   int64 dim0_size = projection_p;
506   int64 dim1_size = projection_z;
507   int64 dim2_size = values.size();
508   int64 dim3_size = values.begin()->size();
509 
510   Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
511   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
512     for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
513       int64 dim2 = 0;
514       for (auto inner_list : values) {
515         int64 dim3 = 0;
516         for (auto value : inner_list) {
517           array(dim0, dim1, dim2, dim3) = value;
518           ++dim3;
519         }
520         CHECK_EQ(dim3_size, dim3);
521         ++dim2;
522       }
523       CHECK_EQ(dim2_size, dim2);
524     }
525   }
526   return CreateR4FromArray4D(array);
527 }
528 
529 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)530 /* static */ Literal LiteralUtil::CreateR4FromArray4D(
531     const Array4D<NativeT>& values) {
532   return CreateFromArray(values);
533 }
534 
535 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)536 /* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
537     const Array4D<NativeT>& values, const Layout& layout) {
538   return CreateFromArrayWithLayout(values, layout);
539 }
540 
541 // Returns an identity matrix (rank 2) with the given row and column count.
542 template <typename NativeT>
MakeIdentityR2(int64 size)543 /* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
544   Array2D<NativeT> array(size, size, 0);
545   for (int64 i = 0; i < size; ++i) {
546     array(i, i) = 1;
547   }
548   return CreateR2FromArray2D(array);
549 }
550 
551 template <typename NativeT>
CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,NativeT value)552 /* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
553     absl::Span<const int64> dimensions, NativeT value) {
554   Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
555       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
556   literal.PopulateWithValue(value);
557   return literal;
558 }
559 
560 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,const std::function<T (absl::Span<const int64>)> & generator)561 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
562     const Shape& shape,
563     const std::function<T(absl::Span<const int64>)>& generator) {
564   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
565   TF_RET_CHECK(shape.element_type() == type);
566   Literal literal(shape);
567   TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
568       [&](absl::Span<const int64> indexes) { return generator(indexes); }));
569   return std::move(literal);
570 }
571 
572 template <PrimitiveType type, typename E, typename T>
CreateRandomLiteral(const Shape & shape,E * engine,T mean,T stddev)573 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
574     const Shape& shape, E* engine, T mean, T stddev) {
575   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
576   std::normal_distribution<NativeT> generator(mean, stddev);
577   return CreateRandomLiteral<type, NativeT>(
578       shape,
579       [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
580 }
581 
582 template <PrimitiveType type, typename T>
CreateRandomLiteral(const Shape & shape,T mean,T stddev)583 /* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
584     const Shape& shape, T mean, T stddev) {
585   std::minstd_rand0 engine;
586   return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
587 }
588 
589 }  // namespace xla
590 
591 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
592