1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Utilities for dealing with Literal protobufs.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
29
30 #include "tensorflow/compiler/xla/array2d.h"
31 #include "tensorflow/compiler/xla/array3d.h"
32 #include "tensorflow/compiler/xla/array4d.h"
33 #include "tensorflow/compiler/xla/index_util.h"
34 #include "tensorflow/compiler/xla/layout_util.h"
35 #include "tensorflow/compiler/xla/primitive_util.h"
36 #include "tensorflow/compiler/xla/ptr_util.h"
37 #include "tensorflow/compiler/xla/shape_tree.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/sparse_index_array.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/bitmap.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/core/stringpiece.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/protobuf.h"
51 #include "tensorflow/core/platform/types.h"
52
53 namespace xla {
54
55 // Class representing literal values in XLA.
56 //
57 // TODO(b/67651157): The methods in this class should be reduced to a minimal
// set of methods which construct Literals and accessor methods. Other methods
59 // which perform computation on Literals (Reshape, Slice, etc) should be moved
60 // elsewhere, and perhaps combined with evaluator code which operates on
61 // Literals.
62 class Literal {
63 public:
Literal()64 Literal() : Literal(ShapeUtil::MakeNil()) {}
65
66 // Create a literal of the given shape. The literal is allocated sufficient
67 // memory to hold the shape. Memory is uninitialized.
68 explicit Literal(const Shape& shape);
69 virtual ~Literal();
70
71 // Literals are moveable, but not copyable. To copy a literal use
72 // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
73 // of literals which can be expensive.
74 Literal(const Literal& other) = delete;
75 Literal& operator=(const Literal& other) = delete;
76 Literal(Literal&& other);
77 Literal& operator=(Literal&& other);
78
79 // Literals are equal if they have compatible shapes and the same data
80 // values. Layout is not compared.
81 bool operator==(const Literal& other) const;
82 bool operator!=(const Literal& other) const { return !(*this == other); }
83
84 // Serialize to and from a proto.
85 static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
86 const LiteralProto& proto);
87 LiteralProto ToProto() const;
88
89 // Return the shape of the literal.
shape()90 const Shape& shape() const { return shape_; }
91
92 // TODO(b/67651157): Remove this accessor. Literal users should not be able to
93 // mutate the shape as this can produce malformed Literals.
mutable_shape_do_not_use()94 Shape* mutable_shape_do_not_use() { return &shape_; }
95
96 // Returns a (Mutable)ArraySlice view of the array for this literal for the
97 // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
98 // given ShapeIndex is not array. See primitive_util.h for the mapping from
99 // XLA type to native type.
100 template <typename NativeT>
101 tensorflow::gtl::ArraySlice<NativeT> data(
102 const ShapeIndex& shape_index = {}) const;
103 template <typename NativeT>
104 tensorflow::gtl::MutableArraySlice<NativeT> data(
105 const ShapeIndex& shape_index = {});
106
107 // Returns a pointer to the sparse index array. Returns nullptr if the literal
108 // is not a sparse array.
109 const SparseIndexArray* sparse_indices(
110 const ShapeIndex& shape_index = {}) const;
111 SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
112
113 // Returns a pointer to (or size of) the underlying buffer holding the array
114 // at the given shape index. CHECKs if the subshape of the literal at the
115 // given ShapeIndex is not array.
116 const void* untyped_data(const ShapeIndex& shape_index = {}) const;
117 void* untyped_data(const ShapeIndex& shape_index = {});
118 int64 size_bytes(const ShapeIndex& shape_index = {}) const;
119
120 // Creates a new literal of a given rank. To minimize ambiguity (for users
121 // and the compiler) these CreateR[0-2] methods should explicitly specify the
122 // native type. For example:
123 //
124 // CreateR1<float>({1.0, 42.0});
125 // CreateR2<uint32>({{1, 2}, {3, 4}});
126 //
127 // The variants not ending with WithLayout use the default XLA layout for the
128 // literal's linear representation in memory.
129 template <typename NativeT>
130 static std::unique_ptr<Literal> CreateR0(NativeT value);
131 template <typename NativeT>
132 static std::unique_ptr<Literal> CreateR1(
133 tensorflow::gtl::ArraySlice<NativeT> values);
134 static std::unique_ptr<Literal> CreateR1(
135 const tensorflow::core::Bitmap& values);
136 template <typename NativeT>
137 static std::unique_ptr<Literal> CreateR2(
138 std::initializer_list<std::initializer_list<NativeT>> values);
139 template <typename NativeT>
140 static std::unique_ptr<Literal> CreateR2WithLayout(
141 std::initializer_list<std::initializer_list<NativeT>> values,
142 const Layout& layout);
143 template <typename NativeT>
144 static std::unique_ptr<Literal> CreateR3(
145 std::initializer_list<
146 std::initializer_list<std::initializer_list<NativeT>>>
147 values);
148 template <typename NativeT>
149 static std::unique_ptr<Literal> CreateR3WithLayout(
150 std::initializer_list<
151 std::initializer_list<std::initializer_list<NativeT>>>
152 values,
153 const Layout& layout);
154 template <typename NativeT>
155 static std::unique_ptr<Literal> CreateR4(
156 std::initializer_list<std::initializer_list<
157 std::initializer_list<std::initializer_list<NativeT>>>>
158 values);
159 template <typename NativeT>
160 static std::unique_ptr<Literal> CreateR4WithLayout(
161 std::initializer_list<std::initializer_list<
162 std::initializer_list<std::initializer_list<NativeT>>>>
163 values,
164 const Layout& layout);
165
166 // Returns this literal's data as a string. This literal must be a rank-1 U8
167 // array.
168 string GetR1U8AsString() const;
169
170 // Creates a literal with a sparse layout and the given indices and values.
171 // The shape is initialized from the given dimensions. The minor dimension of
172 // the indices array must equal the rank of the shape (i.e. size of the
173 // dimensions array). The major dimension of the indices array must equal the
174 // number of elements in the values array. The maximum number of elements in
175 // the array is taken from the max_indices() value of the index array.
176 //
177 // XLA assumes that sparse literals are in sorted order for all operations. If
178 // the `sort` argument is true, then the indices and values will be sorted
179 // while copying them into the literal. If you have ensured that the indices
180 // and values are already sorted, then you may set the `sort` argument to
181 // false to skip the sorting step.
182 //
183 // For example:
184 //
185 // CreateSparse(
186 // {12, 12, 12},
187 // SparseIndexArray(10, 3,
188 // Array2D{
189 // {0, 1, 2},
190 // {3, 4, 5},
191 // {6, 7, 8},
192 // {9, 10, 11},
193 // }),
// {1.0, 2.0, 3.0, 4.0})
195 //
196 // This creates an array with shape F64[12,12,12]sparse{10}, that has the
197 // following non-zero values:
198 //
199 // [0, 1, 2]: 1.0
200 // [3, 4, 5]: 2.0
201 // [6, 7, 8]: 3.0
202 // [9, 10, 11]: 4.0
203 //
204 template <typename NativeT>
205 static std::unique_ptr<Literal> CreateSparse(
206 tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
207 tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
208
209 // Populates a literal with a sparse layout with the given indices and values.
210 // Each index in the indices array is CHECKed against the dimensions in the
211 // literal's shape. If sort is true, then the indices and values will be
212 // sorted. If sort is false, then the indices and values are assumed to
213 // already be in sorted order. See CreateSparse for an example of how data
214 // are populated.
215 template <typename NativeT>
216 void PopulateSparse(SparseIndexArray indices,
217 tensorflow::gtl::ArraySlice<NativeT> values,
218 bool sort = true);
219
220 // Creates a new Literal object with the shape specified as parameter.
221 // The content of the literal values is the default value of the primitive
222 // type of literal itself (0 for numeric types, and false for predicates).
223 static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
224
// Creates a new Literal object with its values having the primitive_type
226 // type, and with dimensions defined by the dimensions parameter.
227 // The content of the literal values is the default value of the primitive
228 // type of literal itself (0 for numeric types, and false for predicates).
229 static std::unique_ptr<Literal> CreateFromDimensions(
230 PrimitiveType primitive_type,
231 tensorflow::gtl::ArraySlice<int64> dimensions);
232
233 // Copy values from 'src_literal' rooted at 'src_shape_index' into this
234 // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
235 // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
236 // rooted at 'src_shape_index', but need not be arrays.
237 Status CopyFrom(const Literal& src_literal,
238 const ShapeIndex& dest_shape_index = {},
239 const ShapeIndex& src_shape_index = {});
240
// Similar to CopyFrom, but with move semantics. The subshape of this literal
242 // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
243 // (layouts and shapes must match), but need not be arrays. The memory
244 // allocated in this literal for the subshape at dest_shape_index is
245 // deallocated, and the respective buffers are replaced with those in
246 // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
247 Status MoveFrom(Literal&& src_literal,
248 const ShapeIndex& dest_shape_index = {});
249
250 // Copies the values from src_literal, starting at src_base shape indexes,
251 // to this literal, starting at dest_base, where the copy size in each
252 // dimension is specified by copy_size.
253 // The src_literal and this literal must have the same primitive type,
254 // src_base+copy_size must fit the source literal dimensions, as well as
255 // dest_base+copy_size must fit the destination literal dimensions.
// Note: if either src_literal or this literal contains dimensions with zero
// elements, then copy_size must be 0 in those dimensions and the
// corresponding base indices must be 0.
259 // This literal and 'src_literal' must be arrays.
260 Status CopySliceFrom(const Literal& src_literal,
261 tensorflow::gtl::ArraySlice<int64> src_base,
262 tensorflow::gtl::ArraySlice<int64> dest_base,
263 tensorflow::gtl::ArraySlice<int64> copy_size);
264
265 // Returns a vector containing the tuple elements of this Literal as separate
266 // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
267 // elements are moved into the new Literals; no data is copied. Upon return
// this Literal is set to a nil shape (empty tuple).
269 std::vector<Literal> DecomposeTuple();
270
271 // This operation is the inverse of DecomposeTuple. The given elements are
272 // moved into the tuple elements of a new tuple-shaped Literal which is
273 // returned. Upon return, each of the Literals in 'elements' is set to a nil
274 // shape (empty tuple).
275 static Literal MoveIntoTuple(
276 tensorflow::gtl::MutableArraySlice<Literal> elements);
277
278 // Creates a new value that has the equivalent value as this literal, but
279 // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
280 // minor-to-major dimension layout can be re-layed-out as {1, 0}
281 // minor-to-major dimension layout and the value in the cell at any given
282 // logical index (i0, i1) will be the same.
283 //
284 // For tuple shaped literals, shape_index should be used to select the inner
285 // array that the new layout applies to.
286 //
287 // Note: this is useful when the client wants to ensure that a value placed in
288 // the XLA allocation tracker has a particular layout; for efficiency
289 // purposes or avoiding unimplemented operation/layout combinations.
290 std::unique_ptr<Literal> Relayout(const Layout& new_layout,
291 const ShapeIndex& shape_index = {}) const;
292
293 // An overload of Relayout which changes the layout of the entire shape rather
294 // than being limited to a single array within the shape.
295 std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
296
297 // Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; the
299 // implementation currently only supports monotonic dim0-major layouts.
300 // This literal must be an array.
301 StatusOr<std::unique_ptr<Literal>> Reshape(
302 tensorflow::gtl::ArraySlice<int64> dimensions) const;
303
304 // Creates a new literal by reordering the dimensions of this literal.
305 // The given `permutation` must be a permutation of the dimension numbers
306 // in the original literal, and it specifies the order of the new dimensions
307 // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
308 // For example, a transpose call on a literal of shape [3 x 8 x 4] and
309 // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
310 // This literal must be an array.
311 std::unique_ptr<Literal> Transpose(
312 tensorflow::gtl::ArraySlice<int64> permutation) const;
313
314 // Creates a sub-array from this literal by extracting the indices
315 // [start_index, limit_index) of each dimension. The result literal has the
316 // same rank and layout as for the given literal. The number of indices in
317 // start_indices and limit_indices must be the rank of the literal, and the
318 // indices follow the order of the dimensions.
319 // This literal must be an array.
320 std::unique_ptr<Literal> Slice(
321 tensorflow::gtl::ArraySlice<int64> start_indices,
322 tensorflow::gtl::ArraySlice<int64> limit_indices) const;
323
324 // Creates a literal with a prepended dimension with bound "times"; e.g. a
325 // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
326 // literal replicated four times.
327 // This literal must be an array.
328 template <typename NativeT>
329 std::unique_ptr<Literal> Replicate(int64 times) const;
330
331 // Converts this literal to another primitive type. Returns an error if the
332 // conversion is not possible. This literal must be array-shaped.
333 StatusOr<std::unique_ptr<Literal>> Convert(
334 PrimitiveType primitive_dest_type) const;
335
336 // Creates a scalar literal value zero of the given primitive type.
337 static Literal Zero(PrimitiveType primitive_type);
338
339 // Creates a scalar literal value one of the given primitive type.
340 static Literal One(PrimitiveType primitive_type);
341
342 // Creates a scalar literal value containing the minimum value of the given
343 // primitive type. For floating-point types, returns -inf.
344 static Literal MinValue(PrimitiveType primitive_type);
345
346 // Creates a scalar literal value containing the maximum value of the given
347 // primitive type. For floating-point types, returns inf.
348 static Literal MaxValue(PrimitiveType primitive_type);
349
350 // Creates a literal of the given shape where each element is `value`.
351 template <typename NativeT>
352 static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
353 tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
354
355 // Creates a new literal from an Array type. The variants not ending with
356 // WithLayout use the default XLA layout for the literal's linear
357 // representation in memory.
358 template <typename NativeT>
359 static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
360 template <typename NativeT>
361 static std::unique_ptr<Literal> CreateFromArrayWithLayout(
362 const Array<NativeT>& values, const Layout& layout);
363 template <typename NativeT>
364 static std::unique_ptr<Literal> CreateR2FromArray2D(
365 const Array2D<NativeT>& values);
366 template <typename NativeT>
367 static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
368 const Array2D<NativeT>& values, const Layout& layout);
369 template <typename NativeT>
370 static std::unique_ptr<Literal> CreateR3FromArray3D(
371 const Array3D<NativeT>& values);
372 template <typename NativeT>
373 static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
374 const Array3D<NativeT>& values, const Layout& layout);
375 template <typename NativeT>
376 static std::unique_ptr<Literal> CreateR4FromArray4D(
377 const Array4D<NativeT>& values);
378 template <typename NativeT>
379 static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
380 const Array4D<NativeT>& values, const Layout& layout);
381
382 // Creates a new vector of U8s literal value from a string.
383 static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value);
384
385 // Creates a linspace-populated literal with the given number of rows and
386 // columns.
387 static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
388 int64 rows, int64 cols);
389
390 // Creates a literal that projects the (x, y) dimensions given in values into
391 // the z dimension given by "projection".
392 template <typename NativeT>
393 static std::unique_ptr<Literal> CreateR3Projected(
394 std::initializer_list<std::initializer_list<NativeT>> values,
395 int64 projection);
396
397 // Creates a literal that projects the (x, y) dimensions given in values into
398 // the z and p dimensions given.
399 template <typename NativeT>
400 static std::unique_ptr<Literal> CreateR4Projected(
401 std::initializer_list<std::initializer_list<NativeT>> values,
402 int64 projection_p, int64 projection_z);
403
404 // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
405 Literal Clone() const;
406 std::unique_ptr<Literal> CloneToUnique() const;
407
408 // Gets or sets an element in the literal at the given index. The multi_index
409 // is CHECKed against the dimension sizes.
410 template <typename NativeT>
411 NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
412 const ShapeIndex& shape_index) const;
413 template <typename NativeT>
414 void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
415 const ShapeIndex& shape_index, NativeT value);
416
417 // Overloads of Get and Set for array literals. CHECKs if the literal is not
418 // array-shaped and dense.
419 template <typename NativeT>
420 NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
421 template <typename NativeT>
422 void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
423
424 // Returns the multi-index of the element in a sparse literal at the given
// sparse element number. The sparse element number is the position within
426 // the sparse array's list of (index, value) pairs, and is checked against the
427 // total number of (index, value) pairs in the sparse array.
428 tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
429 int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
430
431 // Returns the value of the element in a sparse literal at the given sparse
// element number. The sparse element number is the position within the
433 // sparse array's list of (index, value) pairs, and is checked against the
434 // total number of (index, value) pairs in the sparse array.
435 template <typename NativeT>
436 NativeT GetSparseElement(int64 sparse_element_number,
437 const ShapeIndex& shape_index = {}) const;
438
439 // Appends the given element to the literal. If the elements are not appended
440 // in sorted order, then SortSparseElements should be called before calling
441 // other methods. This literal must have a sparse layout.
442 template <typename NativeT>
443 void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
444 NativeT value, const ShapeIndex& shape_index = {});
445
446 // Sorts the elements in a sparse array.
447 void SortSparseElements(const ShapeIndex& shape_index = {});
448
449 // Returns the element value at index (0, ..., 0), however many zeroes are
450 // required for that index.
451 template <typename NativeT>
452 NativeT GetFirstElement() const;
453
454 // As Get(), but determines the correct type and converts the value
455 // into text.
456 string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
457 const ShapeIndex& shape_index = {}) const;
458
459 // As GetSparseElement(), but determines the correct type and converts the
460 // value into text.
461 string GetSparseElementAsString(int64 sparse_element_number,
462 const ShapeIndex& shape_index = {}) const;
463
464 // As Get(), but determines the correct type and converts the value into
465 // int64. This literal must be an array.
466 StatusOr<int64> GetIntegralAsS64(
467 tensorflow::gtl::ArraySlice<int64> multi_index) const;
468
469 // Returns an identity matrix (rank 2) with the given row and column count.
470 template <typename NativeT>
471 static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
472
473 // Returns a tuple literal composed of given literals. Data is copied from the
474 // given elements into the returned literal.
475 static std::unique_ptr<Literal> MakeTuple(
476 tensorflow::gtl::ArraySlice<const Literal*> elements);
477
478 // As above, but intended to be invoked with move semantics; i.e.
479 //
480 // std::vector<std::unique_ptr<Literal>> elements = ...;
481 // auto result = Literal::MakeTupleOwned(std::move(elements));
482 //
483 // This would have been declared as an overload, but there is ambiguity
484 // in invocation between the above signature and this one.
485 static std::unique_ptr<Literal> MakeTupleOwned(
486 std::vector<std::unique_ptr<Literal>> elements);
487
488 // This overload lets you pass a braced list of unique_ptr<Literal>s to
489 // MakeTupleOwned:
490 //
491 // Literal::MakeTupleOwned(Literal::CreateR1(...), ...).
492 //
493 // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
494 // overload doesn't work because std::initializer_list's elements are always
495 // const.
496 //
497 // The arguments to this function must all be unique_ptr<Literal>.
498 template <typename... Ts>
MakeTupleOwned(std::unique_ptr<Ts>...elements)499 static std::unique_ptr<Literal> MakeTupleOwned(
500 std::unique_ptr<Ts>... elements) {
501 std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
502 std::move(elements)...};
503 std::vector<std::unique_ptr<Literal>> v;
504 v.insert(v.begin(), std::make_move_iterator(arr.begin()),
505 std::make_move_iterator(arr.end()));
506 return MakeTupleOwned(std::move(v));
507 }
508
509 // Returns a string representation of the literal value.
510 // Warning: this function can take minutes for multi-million element Literals.
511 string ToString(bool print_layout = false) const;
512
513 // Invokes the "per cell" callback for each element in the provided
514 // literal with the element's indices and a string representation of
515 // the element's value.
516 //
517 // This function is useful if you want a polymorphic representation
518 // of the tensor's elements (turning it to a string for something
519 // like representation in a protobuf).
520 //
521 // This literal must have a dense layout.
522 void EachCellAsString(
523 const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
524 const string& value)>& per_cell) const;
525 template <typename NativeT>
526 void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
527 NativeT value)>
528 per_cell) const;
529
530 // Populate this literal with the given values. Examples:
531 //
532 // // Populate with floats.
533 // Array2D<float> float_values = ...
534 // literal.PopulateR2FromArray2D(values);
535 //
536 // // Populate with int32s.
537 // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
538 //
539 // The shape and element type of this literal must match given values. For
540 // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
541 // array of S32.
542 template <typename NativeT>
543 void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
544 void PopulateR1(const tensorflow::core::Bitmap& values);
545 template <typename NativeT>
546 void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
547 template <typename NativeT>
548 void PopulateFromArray(const Array<NativeT>& values);
549 template <typename NativeT>
550 void PopulateR2FromArray2D(const Array2D<NativeT>& values);
551 template <typename NativeT>
552 void PopulateR3FromArray3D(const Array3D<NativeT>& values);
553 template <typename NativeT>
554 void PopulateR4FromArray4D(const Array4D<NativeT>& values);
555
556 // Populates literal values by calling the generator function for every cell
557 // in this literal object.
558 //
559 // generator must be a callable of the type
560 // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
561 //
562 // This literal must have a dense layout.
563 template <typename NativeT, typename FnType>
564 Status Populate(const FnType& generator);
565
566 // Fills this literal with the given value.
567 template <typename NativeT>
568 void PopulateWithValue(NativeT value);
569
570 // Returns whether every element in this literal is equal to value.
571 //
572 // value is an int8 because we expect this to be called with small
573 // compile-time constants (0, -1, etc.) and so that whatever value you pass
574 // can be represented exactly by floating-point types as small as 16 bits.
575 //
576 // If value doesn't fit in this literal's type, returns false. Values of 1/0
577 // are considered equal to true/false; other values are not considered equal
578 // to true. Also if this literal is not array-shaped false is returned.
579 bool IsAll(int8 value) const;
580
// Like IsAll(int8), except we check whether the literal is
582 // equal to a particular floating-point number.
583 //
584 // If the literal is not a floating-point value, this always returns false.
585 //
586 // This casts value to the type of literal, then compares using ==. The usual
587 // admonishments about floating-point equality checks apply. We expect you to
588 // use this to check for values that can be expressed precisely as a float,
589 // e.g. -0.5. Also if this literal is not array-shaped false is returned.
590 bool IsAllFloat(float value) const;
591
// Like IsAll(int8), except we check whether the literal is
593 // equal to a particular complex number.
594 //
595 // If the literal is not a complex value, this always returns false.
596 //
597 // This casts value to the type of literal, then compares using ==. The usual
598 // admonishments about floating-point equality checks apply. We expect you to
599 // use this to check for complex values that can be expressed precisely as
600 // float pairs e.g. (-0.5, 1.0).
601 //
602 // This literal must have a dense layout.
603 bool IsAllComplex(complex64 value) const;
604
605 // Returns whether this literal is zero at the specified index. This literal
606 // must be an array with a dense layout.
607 bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
608
609 // Return the count of the elements in the array at the given shape index in
610 // this literal.
611 int64 element_count(const ShapeIndex& index = {}) const {
612 return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
613 }
614
615 // Return the count of the elements in the sparse array at the given shape
616 // index in this literal, which will be no larger than
617 // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
618 int64 sparse_element_count() const;
619
620 protected:
621 // 'allocate_arrays' indicates whether to allocate memory for the arrays in
622 // the shape. If false, buffer pointers inside of the Literal::Pieces are set
623 // to nullptr.
624 Literal(const Shape& shape, bool allocate_arrays);
625
626 // Internal template helper for the Literal::CopySliceFrom(), matching its
627 // arguments one by one.
628 template <typename NativeT>
629 Status CopySliceFromInternal(const Literal& src_literal,
630 tensorflow::gtl::ArraySlice<int64> src_base,
631 tensorflow::gtl::ArraySlice<int64> dest_base,
632 tensorflow::gtl::ArraySlice<int64> copy_size);
633
634 // Utility structure which is used to create the optimal configuration for
635 // a ShapeUtil::ForEachIndex() scan across two literals.
636 struct StrideConfig {
637 StrideConfig(const Shape& source_shape, const Shape& dest_shape,
638 tensorflow::gtl::ArraySlice<int64> dimensions);
639
640 // The dimensions of the stride operation. Essentially every dimension
641 // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
642 // steps.
643 tensorflow::gtl::ArraySlice<int64> dimensions;
644 DimensionVector base;
645 DimensionVector step;
646 int64 minor_dimension = 0;
647 // The size of the strides for source and destination. One of the two
648 // (the one looping through its most minor dimension) will be 1, while
649 // the other will be the stride size at the dimension matching the other
650 // shape most minor dimension being scanned.
651 int64 dest_stride = 1;
652 int64 source_stride = 1;
653 // The size of the inner loop on the most minor dimension.
654 int64 minor_loop_size = 1;
655 };
656
657 // A data structure representing a subshape at a particular ShapeIndex within
658 // the literal. For array-shaped ShapeIndexes, this data structure holds the
659 // pointer to the memory allocated for the array data.
660 class Piece {
661 public:
662 // Return the buffer holding the array data for this piece as an array
663 // slice. This piece must be array-shaped.
664 template <typename NativeT>
665 tensorflow::gtl::ArraySlice<NativeT> data() const;
666 template <typename NativeT>
667 tensorflow::gtl::MutableArraySlice<NativeT> data();
668
669 // Return the buffer holding the array data for this piece as a void*. This
670 // piece must be array-shaped.
671 void* untyped_data();
672 const void* untyped_data() const;
673
674 // Gets or sets an element in the array at the given index. The multi_index
675 // is CHECKed against the dimension sizes of the array. This piece must be
676 // array-shaped.
677 template <typename NativeT>
678 NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
679 template <typename NativeT>
680 void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
681
682 // Gets/sets the buffer holding the array data.
buffer()683 char* buffer() const { return buffer_; }
set_buffer(char * buffer)684 void set_buffer(char* buffer) { buffer_ = buffer; }
685
686 // The array of multi-indices that provide the locations of non-zero
687 // elements in a sparse array. Only used if
688 // LayoutUtil::IsSparseArray(shape()) is true.
sparse_indices()689 SparseIndexArray* sparse_indices() const { return sparse_indices_; }
set_sparse_indices(SparseIndexArray * sparse_indices)690 void set_sparse_indices(SparseIndexArray* sparse_indices) {
691 sparse_indices_ = sparse_indices;
692 }
693
694 // Gets or sets the subshape of this piece. This reference points to a
695 // subshape within the shape in the containing Literal (Literal::shape_).
subshape()696 const Shape& subshape() const { return *subshape_; }
set_subshape(const Shape * subshape)697 void set_subshape(const Shape* subshape) { subshape_ = subshape; }
698
699 // Returns the size in bytes of the buffer holding the array data.
size_bytes()700 int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
701
702 // Returns the number of elements in this piece's array.
element_count()703 int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
704
705 // Copy the data from 'src' into this piece's buffer. Shapes of this piece
706 // and src must be compatible.
707 Status CopyFrom(const Piece& src);
708
709 // Returns true if this piece and 'other' contain the same data. This piece
710 // and 'other' must be array-shaped and compatible.
711 bool EqualElements(const Piece& other) const;
712
713 // Writes the shape and data (if array-shaped) into the given proto.
714 void WriteToProto(LiteralProto* proto) const;
715
716 // Copies the data from the given proto into this piece. The shape of this
717 // piece must be equal (not just compatible) to the shape of the proto.
718 Status CopyFromProto(const LiteralProto& proto);
719
720 // Sorts the elements in a sparse array.
721 void SortSparseElements();
722
723 private:
724 // Recursive helper for EqualElements.
725 template <typename NativeT>
726 bool EqualElementsInternal(const Piece& other,
727 std::vector<int64>* multi_index) const;
728
729 // Helper for SortSparseElements that has the element type as a template
730 // parameter.
731 template <typename NativeT>
732 void SortSparseElementsInternal();
733
734 // For array-shaped pieces, this is the buffer holding the literal data.
735 char* buffer_ = nullptr;
736
737 // For sparse arrays, this is the array of indices.
738 SparseIndexArray* sparse_indices_ = nullptr;
739
740 // The shape of piece. This points into the shape of the containing Literal
741 // (Literal::shape_).
742 const Shape* subshape_ = nullptr;
743 };
744
745 // Returns the piece at the given ShapeIndex.
piece(const ShapeIndex & shape_index)746 Piece& piece(const ShapeIndex& shape_index) {
747 return *pieces_.mutable_element(shape_index);
748 }
piece(const ShapeIndex & shape_index)749 const Piece& piece(const ShapeIndex& shape_index) const {
750 return pieces_.element(shape_index);
751 }
752
753 // Returns the piece at the root of the shape (empty ShapeIndex).
root_piece()754 Piece& root_piece() { return piece({}); }
root_piece()755 const Piece& root_piece() const { return piece({}); }
756
  // Deallocate the buffers held by this literal (if the literal owns the
  // buffer).
  void DeallocateBuffers();

  // The shape of this literal. Each Piece's subshape pointer refers to a
  // subshape inside this object.
  Shape shape_;

  // The pieces of this literal, looked up by ShapeIndex through piece() and
  // root_piece().
  ShapeTree<Piece> pieces_;

  // Whether the buffers held in pieces_ are owned by this Literal.
  bool owns_buffers_;

  // LiteralView must access and manipulate Pieces of other Literals.
  friend class LiteralView;
};  // class Literal
770
// Stream-insertion operator for Literal. Only the declaration lives in this
// header; the textual format is determined by the out-of-line definition.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
772
// A read-only view of a Literal. A LiteralView contains pointers to buffers
// owned by the viewed Literal.
//
// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
// and mutable) similar to (Mutable)ArraySlice.
class LiteralView : public Literal {
 public:
  // Create and return a view of the given literal rooted at the given shape
  // index within the given literal. A factory is used rather than a public
  // constructor because only const LiteralViews are supported. It's still
  // possible to create non-const LiteralViews via the copy constructors, but
  // the factory method makes it a bit less likely. Implementing literal slices
  // will fix this undesirable situation (b/71550060).
  //
  // view_root defaults to the empty ShapeIndex, i.e. the root of `literal`.
  static const LiteralView Create(const Literal& literal,
                                  const ShapeIndex& view_root = {});

