1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/primitive_util.h"
17 
18 #include <limits>
19 
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/numbers.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace xla {
28 namespace primitive_util {
29 
SignificandWidth(PrimitiveType type)30 int SignificandWidth(PrimitiveType type) {
31   switch (type) {
32     case F32:
33       return std::numeric_limits<float>::digits;
34     case F64:
35       return std::numeric_limits<double>::digits;
36     case BF16:
37       return std::numeric_limits<bfloat16>::digits;
38     case F16:
39       return std::numeric_limits<half>::digits;
40     default:
41       LOG(FATAL) << "Not a floating data type " << type;
42   }
43 }
44 
ExponentWidth(PrimitiveType type)45 int ExponentWidth(PrimitiveType type) {
46   // Per the IEEE-754 standard: a floating point type is stored as a sign bit, a
47   // biased exponent and a trailing significand field.
48   int total_bit_width = BitWidth(type);
49   // This field contains all bits in the significand other than the leading
50   // digit which is implied by the exponent.
51   int trailing_significand_field_width = SignificandWidth(type) - 1;
52   // The sign is encoded with a single bit.
53   int kSignBitWidth = 1;
54   // The remaining bits are used for encoding the biased exponent.
55   return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
56 }
57 
OverflowExponent(PrimitiveType type)58 int OverflowExponent(PrimitiveType type) {
59   // |std::numeric_limits<float>::max_exponent| is defined as: "Maximum positive
60   // integer such that radix raised to the power one less than that integer is a
61   // representable finite floating-point number." as such it does not actually
62   // yield the maximum exponent but the exponent of the first integer which
63   // overflows.
64   switch (type) {
65     case F32:
66       return std::numeric_limits<float>::max_exponent;
67     case F64:
68       return std::numeric_limits<double>::max_exponent;
69     case BF16:
70       return std::numeric_limits<bfloat16>::max_exponent;
71     case F16:
72       return std::numeric_limits<half>::max_exponent;
73     default:
74       LOG(FATAL) << "Not a floating data type " << type;
75   }
76 }
77 
IsFloatingPointType(PrimitiveType type)78 bool IsFloatingPointType(PrimitiveType type) {
79   return type == F16 || type == F32 || type == F64 || type == BF16;
80 }
81 
IsComplexType(PrimitiveType type)82 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
83 
IsSignedIntegralType(PrimitiveType type)84 bool IsSignedIntegralType(PrimitiveType type) {
85   return type == S8 || type == S16 || type == S32 || type == S64;
86 }
87 
IsUnsignedIntegralType(PrimitiveType type)88 bool IsUnsignedIntegralType(PrimitiveType type) {
89   return type == U8 || type == U16 || type == U32 || type == U64;
90 }
91 
IsIntegralType(PrimitiveType type)92 bool IsIntegralType(PrimitiveType type) {
93   return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
94 }
95 
BitWidth(PrimitiveType type)96 int BitWidth(PrimitiveType type) {
97   switch (type) {
98     case PRED:
99       return 1;
100 
101     case S8:
102     case U8:
103       return 8;
104 
105     case S16:
106     case U16:
107     case F16:
108     case BF16:
109       return 16;
110 
111     case U32:
112     case S32:
113     case F32:
114       return 32;
115 
116     case U64:
117     case S64:
118     case F64:
119     case C64:
120       return 64;
121 
122     case C128:
123       return 128;
124 
125     case TUPLE:
126       LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
127 
128     case OPAQUE_TYPE:
129       LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
130 
131     default:
132       LOG(FATAL) << "Unhandled primitive type " << type;
133   }
134 }
135 
UnsignedIntegralTypeForBitWidth(int64 src_bitwidth)136 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) {
137   switch (src_bitwidth) {
138     case 8:
139       return xla::U8;
140     case 16:
141       return xla::U16;
142     case 32:
143       return xla::U32;
144     case 64:
145       return xla::U64;
146     default:
147       return xla::PRIMITIVE_TYPE_INVALID;
148   }
149 }
150 
SignedIntegralTypeForBitWidth(int64 src_bitwidth)151 xla::PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth) {
152   switch (src_bitwidth) {
153     case 8:
154       return xla::S8;
155     case 16:
156       return xla::S16;
157     case 32:
158       return xla::S32;
159     case 64:
160       return xla::S64;
161     default:
162       return xla::PRIMITIVE_TYPE_INVALID;
163   }
164 }
165 
ComplexComponentType(PrimitiveType complex_type)166 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
167   switch (complex_type) {
168     case C64:
169       return F32;
170     case C128:
171       return F64;
172     default:
173       LOG(FATAL) << "Primitive type is not complex: "
174                  << PrimitiveType_Name(complex_type);
175   }
176 }
177 
IsArrayType(PrimitiveType primitive_type)178 bool IsArrayType(PrimitiveType primitive_type) {
179   return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
180          primitive_type != OPAQUE_TYPE && primitive_type != TOKEN;
181 }
182 
183 // Class to memoize the computation of
184 //   absl::AsciiStrToLower(PrimitiveType_Name(p))
185 // for all PrimitiveType values "p"
186 //
187 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason
188 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro.
189 class PrimitiveTypeNameGenerator {
190  public:
PrimitiveTypeNameGenerator()191   PrimitiveTypeNameGenerator() {
192     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
193       if (i == static_cast<int>(OPAQUE_TYPE)) {
194         lowercase_name_[i] = "opaque";
195       } else if (PrimitiveType_IsValid(i)) {
196         lowercase_name_[i] = absl::AsciiStrToLower(
197             PrimitiveType_Name(static_cast<PrimitiveType>(i)));
198       }
199     }
200   }
LowercaseName(PrimitiveType t)201   const string& LowercaseName(PrimitiveType t) {
202     return lowercase_name_[static_cast<int>(t)];
203   }
204 
205  private:
206   string lowercase_name_[PrimitiveType_ARRAYSIZE];
207 };
208 
LowercasePrimitiveTypeName(PrimitiveType s)209 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
210   static auto* gen = new PrimitiveTypeNameGenerator();
211   return gen->LowercaseName(s);
212 }
213 
214 namespace {
215 
216 // Returns a map from lower-case primitive type name to primitive type.
217 //
218 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to
219 // the xla::OPAQUE_TYPE enumerator.
GetPrimitiveTypeStringMap()220 const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() {
221   static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
222     static auto* map = new std::unordered_map<string, PrimitiveType>;
223     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
224       if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
225         auto value = static_cast<PrimitiveType>(i);
226         (*map)[LowercasePrimitiveTypeName(value)] = value;
227       }
228     }
229     (*map)["opaque"] = OPAQUE_TYPE;
230     return map;
231   }();
232   return *name_to_type;
233 }
234 
235 }  // namespace
236 
StringToPrimitiveType(absl::string_view name)237 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
238   const auto& map = GetPrimitiveTypeStringMap();
239   auto found = map.find(string(name));
240   if (found == map.end()) {
241     return InvalidArgument("Invalid element type string: \"%s\".", name);
242   }
243   return found->second;
244 }
245 
IsPrimitiveTypeName(absl::string_view name)246 bool IsPrimitiveTypeName(absl::string_view name) {
247   const auto& map = GetPrimitiveTypeStringMap();
248   auto found = map.find(string(name));
249   return found != map.end();
250 }
251 
252 }  // namespace primitive_util
253 }  // namespace xla
254