1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h"
17 
18 #include <cmath>
19 #include <cstdint>
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/strings/substitute.h"
27 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
28 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
29 #include "tensorflow/lite/delegates/gpu/common/operations.h"
30 #include "tensorflow/lite/delegates/gpu/common/shape.h"
31 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
32 #include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
33 #include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
34 #include "tensorflow/lite/delegates/gpu/common/types.h"
35 #include "tensorflow/lite/delegates/gpu/common/util.h"
36 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
37 
38 namespace tflite {
39 namespace gpu {
40 
41 namespace {
42 
GetNumOutputSlices(int dst_channels)43 int GetNumOutputSlices(int dst_channels) {
44   const int dst_depth = DivideRoundUp(dst_channels, 4);
45   if (dst_depth % 4 == 0 || dst_depth >= 16) {
46     return 4;
47   } else if (dst_depth % 2 == 0 || dst_depth >= 4) {
48     return 2;
49   } else {
50     return 1;
51   }
52 }
53 
54 struct GlobalIdsParams {
55   std::vector<std::string> global_ids;
56   std::vector<std::string> group_ids;
57   std::vector<std::string> local_sizes;
58   std::vector<std::string> local_ids;
59   int3 block_size;
60   int3 launch_order;
61   bool linear_wh;
62   bool linear_whs;
63   std::string task_size_w;  // must be filled if linear_wh or linear_whs enabled
64   std::string task_size_wh;  // must be filled if linear_whs enabled
65 };
66 
GlobalIdsGen(const GlobalIdsParams & params)67 std::string GlobalIdsGen(const GlobalIdsParams& params) {
68   std::string c;
69   int3 launch_remap;
70   launch_remap[params.launch_order.x] = 0;
71   launch_remap[params.launch_order.y] = 1;
72   launch_remap[params.launch_order.z] = 2;
73   if (params.linear_whs) {
74     c += "  int linear_whs = " + params.global_ids[0] + ";\n";
75     c += "  int Z = (linear_whs / " + params.task_size_wh + ") * " +
76          std::to_string(params.block_size.z) + ";\n";
77     c += "  int linear_wh = linear_whs % " + params.task_size_wh + ";\n";
78     c += "  int Y = (linear_wh / " + params.task_size_w + ") * " +
79          std::to_string(params.block_size.y) + ";\n";
80     c += "  int X = (linear_wh % " + params.task_size_w + ") * " +
81          std::to_string(params.block_size.x) + ";\n";
82   } else if (params.linear_wh) {
83     if (params.launch_order.x == 0) {
84       c += "  int linear_wh = " + params.global_ids[0] + ";\n";
85     } else {
86       c += "  int linear_wh = " + params.group_ids[launch_remap.x] + " * " +
87            params.local_sizes[0] + " + " + params.local_ids[0] + ";\n";
88     }
89     c += "  int Y = (linear_wh / " + params.task_size_w + ") * " +
90          std::to_string(params.block_size.y) + ";\n";
91     c += "  int X = (linear_wh % " + params.task_size_w + ") * " +
92          std::to_string(params.block_size.x) + ";\n";
93     if (params.launch_order.y == 1) {
94       c += "  int Z = " + params.global_ids[1] + " * " +
95            std::to_string(params.block_size.z) + ";\n";
96     } else {
97       c += "  int Z = (" + params.group_ids[launch_remap.y] + " * " +
98            params.local_sizes[1] + " + " + params.local_ids[1] + ") * " +
99            std::to_string(params.block_size.z) + ";\n";
100     }
101   } else {
102     if (params.launch_order.x == 0) {
103       c += "  int X = " + params.global_ids[0] + " * " +
104            std::to_string(params.block_size.x) + ";\n";
105     } else {
106       c += "  int X = (" + params.group_ids[launch_remap.x] + " * " +
107            params.local_sizes[0] + " + " + params.local_ids[0] + ") * " +
108            std::to_string(params.block_size.x) + ";\n";
109     }
110     if (params.launch_order.y == 1) {
111       c += "  int Y = " + params.global_ids[1] + " * " +
112            std::to_string(params.block_size.y) + ";\n";
113     } else {
114       c += "  int Y = (" + params.group_ids[launch_remap.y] + " * " +
115            params.local_sizes[1] + " + " + params.local_ids[1] + ") * " +
116            std::to_string(params.block_size.y) + ";\n";
117     }
118     if (params.launch_order.z == 2) {
119       c += "  int Z = " + params.global_ids[2] + " * " +
120            std::to_string(params.block_size.z) + ";\n";
121     } else {
122       c += "  int Z = (" + params.group_ids[launch_remap.z] + " * " +
123            params.local_sizes[2] + " + " + params.local_ids[2] + ") * " +
124            std::to_string(params.block_size.z) + ";\n";
125     }
126   }
127   return c;
128 }
129 
// Emits Metal statements that copy `elements_to_upload` values from
// `global_ptr_name` into `local_ptr_name`, spread across `total_work_items`
// threads identified by `lid_name`. Each full round copies one element per
// thread at offset lid + round * total_work_items; a trailing partial round is
// guarded by `lid < remainder`. When `global_offset_name` is non-empty it is
// added to the source index only.
std::string GenerateUploadByThreads(const std::string& local_ptr_name,
                                    const std::string& global_ptr_name,
                                    const std::string& global_offset_name,
                                    const std::string& lid_name,
                                    int total_work_items,
                                    int elements_to_upload) {
  const std::string src_prefix =
      global_offset_name.empty() ? std::string() : global_offset_name + " + ";
  const int full_rounds = elements_to_upload / total_work_items;
  const int tail = elements_to_upload % total_work_items;

  // One "local[lid + base] = global[offset + lid + base];" statement.
  auto copy_line = [&](const std::string& indent, int base) {
    const std::string index = lid_name + " + " + std::to_string(base);
    return indent + local_ptr_name + "[" + index + "] = " + global_ptr_name +
           "[" + src_prefix + index + "];\n";
  };

  std::string code;
  for (int round = 0; round < full_rounds; ++round) {
    code += copy_line("    ", round * total_work_items);
  }
  if (tail != 0) {
    code += "    if (" + lid_name + " < " + std::to_string(tail) + ") {\n";
    code += copy_line("      ", full_rounds * total_work_items);
    code += "    }\n";
  }
  return code;
}
157 
GenerateConvolution(const ConvolutionMetal::ConvParams & params,const OperationDef & definition,bool stride_correction)158 std::string GenerateConvolution(const ConvolutionMetal::ConvParams& params,
159                                 const OperationDef& definition,
160                                 bool stride_correction) {
161   GlobalIdsParams ids_params;
162   ids_params.group_ids = {"group_id.x", "group_id.y", "group_id.z"};
163   ids_params.global_ids = {"ugid.x", "ugid.y", "ugid.z"};
164   ids_params.local_ids = {"tid3d.x", "tid3d.y", "tid3d.z"};
165   ids_params.local_sizes = {"lsize.x", "lsize.y", "lsize.z"};
166   ids_params.linear_wh = params.linear_wh;
167   ids_params.task_size_w = "args.task_size_x";
168   ids_params.task_size_wh = "args.task_size_y";
169   ids_params.linear_whs = params.linear_whs;
170   ids_params.block_size = params.block_size;
171   ids_params.launch_order = params.work_group_launch_order;
172 
173   std::string addr_space =
174       params.weights_upload_type ==
175               ConvolutionMetal::WeightsUploadType::CONSTANT_MEM
176           ? "constant"
177           : "device";
178   const bool use_local_mem =
179       params.weights_upload_type ==
180       ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
181   const int local_mem_size =
182       params.block_size.z * 4 * params.src_depth_loop_size;
183 
184   const bool use_simd_broadcast =
185       params.weights_upload_type ==
186           ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
187       params.weights_upload_type ==
188           ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
189       params.weights_upload_type ==
190           ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST;
191   int simd_size = 1;
192   if (params.weights_upload_type ==
193       ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
194     simd_size = 8;
195   } else if (params.weights_upload_type == ConvolutionMetal::WeightsUploadType::
196                                                PRIVATE_MEM_SIMD16_BROADCAST) {
197     simd_size = 16;
198   } else if (params.weights_upload_type == ConvolutionMetal::WeightsUploadType::
199                                                PRIVATE_MEM_SIMD32_BROADCAST) {
200     simd_size = 32;
201   }
202 
203   const bool use_filters_constants =
204       !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
205       params.y_kernel_is_1;
206 
207   const auto src_storage_type = definition.src_tensors[0].storage_type;
208   const auto dst_storage_type = definition.dst_tensors[0].storage_type;
209   const bool src_is_linear =
210       src_storage_type == TensorStorageType::BUFFER ||
211       src_storage_type == TensorStorageType::IMAGE_BUFFER;
212   const bool dst_is_linear =
213       dst_storage_type == TensorStorageType::BUFFER ||
214       dst_storage_type == TensorStorageType::IMAGE_BUFFER;
215 
216   std::string channels[4] = {"x", "y", "z", "w"};
217   std::string c;
218   c.reserve(16 * 1024);  // Reserve large enough buffer.
219   c += R"(
220 kernel void ComputeFunction(
221     $0
222     uint tid[[thread_index_in_threadgroup]],
223     uint3 group_id[[threadgroup_position_in_grid]],
224     uint3 tid3d[[thread_position_in_threadgroup]],
225     uint3 lsize[[threads_per_threadgroup]],
226 )";
227   if (use_simd_broadcast) {
228     c += "    uint simd_id[[thread_index_in_simdgroup]],\n";
229   }
230   c += "    uint3 ugid[[thread_position_in_grid]]){\n";
231   c += GlobalIdsGen(ids_params);
232   c += "  if (Z >= args.dst_tensor.Slices()) return;\n";
233   bool late_xy_check = use_local_mem || use_simd_broadcast;
234   if (!late_xy_check && !params.linear_whs) {
235     c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
236          "return;\n";
237   }
238   for (int z = 0; z < params.block_size.z; ++z) {
239     for (int y = 0; y < params.block_size.y; ++y) {
240       for (int x = 0; x < params.block_size.x; ++x) {
241         const std::string s_i =
242             std::to_string(z) + std::to_string(y) + std::to_string(x);
243         c +=
244             "  ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
245       }
246     }
247   }
248   auto for_every_yx =
249       [&](std::function<std::string(const std::string&, const std::string&,
250                                     const std::string&, int, int)>
251               lambda) {
252         for (int y = 0; y < params.block_size.y; ++y) {
253           const std::string s_y = std::to_string(y);
254           for (int x = 0; x < params.block_size.x; ++x) {
255             const std::string s_x = std::to_string(x);
256             const std::string s_yx = s_y + s_x;
257             c += lambda(s_yx, s_x, s_y, x, y) + "\n";
258           }
259         }
260       };
261   if (!use_filters_constants) {
262     std::string kern_x = params.x_kernel_is_1 ? "" : " * args.kernel_size_x";
263     std::string kern_y = params.y_kernel_is_1 ? "" : " * args.kernel_size_y";
264     std::string dst_offset =
265         params.need_dst_loop ? " + Z * 4 * args.src_tensor.Slices()" : "";
266     if (!params.need_dst_loop) {
267       c += "  " + addr_space + " FLT4* tmp = args.weights.GetPtr();\n";
268     } else {
269       if (params.different_weights_for_height) {
270         c += "  " + addr_space +
271              " FLT4* tmp = args.weights.GetPtr() + (Z * "
272              "args.src_tensor.Height() + Y * " +
273              std::to_string(params.block_size.z) +
274              ") * 4 * args.src_tensor.Slices();\n";
275       } else {
276         c += "  " + addr_space +
277              " FLT4* tmp = args.weights.GetPtr() + Z * 4 * "
278              "args.src_tensor.Slices()" +
279              kern_x + kern_y + ";\n";
280       }
281     }
282   }
283   if (!params.x_kernel_is_1) {
284     for (int x = 0; x < params.block_size.x; ++x) {
285       const std::string s_x = std::to_string(x);
286       if (stride_correction) {
287         c += "  int x" + s_x + " = " +
288              GetXStrideCorrected("(X + " + s_x + ")", "args.src_tensor.Batch()",
289                                  "args.stride_x", "args.padding_x") +
290              ";\n";
291       } else {
292         c += "  int x" + s_x + " = (X + " + s_x +
293              ") * args.stride_x + args.padding_x;\n";
294       }
295     }
296   }
297   if (!params.y_kernel_is_1) {
298     for (int y = 0; y < params.block_size.y; ++y) {
299       const std::string s_y = std::to_string(y);
300       c += "  int y" + s_y + " = (Y + " + s_y +
301            ") * args.stride_y + args.padding_y;\n";
302     }
303   }
304   if (use_local_mem) {
305     c += "  threadgroup FLT4 weights_cache[" + std::to_string(local_mem_size) +
306          "];\n";
307   }
308   if (!params.y_kernel_is_1) {
309     c += "  int y = 0;\n";
310     c += "  do {\n";
311     for (int y = 0; y < params.block_size.y; ++y) {
312       const std::string s_y = std::to_string(y);
313       c += "  int c_y" + s_y + " = y * args.dilation_y + y" + s_y + ";\n";
314       if (src_is_linear) {
315         c += "  bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
316              " >= args.src_tensor.Height();\n";
317         c += "  c_y" + s_y + " = clamp(c_y" + s_y +
318              ", 0, args.src_tensor.Height() - 1);\n";
319       }
320     }
321   } else {
322     for (int y = 0; y < params.block_size.y; ++y) {
323       const std::string s_y = std::to_string(y);
324       c += "  int c_y" + s_y + " = clamp(Y + " + s_y +
325            ", 0, args.src_tensor.Height() - 1);\n";
326     }
327   }
328   if (!params.x_kernel_is_1) {
329     c += "  int x = 0;\n";
330     c += "  do {\n";
331     for (int x = 0; x < params.block_size.x; ++x) {
332       const std::string s_x = std::to_string(x);
333       c += "  int c_x" + s_x + " = x * args.dilation_x + x" + s_x + ";\n";
334       if (src_is_linear) {
335         c += "  bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
336              " >= args.src_tensor.Width();\n";
337         c += "  c_x" + s_x + " = clamp(c_x" + s_x +
338              ", 0, args.src_tensor.Width() - 1);\n";
339       }
340     }
341   } else {
342     for (int x = 0; x < params.block_size.x; ++x) {
343       const std::string s_x = std::to_string(x);
344       c += "  int c_x" + s_x + " = clamp(X + " + s_x +
345            ", 0, args.src_tensor.Width() - 1);\n";
346     }
347   }
348   if (src_is_linear) {
349     for (int y = 0; y < params.block_size.y; ++y) {
350       const std::string s_y = std::to_string(y);
351       for (int x = 0; x < params.block_size.x; ++x) {
352         const std::string s_x = std::to_string(x);
353         const std::string s_yx = s_y + s_x;
354         if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
355           c += "  FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x +
356                "_out);\n";
357         } else if (!params.y_kernel_is_1) {
358           c += "  FLT m" + s_yx + " = !y" + s_y + "_out;\n";
359         } else if (!params.x_kernel_is_1) {
360           c += "  FLT m" + s_yx + " = !x" + s_x + "_out;\n";
361         }
362       }
363     }
364     for (int y = 0; y < params.block_size.y; ++y) {
365       const std::string s_y = std::to_string(y);
366       for (int x = 0; x < params.block_size.x; ++x) {
367         const std::string s_x = std::to_string(x);
368         const std::string s_yx = s_y + s_x;
369         if (definition.src_tensors[0].storage_type ==
370             TensorStorageType::BUFFER) {
371           c += "  device FLT4* src_loc_" + s_yx +
372                " = args.src_tensor.GetHandle() + "
373                "args.src_tensor.GetWHOffset(c_x" +
374                s_x + ", c_y" + s_y + ");\n";
375         } else if (definition.src_tensors[0].storage_type ==
376                    TensorStorageType::IMAGE_BUFFER) {
377           c += "  int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
378                s_x + ", c_y" + s_y + ");\n";
379         }
380       }
381     }
382   }
383   c += "  int s = 0;\n";
384   if (params.need_src_loop) {
385     c += "  do {\n";
386   }
387   if (use_local_mem) {
388     const int total_work_items = params.work_group_size.x *
389                                  params.work_group_size.y *
390                                  params.work_group_size.z;
391     c += "    SIMDGROUP_BARRIER(mem_flags::mem_none);\n";
392     c += GenerateUploadByThreads("weights_cache", "tmp",
393                                  /*global_offset_name*/ "", "tid",
394                                  total_work_items, local_mem_size);
395     c += "    SIMDGROUP_BARRIER(mem_flags::mem_threadgroup);\n";
396   } else if (use_simd_broadcast) {
397     int parts = local_mem_size / simd_size;
398     int reminder = local_mem_size % simd_size;
399     for (int i = 0; i < parts; ++i) {
400       c += "    FLT4 simd_w" + std::to_string(i) + " = tmp[simd_id + " +
401            std::to_string(i * simd_size) + "];\n";
402     }
403     if (reminder) {
404       c += "    FLT4 simd_w" + std::to_string(parts) + ";\n";
405       c += "    if (simd_id < " + std::to_string(reminder) + ") {\n";
406       c += "      simd_w" + std::to_string(parts) + " = tmp[simd_id + " +
407            std::to_string(parts * simd_size) + "];\n";
408       c += "    }\n";
409     }
410   }
411   auto declare_src = [&]() {
412     for (int y = 0; y < params.block_size.y; ++y) {
413       for (int x = 0; x < params.block_size.x; ++x) {
414         const std::string s_yx = std::to_string(y) + std::to_string(x);
415         c += "    FLT4 src" + s_yx + ";\n";
416       }
417     }
418   };
419   auto read_src = [&]() {
420     for (int y = 0; y < params.block_size.y; ++y) {
421       for (int x = 0; x < params.block_size.x; ++x) {
422         const std::string s_yx = std::to_string(y) + std::to_string(x);
423         if (src_is_linear) {
424           if (definition.src_tensors[0].storage_type ==
425               TensorStorageType::BUFFER) {
426             if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
427               c += "    src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
428                    ";\n";
429             } else {
430               c += "    src" + s_yx + " = *src_loc_" + s_yx + ";\n";
431             }
432           } else if (definition.src_tensors[0].storage_type ==
433                      TensorStorageType::IMAGE_BUFFER) {
434             if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
435               c += "    src" + s_yx + " = args.src_tensor.Read(src_loc_" +
436                    s_yx + ") * m" + s_yx + ";\n";
437             } else {
438               c += "    src" + s_yx + " = args.src_tensor.Read(src_loc_" +
439                    s_yx + ");\n";
440             }
441           }
442         } else {
443           c += "    src" + s_yx + " = args.src_tensor.Read(c_x" +
444                std::to_string(x) + ", c_y" + std::to_string(y) + ", s);\n";
445         }
446       }
447     }
448     if (src_is_linear) {
449       for (int y = 0; y < params.block_size.y; ++y) {
450         for (int x = 0; x < params.block_size.x; ++x) {
451           const std::string s_yx = std::to_string(y) + std::to_string(x);
452           c += "    src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
453         }
454       }
455     }
456   };
457   auto conv_core = [&](int offset) {
458     std::string name = use_local_mem ? "weights_cache" : "tmp";
459     if (use_filters_constants) {
460       name = "args.weights.GetPtr()";
461     }
462     for (int z = 0; z < params.block_size.z; ++z) {
463       for (int ch = 0; ch < 4; ++ch) {
464         for (int y = 0; y < params.block_size.y; ++y) {
465           for (int x = 0; x < params.block_size.x; ++x) {
466             std::string s_id = std::to_string(y) + std::to_string(x);
467             std::string r_id =
468                 std::to_string(z) + std::to_string(y) + std::to_string(x);
469             std::string f_val =
470                 name + "[" + std::to_string(z * 4 + ch + offset) + "]";
471             if (use_simd_broadcast) {
472               int simd_id = (z * 4 + ch + offset) / simd_size;
473               int thread_id = (z * 4 + ch + offset) % simd_size;
474               f_val = "simd_broadcast(simd_w" + std::to_string(simd_id) + ", " +
475                       std::to_string(thread_id) + "u)";
476             }
477             std::string s_val = "src" + s_id;
478             std::string r_val = "r" + r_id;
479             if (params.weights_layout == WeightsLayout::kOHWIOGroupO4I4) {
480               c += "    " + r_val + "." + channels[ch] + " += dot(" + f_val +
481                    ", " + s_val + ");\n";
482             } else {  // WeightsInnerBlockLayout::I404
483               std::string temp_sum = f_val + " * " + s_val + "." + channels[ch];
484               if (definition.precision == CalculationsPrecision::F32_F16) {
485                 temp_sum = "float4(" + temp_sum + ")";
486               }
487               c += "    " + r_val + " += " + temp_sum + ";\n";
488             }
489           }
490         }
491       }
492     }
493   };
494   declare_src();
495   read_src();
496   c += "    s += 1;\n";
497   conv_core(0);
498   for (int i = 1; i < params.src_depth_loop_size; ++i) {
499     read_src();
500     conv_core(i * params.block_size.z * 4);
501     c += "    s += 1;\n";
502   }
503   if (!use_filters_constants) {
504     c += "    tmp += " +
505          std::to_string(params.block_size.z * 4 * params.src_depth_loop_size) +
506          ";\n";
507   }
508   if (params.need_src_loop) {
509     c += "  } while (s < args.src_tensor.Slices());\n";
510   }
511   if (!params.x_kernel_is_1) {
512     c += "  x++;\n";
513     c += "  } while (x < args.kernel_size_x);\n";
514   }
515   if (!params.y_kernel_is_1) {
516     c += "  y++;\n";
517     c += "  } while (y < args.kernel_size_y);\n";
518   }
519 
520   if (late_xy_check && !params.linear_whs) {
521     c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
522          "return;\n";
523   }
524 
525   if (dst_is_linear) {
526     for_every_yx([](const std::string& s_yx, const std::string& s_x,
527                     const std::string& s_y, int x, int y) {
528       return "  args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
529              ", Y + " + s_y + ", Z);";
530     });
531   }
532 
533   std::string bias_name = "args.biases.GetPtr()";
534   if (params.need_dst_loop) {
535     c += "  device FLT4* bias_loc = args.biases.GetPtr() + Z;\n";
536     bias_name = "bias_loc";
537   }
538   for (int y = 0; y < params.block_size.y; ++y) {
539     for (int x = 0; x < params.block_size.x; ++x) {
540       for (int z = 0; z < params.block_size.z; ++z) {
541         std::string r_id =
542             std::to_string(z) + std::to_string(y) + std::to_string(x);
543         c += "  r" + r_id + " += TO_ACCUM_TYPE(" + bias_name + "[" +
544              std::to_string(z) + "]);\n";
545       }
546     }
547   }
548   for (int z = 0; z < params.block_size.z; ++z) {
549     const std::string s_z = std::to_string(z);
550     c += "  if (Z + " + s_z + " < args.dst_tensor.Slices()) {\n";
551     for (int y = 0; y < params.block_size.y; ++y) {
552       const std::string s_y = std::to_string(y);
553       for (int x = 0; x < params.block_size.x; ++x) {
554         const std::string s_x = std::to_string(x);
555         const std::string s_yx = s_y + s_x;
556         const std::string s_zyx = s_z + s_yx;
557         bool need_check_x = x >= 1;
558         bool need_check_y = y >= 1;
559         std::string check;
560         if (need_check_x) {
561           check += "(X + " + s_x + ") < args.dst_tensor.Width()";
562         }
563         if (need_check_y) {
564           check += check.empty() ? "" : " && ";
565           check += "(Y + " + s_y + ") < args.dst_tensor.Height()";
566         }
567         if (!check.empty()) {
568           c += "    if (" + check + ") {\n";
569         } else {
570           c += "    {\n";
571         }
572         c += "      FLT4 value = FLT4(r" + s_zyx + ");\n";
573         if (dst_is_linear) {
574           c += "      int linear_index = offset_" + s_yx +
575                " + args.dst_tensor.SliceStride() * " + s_z + ";\n";
576           c += "      args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
577                s_y + ", Z + " + s_z + ");\n";
578           c += "      args.dst_tensor.WriteLinear(value, linear_index);\n";
579         } else {
580           c += "      args.dst_tensor.Write(value, X + " + s_x + ", Y + " +
581                s_y + ", Z + " + s_z + ");\n";
582         }
583         c += "    }\n";
584       }
585     }
586     c += "  }\n";
587   }
588   c += "}\n";
589   return c;
590 }
591 
ReorderWeightsForConv(const tflite::gpu::Tensor<OHWI,DataType::FLOAT32> & weights,const WeightsDescription & weights_desc,const DataType & weights_type)592 std::vector<uint8_t> ReorderWeightsForConv(
593     const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights,
594     const WeightsDescription& weights_desc, const DataType& weights_type) {
595   const int flt_count =
596       GetTotalElementsCountForLayout(weights_desc, weights.shape);
597   std::vector<uint8_t> result(flt_count * SizeOf(weights_type));
598   RearrangeWeights(weights, weights_desc, weights_type, absl::MakeSpan(result));
599   return result;
600 }
601 
ReorderBiasesForConv(const tflite::gpu::Tensor<Linear,DataType::FLOAT32> & biases,const DataType & biases_type,int output_size)602 std::vector<uint8_t> ReorderBiasesForConv(
603     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
604     const DataType& biases_type, int output_size) {
605   std::vector<uint8_t> result(output_size * SizeOf(biases_type));
606   if (biases_type == DataType::FLOAT32) {
607     float* gpu_data = reinterpret_cast<float*>(result.data());
608     for (int i = 0; i < output_size; ++i) {
609       gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
610     }
611   } else {
612     half* gpu_data = reinterpret_cast<half*>(result.data());
613     for (int i = 0; i < output_size; ++i) {
614       gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
615     }
616   }
617   return result;
618 }
619 
GetGroupsCount(const BHWC & dst_shape,const int3 & wg_size,const int3 & block_size)620 int GetGroupsCount(const BHWC& dst_shape, const int3& wg_size,
621                    const int3& block_size) {
622   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
623 
624   int grid_x = DivideRoundUp(dst_shape.w, block_size.x);
625   int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
626   int grid_z = DivideRoundUp(dst_slices, block_size.z);
627 
628   return DivideRoundUp(grid_x, wg_size.x) * DivideRoundUp(grid_y, wg_size.y) *
629          DivideRoundUp(grid_z, wg_size.z);
630 }
631 
GetGroupsCountForLinearWH(const BHWC & dst_shape,const int3 & wg_size,const int3 & block_size)632 int GetGroupsCountForLinearWH(const BHWC& dst_shape, const int3& wg_size,
633                               const int3& block_size) {
634   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
635 
636   int grid_x = DivideRoundUp(dst_shape.w, block_size.x);
637   int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
638   int grid_z = DivideRoundUp(dst_slices, block_size.z);
639 
640   return DivideRoundUp(grid_x * grid_y, wg_size.x) *
641          DivideRoundUp(grid_z, wg_size.y);
642 }
643 
GetGroupsCountForLinearWHS(const BHWC & dst_shape,const int3 & wg_size,const int3 & block_size)644 int GetGroupsCountForLinearWHS(const BHWC& dst_shape, const int3& wg_size,
645                                const int3& block_size) {
646   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
647 
648   int grid_x = DivideRoundUp(dst_shape.w, block_size.x);
649   int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
650   int grid_z = DivideRoundUp(dst_slices, block_size.z);
651 
652   return DivideRoundUp(grid_x * grid_y * grid_z, wg_size.x);
653 }
654 
IsKernelXIs1(const Convolution2DAttributes & attr)655 bool IsKernelXIs1(const Convolution2DAttributes& attr) {
656   return attr.weights.shape.w == 1 && attr.strides.w == 1 &&
657          attr.dilations.w == 1 && attr.padding.prepended.w == 0 &&
658          attr.padding.appended.w == 0;
659 }
660 
IsKernelYIs1(const Convolution2DAttributes & attr)661 bool IsKernelYIs1(const Convolution2DAttributes& attr) {
662   return attr.weights.shape.h == 1 && attr.strides.h == 1 &&
663          attr.dilations.h == 1 && attr.padding.prepended.h == 0 &&
664          attr.padding.appended.h == 0;
665 }
666 
GetMaximumPossibleWavesCount(const AppleInfo & apple_info,const BHWC & dst_shape)667 int GetMaximumPossibleWavesCount(const AppleInfo& apple_info,
668                                  const BHWC& dst_shape) {
669   if (apple_info.IsLocalMemoryPreferredOverGlobal()) {
670     return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1});
671   } else {
672     return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1});
673   }
674 }
675 
GetRecommendedBlockSize(const AppleInfo & apple_info,const BHWC & dst_shape)676 int GetRecommendedBlockSize(const AppleInfo& apple_info,
677                             const BHWC& dst_shape) {
678   const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape);
679   const int cu_count = apple_info.GetComputeUnitsCount();
680   if (max_waves >= cu_count * 64) {
681     return 8;
682   } else if (max_waves >= cu_count * 32) {
683     return 4;
684   } else if (max_waves >= cu_count * 16) {
685     return 2;
686   } else {
687     return 1;
688   }
689 }
690 
GetConvParamsForA7A8(const AppleInfo & apple_info,const Convolution2DAttributes & attr,const BHWC & dst_shape)691 ConvolutionMetal::ConvParams GetConvParamsForA7A8(
692     const AppleInfo& apple_info, const Convolution2DAttributes& attr,
693     const BHWC& dst_shape) {
694   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
695   const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
696 
697   ConvolutionMetal::ConvParams params;
698   params.weights_upload_type =
699       ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
700   params.x_kernel_is_1 = IsKernelXIs1(attr);
701   params.y_kernel_is_1 = IsKernelYIs1(attr);
702   params.src_depth_loop_size = 1;
703   params.block_size = int3(1, 1, 1);
704   params.linear_wh = false;
705   params.linear_whs = false;
706   params.work_group_launch_order = int3(0, 1, 2);
707   params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
708 
709   int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
710 
711   if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
712     params.block_size.z = 4;
713     blk_total_size /= 4;
714   } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) {
715     params.block_size.z = 2;
716     blk_total_size /= 2;
717   }
718   if (blk_total_size >= 4) {
719     params.block_size.x = 2;
720     params.block_size.y = 2;
721     blk_total_size /= 4;
722   } else if (blk_total_size >= 2) {
723     if (dst_shape.w % 2 != 0 && dst_shape.h % 2 == 0) {
724       params.block_size.y = 2;
725     } else {
726       params.block_size.x = 2;
727     }
728     blk_total_size /= 2;
729   }
730 
731   params.work_group_size = params.block_size.x <= params.block_size.y
732                                ? int3(8, 4, 1)
733                                : int3(4, 8, 1);
734 
735   int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size);
736   int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, params.block_size);
737   int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, params.block_size);
738 
739   if (g2 < g1) {
740     params.linear_wh = true;
741     params.work_group_size = int3(32, 1, 1);
742     params.work_group_launch_order = int3(0, 1, 2);
743   }
744   float precise_threshold = 3.1f;
745   float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
746   if (precise_ratio > precise_threshold) {
747     params.linear_wh = false;
748     params.linear_whs = true;
749     params.work_group_size = int3(32, 1, 1);
750     params.weights_upload_type =
751         ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
752   }
753 
754   if (params.src_depth_loop_size == src_slices) {
755     params.need_src_loop = false;
756   }
757   if (params.block_size.z == dst_slices) {
758     params.need_dst_loop = false;
759   }
760   const bool use_filters_constants =
761       !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
762       params.y_kernel_is_1;
763   if (use_filters_constants) {
764     params.weights_upload_type =
765         ConvolutionMetal::WeightsUploadType::CONSTANT_MEM;
766   }
767 
768   return params;
769 }
770 
GetConvParamsForA9AndHigher(const AppleInfo & apple_info,const Convolution2DAttributes & attr,const BHWC & dst_shape)771 ConvolutionMetal::ConvParams GetConvParamsForA9AndHigher(
772     const AppleInfo& apple_info, const Convolution2DAttributes& attr,
773     const BHWC& dst_shape) {
774   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
775   const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
776   int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
777   int3 block_size = int3(1, 1, 1);
778   if (blk_total_size >= 2 && apple_info.IsBionic()) {
779     if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) {
780       block_size.x = 2;
781     } else {
782       block_size.y = 2;
783     }
784     blk_total_size /= 2;
785   }
786   if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
787     block_size.z = 4;
788     blk_total_size /= 4;
789   } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) {
790     block_size.z = 2;
791     blk_total_size /= 2;
792   }
793   if (blk_total_size >= 4 && dst_slices == 3) {
794     block_size.z = 3;
795     blk_total_size /= 4;
796   }
797 
798   ConvolutionMetal::ConvParams params;
799   params.weights_upload_type = ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
800   params.x_kernel_is_1 = IsKernelXIs1(attr);
801   params.y_kernel_is_1 = IsKernelYIs1(attr);
802   params.src_depth_loop_size = 1;
803   params.block_size = block_size;
804   params.linear_wh = false;
805   params.linear_whs = false;
806   params.work_group_size = int3(8, 4, 1);
807   params.work_group_launch_order = int3(2, 0, 1);
808   params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
809   int g1 = GetGroupsCount(dst_shape, {8, 4, 1}, block_size);
810   int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, block_size);
811   int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, block_size);
812   if (g2 < g1) {
813     params.linear_wh = true;
814     params.work_group_size = int3(32, 1, 1);
815     params.work_group_launch_order = int3(0, 1, 2);
816   }
817   float precise_threshold = apple_info.IsBionic() ? 1.0f : 1.04f;
818   float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
819   if (precise_ratio > precise_threshold) {
820     params.linear_wh = false;
821     params.linear_whs = true;
822     params.work_group_size = int3(32, 1, 1);
823   }
824   int total_elements =
825       params.block_size.x * params.block_size.y * params.block_size.z;
826   if (total_elements == 1) {
827     if (src_slices % 4 == 0) {
828       params.src_depth_loop_size = 4;
829     } else if (src_slices % 2 == 0) {
830       params.src_depth_loop_size = 2;
831     }
832   } else if (total_elements == 2) {
833     if (src_slices % 2 == 0) {
834       params.src_depth_loop_size = 2;
835     }
836   }
837   if (params.src_depth_loop_size == src_slices) {
838     params.need_src_loop = false;
839   }
840   if (params.block_size.z == dst_slices) {
841     params.need_dst_loop = false;
842   }
843   const bool use_filters_constants =
844       !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
845       params.y_kernel_is_1;
846   if (use_filters_constants) {
847     params.weights_upload_type =
848         ConvolutionMetal::WeightsUploadType::CONSTANT_MEM;
849   }
850 
851   return params;
852 }
853 
GetConvParamsForIntel(const Convolution2DAttributes & attr,CalculationsPrecision precision,const BHWC & dst_shape)854 ConvolutionMetal::ConvParams GetConvParamsForIntel(
855     const Convolution2DAttributes& attr, CalculationsPrecision precision,
856     const BHWC& dst_shape) {
857   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
858   const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
859   ConvolutionMetal::ConvParams params;
860   params.weights_upload_type =
861       ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
862   params.x_kernel_is_1 = IsKernelXIs1(attr);
863   params.y_kernel_is_1 = IsKernelYIs1(attr);
864   params.src_depth_loop_size = 1;
865   params.linear_wh = false;
866   params.linear_whs = false;
867   params.work_group_launch_order = int3(2, 0, 1);
868   params.block_size = int3(1, 1, 1);
869   if (dst_slices % 4 == 0 || dst_slices >= 8) {
870     params.block_size.z = 4;
871   } else if (dst_slices % 2 == 0 || dst_slices >= 4) {
872     params.block_size.z = 2;
873   }
874   params.work_group_size = int3(8, 2, 1);
875   if (precision == CalculationsPrecision::F32_F16) {
876     params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
877   } else {
878     params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
879   }
880 
881   if (src_slices % 2 == 0) {
882     params.src_depth_loop_size = 2;
883   }
884 
885   int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size);
886   int g2 = GetGroupsCountForLinearWH(dst_shape, {16, 1, 1}, params.block_size);
887 
888   if (g2 < g1) {
889     params.linear_wh = true;
890     params.work_group_size = int3(16, 1, 1);
891     params.work_group_launch_order = int3(1, 0, 2);
892   }
893 
894   return params;
895 }
896 
GetConvParamsForAMD(const Convolution2DAttributes & attr,CalculationsPrecision precision,const BHWC & dst_shape)897 ConvolutionMetal::ConvParams GetConvParamsForAMD(
898     const Convolution2DAttributes& attr, CalculationsPrecision precision,
899     const BHWC& dst_shape) {
900   ConvolutionMetal::ConvParams params;
901   params.block_size = int3(1, 1, 4);
902   params.work_group_size = int3(8, 4, 1);
903   params.work_group_launch_order = int3(2, 0, 1);
904   params.src_depth_loop_size = 1;
905   params.need_src_loop = true;
906   params.need_dst_loop = true;
907   params.linear_wh = false;
908   params.linear_whs = false;
909   params.weights_upload_type = ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
910   params.different_weights_for_height = false;
911   params.x_kernel_is_1 = IsKernelXIs1(attr);
912   params.y_kernel_is_1 = IsKernelYIs1(attr);
913   if (precision == CalculationsPrecision::F32_F16) {
914     params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
915   } else {
916     params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
917   }
918   return params;
919 }
920 
GetConvParams(const GpuInfo & gpu_info,const Convolution2DAttributes & attr,CalculationsPrecision precision,const BHWC & dst_shape)921 ConvolutionMetal::ConvParams GetConvParams(const GpuInfo& gpu_info,
922                                            const Convolution2DAttributes& attr,
923                                            CalculationsPrecision precision,
924                                            const BHWC& dst_shape) {
925   if (gpu_info.IsApple()) {
926     if (gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
927       return GetConvParamsForA7A8(gpu_info.apple_info, attr, dst_shape);
928     } else {
929       return GetConvParamsForA9AndHigher(gpu_info.apple_info, attr, dst_shape);
930     }
931   } else if (gpu_info.IsIntel()) {
932     return GetConvParamsForIntel(attr, precision, dst_shape);
933   } else if (gpu_info.IsAMD()) {
934     return GetConvParamsForAMD(attr, precision, dst_shape);
935   } else {
936     ConvolutionMetal::ConvParams params;
937     params.block_size = int3(1, 1, 4);
938     params.work_group_size = int3(8, 4, 1);
939     params.work_group_launch_order = int3(2, 0, 1);
940     params.src_depth_loop_size = 1;
941     params.need_src_loop = true;
942     params.need_dst_loop = true;
943     params.linear_wh = false;
944     params.linear_whs = false;
945     params.weights_upload_type =
946         ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
947     params.different_weights_for_height = false;
948     params.x_kernel_is_1 = IsKernelXIs1(attr);
949     params.y_kernel_is_1 = IsKernelYIs1(attr);
950     params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
951     return params;
952   }
953 }
954 
955 }  // namespace
956 
BindArguments(ArgumentsBinder * args)957 absl::Status ConvolutionMetal::BindArguments(ArgumentsBinder* args) {
958   RETURN_IF_ERROR(args->SetInt("padding_x", padding_.x * src_[0]->Batch()));
959   RETURN_IF_ERROR(args->SetInt("dilation_x", dilation_.x * src_[0]->Batch()));
960   const int grid_x =
961       DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), params_.block_size.x);
962   const int grid_y = DivideRoundUp(dst_[0]->Height(), params_.block_size.y);
963   RETURN_IF_ERROR(args->SetInt("task_size_x", grid_x));
964   RETURN_IF_ERROR(args->SetInt("task_size_y", grid_x * grid_y));
965   return absl::OkStatus();
966 }
967 
GetGridSize() const968 int3 ConvolutionMetal::GetGridSize() const {
969   int grid_x =
970       DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), params_.block_size.x);
971   int grid_y = DivideRoundUp(dst_[0]->Height(), params_.block_size.y);
972   int grid_z = DivideRoundUp(dst_[0]->Slices(), params_.block_size.z);
973 
974   int3 group_size(params_.work_group_size);
975   int3 wg;
976   uint3 groups_count;
977   if (params_.linear_whs) {
978     return int3(grid_x * grid_y * grid_z, 1, 1);
979   } else if (params_.linear_wh) {
980     return int3(grid_x * grid_y, grid_z, 1);
981   } else {
982     return int3(grid_x, grid_y, grid_z);
983   }
984 }
985 
CreateConvolutionMetal(const OperationDef & definition,const BHWC & dst_shape,const Convolution2DAttributes & attr,const GpuInfo & gpu_info)986 ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition,
987                                         const BHWC& dst_shape,
988                                         const Convolution2DAttributes& attr,
989                                         const GpuInfo& gpu_info) {
990   BHWC new_shape = BHWC(1, dst_shape.h, dst_shape.w * dst_shape.b, dst_shape.c);
991   ConvolutionMetal::ConvParams params =
992       GetConvParams(gpu_info, attr, definition.precision, new_shape);
993 
994   ConvolutionMetal desc(definition);
995   desc.params_ = params;
996   const bool stride_correction =
997       definition.IsBatchSupported() && attr.strides.w != 1;
998   desc.code_ = GenerateConvolution(params, definition, stride_correction);
999 
1000   auto src_desc = definition.src_tensors[0];
1001   if (definition.IsBatchSupported()) {
1002     src_desc.SetStateVar("BatchedWidth", "true");
1003   }
1004   desc.AddSrcTensor("src_tensor", src_desc);
1005   auto dst_desc = definition.dst_tensors[0];
1006   if (definition.IsBatchSupported()) {
1007     dst_desc.SetStateVar("BatchedWidth", "true");
1008   }
1009   desc.AddDstTensor("dst_tensor", dst_desc);
1010 
1011   desc.args_.AddInt("kernel_size_x", attr.weights.shape.w);
1012   desc.args_.AddInt("kernel_size_y", attr.weights.shape.h);
1013   desc.args_.AddInt("dilation_x", attr.dilations.w);
1014   desc.args_.AddInt("dilation_y", attr.dilations.h);
1015   desc.args_.AddInt("stride_x", attr.strides.w);
1016   desc.args_.AddInt("stride_y", attr.strides.h);
1017   desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
1018   desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
1019   desc.padding_ = int2(-attr.padding.prepended.w, -attr.padding.prepended.h);
1020   desc.dilation_ = int2(attr.dilations.w, attr.dilations.h);
1021 
1022   auto weights_type = DeduceDataTypeFromPrecision(definition.precision);
1023 
1024   MemoryType mem_type =
1025       params.weights_upload_type ==
1026               ConvolutionMetal::WeightsUploadType::CONSTANT_MEM
1027           ? MemoryType::CONSTANT
1028           : MemoryType::GLOBAL;
1029 
1030   if (definition.src_tensors.size() == 2) {
1031     // dynamic weights
1032     BufferDescriptor weights_desc;
1033     weights_desc.element_type = definition.src_tensors[1].data_type;
1034     weights_desc.element_size = 4;
1035     weights_desc.memory_type = mem_type;
1036     desc.AddSrcBuffer("weights", weights_desc);
1037   } else {
1038     BufferDescriptor weights_desc;
1039     weights_desc.element_type = weights_type;
1040     weights_desc.element_size = 4;
1041     weights_desc.memory_type = mem_type;
1042     weights_desc.data = ReorderWeightsForConv(
1043         attr.weights, desc.GetWeightsDescription(), weights_type);
1044     weights_desc.size = weights_desc.data.size();
1045     desc.args_.AddObject("weights", absl::make_unique<BufferDescriptor>(
1046                                         std::move(weights_desc)));
1047   }
1048 
1049   BufferDescriptor bias_desc;
1050   bias_desc.element_type = weights_type;
1051   bias_desc.element_size = 4;
1052   bias_desc.memory_type = mem_type;
1053   bias_desc.data = ReorderBiasesForConv(
1054       attr.bias, weights_type,
1055       AlignByN(attr.weights.shape.o, params.block_size.z * 4));
1056   bias_desc.size = bias_desc.data.size();
1057   desc.args_.AddObject(
1058       "biases", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));
1059 
1060   desc.args_.AddInt("task_size_x");
1061   desc.args_.AddInt("task_size_y");
1062 
1063   desc.work_group_size_ = params.work_group_size;
1064   desc.work_group_launch_order_ = params.work_group_launch_order;
1065   if (params.linear_whs) {
1066     desc.grid_dimension_ = 1;
1067   } else if (params.linear_wh) {
1068     desc.grid_dimension_ = 2;
1069   } else {
1070     desc.grid_dimension_ = 3;
1071   }
1072 
1073   return desc;
1074 }
1075 
CreateConvolutionMetalWino4x4To6x6(const OperationDef & definition,const BHWC & dst_shape,const Convolution2DAttributes & attr,const GpuInfo & gpu_info)1076 ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
1077     const OperationDef& definition, const BHWC& dst_shape,
1078     const Convolution2DAttributes& attr, const GpuInfo& gpu_info) {
1079   ConvolutionMetal::ConvParams params;
1080   params.work_group_launch_order = int3(2, 0, 1);
1081   params.src_depth_loop_size = 1;
1082   params.need_src_loop = true;
1083   params.need_dst_loop = true;
1084   params.linear_wh = false;
1085   params.linear_whs = false;
1086   params.different_weights_for_height = true;
1087   params.x_kernel_is_1 = true;
1088   params.y_kernel_is_1 = true;
1089   if (gpu_info.IsApple()) {
1090     params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
1091     if (gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
1092       params.weights_upload_type =
1093           ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
1094       params.work_group_size = int3(32, 1, 1);
1095       params.block_size = int3(4, 1, 4);
1096     } else {
1097       params.weights_upload_type =
1098           ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
1099       params.work_group_size = int3(8, 4, 1);
1100       params.block_size = int3(4, 1, 4);
1101     }
1102   } else if (gpu_info.IsIntel()) {
1103     params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
1104     params.weights_upload_type =
1105         ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
1106     params.work_group_size = int3(16, 1, 1);
1107     params.block_size = int3(1, 1, 4);
1108   } else if (gpu_info.IsAMD()) {
1109     params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
1110     params.weights_upload_type =
1111         ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
1112     params.work_group_size = int3(32, 1, 1);
1113     params.block_size = int3(2, 1, 4);
1114   } else {
1115     params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
1116     params.weights_upload_type =
1117         ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
1118     params.work_group_size = int3(32, 1, 1);
1119     params.block_size = int3(2, 1, 4);
1120   }
1121 
1122   ConvolutionMetal desc(definition);
1123   desc.params_ = params;
1124   desc.code_ = GenerateConvolution(params, definition, false);
1125   auto src_desc = definition.src_tensors[0];
1126   if (definition.IsBatchSupported()) {
1127     src_desc.SetStateVar("BatchedWidth", "true");
1128   }
1129   desc.AddSrcTensor("src_tensor", src_desc);
1130   auto dst_desc = definition.dst_tensors[0];
1131   if (definition.IsBatchSupported()) {
1132     dst_desc.SetStateVar("BatchedWidth", "true");
1133   }
1134   desc.AddDstTensor("dst_tensor", dst_desc);
1135 
1136   desc.args_.AddInt("kernel_size_x", 1);
1137   desc.args_.AddInt("kernel_size_y", 1);
1138   desc.args_.AddInt("dilation_x", 1);
1139   desc.args_.AddInt("dilation_y", 1);
1140   desc.args_.AddInt("stride_x", 1);
1141   desc.args_.AddInt("stride_y", 1);
1142   desc.args_.AddInt("padding_x", 0);
1143   desc.args_.AddInt("padding_y", 0);
1144   desc.padding_ = int2(0, 0);
1145   desc.dilation_ = int2(1, 1);
1146 
1147   auto weights_type = DeduceDataTypeFromPrecision(definition.precision);
1148 
1149   tflite::gpu::Tensor<OHWI, DataType::FLOAT32> wino_weights;
1150   tflite::gpu::Tensor<Linear, DataType::FLOAT32> wino_biases;
1151   RearrangeWeightsToWinograd4x4To6x6Weights(attr.weights, &wino_weights);
1152   wino_biases.shape = Linear(attr.weights.shape.o);
1153   wino_biases.data.resize(attr.weights.shape.o, 0.0f);
1154 
1155   BufferDescriptor weights_desc;
1156   weights_desc.element_type = weights_type;
1157   weights_desc.element_size = 4;
1158   weights_desc.data = ReorderWeightsForConv(
1159       wino_weights, desc.GetWeightsDescription(), weights_type);
1160   weights_desc.size = weights_desc.data.size();
1161   desc.args_.AddObject(
1162       "weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
1163 
1164   BufferDescriptor bias_desc;
1165   bias_desc.element_type = weights_type;
1166   bias_desc.element_size = 4;
1167   bias_desc.data = ReorderBiasesForConv(
1168       wino_biases, weights_type,
1169       AlignByN(attr.weights.shape.o, params.block_size.z * 4));
1170   bias_desc.size = bias_desc.data.size();
1171   desc.args_.AddObject(
1172       "biases", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));
1173 
1174   desc.args_.AddInt("task_size_x");
1175   desc.args_.AddInt("task_size_y");
1176 
1177   desc.work_group_size_ = params.work_group_size;
1178   desc.work_group_launch_order_ = params.work_group_launch_order;
1179   if (params.linear_whs) {
1180     desc.grid_dimension_ = 1;
1181   } else if (params.linear_wh) {
1182     desc.grid_dimension_ = 2;
1183   } else {
1184     desc.grid_dimension_ = 3;
1185   }
1186 
1187   return desc;
1188 }
1189 
IsConvolutionMetalSupported(const OperationDef & definition)1190 bool IsConvolutionMetalSupported(const OperationDef& definition) {
1191   return !definition.src_tensors[0].HasAxis(Axis::DEPTH);
1192 }
1193 
1194 }  // namespace gpu
1195 }  // namespace tflite
1196