1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"

#include <algorithm>
#include <map>
#include <string>
#include <vector>

#include "absl/strings/ascii.h"
22 
23 namespace tflite {
24 namespace gpu {
25 namespace {
26 
GetGpuVendor(const std::string & gpu_description)27 GpuVendor GetGpuVendor(const std::string& gpu_description) {
28   const std::map<std::string, GpuVendor> kMapping = {
29       {"adreno", GpuVendor::kQualcomm},
30       {"apple", GpuVendor::kApple},
31       {"qualcomm", GpuVendor::kQualcomm},
32       {"mali", GpuVendor::kMali},
33       {"powervr", GpuVendor::kPowerVR},
34       {"advanced micro devices", GpuVendor::kAMD},
35       {"intel", GpuVendor::kIntel},
36       {"nvidia", GpuVendor::kNvidia},
37       {"amd", GpuVendor::kAMD},
38       {"power", GpuVendor::kPowerVR},
39   };
40   for (const auto& v : kMapping) {
41     if (gpu_description.find(v.first) != std::string::npos) {
42       return v.second;
43     }
44   }
45   return GpuVendor::kUnknown;
46 }
47 
GetAdrenoGpuVersion(const std::string & gpu_description)48 AdrenoGpu GetAdrenoGpuVersion(const std::string& gpu_description) {
49   const std::map<std::string, AdrenoGpu> kMapping = {
50       // Adreno 6xx series
51       {"685", AdrenoGpu::kAdreno685},
52       {"680", AdrenoGpu::kAdreno680},
53       {"675", AdrenoGpu::kAdreno675},
54       {"650", AdrenoGpu::kAdreno650},
55       {"640", AdrenoGpu::kAdreno640},
56       {"630", AdrenoGpu::kAdreno630},
57       {"620", AdrenoGpu::kAdreno620},
58       {"618", AdrenoGpu::kAdreno618},
59       {"616", AdrenoGpu::kAdreno616},
60       {"615", AdrenoGpu::kAdreno615},
61       {"612", AdrenoGpu::kAdreno612},
62       {"610", AdrenoGpu::kAdreno610},
63       {"605", AdrenoGpu::kAdreno605},
64       // Adreno 5xx series
65       {"540", AdrenoGpu::kAdreno540},
66       {"530", AdrenoGpu::kAdreno530},
67       {"512", AdrenoGpu::kAdreno512},
68       {"510", AdrenoGpu::kAdreno510},
69       {"509", AdrenoGpu::kAdreno509},
70       {"508", AdrenoGpu::kAdreno508},
71       {"506", AdrenoGpu::kAdreno506},
72       {"505", AdrenoGpu::kAdreno505},
73       {"504", AdrenoGpu::kAdreno504},
74       // Adreno 4xx series
75       {"430", AdrenoGpu::kAdreno430},
76       {"420", AdrenoGpu::kAdreno420},
77       {"418", AdrenoGpu::kAdreno418},
78       {"405", AdrenoGpu::kAdreno405},
79       // Adreno 3xx series
80       {"330", AdrenoGpu::kAdreno330},
81       {"320", AdrenoGpu::kAdreno320},
82       {"308", AdrenoGpu::kAdreno308},
83       {"306", AdrenoGpu::kAdreno306},
84       {"305", AdrenoGpu::kAdreno305},
85       {"304", AdrenoGpu::kAdreno304},
86       // Adreno 2xx series
87       {"225", AdrenoGpu::kAdreno225},
88       {"220", AdrenoGpu::kAdreno220},
89       {"205", AdrenoGpu::kAdreno205},
90       {"203", AdrenoGpu::kAdreno203},
91       {"200", AdrenoGpu::kAdreno200},
92       // Adreno 1xx series
93       {"130", AdrenoGpu::kAdreno130},
94       {"120", AdrenoGpu::kAdreno120},
95   };
96 
97   for (const auto& v : kMapping) {
98     if (gpu_description.find(v.first) != std::string::npos) {
99       return v.second;
100     }
101   }
102   return AdrenoGpu::kUnknown;
103 }
104 
GetMaliGpuVersion(const std::string & gpu_description)105 MaliGpu GetMaliGpuVersion(const std::string& gpu_description) {
106   const std::map<std::string, MaliGpu> kMapping = {
107       {"t604", MaliGpu::kT604}, {"t622", MaliGpu::kT622},
108       {"t624", MaliGpu::kT624}, {"t628", MaliGpu::kT628},
109       {"t658", MaliGpu::kT658}, {"t678", MaliGpu::kT678},
110       {"t720", MaliGpu::kT720}, {"t760", MaliGpu::kT760},
111       {"t820", MaliGpu::kT820}, {"t830", MaliGpu::kT830},
112       {"t860", MaliGpu::kT860}, {"t880", MaliGpu::kT880},
113       {"g31", MaliGpu::kG31},   {"g51", MaliGpu::kG51},
114       {"g71", MaliGpu::kG71},   {"g52", MaliGpu::kG52},
115       {"g72", MaliGpu::kG72},   {"g76", MaliGpu::kG76},
116       {"g57", MaliGpu::kG57},   {"g77", MaliGpu::kG77},
117       {"g68", MaliGpu::kG68},   {"g78", MaliGpu::kG78},
118   };
119   for (const auto& v : kMapping) {
120     if (gpu_description.find(v.first) != std::string::npos) {
121       return v.second;
122     }
123   }
124   return MaliGpu::kUnknown;
125 }
126 
127 }  // namespace
128 
AdrenoInfo(const std::string & device_version)129 AdrenoInfo::AdrenoInfo(const std::string& device_version)
130     : adreno_gpu(GetAdrenoGpuVersion(device_version)) {}
131 
IsAdreno1xx() const132 bool AdrenoInfo::IsAdreno1xx() const {
133   return adreno_gpu == AdrenoGpu::kAdreno120 ||
134          adreno_gpu == AdrenoGpu::kAdreno130;
135 }
136 
IsAdreno2xx() const137 bool AdrenoInfo::IsAdreno2xx() const {
138   return adreno_gpu == AdrenoGpu::kAdreno200 ||
139          adreno_gpu == AdrenoGpu::kAdreno203 ||
140          adreno_gpu == AdrenoGpu::kAdreno205 ||
141          adreno_gpu == AdrenoGpu::kAdreno220 ||
142          adreno_gpu == AdrenoGpu::kAdreno225;
143 }
144 
IsAdreno3xx() const145 bool AdrenoInfo::IsAdreno3xx() const {
146   return adreno_gpu == AdrenoGpu::kAdreno304 ||
147          adreno_gpu == AdrenoGpu::kAdreno305 ||
148          adreno_gpu == AdrenoGpu::kAdreno306 ||
149          adreno_gpu == AdrenoGpu::kAdreno308 ||
150          adreno_gpu == AdrenoGpu::kAdreno320 ||
151          adreno_gpu == AdrenoGpu::kAdreno330;
152 }
153 
IsAdreno4xx() const154 bool AdrenoInfo::IsAdreno4xx() const {
155   return adreno_gpu == AdrenoGpu::kAdreno405 ||
156          adreno_gpu == AdrenoGpu::kAdreno418 ||
157          adreno_gpu == AdrenoGpu::kAdreno420 ||
158          adreno_gpu == AdrenoGpu::kAdreno430;
159 }
160 
IsAdreno5xx() const161 bool AdrenoInfo::IsAdreno5xx() const {
162   return adreno_gpu == AdrenoGpu::kAdreno504 ||
163          adreno_gpu == AdrenoGpu::kAdreno505 ||
164          adreno_gpu == AdrenoGpu::kAdreno506 ||
165          adreno_gpu == AdrenoGpu::kAdreno508 ||
166          adreno_gpu == AdrenoGpu::kAdreno509 ||
167          adreno_gpu == AdrenoGpu::kAdreno510 ||
168          adreno_gpu == AdrenoGpu::kAdreno512 ||
169          adreno_gpu == AdrenoGpu::kAdreno530 ||
170          adreno_gpu == AdrenoGpu::kAdreno540;
171 }
172 
IsAdreno6xx() const173 bool AdrenoInfo::IsAdreno6xx() const {
174   return adreno_gpu == AdrenoGpu::kAdreno605 ||
175          adreno_gpu == AdrenoGpu::kAdreno610 ||
176          adreno_gpu == AdrenoGpu::kAdreno612 ||
177          adreno_gpu == AdrenoGpu::kAdreno615 ||
178          adreno_gpu == AdrenoGpu::kAdreno616 ||
179          adreno_gpu == AdrenoGpu::kAdreno618 ||
180          adreno_gpu == AdrenoGpu::kAdreno620 ||
181          adreno_gpu == AdrenoGpu::kAdreno630 ||
182          adreno_gpu == AdrenoGpu::kAdreno640 ||
183          adreno_gpu == AdrenoGpu::kAdreno650 ||
184          adreno_gpu == AdrenoGpu::kAdreno675 ||
185          adreno_gpu == AdrenoGpu::kAdreno680 ||
186          adreno_gpu == AdrenoGpu::kAdreno685;
187 }
188 
IsAdreno6xxOrHigher() const189 bool AdrenoInfo::IsAdreno6xxOrHigher() const {
190   return !compiler_bugs_in_a6xx && IsAdreno6xx();
191 }
192 
GetMaximumWavesCount() const193 int AdrenoInfo::GetMaximumWavesCount() const {
194   if (IsAdreno6xx()) {
195     if (adreno_gpu == AdrenoGpu::kAdreno640) {
196       return 30;
197     } else {
198       return 16;
199     }
200   } else {
201     // all other versions not supported
202     return 1;
203   }
204 }
205 
GetRegisterMemorySizePerComputeUnit() const206 int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const {
207   if (IsAdreno6xx()) {
208     if (adreno_gpu == AdrenoGpu::kAdreno640) {
209       return 128 * 144 * 16;
210     } else if (adreno_gpu == AdrenoGpu::kAdreno650 ||
211                adreno_gpu == AdrenoGpu::kAdreno620) {
212       return 128 * 64 * 16;
213     } else {
214       return 128 * 96 * 16;
215     }
216   } else {
217     // all other versions not supported
218     return 1;
219   }
220 }
221 
GetMaximumWavesCount(int register_footprint_per_tread,bool full_wave) const222 int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread,
223                                      bool full_wave) const {
224   const int register_usage_per_wave =
225       GetWaveSize(full_wave) * register_footprint_per_tread;
226   const int possible_waves_count =
227       GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave;
228   return std::min(possible_waves_count, GetMaximumWavesCount());
229 }
230 
GetWaveSize(bool full_wave) const231 int AdrenoInfo::GetWaveSize(bool full_wave) const {
232   if (IsAdreno6xx()) {
233     return full_wave ? 128 : 64;
234   } else if (IsAdreno5xx() || IsAdreno4xx()) {
235     return full_wave ? 64 : 32;
236   } else {
237     // all other versions not supported
238     return 1;
239   }
240 }
241 
AppleInfo(const std::string & gpu_description)242 AppleInfo::AppleInfo(const std::string& gpu_description) {
243   const std::map<std::string, AppleGpu> kMapping = {
244       {"apple a7 gpu", AppleGpu::kA7},     {"apple a8 gpu", AppleGpu::kA8},
245       {"apple a8x gpu", AppleGpu::kA8X},   {"apple a9 gpu", AppleGpu::kA9},
246       {"apple a9x gpu", AppleGpu::kA9X},   {"apple a10 gpu", AppleGpu::kA10},
247       {"apple a10x gpu", AppleGpu::kA10X}, {"apple a11 gpu", AppleGpu::kA11},
248       {"apple a12 gpu", AppleGpu::kA12},   {"apple a12x gpu", AppleGpu::kA12X},
249       {"apple a12z gpu", AppleGpu::kA12Z}, {"apple a13 gpu", AppleGpu::kA13},
250       {"apple a14 gpu", AppleGpu::kA14},
251   };
252   auto it = kMapping.find(gpu_description);
253   if (it != kMapping.end()) {
254     gpu_type = it->second;
255   } else {
256     gpu_type = AppleGpu::kUnknown;
257   }
258 }
259 
IsLocalMemoryPreferredOverGlobal() const260 bool AppleInfo::IsLocalMemoryPreferredOverGlobal() const {
261   return gpu_type == AppleGpu::kA7 || gpu_type == AppleGpu::kA8 ||
262          gpu_type == AppleGpu::kA8X;
263 }
264 
IsBionic() const265 bool AppleInfo::IsBionic() const {
266   return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 ||
267          gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z ||
268          gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14;
269 }
270 
IsRoundToNearestSupported() const271 bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); }
272 
GetComputeUnitsCount() const273 int AppleInfo::GetComputeUnitsCount() const {
274   switch (gpu_type) {
275     case AppleGpu::kA7:
276       return 4;
277     case AppleGpu::kA8:
278       return 4;
279     case AppleGpu::kA8X:
280       return 8;
281     case AppleGpu::kA9:
282       return 6;
283     case AppleGpu::kA9X:
284       return 12;
285     case AppleGpu::kA10:
286       return 6;
287     case AppleGpu::kA10X:
288       return 12;
289     case AppleGpu::kA11:
290       return 3;
291     case AppleGpu::kA12:
292       return 4;
293     case AppleGpu::kA12X:
294       return 7;
295     case AppleGpu::kA12Z:
296       return 8;
297     case AppleGpu::kA13:
298       return 4;
299     case AppleGpu::kA14:
300       return 4;
301     case AppleGpu::kUnknown:
302       return 1;
303   }
304 }
305 
MaliInfo(const std::string & gpu_description)306 MaliInfo::MaliInfo(const std::string& gpu_description)
307     : gpu_version(GetMaliGpuVersion(gpu_description)) {}
308 
IsMaliT6xx() const309 bool MaliInfo::IsMaliT6xx() const {
310   return gpu_version == MaliGpu::kT604 || gpu_version == MaliGpu::kT622 ||
311          gpu_version == MaliGpu::kT624 || gpu_version == MaliGpu::kT628 ||
312          gpu_version == MaliGpu::kT658 || gpu_version == MaliGpu::kT678;
313 }
314 
IsMaliT7xx() const315 bool MaliInfo::IsMaliT7xx() const {
316   return gpu_version == MaliGpu::kT720 || gpu_version == MaliGpu::kT760;
317 }
318 
IsMaliT8xx() const319 bool MaliInfo::IsMaliT8xx() const {
320   return gpu_version == MaliGpu::kT820 || gpu_version == MaliGpu::kT830 ||
321          gpu_version == MaliGpu::kT860 || gpu_version == MaliGpu::kT880;
322 }
323 
IsMidgard() const324 bool MaliInfo::IsMidgard() const {
325   return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
326 }
327 
IsBifrostGen1() const328 bool MaliInfo::IsBifrostGen1() const {
329   return gpu_version == MaliGpu::kG31 || gpu_version == MaliGpu::kG51 ||
330          gpu_version == MaliGpu::kG71;
331 }
332 
IsBifrostGen2() const333 bool MaliInfo::IsBifrostGen2() const {
334   return gpu_version == MaliGpu::kG52 || gpu_version == MaliGpu::kG72;
335 }
336 
IsBifrostGen3() const337 bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGpu::kG76; }
338 
IsBifrost() const339 bool MaliInfo::IsBifrost() const {
340   return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
341 }
342 
IsValhall() const343 bool MaliInfo::IsValhall() const {
344   return gpu_version == MaliGpu::kG57 || gpu_version == MaliGpu::kG77 ||
345          gpu_version == MaliGpu::kG68 || gpu_version == MaliGpu::kG78;
346 }
347 
GetGpuInfoFromDeviceDescription(const std::string & gpu_description,GpuApi gpu_api,GpuInfo * gpu_info)348 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
349                                      GpuApi gpu_api, GpuInfo* gpu_info) {
350   gpu_info->gpu_api = gpu_api;
351   std::string lowered = gpu_description;
352   absl::AsciiStrToLower(&lowered);
353   gpu_info->vendor = GetGpuVendor(lowered);
354   if (gpu_info->IsAdreno()) {
355     gpu_info->adreno_info = AdrenoInfo(lowered);
356   } else if (gpu_info->IsApple()) {
357     gpu_info->apple_info = AppleInfo(lowered);
358     gpu_info->supported_subgroup_sizes = {32};
359   } else if (gpu_info->IsMali()) {
360     gpu_info->mali_info = MaliInfo(lowered);
361   }
362 }
363 
OpenClVersionToString(OpenClVersion version)364 std::string OpenClVersionToString(OpenClVersion version) {
365   switch (version) {
366     case OpenClVersion::kCl1_0:
367       return "1.0";
368     case OpenClVersion::kCl1_1:
369       return "1.1";
370     case OpenClVersion::kCl1_2:
371       return "1.2";
372     case OpenClVersion::kCl2_0:
373       return "2.0";
374     case OpenClVersion::kCl2_1:
375       return "2.1";
376     case OpenClVersion::kCl2_2:
377       return "2.2";
378     case OpenClVersion::kCl3_0:
379       return "3.0";
380     default:
381       return "Unknown OpenCL version";
382   }
383 }
384 
IsAdreno() const385 bool GpuInfo::IsAdreno() const { return vendor == GpuVendor::kQualcomm; }
386 
IsApple() const387 bool GpuInfo::IsApple() const { return vendor == GpuVendor::kApple; }
388 
IsMali() const389 bool GpuInfo::IsMali() const { return vendor == GpuVendor::kMali; }
390 
IsPowerVR() const391 bool GpuInfo::IsPowerVR() const { return vendor == GpuVendor::kPowerVR; }
392 
IsNvidia() const393 bool GpuInfo::IsNvidia() const { return vendor == GpuVendor::kNvidia; }
394 
IsAMD() const395 bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; }
396 
IsIntel() const397 bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
398 
IsRoundToNearestSupported() const399 bool GpuInfo::IsRoundToNearestSupported() const {
400   if (IsApiOpenCl()) {
401     return opencl_info.supports_fp16_rtn || opencl_info.supports_fp32_rtn;
402   }
403   if (IsApple()) {
404     return apple_info.IsRoundToNearestSupported();
405   }
406   return true;
407 }
408 
SupportsFP16() const409 bool GpuInfo::SupportsFP16() const {
410   if (IsApiOpenCl()) {
411     return opencl_info.supports_fp16;
412   }
413   return true;
414 }
415 
SupportsTextureArray() const416 bool GpuInfo::SupportsTextureArray() const {
417   if (!SupportsImages()) {
418     return false;
419   }
420   if (IsApiOpenCl()) {
421     return opencl_info.cl_version >= OpenClVersion::kCl1_2;
422   }
423   return true;
424 }
425 
SupportsImageBuffer() const426 bool GpuInfo::SupportsImageBuffer() const {
427   if (!SupportsImages()) {
428     return false;
429   }
430   if (IsApiOpenCl()) {
431     return opencl_info.cl_version >= OpenClVersion::kCl1_2;
432   }
433   return true;
434 }
435 
SupportsImage3D() const436 bool GpuInfo::SupportsImage3D() const {
437   if (!SupportsImages()) {
438     return false;
439   }
440   if (IsApiOpenCl()) {
441     if (IsMali() && mali_info.IsMidgard()) {
442       // On Mali T880 read_imageh doesn't compile with image3d_t
443       return false;
444     }
445     return opencl_info.supports_image3d_writes;
446   }
447   return true;
448 }
449 
SupportsImages() const450 bool GpuInfo::SupportsImages() const {
451   if (IsApiOpenCl()) {
452     return opencl_info.supports_images;
453   }
454   return true;
455 }
456 
IsWaveSizeEqualTo32() const457 bool GpuInfo::IsWaveSizeEqualTo32() const {
458   return supported_subgroup_sizes.size() == 1 &&
459          supported_subgroup_sizes[0] == 32;
460 }
461 
SupportsExtension(const std::string & extension) const462 bool GpuInfo::SupportsExtension(const std::string& extension) const {
463   const std::vector<std::string>* extensions = nullptr;
464   if (IsApiOpenGl()) {
465     extensions = &opengl_info.extensions;
466   } else if (IsApiVulkan()) {
467     extensions = &vulkan_info.extensions;
468   } else if (IsApiOpenCl()) {
469     extensions = &opencl_info.extensions;
470   }
471   if (!extensions) {
472     return false;
473   }
474   for (const auto& ext : *extensions) {
475     if (ext == extension) {
476       return true;
477     }
478   }
479   return false;
480 }
481 
SupportsSubGroupWithSize(int sub_group_size) const482 bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const {
483   for (auto subgroup_size : supported_subgroup_sizes) {
484     if (sub_group_size == subgroup_size) {
485       return true;
486     }
487   }
488   return false;
489 }
490 
SupportsFloatImage2D(DataType data_type,int channels) const491 bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const {
492   if (IsApiOpenCl()) {
493     if (channels == 1) {
494       return data_type == DataType::FLOAT32 ? opencl_info.supports_r_f32_tex2d
495                                             : opencl_info.supports_r_f16_tex2d;
496     } else if (channels == 2) {
497       return data_type == DataType::FLOAT32 ? opencl_info.supports_rg_f32_tex2d
498                                             : opencl_info.supports_rg_f16_tex2d;
499     } else if (channels == 3) {
500       return data_type == DataType::FLOAT32
501                  ? opencl_info.supports_rgb_f32_tex2d
502                  : opencl_info.supports_rgb_f16_tex2d;
503     } else if (channels == 4) {
504       return data_type == DataType::FLOAT32
505                  ? opencl_info.supports_rgba_f32_tex2d
506                  : opencl_info.supports_rgba_f16_tex2d;
507     } else {
508       return false;
509     }
510   }
511   return false;
512 }
513 
GetComputeUnitsCount() const514 int GpuInfo::GetComputeUnitsCount() const {
515   if (IsApiOpenCl()) {
516     return opencl_info.compute_units_count;
517   }
518   if (IsApple()) {
519     return apple_info.GetComputeUnitsCount();
520   }
521   return 1;
522 }
523 
GetMaxWorkGroupSizeForX() const524 int GpuInfo::GetMaxWorkGroupSizeForX() const {
525   if (IsApiOpenGl()) {
526     return opengl_info.max_compute_work_group_size_x;
527   }
528   if (IsApiVulkan()) {
529     return vulkan_info.max_compute_work_group_size_x;
530   }
531   if (IsApiOpenCl()) {
532     return opencl_info.max_work_group_size_x;
533   }
534   if (IsApiMetal()) {
535     return metal_info.max_work_group_size_x;
536   }
537   return 256;
538 }
539 
GetMaxWorkGroupSizeForY() const540 int GpuInfo::GetMaxWorkGroupSizeForY() const {
541   if (IsApiOpenGl()) {
542     return opengl_info.max_compute_work_group_size_y;
543   }
544   if (IsApiVulkan()) {
545     return vulkan_info.max_compute_work_group_size_y;
546   }
547   if (IsApiOpenCl()) {
548     return opencl_info.max_work_group_size_y;
549   }
550   if (IsApiMetal()) {
551     return metal_info.max_work_group_size_y;
552   }
553   return 256;
554 }
555 
GetMaxWorkGroupSizeForZ() const556 int GpuInfo::GetMaxWorkGroupSizeForZ() const {
557   if (IsApiOpenGl()) {
558     return opengl_info.max_compute_work_group_size_z;
559   }
560   if (IsApiVulkan()) {
561     return vulkan_info.max_compute_work_group_size_z;
562   }
563   if (IsApiOpenCl()) {
564     return opencl_info.max_work_group_size_z;
565   }
566   if (IsApiMetal()) {
567     return metal_info.max_work_group_size_z;
568   }
569   return 64;
570 }
571 
GetMaxWorkGroupTotalSize() const572 int GpuInfo::GetMaxWorkGroupTotalSize() const {
573   if (IsApiOpenGl()) {
574     return opengl_info.max_work_group_invocations;
575   }
576   if (IsApiVulkan()) {
577     return vulkan_info.max_compute_work_group_invocations;
578   }
579   if (IsApiOpenCl()) {
580     return opencl_info.max_work_group_total_size;
581   }
582   if (IsApiMetal()) {
583     int max_size = metal_info.max_work_group_size_x;
584     max_size = std::max(max_size, metal_info.max_work_group_size_y);
585     max_size = std::max(max_size, metal_info.max_work_group_size_z);
586     return max_size;
587   }
588   return 256;
589 }
590 
GetMaxImage2DWidth() const591 uint64_t GpuInfo::GetMaxImage2DWidth() const {
592   if (IsApiOpenGl()) {
593     return opengl_info.max_texture_size;
594   }
595   if (IsApiVulkan()) {
596     return vulkan_info.max_image_dimension_2d;
597   }
598   if (IsApiOpenCl()) {
599     return opencl_info.image2d_max_width;
600   }
601   return 2048;
602 }
603 
GetMaxImage2DHeight() const604 uint64_t GpuInfo::GetMaxImage2DHeight() const {
605   if (IsApiOpenGl()) {
606     return opengl_info.max_texture_size;
607   }
608   if (IsApiVulkan()) {
609     return vulkan_info.max_image_dimension_2d;
610   }
611   if (IsApiOpenCl()) {
612     return opencl_info.image2d_max_height;
613   }
614   return 2048;
615 }
616 
GetMaxImage2DArrayLayers() const617 uint64_t GpuInfo::GetMaxImage2DArrayLayers() const {
618   if (IsApiOpenGl()) {
619     return opengl_info.max_array_texture_layers;
620   }
621   if (IsApiVulkan()) {
622     return vulkan_info.max_image_array_layers;
623   }
624   if (IsApiOpenCl()) {
625     return opencl_info.image_array_max_layers;
626   }
627   return 256;
628 }
629 
GetMaxImage3DWidth() const630 uint64_t GpuInfo::GetMaxImage3DWidth() const {
631   if (IsApiOpenCl()) {
632     return opencl_info.image3d_max_width;
633   }
634   return 256;
635 }
636 
GetMaxImage3DHeight() const637 uint64_t GpuInfo::GetMaxImage3DHeight() const {
638   if (IsApiOpenCl()) {
639     return opencl_info.image3d_max_height;
640   }
641   return 256;
642 }
643 
GetMaxImage3DDepth() const644 uint64_t GpuInfo::GetMaxImage3DDepth() const {
645   if (IsApiOpenCl()) {
646     return opencl_info.image3d_max_depth;
647   }
648   return 256;
649 }
650 
GetMaxBufferSize() const651 uint64_t GpuInfo::GetMaxBufferSize() const {
652   if (IsApiOpenCl()) {
653     return opencl_info.buffer_max_size;
654   } else if (IsApiMetal()) {
655     return metal_info.buffer_max_size;
656   }
657   return 128 * 1024 * 1024;
658 }
659 
GetMaxImageBufferWidth() const660 uint64_t GpuInfo::GetMaxImageBufferWidth() const {
661   if (IsApiOpenCl()) {
662     return opencl_info.image_buffer_max_size;
663   }
664   return 64 * 1024;
665 }
666 
GetMaxImageArguments() const667 int GpuInfo::GetMaxImageArguments() const {
668   if (IsApiOpenGl()) {
669     return opengl_info.max_image_units;
670   }
671   if (IsApiVulkan()) {
672     return vulkan_info.max_per_stage_descriptor_sampled_images;
673   }
674   if (IsApiMetal()) {
675     return 32;
676   }
677   if (IsApiOpenCl()) {
678     return 128;
679   }
680   return 1;
681 }
682 
IsApiOpenGl() const683 bool GpuInfo::IsApiOpenGl() const { return gpu_api == GpuApi::kOpenGl; }
684 
IsApiOpenGl31OrAbove() const685 bool GpuInfo::IsApiOpenGl31OrAbove() const {
686   if (!IsApiOpenGl()) {
687     return false;
688   }
689   return (opengl_info.major_version == 3 && opengl_info.minor_version >= 1) ||
690          opengl_info.major_version > 3;
691 }
692 
IsApiVulkan() const693 bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
694 
IsApiMetal() const695 bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; }
696 
IsApiOpenCl() const697 bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; }
698 
IsCL20OrHigher() const699 bool GpuInfo::IsCL20OrHigher() const {
700   if (!IsApiOpenCl()) {
701     return false;
702   }
703   return opencl_info.cl_version != OpenClVersion::kCl1_0 &&
704          opencl_info.cl_version != OpenClVersion::kCl1_1 &&
705          opencl_info.cl_version != OpenClVersion::kCl1_2;
706 }
707 
IsCL30OrHigher() const708 bool GpuInfo::IsCL30OrHigher() const {
709   if (!IsApiOpenCl()) {
710     return false;
711   }
712   return IsCL20OrHigher() && opencl_info.cl_version != OpenClVersion::kCl2_0 &&
713          opencl_info.cl_version != OpenClVersion::kCl2_1 &&
714          opencl_info.cl_version != OpenClVersion::kCl2_2;
715 }
716 
717 }  // namespace gpu
718 }  // namespace tflite
719