  // Copy construction and copy assignment from another view. The private
  // CopyFrom helper below is documented as the shared implementation.
  LiteralView(const LiteralView& other);
  LiteralView& operator=(const LiteralView& other);

  virtual ~LiteralView();

 private:
  // Constructs a view of `literal` rooted at `view_root`. Private so that
  // views are obtained through Create().
  LiteralView(const Literal& literal, const ShapeIndex& view_root);

  // Helper for the copy constructor and copy assignment operator.
  void CopyFrom(const LiteralView& other);
};
800
801 template <typename NativeT>
data()802 tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
803 CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
804 CHECK_EQ(subshape().element_type(),
805 primitive_util::NativeToPrimitiveType<NativeT>())
806 << "Attempting to access "
807 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
808 << " type, but literal element type is "
809 << PrimitiveType_Name(subshape().element_type());
810 return tensorflow::gtl::ArraySlice<NativeT>(
811 reinterpret_cast<const NativeT*>(buffer()),
812 ShapeUtil::ElementsIn(subshape()));
813 }
814
815 template <typename NativeT>
data()816 tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
817 CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
818 CHECK_EQ(subshape().element_type(),
819 primitive_util::NativeToPrimitiveType<NativeT>())
820 << "Attempting to access "
821 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
822 << " type, but literal element type is "
823 << PrimitiveType_Name(subshape().element_type());
824 return tensorflow::gtl::MutableArraySlice<NativeT>(
825 reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape()));
826 }
827
828 template <typename NativeT>
Get(tensorflow::gtl::ArraySlice<int64> multi_index)829 NativeT Literal::Piece::Get(
830 tensorflow::gtl::ArraySlice<int64> multi_index) const {
831 CHECK(LayoutUtil::IsDenseArray(subshape()));
832 return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
833 subshape(), multi_index)];
834 }
835
836 template <typename NativeT>
Set(tensorflow::gtl::ArraySlice<int64> multi_index,NativeT value)837 void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
838 NativeT value) {
839 CHECK(LayoutUtil::IsDenseArray(subshape()));
840 data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
841 subshape(), multi_index)] = value;
842 }
843
844 template <typename NativeT>
data(const ShapeIndex & shape_index)845 tensorflow::gtl::ArraySlice<NativeT> Literal::data(
846 const ShapeIndex& shape_index) const {
847 return piece(shape_index).data<NativeT>();
848 }
849
850 template <typename NativeT>
data(const ShapeIndex & shape_index)851 tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
852 const ShapeIndex& shape_index) {
853 return piece(shape_index).data<NativeT>();
854 }
855
856 template <typename NativeT>
Get(tensorflow::gtl::ArraySlice<int64> multi_index,const ShapeIndex & shape_index)857 inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
858 const ShapeIndex& shape_index) const {
859 return piece(shape_index).Get<NativeT>(multi_index);
860 }
861
862 template <typename NativeT>
Get(tensorflow::gtl::ArraySlice<int64> multi_index)863 inline NativeT Literal::Get(
864 tensorflow::gtl::ArraySlice<int64> multi_index) const {
865 return root_piece().Get<NativeT>(multi_index);
866 }
867
868 template <typename NativeT>
Set(tensorflow::gtl::ArraySlice<int64> multi_index,const ShapeIndex & shape_index,NativeT value)869 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
870 const ShapeIndex& shape_index, NativeT value) {
871 return piece(shape_index).Set<NativeT>(multi_index, value);
872 }
873
874 template <typename NativeT>
Set(tensorflow::gtl::ArraySlice<int64> multi_index,NativeT value)875 inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
876 NativeT value) {
877 return root_piece().Set<NativeT>(multi_index, value);
878 }
879
880 template <typename NativeT>
CreateR0(NativeT value)881 /* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
882 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
883 primitive_util::NativeToPrimitiveType<NativeT>(), {}));
884 literal->Set({}, value);
885 return literal;
886 }
887
888 template <typename NativeT>
CreateR1(tensorflow::gtl::ArraySlice<NativeT> values)889 /* static */ std::unique_ptr<Literal> Literal::CreateR1(
890 tensorflow::gtl::ArraySlice<NativeT> values) {
891 auto literal = MakeUnique<Literal>(
892 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
893 {static_cast<int64>(values.size())}));
894 literal->PopulateR1(values);
895 return literal;
896 }
897
898 template <typename NativeT>
CreateR2WithLayout(std::initializer_list<std::initializer_list<NativeT>> values,const Layout & layout)899 /* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
900 std::initializer_list<std::initializer_list<NativeT>> values,
901 const Layout& layout) {
902 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
903 primitive_util::NativeToPrimitiveType<NativeT>(),
904 {static_cast<int64>(values.size()),
905 static_cast<int64>(values.begin()->size())},
906 AsInt64Slice(layout.minor_to_major())));
907 literal->PopulateR2(values);
908 return literal;
909 }
910
911 template <typename NativeT>
CreateR2(std::initializer_list<std::initializer_list<NativeT>> values)912 /* static */ std::unique_ptr<Literal> Literal::CreateR2(
913 std::initializer_list<std::initializer_list<NativeT>> values) {
914 return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
915 }
916
917 template <typename NativeT>
CreateR3WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values,const Layout & layout)918 /* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout(
919 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
920 values,
921 const Layout& layout) {
922 const int64 d0 = values.size();
923 const int64 d1 = values.begin()->size();
924 const int64 d2 = values.begin()->begin()->size();
925 Array3D<NativeT> tmp(d0, d1, d2);
926 int64 i0 = 0;
927 for (auto d1_values : values) {
928 int64 i1 = 0;
929 for (auto d2_values : d1_values) {
930 int64 i2 = 0;
931 for (auto value : d2_values) {
932 tmp(i0, i1, i2) = value;
933 ++i2;
934 }
935 ++i1;
936 }
937 ++i0;
938 }
939 return CreateR3FromArray3DWithLayout(tmp, layout);
940 }
941
942 template <typename NativeT>
CreateR3(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> values)943 /* static */ std::unique_ptr<Literal> Literal::CreateR3(
944 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
945 values) {
946 return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
947 }
948
949 template <typename NativeT>
CreateR4WithLayout(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values,const Layout & layout)950 /* static */ std::unique_ptr<Literal> Literal::CreateR4WithLayout(
951 std::initializer_list<std::initializer_list<
952 std::initializer_list<std::initializer_list<NativeT>>>>
953 values,
954 const Layout& layout) {
955 const int64 d0 = values.size();
956 const int64 d1 = values.begin()->size();
957 const int64 d2 = values.begin()->begin()->size();
958 const int64 d3 = values.begin()->begin()->begin()->size();
959 Array4D<NativeT> tmp(d0, d1, d2, d3);
960 int64 i0 = 0;
961 for (auto d1_values : values) {
962 int64 i1 = 0;
963 for (auto d2_values : d1_values) {
964 int64 i2 = 0;
965 for (auto d3_values : d2_values) {
966 int64 i3 = 0;
967 for (auto value : d3_values) {
968 tmp(i0, i1, i2, i3) = value;
969 ++i3;
970 }
971 ++i2;
972 }
973 ++i1;
974 }
975 ++i0;
976 }
977 return CreateR4FromArray4DWithLayout(tmp, layout);
978 }
979
980 template <typename NativeT>
CreateSparse(tensorflow::gtl::ArraySlice<int64> dimensions,SparseIndexArray indices,tensorflow::gtl::ArraySlice<NativeT> values,bool sort)981 /* static */ std::unique_ptr<Literal> Literal::CreateSparse(
982 tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
983 tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
984 int64 num_elements = values.size();
985 int64 rank = dimensions.size();
986 CHECK_EQ(num_elements, indices.index_count());
987 CHECK_EQ(rank, indices.rank());
988 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
989 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
990 indices.max_indices()));
991 literal->PopulateSparse(indices, values, sort);
992 return literal;
993 }
994
995 template <typename NativeT>
CreateR4(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> values)996 /* static */ std::unique_ptr<Literal> Literal::CreateR4(
997 std::initializer_list<std::initializer_list<
998 std::initializer_list<std::initializer_list<NativeT>>>>
999 values) {
1000 return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
1001 }
1002
1003 template <typename NativeT>
CreateFromArrayWithLayout(const Array<NativeT> & values,const Layout & layout)1004 /* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
1005 const Array<NativeT>& values, const Layout& layout) {
1006 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
1007 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
1008 AsInt64Slice(layout.minor_to_major())));
1009 literal->PopulateFromArray(values);
1010 return literal;
1011 }
1012
1013 template <typename NativeT>
CreateFromArray(const Array<NativeT> & values)1014 /* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
1015 const Array<NativeT>& values) {
1016 return CreateFromArrayWithLayout(
1017 values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
1018 }
1019
1020 template <typename NativeT>
CreateR2FromArray2DWithLayout(const Array2D<NativeT> & values,const Layout & layout)1021 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
1022 const Array2D<NativeT>& values, const Layout& layout) {
1023 return CreateFromArrayWithLayout(values, layout);
1024 }
1025
1026 template <typename NativeT>
CreateR2FromArray2D(const Array2D<NativeT> & values)1027 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
1028 const Array2D<NativeT>& values) {
1029 return CreateFromArray(values);
1030 }
1031
1032 template <typename NativeT>
CreateR3FromArray3DWithLayout(const Array3D<NativeT> & values,const Layout & layout)1033 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
1034 const Array3D<NativeT>& values, const Layout& layout) {
1035 return CreateFromArrayWithLayout(values, layout);
1036 }
1037
1038 template <typename NativeT>
CreateR3FromArray3D(const Array3D<NativeT> & values)1039 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
1040 const Array3D<NativeT>& values) {
1041 return CreateFromArray(values);
1042 }
1043
1044 template <typename NativeT>
CreateR3Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection)1045 /* static */ std::unique_ptr<Literal> Literal::CreateR3Projected(
1046 std::initializer_list<std::initializer_list<NativeT>> values,
1047 int64 projection) {
1048 int64 dim0_size = projection;
1049 int64 dim1_size = values.size();
1050 int64 dim2_size = values.begin()->size();
1051
1052 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
1053 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
1054 int64 dim1 = 0;
1055 for (auto inner_list : values) {
1056 int64 dim2 = 0;
1057 for (auto value : inner_list) {
1058 array(dim0, dim1, dim2) = value;
1059 ++dim2;
1060 }
1061 CHECK_EQ(dim2_size, dim2);
1062 ++dim1;
1063 }
1064 CHECK_EQ(dim1_size, dim1);
1065 }
1066 return CreateR3FromArray3D(array);
1067 }
1068
1069 template <typename NativeT>
CreateR4Projected(std::initializer_list<std::initializer_list<NativeT>> values,int64 projection_p,int64 projection_z)1070 /* static */ std::unique_ptr<Literal> Literal::CreateR4Projected(
1071 std::initializer_list<std::initializer_list<NativeT>> values,
1072 int64 projection_p, int64 projection_z) {
1073 int64 dim0_size = projection_p;
1074 int64 dim1_size = projection_z;
1075 int64 dim2_size = values.size();
1076 int64 dim3_size = values.begin()->size();
1077
1078 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
1079 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
1080 for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
1081 int64 dim2 = 0;
1082 for (auto inner_list : values) {
1083 int64 dim3 = 0;
1084 for (auto value : inner_list) {
1085 array(dim0, dim1, dim2, dim3) = value;
1086 ++dim3;
1087 }
1088 CHECK_EQ(dim3_size, dim3);
1089 ++dim2;
1090 }
1091 CHECK_EQ(dim2_size, dim2);
1092 }
1093 }
1094 return CreateR4FromArray4D(array);
1095 }
1096
1097 template <typename NativeT>
CreateR4FromArray4D(const Array4D<NativeT> & values)1098 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
1099 const Array4D<NativeT>& values) {
1100 return CreateFromArray(values);
1101 }
1102
1103 template <typename NativeT>
CreateR4FromArray4DWithLayout(const Array4D<NativeT> & values,const Layout & layout)1104 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
1105 const Array4D<NativeT>& values, const Layout& layout) {
1106 return CreateFromArrayWithLayout(values, layout);
1107 }
1108
1109 template <typename NativeT>
GetFirstElement()1110 NativeT Literal::GetFirstElement() const {
1111 return data<NativeT>().at(0);
1112 }
1113
1114 template <typename NativeT>
GetSparseElement(int64 sparse_element_number,const ShapeIndex & shape_index)1115 NativeT Literal::GetSparseElement(int64 sparse_element_number,
1116 const ShapeIndex& shape_index) const {
1117 CHECK(
1118 LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
1119 return data<NativeT>(shape_index)[sparse_element_number];
1120 }
1121
1122 template <typename NativeT>
AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,NativeT value,const ShapeIndex & shape_index)1123 void Literal::AppendSparseElement(
1124 tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
1125 const ShapeIndex& shape_index) {
1126 Piece& p = piece(shape_index);
1127 const Shape& subshape = p.subshape();
1128 CHECK(LayoutUtil::IsSparseArray(subshape));
1129 int64 rank = ShapeUtil::Rank(subshape);
1130 CHECK_EQ(multi_index.size(), rank);
1131 int64 last_element = p.sparse_indices()->index_count();
1132 CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
1133 p.sparse_indices()->Append(multi_index);
1134 CHECK_LT(last_element, p.data<NativeT>().size());
1135 p.data<NativeT>()[last_element] = value;
1136 }
1137
1138 // Returns an identity matrix (rank 2) with the given row and column count.
1139 template <typename NativeT>
MakeIdentityR2(int64 size)1140 /* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) {
1141 Array2D<NativeT> array(size, size, 0);
1142 for (int64 i = 0; i < size; ++i) {
1143 array(i, i) = 1;
1144 }
1145 return CreateR2FromArray2D(array);
1146 }
1147
1148 template <typename NativeT>
EachCell(std::function<void (tensorflow::gtl::ArraySlice<int64> indices,NativeT value)> per_cell)1149 void Literal::EachCell(
1150 std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
1151 NativeT value)>
1152 per_cell) const {
1153 if (ShapeUtil::HasZeroElements(shape())) {
1154 return;
1155 }
1156 std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
1157 do {
1158 per_cell(indices, Get<NativeT>(indices));
1159 } while (IndexUtil::BumpIndices(shape(), &indices));
1160 }
1161
1162 template <typename NativeT>
PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values)1163 inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
1164 CHECK(ShapeUtil::IsArray(shape()));
1165 CHECK_EQ(ShapeUtil::Rank(shape()), 1);
1166 CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
1167 CHECK_EQ(shape().element_type(),
1168 primitive_util::NativeToPrimitiveType<NativeT>());
1169 for (int64 i = 0; i < values.size(); ++i) {
1170 Set({i}, values[i]);
1171 }
1172 }
1173
1174 template <typename NativeT>
PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values)1175 void Literal::PopulateR2(
1176 std::initializer_list<std::initializer_list<NativeT>> values) {
1177 CHECK(ShapeUtil::IsArray(shape()));
1178 CHECK_EQ(ShapeUtil::Rank(shape()), 2);
1179 CHECK_EQ(shape().element_type(),
1180 primitive_util::NativeToPrimitiveType<NativeT>());
1181
1182 const int64 dim0_size = values.size();
1183 const int64 dim1_size = values.begin()->size();
1184 CHECK_EQ(dim0_size, shape().dimensions(0));
1185 CHECK_EQ(dim1_size, shape().dimensions(1));
1186
1187 int64 dim0 = 0;
1188 for (auto inner_list : values) {
1189 int64 dim1 = 0;
1190 for (auto value : inner_list) {
1191 Set({dim0, dim1}, value);
1192 ++dim1;
1193 }
1194 CHECK_EQ(dim1_size, dim1);
1195 ++dim0;
1196 }
1197 }
1198
1199 template <typename NativeT>
PopulateFromArray(const Array<NativeT> & values)1200 void Literal::PopulateFromArray(const Array<NativeT>& values) {
1201 CHECK(ShapeUtil::IsArray(shape()));
1202 CHECK_EQ(shape().element_type(),
1203 primitive_util::NativeToPrimitiveType<NativeT>());
1204 CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
1205 for (int dim = 0; dim < values.num_dimensions(); ++dim) {
1206 CHECK_EQ(values.dim(dim), shape().dimensions(dim));
1207 }
1208 values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
1209 NativeT value) { this->Set(indices, value); });
1210 }
1211
1212 template <typename NativeT>
PopulateR2FromArray2D(const Array2D<NativeT> & values)1213 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
1214 PopulateFromArray(values);
1215 }
1216
1217 template <typename NativeT>
PopulateR3FromArray3D(const Array3D<NativeT> & values)1218 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
1219 PopulateFromArray(values);
1220 }
1221
1222 template <typename NativeT>
PopulateR4FromArray4D(const Array4D<NativeT> & values)1223 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
1224 PopulateFromArray(values);
1225 }
1226
1227 template <typename NativeT>
PopulateSparse(SparseIndexArray indices,tensorflow::gtl::ArraySlice<NativeT> values,bool sort)1228 void Literal::PopulateSparse(SparseIndexArray indices,
1229 tensorflow::gtl::ArraySlice<NativeT> values,
1230 bool sort) {
1231 CHECK(LayoutUtil::IsSparseArray(shape()));
1232 int rank = ShapeUtil::Rank(shape());
1233 CHECK_EQ(indices.rank(), rank);
1234 int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
1235 CHECK_LE(indices.max_indices(), max_elements);
1236 int64 num_elements = values.size();
1237 CHECK_LE(num_elements, max_elements);
1238 CHECK_EQ(num_elements, indices.index_count());
1239 auto root_data = root_piece().data<NativeT>();
1240 root_data.remove_suffix(max_elements - values.size());
1241 std::copy(values.begin(), values.end(), root_data.begin());
1242 *this->root_piece().sparse_indices() = std::move(indices);
1243 if (sort) {
1244 auto root_data = this->root_piece().data<NativeT>();
1245 root_data.remove_suffix(root_data.size() - num_elements);
1246 this->root_piece().sparse_indices()->SortWithValues(root_data);
1247 }
1248 DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
1249 }
1250
1251 template <typename NativeT, typename FnType>
Populate(const FnType & generator)1252 Status Literal::Populate(const FnType& generator) {
1253 const Shape& this_shape = shape();
1254 const int64 rank = ShapeUtil::Rank(this_shape);
1255 TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
1256 TF_RET_CHECK(this_shape.element_type() ==
1257 primitive_util::NativeToPrimitiveType<NativeT>());
1258 tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
1259 if (rank > 0) {
1260 StrideConfig stride_config(this_shape, this_shape,
1261 AsInt64Slice(this_shape.dimensions()));
1262 DimensionVector minor_scan_indexes(rank, 0);
1263 int64 minor_dimension_size =
1264 ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
1265
1266 auto init_function = [&](const std::vector<int64>& indexes) {
1267 const int64 index =
1268 IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
1269 std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
1270 for (int64 i = 0; i < minor_dimension_size; ++i) {
1271 minor_scan_indexes[stride_config.minor_dimension] = i;
1272 literal_data.at(index + i) = generator(minor_scan_indexes);
1273 }
1274 return true;
1275 };
1276 ShapeUtil::ForEachIndex(this_shape, stride_config.base,
1277 stride_config.dimensions, stride_config.step,
1278 init_function);
1279 } else {
1280 // For scalars.
1281 literal_data.at(0) = generator({});
1282 }
1283 return Status::OK();
1284 }
1285
1286 template <typename NativeT>
PopulateWithValue(NativeT value)1287 void Literal::PopulateWithValue(NativeT value) {
1288 CHECK(ShapeUtil::IsArray(shape()));
1289 CHECK_EQ(shape().element_type(),
1290 primitive_util::NativeToPrimitiveType<NativeT>());
1291 for (NativeT& element : data<NativeT>()) {
1292 element = value;
1293 }
1294 }
1295
1296 template <typename NativeT>
CreateFullWithDescendingLayout(tensorflow::gtl::ArraySlice<int64> dimensions,NativeT value)1297 /* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
1298 tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
1299 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
1300 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
1301 literal->PopulateWithValue(value);
1302 return literal;
1303 }
1304
1305 template <typename NativeT>
Replicate(int64 times)1306 std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
1307 DimensionVector bounds = {times};
1308 bounds.reserve(shape().dimensions_size() + 1);
1309 for (int64 bound : shape().dimensions()) {
1310 bounds.push_back(bound);
1311 }
1312 auto literal =
1313 MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
1314 int64 elements = ShapeUtil::ElementsIn(literal->shape());
1315 if (elements == 0) {
1316 return literal;
1317 }
1318
1319 DimensionVector output_indices(bounds.size(), 0);
1320 tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
1321 input_indices.remove_prefix(1);
1322
1323 bool done = false;
1324 while (!done) {
1325 const auto element = Get<NativeT>(input_indices);
1326 literal->Set<NativeT>(output_indices, element);
1327
1328 done = true;
1329 for (int n = 0; n < output_indices.size(); ++n) {
1330 ++output_indices[n];
1331 if (output_indices[n] < bounds[n]) {
1332 done = false;
1333 break;
1334 }
1335 output_indices[n] = 0;
1336 }
1337 }
1338 return literal;
1339 }
1340
1341 } // namespace xla
1342
1343 #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
1344