1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/primitive_util.h"
17 
18 #include "absl/strings/ascii.h"
19 #include "absl/strings/numbers.h"
20 #include "tensorflow/compiler/xla/util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace xla {
25 namespace primitive_util {
26 
SignificandWidth(PrimitiveType type)27 int SignificandWidth(PrimitiveType type) {
28   switch (type) {
29     case F32:
30       return std::numeric_limits<float>::digits;
31     case F64:
32       return std::numeric_limits<double>::digits;
33     case BF16:
34       return kBFloat16MantissaBits + 1;
35     case F16:
36       return 11;
37     default:
38       LOG(FATAL) << "Not a floating data type " << type;
39   }
40 }
41 
IsFloatingPointType(PrimitiveType type)42 bool IsFloatingPointType(PrimitiveType type) {
43   return type == F16 || type == F32 || type == F64 || type == BF16;
44 }
45 
IsComplexType(PrimitiveType type)46 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
47 
IsSignedIntegralType(PrimitiveType type)48 bool IsSignedIntegralType(PrimitiveType type) {
49   return type == S8 || type == S16 || type == S32 || type == S64;
50 }
51 
IsUnsignedIntegralType(PrimitiveType type)52 bool IsUnsignedIntegralType(PrimitiveType type) {
53   return type == U8 || type == U16 || type == U32 || type == U64;
54 }
55 
IsIntegralType(PrimitiveType type)56 bool IsIntegralType(PrimitiveType type) {
57   return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
58 }
59 
BitWidth(PrimitiveType type)60 int BitWidth(PrimitiveType type) {
61   switch (type) {
62     case PRED:
63       return 1;
64 
65     case S8:
66     case U8:
67       return 8;
68 
69     case S16:
70     case U16:
71     case F16:
72     case BF16:
73       return 16;
74 
75     case U32:
76     case S32:
77     case F32:
78       return 32;
79 
80     case U64:
81     case S64:
82     case F64:
83     case C64:
84       return 64;
85 
86     case C128:
87       return 128;
88 
89     case TUPLE:
90       LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
91 
92     case OPAQUE:
93       LOG(FATAL) << "OPAQUE is an invalid type for BitWidth";
94 
95     default:
96       LOG(FATAL) << "Unhandled primitive type " << type;
97   }
98 }
99 
UnsignedIntegralTypeForBitWidth(int64 src_bitwidth)100 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) {
101   switch (src_bitwidth) {
102     case 8:
103       return xla::U8;
104     case 16:
105       return xla::U16;
106     case 32:
107       return xla::U32;
108     case 64:
109       return xla::U64;
110     default:
111       return xla::PRIMITIVE_TYPE_INVALID;
112   }
113 }
114 
ComplexComponentType(PrimitiveType complex_type)115 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
116   switch (complex_type) {
117     case C64:
118       return F32;
119     case C128:
120       return F64;
121     default:
122       LOG(FATAL) << "Primitive type is not complex: "
123                  << PrimitiveType_Name(complex_type);
124   }
125 }
126 
IsArrayType(PrimitiveType primitive_type)127 bool IsArrayType(PrimitiveType primitive_type) {
128   return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
129          primitive_type != OPAQUE && primitive_type != TOKEN;
130 }
131 
132 // Class to memoize the computation of
133 //   absl::AsciiStrToLower(PrimitiveType_Name(p))
134 // for all PrimitiveType values "p"
135 class PrimitiveTypeNameGenerator {
136  public:
PrimitiveTypeNameGenerator()137   PrimitiveTypeNameGenerator() {
138     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
139       if (PrimitiveType_IsValid(i)) {
140         lowercase_name_[i] = absl::AsciiStrToLower(
141             PrimitiveType_Name(static_cast<PrimitiveType>(i)));
142       }
143     }
144   }
LowercaseName(PrimitiveType t)145   const string& LowercaseName(PrimitiveType t) {
146     return lowercase_name_[static_cast<int>(t)];
147   }
148 
149  private:
150   string lowercase_name_[PrimitiveType_ARRAYSIZE];
151 };
152 
LowercasePrimitiveTypeName(PrimitiveType s)153 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
154   static auto* gen = new PrimitiveTypeNameGenerator();
155   return gen->LowercaseName(s);
156 }
157 
158 namespace {
159 
160 // Returns a map from lower-case primitive type name to primitive type.
GetPrimitiveTypeStringMap()161 const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() {
162   static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
163     static auto* map = new std::unordered_map<string, PrimitiveType>;
164     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
165       if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
166         auto value = static_cast<PrimitiveType>(i);
167         (*map)[LowercasePrimitiveTypeName(value)] = value;
168       }
169     }
170     return map;
171   }();
172   return *name_to_type;
173 }
174 
175 }  // namespace
176 
StringToPrimitiveType(absl::string_view name)177 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
178   const auto& map = GetPrimitiveTypeStringMap();
179   auto found = map.find(string(name));
180   if (found == map.end()) {
181     return InvalidArgument("Invalid element type string: \"%s\".", name);
182   }
183   return found->second;
184 }
185 
IsPrimitiveTypeName(absl::string_view name)186 bool IsPrimitiveTypeName(absl::string_view name) {
187   const auto& map = GetPrimitiveTypeStringMap();
188   auto found = map.find(string(name));
189   return found != map.end();
190 }
191 
192 }  // namespace primitive_util
193 }  // namespace xla
194