1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for dealing with Literal protobufs.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20 
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>
29 
30 #include "tensorflow/compiler/xla/array2d.h"
31 #include "tensorflow/compiler/xla/array3d.h"
32 #include "tensorflow/compiler/xla/array4d.h"
33 #include "tensorflow/compiler/xla/index_util.h"
34 #include "tensorflow/compiler/xla/layout_util.h"
35 #include "tensorflow/compiler/xla/primitive_util.h"
36 #include "tensorflow/compiler/xla/ptr_util.h"
37 #include "tensorflow/compiler/xla/shape_tree.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/sparse_index_array.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/bitmap.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/core/stringpiece.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/protobuf.h"
51 #include "tensorflow/core/platform/types.h"
52 
53 namespace xla {
54 
55 // Class representing literal values in XLA.
56 //
57 // TODO(b/67651157): The methods in this class should be reduced to a minimal
58 // set of methods which construct Literals and accessors methods. Other methods
59 // which perform computation on Literals (Reshape, Slice, etc) should be moved
60 // elsewhere, and perhaps combined with evaluator code which operates on
61 // Literals.
62 class Literal {
63  public:
Literal()64   Literal() : Literal(ShapeUtil::MakeNil()) {}
65 
66   // Create a literal of the given shape. The literal is allocated sufficient
67   // memory to hold the shape. Memory is uninitialized.
68   explicit Literal(const Shape& shape);
69   virtual ~Literal();
70 
71   // Literals are moveable, but not copyable. To copy a literal use
72   // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
73   // of literals which can be expensive.
74   Literal(const Literal& other) = delete;
75   Literal& operator=(const Literal& other) = delete;
76   Literal(Literal&& other);
77   Literal& operator=(Literal&& other);
78 
79   // Literals are equal if they have compatible shapes and the same data
80   // values. Layout is not compared.
81   bool operator==(const Literal& other) const;
82   bool operator!=(const Literal& other) const { return !(*this == other); }
83 
84   // Serialize to and from a proto.
85   static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
86       const LiteralProto& proto);
87   LiteralProto ToProto() const;
88 
89   // Return the shape of the literal.
shape()90   const Shape& shape() const { return shape_; }
91 
92   // TODO(b/67651157): Remove this accessor. Literal users should not be able to
93   // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()94   Shape* mutable_shape_do_not_use() { return &shape_; }
95 
96   // Returns a (Mutable)ArraySlice view of the array for this literal for the
97   // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
98   // given ShapeIndex is not array. See primitive_util.h for the mapping from
99   // XLA type to native type.
100   template <typename NativeT>
101   tensorflow::gtl::ArraySlice<NativeT> data(
102       const ShapeIndex& shape_index = {}) const;
103   template <typename NativeT>
104   tensorflow::gtl::MutableArraySlice<NativeT> data(
105       const ShapeIndex& shape_index = {});
106 
107   // Returns a pointer to the sparse index array. Returns nullptr if the literal
108   // is not a sparse array.
109   const SparseIndexArray* sparse_indices(
110       const ShapeIndex& shape_index = {}) const;
111   SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
112 
113   // Returns a pointer to (or size of) the underlying buffer holding the array
114   // at the given shape index. CHECKs if the subshape of the literal at the
115   // given ShapeIndex is not array.
116   const void* untyped_data(const ShapeIndex& shape_index = {}) const;
117   void* untyped_data(const ShapeIndex& shape_index = {});
118   int64 size_bytes(const ShapeIndex& shape_index = {}) const;
119 
// Creates a new literal of a given rank. To minimize ambiguity (for users
// and the compiler) these CreateR[0-4] methods should explicitly specify the
// native type. For example:
123   //
124   //  CreateR1<float>({1.0, 42.0});
125   //  CreateR2<uint32>({{1, 2}, {3, 4}});
126   //
127   // The variants not ending with WithLayout use the default XLA layout for the
128   // literal's linear representation in memory.
129   template <typename NativeT>
130   static std::unique_ptr<Literal> CreateR0(NativeT value);
131   template <typename NativeT>
132   static std::unique_ptr<Literal> CreateR1(
133       tensorflow::gtl::ArraySlice<NativeT> values);
134   static std::unique_ptr<Literal> CreateR1(
135       const tensorflow::core::Bitmap& values);
136   template <typename NativeT>
137   static std::unique_ptr<Literal> CreateR2(
138       std::initializer_list<std::initializer_list<NativeT>> values);
139   template <typename NativeT>
140   static std::unique_ptr<Literal> CreateR2WithLayout(
141       std::initializer_list<std::initializer_list<NativeT>> values,
142       const Layout& layout);
143   template <typename NativeT>
144   static std::unique_ptr<Literal> CreateR3(
145       std::initializer_list<
146           std::initializer_list<std::initializer_list<NativeT>>>
147           values);
148   template <typename NativeT>
149   static std::unique_ptr<Literal> CreateR3WithLayout(
150       std::initializer_list<
151           std::initializer_list<std::initializer_list<NativeT>>>
152           values,
153       const Layout& layout);
154   template <typename NativeT>
155   static std::unique_ptr<Literal> CreateR4(
156       std::initializer_list<std::initializer_list<
157           std::initializer_list<std::initializer_list<NativeT>>>>
158           values);
159   template <typename NativeT>
160   static std::unique_ptr<Literal> CreateR4WithLayout(
161       std::initializer_list<std::initializer_list<
162           std::initializer_list<std::initializer_list<NativeT>>>>
163           values,
164       const Layout& layout);
165 
166   // Returns this literal's data as a string. This literal must be a rank-1 U8
167   // array.
168   string GetR1U8AsString() const;
169 
170   // Creates a literal with a sparse layout and the given indices and values.
171   // The shape is initialized from the given dimensions.  The minor dimension of
172   // the indices array must equal the rank of the shape (i.e. size of the
173   // dimensions array). The major dimension of the indices array must equal the
174   // number of elements in the values array. The maximum number of elements in
175   // the array is taken from the max_indices() value of the index array.
176   //
177   // XLA assumes that sparse literals are in sorted order for all operations. If
178   // the `sort` argument is true, then the indices and values will be sorted
179   // while copying them into the literal. If you have ensured that the indices
180   // and values are already sorted, then you may set the `sort` argument to
181   // false to skip the sorting step.
182   //
183   // For example:
184   //
185   //   CreateSparse(
186   //     {12, 12, 12},
187   //     SparseIndexArray(10, 3,
188   //                      Array2D{
189   //                        {0, 1, 2},
190   //                        {3, 4, 5},
191   //                        {6, 7, 8},
192   //                        {9, 10, 11},
193   //                      }),
//     {1.0, 2.0, 3.0, 4.0})
195   //
196   // This creates an array with shape F64[12,12,12]sparse{10}, that has the
197   // following non-zero values:
198   //
199   //     [0,  1,  2]: 1.0
200   //     [3,  4,  5]: 2.0
201   //     [6,  7,  8]: 3.0
202   //     [9, 10, 11]: 4.0
203   //
204   template <typename NativeT>
205   static std::unique_ptr<Literal> CreateSparse(
206       tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
207       tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
208 
209   // Populates a literal with a sparse layout with the given indices and values.
210   // Each index in the indices array is CHECKed against the dimensions in the
211   // literal's shape.  If sort is true, then the indices and values will be
212   // sorted.  If sort is false, then the indices and values are assumed to
213   // already be in sorted order.  See CreateSparse for an example of how data
214   // are populated.
215   template <typename NativeT>
216   void PopulateSparse(SparseIndexArray indices,
217                       tensorflow::gtl::ArraySlice<NativeT> values,
218                       bool sort = true);
219 
220   // Creates a new Literal object with the shape specified as parameter.
221   // The content of the literal values is the default value of the primitive
222   // type of literal itself (0 for numeric types, and false for predicates).
223   static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
224 
// Creates a new Literal object whose values have the primitive_type
// type, and with dimensions defined by the dimensions parameter.
227   // The content of the literal values is the default value of the primitive
228   // type of literal itself (0 for numeric types, and false for predicates).
229   static std::unique_ptr<Literal> CreateFromDimensions(
230       PrimitiveType primitive_type,
231       tensorflow::gtl::ArraySlice<int64> dimensions);
232 
233   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
234   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
235   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
236   // rooted at 'src_shape_index', but need not be arrays.
237   Status CopyFrom(const Literal& src_literal,
238                   const ShapeIndex& dest_shape_index = {},
239                   const ShapeIndex& src_shape_index = {});
240 
// Similar to CopyFrom, but with move semantics. The subshape of this literal
// rooted at 'dest_shape_index' must be *equal* to the shape of 'src_literal'
// (layouts and shapes must match), but need not be arrays. The memory
244   // allocated in this literal for the subshape at dest_shape_index is
245   // deallocated, and the respective buffers are replaced with those in
246   // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
247   Status MoveFrom(Literal&& src_literal,
248                   const ShapeIndex& dest_shape_index = {});
249 
250   // Copies the values from src_literal, starting at src_base shape indexes,
251   // to this literal, starting at dest_base, where the copy size in each
252   // dimension is specified by copy_size.
253   // The src_literal and this literal must have the same primitive type,
254   // src_base+copy_size must fit the source literal dimensions, as well as
255   // dest_base+copy_size must fit the destination literal dimensions.
256   // Note: if either src_literal or this literal contains dimensions with zero
257   // element, then copy_size must be 0 in these dimensions while the
258   // corresponding base indices being 0.
259   // This literal and 'src_literal' must be arrays.
260   Status CopySliceFrom(const Literal& src_literal,
261                        tensorflow::gtl::ArraySlice<int64> src_base,
262                        tensorflow::gtl::ArraySlice<int64> dest_base,
263                        tensorflow::gtl::ArraySlice<int64> copy_size);
264 
265   // Returns a vector containing the tuple elements of this Literal as separate
266   // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
// elements are moved into the new Literals; no data is copied. Upon return
// this Literal is set to a nil shape (empty tuple).
269   std::vector<Literal> DecomposeTuple();
270 
271   // This operation is the inverse of DecomposeTuple. The given elements are
272   // moved into the tuple elements of a new tuple-shaped Literal which is
273   // returned. Upon return, each of the Literals in 'elements' is set to a nil
274   // shape (empty tuple).
275   static Literal MoveIntoTuple(
276       tensorflow::gtl::MutableArraySlice<Literal> elements);
277 
278   // Creates a new value that has the equivalent value as this literal, but
279   // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
280   // minor-to-major dimension layout can be re-layed-out as {1, 0}
281   // minor-to-major dimension layout and the value in the cell at any given
282   // logical index (i0, i1) will be the same.
283   //
284   // For tuple shaped literals, shape_index should be used to select the inner
285   // array that the new layout applies to.
286   //
287   // Note: this is useful when the client wants to ensure that a value placed in
288   // the XLA allocation tracker has a particular layout; for efficiency
289   // purposes or avoiding unimplemented operation/layout combinations.
290   std::unique_ptr<Literal> Relayout(const Layout& new_layout,
291                                     const ShapeIndex& shape_index = {}) const;
292 
293   // An overload of Relayout which changes the layout of the entire shape rather
294   // than being limited to a single array within the shape.
295   std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
296 
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; the
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
301   StatusOr<std::unique_ptr<Literal>> Reshape(
302       tensorflow::gtl::ArraySlice<int64> dimensions) const;
303 
304   // Creates a new literal by reordering the dimensions of this literal.
305   // The given `permutation` must be a permutation of the dimension numbers
306   // in the original literal, and it specifies the order of the new dimensions
307   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
308   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
309   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
310   // This literal must be an array.
311   std::unique_ptr<Literal> Transpose(
312       tensorflow::gtl::ArraySlice<int64> permutation) const;
313 
314   // Creates a sub-array from this literal by extracting the indices
315   // [start_index, limit_index) of each dimension. The result literal has the
316   // same rank and layout as for the given literal. The number of indices in
317   // start_indices and limit_indices must be the rank of the literal, and the
318   // indices follow the order of the dimensions.
319   // This literal must be an array.
320   std::unique_ptr<Literal> Slice(
321       tensorflow::gtl::ArraySlice<int64> start_indices,
322       tensorflow::gtl::ArraySlice<int64> limit_indices) const;
323 
324   // Creates a literal with a prepended dimension with bound "times"; e.g. a
325   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
326   // literal replicated four times.
327   // This literal must be an array.
328   template <typename NativeT>
329   std::unique_ptr<Literal> Replicate(int64 times) const;
330 
331   // Converts this literal to another primitive type. Returns an error if the
332   // conversion is not possible. This literal must be array-shaped.
333   StatusOr<std::unique_ptr<Literal>> Convert(
334       PrimitiveType primitive_dest_type) const;
335 
336   // Creates a scalar literal value zero of the given primitive type.
337   static Literal Zero(PrimitiveType primitive_type);
338 
339   // Creates a scalar literal value one of the given primitive type.
340   static Literal One(PrimitiveType primitive_type);
341 
342   // Creates a scalar literal value containing the minimum value of the given
343   // primitive type. For floating-point types, returns -inf.
344   static Literal MinValue(PrimitiveType primitive_type);
345 
346   // Creates a scalar literal value containing the maximum value of the given
347   // primitive type. For floating-point types, returns inf.
348   static Literal MaxValue(PrimitiveType primitive_type);
349 
350   // Creates a literal of the given shape where each element is `value`.
351   template <typename NativeT>
352   static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
353       tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
354 
355   // Creates a new literal from an Array type. The variants not ending with
356   // WithLayout use the default XLA layout for the literal's linear
357   // representation in memory.
358   template <typename NativeT>
359   static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
360   template <typename NativeT>
361   static std::unique_ptr<Literal> CreateFromArrayWithLayout(
362       const Array<NativeT>& values, const Layout& layout);
363   template <typename NativeT>
364   static std::unique_ptr<Literal> CreateR2FromArray2D(
365       const Array2D<NativeT>& values);
366   template <typename NativeT>
367   static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
368       const Array2D<NativeT>& values, const Layout& layout);
369   template <typename NativeT>
370   static std::unique_ptr<Literal> CreateR3FromArray3D(
371       const Array3D<NativeT>& values);
372   template <typename NativeT>
373   static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
374       const Array3D<NativeT>& values, const Layout& layout);
375   template <typename NativeT>
376   static std::unique_ptr<Literal> CreateR4FromArray4D(
377       const Array4D<NativeT>& values);
378   template <typename NativeT>
379   static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
380       const Array4D<NativeT>& values, const Layout& layout);
381 
382   // Creates a new vector of U8s literal value from a string.
383   static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value);
384 
385   // Creates a linspace-populated literal with the given number of rows and
386   // columns.
387   static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
388                                                       int64 rows, int64 cols);
389 
390   // Creates a literal that projects the (x, y) dimensions given in values into
391   // the z dimension given by "projection".
392   template <typename NativeT>
393   static std::unique_ptr<Literal> CreateR3Projected(
394       std::initializer_list<std::initializer_list<NativeT>> values,
395       int64 projection);
396 
397   // Creates a literal that projects the (x, y) dimensions given in values into
398   // the z and p dimensions given.
399   template <typename NativeT>
400   static std::unique_ptr<Literal> CreateR4Projected(
401       std::initializer_list<std::initializer_list<NativeT>> values,
402       int64 projection_p, int64 projection_z);
403 
404   // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
405   Literal Clone() const;
406   std::unique_ptr<Literal> CloneToUnique() const;
407 
408   // Gets or sets an element in the literal at the given index. The multi_index
409   // is CHECKed against the dimension sizes.
410   template <typename NativeT>
411   NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
412               const ShapeIndex& shape_index) const;
413   template <typename NativeT>
414   void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
415            const ShapeIndex& shape_index, NativeT value);
416 
417   // Overloads of Get and Set for array literals. CHECKs if the literal is not
418   // array-shaped and dense.
419   template <typename NativeT>
420   NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
421   template <typename NativeT>
422   void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
423 
// Returns the multi-index of the element in a sparse literal at the given
// sparse element number.  The sparse element number is the position within
// the sparse array's list of (index, value) pairs, and is checked against the
// total number of (index, value) pairs in the sparse array.
428   tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
429       int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
430 
// Returns the value of the element in a sparse literal at the given sparse
// element number.  The sparse element number is the position within the
// sparse array's list of (index, value) pairs, and is checked against the
// total number of (index, value) pairs in the sparse array.
435   template <typename NativeT>
436   NativeT GetSparseElement(int64 sparse_element_number,
437                            const ShapeIndex& shape_index = {}) const;
438 
439   // Appends the given element to the literal.  If the elements are not appended
440   // in sorted order, then SortSparseElements should be called before calling
441   // other methods.  This literal must have a sparse layout.
442   template <typename NativeT>
443   void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
444                            NativeT value, const ShapeIndex& shape_index = {});
445 
446   // Sorts the elements in a sparse array.
447   void SortSparseElements(const ShapeIndex& shape_index = {});
448 
449   // Returns the element value at index (0, ..., 0), however many zeroes are
450   // required for that index.
451   template <typename NativeT>
452   NativeT GetFirstElement() const;
453 
454   // As Get(), but determines the correct type and converts the value
455   // into text.
456   string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
457                      const ShapeIndex& shape_index = {}) const;
458 
459   // As GetSparseElement(), but determines the correct type and converts the
460   // value into text.
461   string GetSparseElementAsString(int64 sparse_element_number,
462                                   const ShapeIndex& shape_index = {}) const;
463 
464   // As Get(), but determines the correct type and converts the value into
465   // int64.  This literal must be an array.
466   StatusOr<int64> GetIntegralAsS64(
467       tensorflow::gtl::ArraySlice<int64> multi_index) const;
468 
469   // Returns an identity matrix (rank 2) with the given row and column count.
470   template <typename NativeT>
471   static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
472 
473   // Returns a tuple literal composed of given literals. Data is copied from the
474   // given elements into the returned literal.
475   static std::unique_ptr<Literal> MakeTuple(
476       tensorflow::gtl::ArraySlice<const Literal*> elements);
477 
478   // As above, but intended to be invoked with move semantics; i.e.
479   //
480   //  std::vector<std::unique_ptr<Literal>> elements = ...;
481   //  auto result = Literal::MakeTupleOwned(std::move(elements));
482   //
483   // This would have been declared as an overload, but there is ambiguity
484   // in invocation between the above signature and this one.
485   static std::unique_ptr<Literal> MakeTupleOwned(
486       std::vector<std::unique_ptr<Literal>> elements);
487 
488   // This overload lets you pass a braced list of unique_ptr<Literal>s to
489   // MakeTupleOwned:
490   //
491   //   Literal::MakeTupleOwned(Literal::CreateR1(...), ...).
492   //
493   // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
494   // overload doesn't work because std::initializer_list's elements are always
495   // const.
496   //
497   // The arguments to this function must all be unique_ptr<Literal>.
498   template <typename... Ts>
MakeTupleOwned(std::unique_ptr<Ts>...elements)499   static std::unique_ptr<Literal> MakeTupleOwned(
500       std::unique_ptr<Ts>... elements) {
501     std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
502         std::move(elements)...};
503     std::vector<std::unique_ptr<Literal>> v;
504     v.insert(v.begin(), std::make_move_iterator(arr.begin()),
505              std::make_move_iterator(arr.end()));
506     return MakeTupleOwned(std::move(v));
507   }
508 
509   // Returns a string representation of the literal value.
510   // Warning: this function can take minutes for multi-million element Literals.
511   string ToString(bool print_layout = false) const;
512 
513   // Invokes the "per cell" callback for each element in the provided
514   // literal with the element's indices and a string representation of
515   // the element's value.
516   //
517   // This function is useful if you want a polymorphic representation
518   // of the tensor's elements (turning it to a string for something
519   // like representation in a protobuf).
520   //
521   // This literal must have a dense layout.
522   void EachCellAsString(
523       const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
524                                const string& value)>& per_cell) const;
525   template <typename NativeT>
526   void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
527                                    NativeT value)>
528                     per_cell) const;
529 
530   // Populate this literal with the given values. Examples:
531   //
532   //   // Populate with floats.
533   //   Array2D<float> float_values = ...
//   literal.PopulateR2FromArray2D(float_values);
535   //
536   //   // Populate with int32s.
537   //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
538   //
539   // The shape and element type of this literal must match given values. For
540   // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
541   // array of S32.
542   template <typename NativeT>
543   void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
544   void PopulateR1(const tensorflow::core::Bitmap& values);
545   template <typename NativeT>
546   void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
547   template <typename NativeT>
548   void PopulateFromArray(const Array<NativeT>& values);
549   template <typename NativeT>
550   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
551   template <typename NativeT>
552   void PopulateR3FromArray3D(const Array3D<NativeT>& values);
553   template <typename NativeT>
554   void PopulateR4FromArray4D(const Array4D<NativeT>& values);
555 
556   // Populates literal values by calling the generator function for every cell
557   // in this literal object.
558   //
559   // generator must be a callable of the type
560   // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
561   //
562   // This literal must have a dense layout.
563   template <typename NativeT, typename FnType>
564   Status Populate(const FnType& generator);
565 
566   // Fills this literal with the given value.
567   template <typename NativeT>
568   void PopulateWithValue(NativeT value);
569 
570   // Returns whether every element in this literal is equal to value.
571   //
572   // value is an int8 because we expect this to be called with small
573   // compile-time constants (0, -1, etc.) and so that whatever value you pass
574   // can be represented exactly by floating-point types as small as 16 bits.
575   //
576   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
577   // are considered equal to true/false; other values are not considered equal
578   // to true. Also if this literal is not array-shaped false is returned.
579   bool IsAll(int8 value) const;
580 
// Like IsAll(int8), except we check whether the literal is
// equal to a particular floating-point number.
583   //
584   // If the literal is not a floating-point value, this always returns false.
585   //
586   // This casts value to the type of literal, then compares using ==.  The usual
587   // admonishments about floating-point equality checks apply.  We expect you to
588   // use this to check for values that can be expressed precisely as a float,
589   // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
590   bool IsAllFloat(float value) const;
591 
// Like IsAll(int8), except we check whether the literal is
// equal to a particular complex number.
594   //
595   // If the literal is not a complex value, this always returns false.
596   //
597   // This casts value to the type of literal, then compares using ==.  The usual
598   // admonishments about floating-point equality checks apply.  We expect you to
599   // use this to check for complex values that can be expressed precisely as
600   // float pairs e.g. (-0.5, 1.0).
601   //
602   // This literal must have a dense layout.
603   bool IsAllComplex(complex64 value) const;
604 
605   // Returns whether this literal is zero at the specified index. This literal
606   // must be an array with a dense layout.
607   bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
608 
609   // Return the count of the elements in the array at the given shape index in
610   // this literal.
611   int64 element_count(const ShapeIndex& index = {}) const {
612     return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
613   }
614 
// Return the count of the elements in the sparse array in this literal,
// which will be no larger than LayoutUtil::MaxSparseElements() of the
// layout of that array. Note: this method takes no ShapeIndex argument.
618   int64 sparse_element_count() const;
619 
620  protected:
621   // 'allocate_arrays' indicates whether to allocate memory for the arrays in
622   // the shape. If false, buffer pointers inside of the Literal::Pieces are set
623   // to nullptr.
624   Literal(const Shape& shape, bool allocate_arrays);
625 
626   // Internal template helper for the Literal::CopySliceFrom(), matching its
627   // arguments one by one.
628   template <typename NativeT>
629   Status CopySliceFromInternal(const Literal& src_literal,
630                                tensorflow::gtl::ArraySlice<int64> src_base,
631                                tensorflow::gtl::ArraySlice<int64> dest_base,
632                                tensorflow::gtl::ArraySlice<int64> copy_size);
633 
  // Precomputed parameters for a ShapeUtil::ForEachIndex() scan that walks two
  // literals (a source and a destination) in step. The configuration is
  // chosen to make that scan as cheap as possible for the two shapes.
  struct StrideConfig {
    // Derives the configuration for scanning 'dimensions' of source_shape and
    // dest_shape. Defined out of line.
    StrideConfig(const Shape& source_shape, const Shape& dest_shape,
                 tensorflow::gtl::ArraySlice<int64> dimensions);

    // The extent of the stride operation. Every dimension i is iterated from
    // base[i] to base[i] + dimensions[i], in steps of step[i].
    // NOTE(review): an ArraySlice does not own its data, so the 'dimensions'
    // passed to the constructor presumably must outlive this config.
    tensorflow::gtl::ArraySlice<int64> dimensions;
    // Starting index of the scan in each dimension.
    DimensionVector base;
    // Increment applied to each dimension between visited indices.
    DimensionVector step;
    // Dimension number that the innermost loop scans.
    int64 minor_dimension = 0;
    // The size of the strides for source and destination. One of the two
    // (the one looping through its most minor dimension) will be 1. The
    // other will be the stride size at the dimension matching the other
    // shape's most minor dimension being scanned.
    int64 dest_stride = 1;
    int64 source_stride = 1;
    // Trip count of the inner loop over the most minor dimension.
    int64 minor_loop_size = 1;
  };
