1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
17 
18 #include <stddef.h>
19 #include <string>
20 
21 namespace tflite {
22 namespace gpu {
23 
SizeOf(DataType data_type)24 size_t SizeOf(DataType data_type) {
25   switch (data_type) {
26     case DataType::UINT8:
27     case DataType::INT8:
28       return 1;
29     case DataType::FLOAT16:
30     case DataType::INT16:
31     case DataType::UINT16:
32       return 2;
33     case DataType::FLOAT32:
34     case DataType::INT32:
35     case DataType::UINT32:
36       return 4;
37     case DataType::FLOAT64:
38     case DataType::INT64:
39     case DataType::UINT64:
40       return 8;
41     case DataType::UNKNOWN:
42       return 0;
43   }
44   return 0;
45 }
46 
ToString(DataType data_type)47 std::string ToString(DataType data_type) {
48   switch (data_type) {
49     case DataType::FLOAT16:
50       return "float16";
51     case DataType::FLOAT32:
52       return "float32";
53     case DataType::FLOAT64:
54       return "float64";
55     case DataType::INT16:
56       return "int16";
57     case DataType::INT32:
58       return "int32";
59     case DataType::INT64:
60       return "int64";
61     case DataType::INT8:
62       return "int8";
63     case DataType::UINT16:
64       return "uint16";
65     case DataType::UINT32:
66       return "uint32";
67     case DataType::UINT64:
68       return "uint64";
69     case DataType::UINT8:
70       return "uint8";
71     case DataType::UNKNOWN:
72       return "unknown";
73   }
74   return "undefined";
75 }
76 
ToCLDataType(DataType data_type,int vec_size)77 std::string ToCLDataType(DataType data_type, int vec_size) {
78   const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
79   switch (data_type) {
80     case DataType::FLOAT16:
81       return "half" + postfix;
82     case DataType::FLOAT32:
83       return "float" + postfix;
84     case DataType::FLOAT64:
85       return "double" + postfix;
86     case DataType::INT16:
87       return "short" + postfix;
88     case DataType::INT32:
89       return "int" + postfix;
90     case DataType::INT64:
91       return "long" + postfix;
92     case DataType::INT8:
93       return "char" + postfix;
94     case DataType::UINT16:
95       return "ushort" + postfix;
96     case DataType::UINT32:
97       return "uint" + postfix;
98     case DataType::UINT64:
99       return "ulong" + postfix;
100     case DataType::UINT8:
101       return "uchar" + postfix;
102     case DataType::UNKNOWN:
103       return "unknown";
104   }
105   return "undefined";
106 }
107 
ToMetalDataType(DataType data_type,int vec_size)108 std::string ToMetalDataType(DataType data_type, int vec_size) {
109   const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
110   switch (data_type) {
111     case DataType::FLOAT16:
112       return "half" + postfix;
113     case DataType::FLOAT32:
114       return "float" + postfix;
115     case DataType::FLOAT64:
116       return "double" + postfix;
117     case DataType::INT16:
118       return "short" + postfix;
119     case DataType::INT32:
120       return "int" + postfix;
121     case DataType::INT64:
122       return "long" + postfix;
123     case DataType::INT8:
124       return "char" + postfix;
125     case DataType::UINT16:
126       return "ushort" + postfix;
127     case DataType::UINT32:
128       return "uint" + postfix;
129     case DataType::UINT64:
130       return "ulong" + postfix;
131     case DataType::UINT8:
132       return "uchar" + postfix;
133     case DataType::UNKNOWN:
134       return "unknown";
135   }
136   return "undefined";
137 }
138 
139 }  // namespace gpu
140 }  // namespace tflite
141