1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal_util.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/index_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/hash/hash.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mem.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace xla {
40 namespace {
41 
42 using absl::StrCat;
43 
44 // Return a literal with all arrays of type FromNativeT converted to type
45 // ToNativeT in the given literal.
46 template <typename FromNativeT, typename ToNativeT>
ConvertType(LiteralSlice literal)47 Literal ConvertType(LiteralSlice literal) {
48   // First construct shape of the result.
49   Shape result_shape(literal.shape());
50   ShapeUtil::ForEachMutableSubshape(
51       &result_shape, [](Shape* subshape, const ShapeIndex&) {
52         if (subshape->element_type() ==
53             primitive_util::NativeToPrimitiveType<FromNativeT>()) {
54           subshape->set_element_type(
55               primitive_util::NativeToPrimitiveType<ToNativeT>());
56         }
57       });
58   Literal result(result_shape);
59 
60   // Then copy over the data from 'literal' converting FromNativeT values to
61   // ToNativeT values as necessary.
62   ShapeUtil::ForEachSubshape(
63       literal.shape(),
64       [&](const Shape& subshape, const ShapeIndex& shape_index) {
65         if (subshape.IsArray()) {
66           if (subshape.element_type() ==
67               primitive_util::NativeToPrimitiveType<FromNativeT>()) {
68             auto src = literal.data<FromNativeT>(shape_index);
69             auto dest = result.data<ToNativeT>(shape_index);
70             for (int64 i = 0, end = src.size(); i < end; ++i) {
71               dest[i] = static_cast<ToNativeT>(src[i]);
72             }
73           } else {
74             TF_CHECK_OK(result.CopyFrom(literal,
75                                         /*dest_shape_index=*/shape_index,
76                                         /*src_shape_index=*/shape_index));
77           }
78         }
79       });
80   return result;
81 }
82 
83 }  // namespace
84 
CreateFromDimensions(PrimitiveType primitive_type,absl::Span<const int64> dimensions)85 /* static */ Literal LiteralUtil::CreateFromDimensions(
86     PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
87   return Literal::CreateFromShape(
88       ShapeUtil::MakeShape(primitive_type, dimensions));
89 }
90 
ConvertBF16ToF32(const LiteralSlice & bf16_literal)91 /* static */ Literal LiteralUtil::ConvertBF16ToF32(
92     const LiteralSlice& bf16_literal) {
93   return ConvertType<bfloat16, float>(bf16_literal);
94 }
95 
ConvertBF16ToF64(const LiteralSlice & bf16_literal)96 /* static */ Literal LiteralUtil::ConvertBF16ToF64(
97     const LiteralSlice& bf16_literal) {
98   return ConvertType<bfloat16, double>(bf16_literal);
99 }
100 
ConvertF32ToBF16(const LiteralSlice & f32_literal)101 /* static */ Literal LiteralUtil::ConvertF32ToBF16(
102     const LiteralSlice& f32_literal) {
103   return ConvertType<float, bfloat16>(f32_literal);
104 }
105 
ConvertF32ToF64(const LiteralSlice & f32_literal)106 /* static */ Literal LiteralUtil::ConvertF32ToF64(
107     const LiteralSlice& f32_literal) {
108   return ConvertType<float, double>(f32_literal);
109 }
110 
ConvertF64ToBF16(const LiteralSlice & f64_literal)111 /* static */ Literal LiteralUtil::ConvertF64ToBF16(
112     const LiteralSlice& f64_literal) {
113   return ConvertType<double, bfloat16>(f64_literal);
114 }
115 
ConvertF64ToF32(const LiteralSlice & f64_literal)116 /* static */ Literal LiteralUtil::ConvertF64ToF32(
117     const LiteralSlice& f64_literal) {
118   return ConvertType<double, float>(f64_literal);
119 }
120 
CreateToken()121 /* static */ Literal LiteralUtil::CreateToken() {
122   return Literal(ShapeUtil::MakeTokenShape());
123 }
124 
Zero(PrimitiveType primitive_type)125 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
126   switch (primitive_type) {
127     case U8:
128       return LiteralUtil::CreateR0<uint8>(0);
129     case U16:
130       return LiteralUtil::CreateR0<uint16>(0);
131     case U32:
132       return LiteralUtil::CreateR0<uint32>(0);
133     case U64:
134       return LiteralUtil::CreateR0<uint64>(0);
135     case S8:
136       return LiteralUtil::CreateR0<int8>(0);
137     case S16:
138       return LiteralUtil::CreateR0<int16>(0);
139     case S32:
140       return LiteralUtil::CreateR0<int32>(0);
141     case S64:
142       return LiteralUtil::CreateR0<int64>(0);
143     case F16:
144       return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
145     case BF16:
146       return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
147     case F32:
148       return LiteralUtil::CreateR0<float>(0);
149     case F64:
150       return LiteralUtil::CreateR0<double>(0);
151     case C64:
152       return LiteralUtil::CreateR0<complex64>(0);
153     case C128:
154       return LiteralUtil::CreateR0<complex128>(0);
155     case PRED:
156       return LiteralUtil::CreateR0<bool>(false);
157     case TUPLE:
158       LOG(FATAL) << "tuple element type cannot take on value of 0";
159     case OPAQUE_TYPE:
160       LOG(FATAL) << "opaque element type cannot take on value of 0";
161     default:
162       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
163   }
164 }
165 
One(PrimitiveType primitive_type)166 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
167   switch (primitive_type) {
168     case U8:
169       return LiteralUtil::CreateR0<uint8>(1);
170     case U16:
171       return LiteralUtil::CreateR0<uint16>(1);
172     case U32:
173       return LiteralUtil::CreateR0<uint32>(1);
174     case U64:
175       return LiteralUtil::CreateR0<uint64>(1);
176     case S8:
177       return LiteralUtil::CreateR0<int8>(1);
178     case S16:
179       return LiteralUtil::CreateR0<int16>(1);
180     case S32:
181       return LiteralUtil::CreateR0<int32>(1);
182     case S64:
183       return LiteralUtil::CreateR0<int64>(1);
184     case F16:
185       return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
186     case BF16:
187       return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
188     case F32:
189       return LiteralUtil::CreateR0<float>(1);
190     case F64:
191       return LiteralUtil::CreateR0<double>(1);
192     case C64:
193       return LiteralUtil::CreateR0<complex64>(1);
194     case C128:
195       return LiteralUtil::CreateR0<complex128>(1);
196     case PRED:
197       return LiteralUtil::CreateR0<bool>(true);
198     case TUPLE:
199       LOG(FATAL) << "tuple element type cannot take on value of 1";
200     case OPAQUE_TYPE:
201       LOG(FATAL) << "opaque element type cannot take on value of 1";
202     default:
203       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
204   }
205 }
206 
MinValue(PrimitiveType primitive_type)207 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
208   switch (primitive_type) {
209     case U8:
210       return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
211     case U16:
212       return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::min());
213     case U32:
214       return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
215     case U64:
216       return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
217     case S8:
218       return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
219     case S16:
220       return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::min());
221     case S32:
222       return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
223     case S64:
224       return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
225     case F32:
226       return LiteralUtil::CreateR0<float>(
227           -std::numeric_limits<float>::infinity());
228     case F64:
229       return LiteralUtil::CreateR0<double>(
230           -std::numeric_limits<double>::infinity());
231     case C64:
232       LOG(FATAL) << "C64 element type has no minimum value";
233     case C128:
234       LOG(FATAL) << "C128 element type has no minimum value";
235     case PRED:
236       return LiteralUtil::CreateR0<bool>(false);
237     case F16:
238       return LiteralUtil::CreateR0<half>(
239           static_cast<half>(-std::numeric_limits<float>::infinity()));
240     case BF16:
241       return LiteralUtil::CreateR0<bfloat16>(
242           static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
243     case TUPLE:
244       LOG(FATAL) << "tuple element type has no minimum value";
245     case OPAQUE_TYPE:
246       LOG(FATAL) << "opaque element type has no minimum value";
247     default:
248       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
249   }
250 }
251 
MaxValue(PrimitiveType primitive_type)252 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
253   switch (primitive_type) {
254     case U8:
255       return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
256     case U16:
257       return LiteralUtil::CreateR0<uint16>(std::numeric_limits<uint16>::max());
258     case U32:
259       return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
260     case U64:
261       return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
262     case S8:
263       return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
264     case S16:
265       return LiteralUtil::CreateR0<int16>(std::numeric_limits<int16>::max());
266     case S32:
267       return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
268     case S64:
269       return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
270     case F32:
271       return LiteralUtil::CreateR0<float>(
272           std::numeric_limits<float>::infinity());
273     case F64:
274       return LiteralUtil::CreateR0<double>(
275           std::numeric_limits<double>::infinity());
276     case PRED:
277       return LiteralUtil::CreateR0<bool>(true);
278     case F16:
279       return LiteralUtil::CreateR0<half>(
280           static_cast<half>(std::numeric_limits<float>::infinity()));
281     case BF16:
282       return LiteralUtil::CreateR0<bfloat16>(
283           static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
284     case TUPLE:
285       LOG(FATAL) << "tuple element type has no maximum value";
286     case OPAQUE_TYPE:
287       LOG(FATAL) << "opaque element type has no maximum value";
288     default:
289       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
290   }
291 }
292 
NanValue(PrimitiveType primitive_type)293 /* static */ StatusOr<Literal> LiteralUtil::NanValue(
294     PrimitiveType primitive_type) {
295   switch (primitive_type) {
296     case F16:
297       return LiteralUtil::CreateR0<half>(
298           static_cast<half>(std::numeric_limits<float>::quiet_NaN()));
299     case BF16:
300       return LiteralUtil::CreateR0<bfloat16>(
301           static_cast<bfloat16>(std::numeric_limits<float>::quiet_NaN()));
302     case F32:
303       return LiteralUtil::CreateR0<float>(
304           std::numeric_limits<float>::quiet_NaN());
305     case F64:
306       return LiteralUtil::CreateR0<double>(
307           std::numeric_limits<double>::quiet_NaN());
308     case C64: {
309       float nan = std::numeric_limits<float>::quiet_NaN();
310       return LiteralUtil::CreateR0<complex64>(complex64(nan, nan));
311     }
312     case C128: {
313       double nan = std::numeric_limits<double>::quiet_NaN();
314       return LiteralUtil::CreateR0<complex128>(complex128(nan, nan));
315     }
316     default:
317       return InvalidArgument("Invalid type for NanValue: %s",
318                              PrimitiveType_Name(primitive_type));
319   }
320 }
321 
CreateR1(const tensorflow::core::Bitmap & values)322 /* static */ Literal LiteralUtil::CreateR1(
323     const tensorflow::core::Bitmap& values) {
324   Literal literal(
325       ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
326   literal.PopulateR1(values);
327   return literal;
328 }
329 
CreateR1U8(absl::string_view value)330 /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
331   Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
332   for (int i = 0, end = value.size(); i < end; ++i) {
333     literal.Set<uint8>({i}, value[i]);
334   }
335   return literal;
336 }
337 
CreateR2F32Linspace(float from,float to,int64 rows,int64 cols)338 /* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
339                                                       int64 rows, int64 cols) {
340   auto value = MakeLinspaceArray2D(from, to, rows, cols);
341   return CreateR2FromArray2D(*value);
342 }
343 
ReshapeSlice(absl::Span<const int64> new_dimensions,absl::Span<const int64> minor_to_major,const LiteralSlice & literal)344 /* static */ Literal LiteralUtil::ReshapeSlice(
345     absl::Span<const int64> new_dimensions,
346     absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
347   int64 new_num_elements = 1;
348   for (int64 i = 0, end = new_dimensions.size(); i < end; ++i) {
349     new_num_elements *= new_dimensions[i];
350   }
351   CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
352   CHECK_EQ(new_dimensions.size(), minor_to_major.size());
353 
354   Literal new_literal(
355       ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
356 
357   // Create a new shape with the given minor-to-major layout. This shape is used
358   // solely for converting linear address to multi-dimensional addresses when
359   // writing elements to the new literal.
360   Shape shape_with_layout = new_literal.shape();
361   *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
362 
363   // Copy data into new literal, element-by-element.
364   for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
365     std::vector<int64> from_multi_index =
366         IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
367     std::vector<int64> to_multi_index =
368         IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
369     switch (literal.shape().element_type()) {
370       case PRED:
371         new_literal.Set<bool>(to_multi_index,
372                               literal.Get<bool>(from_multi_index));
373         break;
374       case U8:
375         new_literal.Set<uint8>(to_multi_index,
376                                literal.Get<uint8>(from_multi_index));
377         break;
378       case U32:
379         new_literal.Set<uint32>(to_multi_index,
380                                 literal.Get<uint32>(from_multi_index));
381         break;
382       case S32:
383         new_literal.Set<int32>(to_multi_index,
384                                literal.Get<int32>(from_multi_index));
385         break;
386       case U64:
387         new_literal.Set<uint64>(to_multi_index,
388                                 literal.Get<uint64>(from_multi_index));
389         break;
390       case S64:
391         new_literal.Set<int64>(to_multi_index,
392                                literal.Get<int64>(from_multi_index));
393         break;
394       case F32:
395         new_literal.Set<float>(to_multi_index,
396                                literal.Get<float>(from_multi_index));
397         break;
398       case F64:
399         new_literal.Set<double>(to_multi_index,
400                                 literal.Get<double>(from_multi_index));
401         break;
402       case C64:
403         new_literal.Set<complex64>(to_multi_index,
404                                    literal.Get<complex64>(from_multi_index));
405         break;
406       case C128:
407         new_literal.Set<complex128>(to_multi_index,
408                                     literal.Get<complex128>(from_multi_index));
409         break;
410       default:
411         LOG(FATAL) << "Unhandled primitive element type: "
412                    << PrimitiveType_Name(literal.shape().element_type());
413     }
414   }
415 
416   return new_literal;
417 }
418 
GetFirstScalarLiteral(const LiteralSlice & literal)419 /* static */ Literal LiteralUtil::GetFirstScalarLiteral(
420     const LiteralSlice& literal) {
421   CHECK(literal.shape().IsArray());
422   CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
423   switch (literal.shape().element_type()) {
424     case PRED:
425       return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
426     // 8 bit types.
427     case S8:
428       return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
429     case U8:
430       return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
431     // 16 bit types.
432     case BF16:
433       return LiteralUtil::CreateR0<bfloat16>(
434           literal.GetFirstElement<bfloat16>());
435     case F16:
436       return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
437     case S16:
438       return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
439     case U16:
440       return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
441     // 32 bit types.
442     case F32:
443       return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
444     case S32:
445       return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
446     case U32:
447       return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
448     // 64 bit types.
449     case C64:
450       return LiteralUtil::CreateR0<complex64>(
451           literal.GetFirstElement<complex64>());
452     case F64:
453       return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
454     case S64:
455       return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
456     case U64:
457       return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
458 
459     case C128:
460       return LiteralUtil::CreateR0<complex128>(
461           literal.GetFirstElement<complex128>());
462     default:
463       LOG(FATAL) << "Unhandled primitive type "
464                  << literal.shape().element_type();
465   }
466 }
467 
MakeTuple(absl::Span<const Literal * const> elements)468 /* static */ Literal LiteralUtil::MakeTuple(
469     absl::Span<const Literal* const> elements) {
470   std::vector<Shape> element_shapes;
471   for (const auto* element : elements) {
472     element_shapes.push_back(element->shape());
473   }
474   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
475   for (int i = 0, end = elements.size(); i < end; ++i) {
476     TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
477   }
478   return literal;
479 }
480 
MakeTupleFromSlices(absl::Span<const LiteralSlice> elements)481 /* static */ Literal LiteralUtil::MakeTupleFromSlices(
482     absl::Span<const LiteralSlice> elements) {
483   std::vector<Shape> element_shapes;
484   for (const auto& element : elements) {
485     element_shapes.push_back(element.shape());
486   }
487   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
488   for (int i = 0, end = elements.size(); i < end; ++i) {
489     TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
490   }
491   return literal;
492 }
493 
MakeTupleOwned(std::vector<Literal> elements)494 /* static */ Literal LiteralUtil::MakeTupleOwned(
495     std::vector<Literal> elements) {
496   std::vector<Shape> element_shapes;
497   element_shapes.reserve(elements.size());
498   for (const auto& element : elements) {
499     element_shapes.push_back(element.shape());
500   }
501   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
502   for (int64 i = 0, end = elements.size(); i < end; ++i) {
503     TF_CHECK_OK(
504         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
505   }
506   return literal;
507 }
508 
MultiIndexAsString(absl::Span<const int64> multi_index)509 /* static */ string LiteralUtil::MultiIndexAsString(
510     absl::Span<const int64> multi_index) {
511   return StrCat("{", absl::StrJoin(multi_index, ","), "}");
512 }
513 
514 }  // namespace xla
515