656 
657   // A data structure representing a subshape at a particular ShapeIndex within
658   // the literal. For array-shaped ShapeIndexes, this data structure holds the
659   // pointer to the memory allocated for the array data.
660   class Piece {
661    public:
662     // Return the buffer holding the array data for this piece as an array
663     // slice. This piece must be array-shaped.
664     template <typename NativeT>
665     tensorflow::gtl::ArraySlice<NativeT> data() const;
666     template <typename NativeT>
667     tensorflow::gtl::MutableArraySlice<NativeT> data();
668 
669     // Return the buffer holding the array data for this piece as a void*. This
670     // piece must be array-shaped.
671     void* untyped_data();
672     const void* untyped_data() const;
673 
674     // Gets or sets an element in the array at the given index. The multi_index
675     // is CHECKed against the dimension sizes of the array.  This piece must be
676     // array-shaped.
677     template <typename NativeT>
678     NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
679     template <typename NativeT>
680     void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
681 
682     // Gets/sets the buffer holding the array data.
buffer()683     char* buffer() const { return buffer_; }
set_buffer(char * buffer)684     void set_buffer(char* buffer) { buffer_ = buffer; }
685 
686     // The array of multi-indices that provide the locations of non-zero
687     // elements in a sparse array.  Only used if
688     // LayoutUtil::IsSparseArray(shape()) is true.
sparse_indices()689     SparseIndexArray* sparse_indices() const { return sparse_indices_; }
set_sparse_indices(SparseIndexArray * sparse_indices)690     void set_sparse_indices(SparseIndexArray* sparse_indices) {
691       sparse_indices_ = sparse_indices;
692     }
693 
694     // Gets or sets the subshape of this piece. This reference points to a
695     // subshape within the shape in the containing Literal (Literal::shape_).
subshape()696     const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)697     void set_subshape(const Shape* subshape) { subshape_ = subshape; }
698 
699     // Returns the size in bytes of the buffer holding the array data.
size_bytes()700     int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
701 
702     // Returns the number of elements in this piece's array.
element_count()703     int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
704 
705     // Copy the data from 'src' into this piece's buffer. Shapes of this piece
706     // and src must be compatible.
707     Status CopyFrom(const Piece& src);
708 
709     // Returns true if this piece and 'other' contain the same data. This piece
710     // and 'other' must be array-shaped and compatible.
711     bool EqualElements(const Piece& other) const;
712 
713     // Writes the shape and data (if array-shaped) into the given proto.
714     void WriteToProto(LiteralProto* proto) const;
715 
716     // Copies the data from the given proto into this piece. The shape of this
717     // piece must be equal (not just compatible) to the shape of the proto.
718     Status CopyFromProto(const LiteralProto& proto);
719 
720     // Sorts the elements in a sparse array.
721     void SortSparseElements();
722 
723    private:
724     // Recursive helper for EqualElements.
725     template <typename NativeT>
726     bool EqualElementsInternal(const Piece& other,
727                                std::vector<int64>* multi_index) const;
728 
729     // Helper for SortSparseElements that has the element type as a template
730     // parameter.
731     template <typename NativeT>
732     void SortSparseElementsInternal();
733 
734     // For array-shaped pieces, this is the buffer holding the literal data.
735     char* buffer_ = nullptr;
736 
737     // For sparse arrays, this is the array of indices.
738     SparseIndexArray* sparse_indices_ = nullptr;
739 
740     // The shape of piece. This points into the shape of the containing Literal
741     // (Literal::shape_).
742     const Shape* subshape_ = nullptr;
743   };
744 
745   // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)746   Piece& piece(const ShapeIndex& shape_index) {
747     return *pieces_.mutable_element(shape_index);
748   }
piece(const ShapeIndex & shape_index)749   const Piece& piece(const ShapeIndex& shape_index) const {
750     return pieces_.element(shape_index);
751   }
752 
753   // Returns the piece at the root of the shape (empty ShapeIndex).
root_piece()754   Piece& root_piece() { return piece({}); }
root_piece()755   const Piece& root_piece() const { return piece({}); }
756 
  // Deallocates the buffers held by this literal, if this literal owns them
  // (see owns_buffers_). Defined out of line.
  void DeallocateBuffers();
760 
  // The full shape of this literal. Each Piece's subshape points into it.
  Shape shape_;
  // The Pieces of this literal, addressed by ShapeIndex (see piece()).
  ShapeTree<Piece> pieces_;

  // Whether the buffers held in pieces_ are owned by this Literal.
  bool owns_buffers_;

  // LiteralView must access and manipulate Pieces of other Literals.
  friend class LiteralView;
};  // class Literal
770 
// Streams a textual representation of 'literal' to 'out' and returns 'out'.
// The formatting is defined out of line.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
772 
// A read-only view of a Literal. A LiteralView contains pointers to buffers
// owned by the viewed Literal, so that Literal presumably must outlive every
// view of it.
//
// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
// and mutable) similar to (Mutable)ArraySlice.
class LiteralView : public Literal {
 public:
  // Creates and returns a view of the given literal, rooted at the given
  // shape index within that literal (the whole literal by default). A factory
  // is used rather than a public constructor because only const LiteralViews
  // are supported. Non-const LiteralViews can still be created via the copy
  // constructors, but the factory method makes that a bit less likely.
  // Implementing literal slices will fix this undesirable situation
  // (b/71550060).
  static const LiteralView Create(const Literal& literal,
                                  const ShapeIndex& view_root = {});

  // Copy construction and copy assignment. Both are implemented with the
  // private CopyFrom() helper.
  LiteralView(const LiteralView& other);
  LiteralView& operator=(const LiteralView& other);

  virtual ~LiteralView();

 private:
  // Builds a view of 'literal' rooted at 'view_root'. Presumably reached only
  // through Create().
  LiteralView(const Literal& literal, const ShapeIndex& view_root);

  // Helper for the copy constructor and copy assignment operator.
  void CopyFrom(const LiteralView& other);
};
800 
801 template <typename NativeT>
data()802 tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
803   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
804   CHECK_EQ(subshape().element_type(),
805            primitive_util::NativeToPrimitiveType<NativeT>())
806       << "Attempting to access "
807       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
808       << " type, but literal element type is "
809       << PrimitiveType_Name(subshape().element_type());
810   return tensorflow::gtl::ArraySlice<NativeT>(
811       reinterpret_cast<const NativeT*>(buffer()),
812       ShapeUtil::ElementsIn(subshape()));
813 }
814 
815 template <typename NativeT>
data()816 tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
817   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
818   CHECK_EQ(subshape().element_type(),
819            primitive_util::NativeToPrimitiveType<NativeT>())
820       << "Attempting to access "
821       << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
822       << " type, but literal element type is "
823       << PrimitiveType_Name(subshape().element_type());
824   return tensorflow::gtl::MutableArraySlice<NativeT>(
825       reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape()));
826 }
827 
828 template <typename NativeT>
Get(tensorflow::gtl::ArraySlice<int64> multi_index)829 NativeT Literal::Piece::Get(
830     tensorflow::gtl::ArraySlice<int64> multi_index) const {
831   CHECK(LayoutUtil::IsDenseArray(subshape()));
832   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
833       subshape(), multi_index)];
834 }
835 
836 template <typename NativeT>
Set(tensorflow::gtl::ArraySlice<int64> multi_index,NativeT value)837 void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
838                          NativeT value) {
839   CHECK(LayoutUtil::IsDenseArray(subshape()));
840   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
841       subshape(), multi_index)] = value;
842 }
843 
844 template <typename NativeT>
data(const ShapeIndex & shape_index)845 tensorflow::gtl::ArraySlice<NativeT> Literal::data(
846     const ShapeIndex& shape_index) const {
847   return piece(shape_index).data<NativeT>();
848 }
849 
850 template <typename NativeT>
data(const ShapeIndex & shape_index)851 tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
852     const ShapeIndex& shape_index) {
853   return piece(shape_index).data<NativeT>();
854 }
855 
856 template <typename NativeT>
Get(tensorflow::gtl::ArraySlice<int64> multi_index,const ShapeIndex & shape_index)857 inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
858                             const ShapeIndex& shape_index) const {
859   return piece(shape_index).Get<NativeT>(multi_index);
860 }
861 
862 template <typename NativeT>
Get(tensorflow::gtl::ArraySlice<int64> multi_index)863 inline NativeT Literal::Get(
864     tensorflow::gtl::ArraySlice<int64> multi_index) const {
865   return root_piece().Get<NativeT>(multi_index);
866 }
867 
868 template <typename NativeT>
Set(tensorflow::gtl::ArraySlice<int64> multi_index,const ShapeIndex & shape_index,NativeT value)869 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
870                          const ShapeIndex& shape_index, NativeT value) {
871   return piece(shape_index).Set<NativeT>(multi_index, value);
872 }
873 
874 template <typename NativeT>
Set(tensorflow::gtl::ArraySlice<int64> multi_index,NativeT value)875 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
876                          NativeT value) {
877   return root_piece().Set<NativeT>(multi_index, value);
878 }
879 
880 template <typename NativeT>
CreateR0(NativeT value)881 /* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
882   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
883       primitive_util::NativeToPrimitiveType<NativeT>(), {}));
884   literal->Set({}, value);
885   return literal;
886 }
887 
888 template <typename NativeT>
CreateR1(tensorflow::gtl::ArraySlice<NativeT> values)889 /* static */ std::unique_ptr<Literal> Literal::CreateR1(
890     tensorflow::gtl::ArraySlice<NativeT> values) {
891   auto literal = MakeUnique<Literal>(
892       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
893                            {static_cast<int64>(values.size())}));
894   literal->PopulateR1(values);
895   return literal;
896 }
897 
898 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)899 /* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
900     std::initializer_list<std::initializer_list<NativeT>> values,
901     const Layout& layout) {
902   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
903       primitive_util::NativeToPrimitiveType<NativeT>(),
904       {static_cast<int64>(values.size()),
905        static_cast<int64>(values.begin()->size())},
906       AsInt64Slice(layout.minor_to_major())));
907   literal->PopulateR2(values);
908   return literal;
909 }
910 
911 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)912 /* static */ std::unique_ptr<Literal> Literal::CreateR2(
913     std::initializer_list<std::initializer_list<NativeT>> values) {
914   return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
915 }
916 
917 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)918 /* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout(
919     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
920         values,
921     const Layout& layout) {
922   const int64 d0 = values.size();
923   const int64 d1 = values.begin()->size();
924   const int64 d2 = values.begin()->begin()->size();
925   Array3D<NativeT> tmp(d0, d1, d2);
926   int64 i0 = 0;
927   for (auto d1_values : values) {
928     int64 i1 = 0;
929     for (auto d2_values : d1_values) {
930       int64 i2 = 0;
931       for (auto value : d2_values) {
932         tmp(i0, i1, i2) = value;
933         ++i2;
934       }
935       ++i1;
936     }
937     ++i0;
938   }
939   return CreateR3FromArray3DWithLayout(tmp, layout);
940 }
941 
942 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)943 /* static */ std::unique_ptr<Literal> Literal::CreateR3(
944     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
945         values) {
946   return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
947 }
948 
949 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)950 /* static */ std::unique_ptr<Literal> Literal::CreateR4WithLayout(
951     std::initializer_list<std::initializer_list<
952         std::initializer_list<std::initializer_list<NativeT>>>>
953         values,
954     const Layout& layout) {
955   const int64 d0 = values.size();
956   const int64 d1 = values.begin()->size();
957   const int64 d2 = values.begin()->begin()->size();
958   const int64 d3 = values.begin()->begin()->begin()->size();
959   Array4D<NativeT> tmp(d0, d1, d2, d3);
960   int64 i0 = 0;
961   for (auto d1_values : values) {
962     int64 i1 = 0;
963     for (auto d2_values : d1_values) {
964       int64 i2 = 0;
965       for (auto d3_values : d2_values) {
966         int64 i3 = 0;
967         for (auto value : d3_values) {
968           tmp(i0, i1, i2, i3) = value;
969           ++i3;
970         }
971         ++i2;
972       }
973       ++i1;
974     }
975     ++i0;
976   }
977   return CreateR4FromArray4DWithLayout(tmp, layout);
978 }
979 
980 template <typename NativeT>
CreateSparse(tensorflow::gtl::ArraySlice<int64> dimensions,SparseIndexArray indices,tensorflow::gtl::ArraySlice<NativeT> values,bool sort)981 /* static */ std::unique_ptr<Literal> Literal::CreateSparse(
982     tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
983     tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
984   int64 num_elements = values.size();
985   int64 rank = dimensions.size();
986   CHECK_EQ(num_elements, indices.index_count());
987   CHECK_EQ(rank, indices.rank());
988   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
989       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
990       indices.max_indices()));
991   literal->PopulateSparse(indices, values, sort);
992   return literal;
993 }
994 
995 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)996 /* static */ std::unique_ptr<Literal> Literal::CreateR4(
997     std::initializer_list<std::initializer_list<
998         std::initializer_list<std::initializer_list<NativeT>>>>
999         values) {
1000   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
1001 }
1002 
1003 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)1004 /* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
1005     const Array<NativeT>& values, const Layout& layout) {
1006   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
1007       primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
1008       AsInt64Slice(layout.minor_to_major())));
1009   literal->PopulateFromArray(values);
1010   return literal;
1011 }
1012 
1013 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)1014 /* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
1015     const Array<NativeT>& values) {
1016   return CreateFromArrayWithLayout(
1017       values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
1018 }
1019 
1020 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)1021 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
1022     const Array2D<NativeT>& values, const Layout& layout) {
1023   return CreateFromArrayWithLayout(values, layout);
1024 }
1025 
1026 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)1027 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
1028     const Array2D<NativeT>& values) {
1029   return CreateFromArray(values);
1030 }
1031 
1032 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)1033 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
1034     const Array3D<NativeT>& values, const Layout& layout) {
1035   return CreateFromArrayWithLayout(values, layout);
1036 }
1037 
1038 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)1039 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
1040     const Array3D<NativeT>& values) {
1041   return CreateFromArray(values);
1042 }
1043 
1044 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection)1045 /* static */ std::unique_ptr<Literal> Literal::CreateR3Projected(
1046     std::initializer_list<std::initializer_list<NativeT>> values,
1047     int64 projection) {
1048   int64 dim0_size = projection;
1049   int64 dim1_size = values.size();
1050   int64 dim2_size = values.begin()->size();
1051 
1052   Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
1053   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
1054     int64 dim1 = 0;
1055     for (auto inner_list : values) {
1056       int64 dim2 = 0;
1057       for (auto value : inner_list) {
1058         array(dim0, dim1, dim2) = value;
1059         ++dim2;
1060       }
1061       CHECK_EQ(dim2_size, dim2);
1062       ++dim1;
1063     }
1064     CHECK_EQ(dim1_size, dim1);
1065   }
1066   return CreateR3FromArray3D(array);
1067 }
1068 
1069 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection_p,int64 projection_z)1070 /* static */ std::unique_ptr<Literal> Literal::CreateR4Projected(
1071     std::initializer_list<std::initializer_list<NativeT>> values,
1072     int64 projection_p, int64 projection_z) {
1073   int64 dim0_size = projection_p;
1074   int64 dim1_size = projection_z;
1075   int64 dim2_size = values.size();
1076   int64 dim3_size = values.begin()->size();
1077 
1078   Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
1079   for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
1080     for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
1081       int64 dim2 = 0;
1082       for (auto inner_list : values) {
1083         int64 dim3 = 0;
1084         for (auto value : inner_list) {
1085           array(dim0, dim1, dim2, dim3) = value;
1086           ++dim3;
1087         }
1088         CHECK_EQ(dim3_size, dim3);
1089         ++dim2;
1090       }
1091       CHECK_EQ(dim2_size, dim2);
1092     }
1093   }
1094   return CreateR4FromArray4D(array);
1095 }
1096 
1097 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)1098 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
1099     const Array4D<NativeT>& values) {
1100   return CreateFromArray(values);
1101 }
1102 
1103 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)1104 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
1105     const Array4D<NativeT>& values, const Layout& layout) {
1106   return CreateFromArrayWithLayout(values, layout);
1107 }
1108 
1109 template <typename NativeT>
GetFirstElement()1110 NativeT Literal::GetFirstElement() const {
1111   return data<NativeT>().at(0);
1112 }
1113 
1114 template <typename NativeT>
GetSparseElement(int64 sparse_element_number,const ShapeIndex & shape_index)1115 NativeT Literal::GetSparseElement(int64 sparse_element_number,
1116                                   const ShapeIndex& shape_index) const {
1117   CHECK(
1118       LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
1119   return data<NativeT>(shape_index)[sparse_element_number];
1120 }
1121 
1122 template <typename NativeT>
AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,NativeT value,const ShapeIndex & shape_index)1123 void Literal::AppendSparseElement(
1124     tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
1125     const ShapeIndex& shape_index) {
1126   Piece& p = piece(shape_index);
1127   const Shape& subshape = p.subshape();
1128   CHECK(LayoutUtil::IsSparseArray(subshape));
1129   int64 rank = ShapeUtil::Rank(subshape);
1130   CHECK_EQ(multi_index.size(), rank);
1131   int64 last_element = p.sparse_indices()->index_count();
1132   CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
1133   p.sparse_indices()->Append(multi_index);
1134   CHECK_LT(last_element, p.data<NativeT>().size());
1135   p.data<NativeT>()[last_element] = value;
1136 }
1137 
1138 // Returns an identity matrix (rank 2) with the given row and column count.
1139 template <typename NativeT>
MakeIdentityR2(int64 size)1140 /* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) {
1141   Array2D<NativeT> array(size, size, 0);
1142   for (int64 i = 0; i < size; ++i) {
1143     array(i, i) = 1;
1144   }
1145   return CreateR2FromArray2D(array);
1146 }
1147 
1148 template <typename NativeT>
EachCell(std::function<void (tensorflow::gtl::ArraySlice<int64> indices,NativeT value)> per_cell)1149 void Literal::EachCell(
1150     std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
1151                        NativeT value)>
1152         per_cell) const {
1153   if (ShapeUtil::HasZeroElements(shape())) {
1154     return;
1155   }
1156   std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
1157   do {
1158     per_cell(indices, Get<NativeT>(indices));
1159   } while (IndexUtil::BumpIndices(shape(), &indices));
1160 }
1161 
1162 template <typename NativeT>
PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values)1163 inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
1164   CHECK(ShapeUtil::IsArray(shape()));
1165   CHECK_EQ(ShapeUtil::Rank(shape()), 1);
1166   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
1167   CHECK_EQ(shape().element_type(),
1168            primitive_util::NativeToPrimitiveType<NativeT>());
1169   for (int64 i = 0; i < values.size(); ++i) {
1170     Set({i}, values[i]);
1171   }
1172 }
1173 
1174 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1175 void Literal::PopulateR2(
1176     std::initializer_list<std::initializer_list<NativeT>> values) {
1177   CHECK(ShapeUtil::IsArray(shape()));
1178   CHECK_EQ(ShapeUtil::Rank(shape()), 2);
1179   CHECK_EQ(shape().element_type(),
1180            primitive_util::NativeToPrimitiveType<NativeT>());
1181 
1182   const int64 dim0_size = values.size();
1183   const int64 dim1_size = values.begin()->size();
1184   CHECK_EQ(dim0_size, shape().dimensions(0));
1185   CHECK_EQ(dim1_size, shape().dimensions(1));
1186 
1187   int64 dim0 = 0;
1188   for (auto inner_list : values) {
1189     int64 dim1 = 0;
1190     for (auto value : inner_list) {
1191       Set({dim0, dim1}, value);
1192       ++dim1;
1193     }
1194     CHECK_EQ(dim1_size, dim1);
1195     ++dim0;
1196   }
1197 }
1198 
1199 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1200 void Literal::PopulateFromArray(const Array<NativeT>& values) {
1201   CHECK(ShapeUtil::IsArray(shape()));
1202   CHECK_EQ(shape().element_type(),
1203            primitive_util::NativeToPrimitiveType<NativeT>());
1204   CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
1205   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1206     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1207   }
1208   values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
1209                      NativeT value) { this->Set(indices, value); });
1210 }
1211 
1212 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1213 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1214   PopulateFromArray(values);
1215 }
1216 
1217 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1218 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1219   PopulateFromArray(values);
1220 }
1221 
1222 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1223 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1224   PopulateFromArray(values);
1225 }
1226 
1227 template <typename NativeT>
PopulateSparse(SparseIndexArray indices,tensorflow::gtl::ArraySlice<NativeT> values,bool sort)1228 void Literal::PopulateSparse(SparseIndexArray indices,
1229                              tensorflow::gtl::ArraySlice<NativeT> values,
1230                              bool sort) {
1231   CHECK(LayoutUtil::IsSparseArray(shape()));
1232   int rank = ShapeUtil::Rank(shape());
1233   CHECK_EQ(indices.rank(), rank);
1234   int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
1235   CHECK_LE(indices.max_indices(), max_elements);
1236   int64 num_elements = values.size();
1237   CHECK_LE(num_elements, max_elements);
1238   CHECK_EQ(num_elements, indices.index_count());
1239   auto root_data = root_piece().data<NativeT>();
1240   root_data.remove_suffix(max_elements - values.size());
1241   std::copy(values.begin(), values.end(), root_data.begin());
1242   *this->root_piece().sparse_indices() = std::move(indices);
1243   if (sort) {
1244     auto root_data = this->root_piece().data<NativeT>();
1245     root_data.remove_suffix(root_data.size() - num_elements);
1246     this->root_piece().sparse_indices()->SortWithValues(root_data);
1247   }
1248   DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
1249 }
1250 
1251 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1252 Status Literal::Populate(const FnType& generator) {
1253   const Shape& this_shape = shape();
1254   const int64 rank = ShapeUtil::Rank(this_shape);
1255   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1256   TF_RET_CHECK(this_shape.element_type() ==
1257                primitive_util::NativeToPrimitiveType<NativeT>());
1258   tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
1259   if (rank > 0) {
1260     StrideConfig stride_config(this_shape, this_shape,
1261                                AsInt64Slice(this_shape.dimensions()));
1262     DimensionVector minor_scan_indexes(rank, 0);
1263     int64 minor_dimension_size =
1264         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1265 
1266     auto init_function = [&](const std::vector<int64>& indexes) {
1267       const int64 index =
1268           IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1269       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1270       for (int64 i = 0; i < minor_dimension_size; ++i) {
1271         minor_scan_indexes[stride_config.minor_dimension] = i;
1272         literal_data.at(index + i) = generator(minor_scan_indexes);
1273       }
1274       return true;
1275     };
1276     ShapeUtil::ForEachIndex(this_shape, stride_config.base,
1277                             stride_config.dimensions, stride_config.step,
1278                             init_function);
1279   } else {
1280     // For scalars.
1281     literal_data.at(0) = generator({});
1282   }
1283   return Status::OK();
1284 }
1285 
1286 template <typename NativeT>
PopulateWithValue(NativeT value)1287 void Literal::PopulateWithValue(NativeT value) {
1288   CHECK(ShapeUtil::IsArray(shape()));
1289   CHECK_EQ(shape().element_type(),
1290            primitive_util::NativeToPrimitiveType<NativeT>());
1291   for (NativeT& element : data<NativeT>()) {
1292     element = value;
1293   }
1294 }
1295 
1296 template <typename NativeT>
CreateFullWithDescendingLayout(tensorflow::gtl::ArraySlice<int64> dimensions,NativeT value)1297 /* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
1298     tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
1299   auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
1300       primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
1301   literal->PopulateWithValue(value);
1302   return literal;
1303 }
1304 
1305 template <typename NativeT>
Replicate(int64 times)1306 std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
1307   DimensionVector bounds = {times};
1308   bounds.reserve(shape().dimensions_size() + 1);
1309   for (int64 bound : shape().dimensions()) {
1310     bounds.push_back(bound);
1311   }
1312   auto literal =
1313       MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
1314   int64 elements = ShapeUtil::ElementsIn(literal->shape());
1315   if (elements == 0) {
1316     return literal;
1317   }
1318 
1319   DimensionVector output_indices(bounds.size(), 0);
1320   tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
1321   input_indices.remove_prefix(1);
1322 
1323   bool done = false;
1324   while (!done) {
1325     const auto element = Get<NativeT>(input_indices);
1326     literal->Set<NativeT>(output_indices, element);
1327 
1328     done = true;
1329     for (int n = 0; n < output_indices.size(); ++n) {
1330       ++output_indices[n];
1331       if (output_indices[n] < bounds[n]) {
1332         done = false;
1333         break;
1334       }
1335       output_indices[n] = 0;
1336     }
1337   }
1338   return literal;
1339 }
1340 
1341 }  // namespace xla
1342 
1343 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
1344