1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23 
24 namespace tflite {
25 namespace gpu {
26 
ConvolutionTransposed3x3(const OperationDef & definition,const GpuInfo & gpu_info,int2 padding)27 ConvolutionTransposed3x3::ConvolutionTransposed3x3(
28     const OperationDef& definition, const GpuInfo& gpu_info, int2 padding)
29     : GPUOperation(definition), padding_(padding) {
30   work_group_size_ = int3(8, 4, 1);
31   work_group_launch_order_ = int3(2, 0, 1);
32   if (gpu_info.IsPowerVR()) {
33     weights_upload_type_ = WeightsUploadType::LOCAL_MEM_ASYNC;
34   } else if (gpu_info.IsNvidia() || gpu_info.IsIntel()) {
35     weights_upload_type_ = WeightsUploadType::LOCAL_MEM_BY_THREADS;
36   } else if (gpu_info.IsAMD()) {
37     weights_upload_type_ = WeightsUploadType::CONSTANT_MEM;
38   } else {
39     weights_upload_type_ = WeightsUploadType::GLOBAL_MEM;
40   }
41   if (gpu_info.IsApple()) {
42     weights_layout_ = WeightsLayout::kOICustomSpatialO4I4;
43   } else {
44     weights_layout_ = WeightsLayout::kOICustomSpatialI4O4;
45   }
46   code_ = GenerateConvolutionTransposedCode(gpu_info, definition_,
47                                             weights_upload_type_, padding_,
48                                             work_group_launch_order_);
49   if (definition_.precision == CalculationsPrecision::F16 &&
50       gpu_info.IsPowerVR()) {
51     compiler_options_.push_back(CompilerOptions::kClPowervrFp16);
52   }
53 }
54 
GenerateConvolutionTransposedCode(const GpuInfo & gpu_info,const OperationDef & op_def,ConvolutionTransposed3x3::WeightsUploadType weights_upload_type,int2 padding,int3 work_group_launch_order)55 std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
56     const GpuInfo& gpu_info, const OperationDef& op_def,
57     ConvolutionTransposed3x3::WeightsUploadType weights_upload_type,
58     int2 padding, int3 work_group_launch_order) {
59   auto src_desc = op_def.src_tensors[0];
60   src_desc.SetAddressMode(AddressMode::kZero);
61   if (op_def.IsBatchSupported()) {
62     src_desc.SetStateVar("BatchedWidth", "true");
63   }
64   AddSrcTensor("src_tensor", src_desc);
65 
66   auto dst_desc = op_def.dst_tensors[0];
67   if (op_def.IsBatchSupported()) {
68     dst_desc.SetStateVar("BatchedWidth", "true");
69   }
70   AddDstTensor("dst_tensor", dst_desc);
71 
72   if (op_def.src_tensors.size() == 2) {
73     // dynamic weights
74     BufferDescriptor desc;
75     desc.element_type = op_def.src_tensors[1].data_type;
76     desc.element_size = 4;
77     desc.memory_type =
78         weights_upload_type ==
79                 ConvolutionTransposed3x3::WeightsUploadType::CONSTANT_MEM
80             ? MemoryType::CONSTANT
81             : MemoryType::GLOBAL;
82     AddSrcBuffer("weights", desc);
83   }
84 
85   args_.AddInt("filter_offset");
86   args_.AddInt("padding_x");
87   args_.AddInt("padding_y");
88 
89   const bool need_local_mem =
90       weights_upload_type ==
91           ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS ||
92       weights_upload_type ==
93           ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_ASYNC;
94 
95   std::string c;
96   if (GetWeightsDescription().IsI4O4()) {
97     switch (op_def.precision) {
98       case CalculationsPrecision::F32:
99       case CalculationsPrecision::F16:
100         c += "#define CONV(R, SRC, F) \\\n";
101         c += "  R += SRC.x * weights_cache[F]; \\\n";
102         c += "  R += SRC.y * weights_cache[F + 1]; \\\n";
103         c += "  R += SRC.z * weights_cache[F + 2]; \\\n";
104         c += "  R += SRC.w * weights_cache[F + 3];   \n";
105         break;
106       case CalculationsPrecision::F32_F16:
107         c += "#define CONV(R, SRC, F) \\\n";
108         c += "  R += TO_ACCUM_TYPE(SRC.x * weights_cache[F] + SRC.y * "
109              "weights_cache[F + 1] + SRC.z * weights_cache[F + 2] + SRC.w * "
110              "weights_cache[F + 3]);\n";
111         break;
112     }
113   } else {
114     // O4I4
115     c += "#define CONV(R, SRC, F) \\\n";
116     c += "  R.x += dot(SRC, weights_cache[F]); \\\n";
117     c += "  R.y += dot(SRC, weights_cache[F + 1]); \\\n";
118     c += "  R.z += dot(SRC, weights_cache[F + 2]); \\\n";
119     c += "  R.w += dot(SRC, weights_cache[F + 3]);   \n";
120   }
121 
122   const int wg_total_size =
123       work_group_size_.x * work_group_size_.y * work_group_size_.z;
124   const std::string barrier =
125       wg_total_size == 32 && gpu_info.IsWaveSizeEqualTo32()
126           ? "SIMD_LOCAL_MEM_BARRIER"
127           : "LOCAL_MEM_BARRIER";
128   const std::string weights_space =
129       weights_upload_type ==
130               ConvolutionTransposed3x3::WeightsUploadType::CONSTANT_MEM
131           ? "__constant"
132           : "__global";
133 
134   const std::string pixel_stride =
135       op_def.IsBatchSupported() ? "args.dst_tensor.Batch()" : "1";
136   if (gpu_info.IsApiOpenCl()) {
137     c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
138   }
139   c += "MAIN_FUNCTION($0) {\n";
140   int3 launch_remap;
141   launch_remap[work_group_launch_order.x] = 0;
142   launch_remap[work_group_launch_order.y] = 1;
143   launch_remap[work_group_launch_order.z] = 2;
144   auto GetGlobalID = [&](int id) {
145     std::string result;
146     const std::string sid = std::to_string(id);
147     if (work_group_launch_order[id] == id) {
148       return "GLOBAL_ID_" + sid;
149     } else {
150       return "GROUP_ID_" + std::to_string(launch_remap[id]) + " * GROUP_SIZE_" +
151              sid + " + LOCAL_ID_" + sid;
152     }
153   };
154   if (op_def.IsBatchSupported()) {
155     c += "  int linear_id = " + GetGlobalID(0) + ";\n";
156     c += "  int X0 = linear_id / args.dst_tensor.Batch();\n";
157     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
158     c += "  int DST_X = X0 * 2 * args.dst_tensor.Batch() + B;\n";
159     c += "  int SRC_X = linear_id + args.padding_x;\n";
160   } else {
161     c += "  int X = " + GetGlobalID(0) + ";\n";
162     c += "  int DST_X = X * 2;\n";
163     c += "  int SRC_X = X + args.padding_x;\n";
164   }
165   c += "  int Y = " + GetGlobalID(1) + ";\n";
166   c += "  int DST_Y = Y * 2;\n";
167   c += "  int SRC_Y = Y + args.padding_y;\n";
168   c += "  int Z = " + GetGlobalID(2) + ";\n";
169   if (!need_local_mem) {
170     c += "  if (DST_X >= args.dst_tensor.Width() || DST_Y >= "
171          "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
172   }
173   c += "  ACCUM_FLT4 r0 = INIT_ACCUM_FLT4(0.0f);\n";
174   c += "  ACCUM_FLT4 r1 = INIT_ACCUM_FLT4(0.0f);\n";
175   c += "  ACCUM_FLT4 r2 = INIT_ACCUM_FLT4(0.0f);\n";
176   c += "  ACCUM_FLT4 r3 = INIT_ACCUM_FLT4(0.0f);\n";
177   c += "  int f_offset = Z * args.filter_offset;\n";
178   if (need_local_mem) {
179     c += "  __local FLT4 weights_cache[36];\n";
180   }
181   if (weights_upload_type ==
182       ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
183     c += "  int local_id = LOCAL_ID_1 * 8 + LOCAL_ID_0;\n";
184   }
185   const std::string next_x = "SRC_X + " + pixel_stride;
186   if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
187     c += "  bool in_x0 = SRC_X >= 0 && SRC_X < args.src_tensor.Width();\n";
188     c += "  bool in_x1 = " + next_x + " >= 0 && " + next_x +
189          " < args.src_tensor.Width();\n";
190   }
191   if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
192     c += "  bool in_y0 = SRC_Y >= 0 && SRC_Y < args.src_tensor.Height();\n";
193     c += "  bool in_y1 = SRC_Y + 1 >= 0 && SRC_Y + 1 < "
194          "args.src_tensor.Height();\n";
195   }
196   auto generate_check = [&](int x, int y) {
197     std::string check;
198     const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT};
199     const std::vector<std::string> names{"in_x" + std::to_string(x),
200                                          "in_y" + std::to_string(y)};
201     for (int i = 0; i < axes.size(); ++i) {
202       const auto& axis = axes[i];
203       if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
204         if (!check.empty()) {
205           check += " && ";
206         }
207         check += names[i];
208       }
209     }
210     return check;
211   };
212   if (src_desc.IsLinear()) {
213     if (src_desc.ReturnsZeroForNegOneRead()) {
214       c += "  args.src_tensor.GetAddress(addr_0, SRC_X, SRC_Y, 0);\n";
215       c += "  args.src_tensor.GetAddress(addr_1," + next_x + ", SRC_Y, 0);\n";
216       c += "  args.src_tensor.GetAddress(addr_2, SRC_X, SRC_Y + 1, 0);\n";
217       c += "  args.src_tensor.GetAddress(addr_3," + next_x + ", SRC_Y+1, 0);\n";
218       c += "  addr_0 = select(-1, addr_0, (in_x0 && in_y0));\n";
219       c += "  addr_1 = select(-1, addr_1, (in_x1 && in_y0));\n";
220       c += "  addr_2 = select(-1, addr_2, (in_x0 && in_y1));\n";
221       c += "  addr_3 = select(-1, addr_3, (in_x1 && in_y1));\n";
222       c += "  int dz_0 = select(0, args.src_tensor.SliceStride(), (in_x0 && "
223            "in_y0));\n";
224       c += "  int dz_1 = select(0, args.src_tensor.SliceStride(), (in_x1 && "
225            "in_y0));\n";
226       c += "  int dz_2 = select(0, args.src_tensor.SliceStride(), (in_x0 && "
227            "in_y1));\n";
228       c += "  int dz_3 = select(0, args.src_tensor.SliceStride(), (in_x1 && "
229            "in_y1));\n";
230     } else {
231       c += "  int xc0 = clamp(SRC_X, 0, args.src_tensor.Width() - 1);\n";
232       c += "  int xc1 = clamp(" + next_x +
233            ", 0, args.src_tensor.Width() - 1);\n";
234       c += "  int yc0 = clamp(SRC_Y, 0, args.src_tensor.Height() - 1);\n";
235       c += "  int yc1 = clamp(SRC_Y + 1, 0, args.src_tensor.Height() - 1);\n";
236       c += "  args.src_tensor.GetAddress(addr_0, xc0, yc0, 0);\n";
237       c += "  args.src_tensor.GetAddress(addr_1, xc1, yc0, 0);\n";
238       c += "  args.src_tensor.GetAddress(addr_2, xc0, yc1, 0);\n";
239       c += "  args.src_tensor.GetAddress(addr_3, xc1, yc1, 0);\n";
240       c += "  int dz = args.src_tensor.SliceStride();\n";
241     }
242   }
243   auto read_src = [&](int x, int y) {
244     if (src_desc.IsLinear()) {
245       const std::string id = std::to_string(y * 2 + x);
246       const std::string addr = "addr_" + std::to_string(y * 2 + x);
247       if (src_desc.ReturnsZeroForNegOneRead()) {
248         return "args.src_tensor.Read(" + addr + "); " + addr + " += dz_" + id +
249                ";\n";
250       } else {
251         return "args.src_tensor.Read(" + addr + ") * INIT_FLT(in_x" +
252                std::to_string(x) + " && in_y" + std::to_string(y) + "); " +
253                addr + " += dz;\n";
254       }
255     } else {
256       std::string check = generate_check(x, y);
257       if (!check.empty()) {
258         check = " * INIT_FLT(" + check + ")";
259       }
260       return "args.src_tensor.Read(SRC_X + " + std::to_string(x) + "*" +
261              pixel_stride + ", SRC_Y + " + std::to_string(y) + ", s)" + check +
262              ";\n";
263     }
264   };
265   const int padding_x_rem = abs(padding.x) % 2;
266   const int padding_y_rem = abs(padding.y) % 2;
267   std::vector<std::pair<int, int>> permutation;
268   if (padding_x_rem == 1 && padding_y_rem == 1) {
269     permutation = {{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2},
270                    {3, 0}, {3, 1}, {3, 2}, {3, 3}};
271   } else if (padding_x_rem == 0 && padding_y_rem == 1) {
272     permutation = {{0, 0}, {0, 1}, {1, 1}, {2, 0}, {2, 1},
273                    {2, 2}, {2, 3}, {3, 1}, {3, 3}};
274   } else if (padding_x_rem == 1 && padding_y_rem == 0) {
275     permutation = {{0, 0}, {0, 2}, {1, 0}, {1, 1}, {1, 2},
276                    {1, 3}, {2, 2}, {3, 2}, {3, 3}};
277   } else {  // padding_x_rem == 0 && padding_y_rem == 0
278     permutation = {{0, 0}, {0, 1}, {0, 2}, {0, 3}, {1, 1},
279                    {1, 3}, {2, 2}, {2, 3}, {3, 3}};
280   }
281   c += "  for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n";
282   if (need_local_mem) {
283     c += "    " + barrier + ";\n";
284   }
285   if (weights_upload_type ==
286       ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_ASYNC) {
287     c += "    async_work_group_copy(weights_cache, "
288          "args.weights.GetPtr(f_offset), 36, "
289          "0);\n";
290   } else if (weights_upload_type ==
291              ConvolutionTransposed3x3::WeightsUploadType::
292                  LOCAL_MEM_BY_THREADS) {
293     c += "    weights_cache[local_id] = args.weights.Read(f_offset + "
294          "local_id);\n";
295     c += "    if (local_id < 4) {\n";
296     c += "      weights_cache[local_id + 32] = args.weights.Read(f_offset + "
297          "local_id + "
298          "32);\n";
299     c += "    };\n";
300   } else {  // GLOBAL_MEM/CONSTANT_MEM
301     c += "    " + weights_space +
302          " FLT4* weights_cache = args.weights.GetPtr(f_offset);\n";
303   }
304   c += "    FLT4 src0 = " + read_src(0, 0);
305   c += "    FLT4 src1 = " + read_src(1, 0);
306   c += "    FLT4 src2 = " + read_src(0, 1);
307   c += "    FLT4 src3 = " + read_src(1, 1);
308   c += "    f_offset += 36;\n";
309   if (need_local_mem) {
310     c += "    " + barrier + ";\n";
311   }
312   for (int i = 0; i < 9; ++i) {
313     const std::string r_name = "r" + std::to_string(permutation[i].first);
314     const std::string s_name = "src" + std::to_string(permutation[i].second);
315     const std::string w_name = std::to_string(i * 4);
316     c += "    CONV(" + r_name + ", " + s_name + ", " + w_name + ");\n";
317   }
318   c += "  }\n";
319   if (need_local_mem) {
320     c += "  if (DST_X >= args.dst_tensor.Width() || DST_Y >= "
321          "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
322   }
323   c += "  FLT4 bias_val = args.biases.Read(Z);\n";
324   for (int y = 0; y < 2; ++y) {
325     for (int x = 0; x < 2; ++x) {
326       const std::string s_x = std::to_string(x);
327       const std::string s_y = std::to_string(y);
328       const std::string id = std::to_string(y * 2 + x);
329       const std::string x_c = "DST_X + " + s_x + " * " + pixel_stride;
330       const std::string y_c = "DST_Y + " + s_y;
331       c += "  if (" + x_c + " < args.dst_tensor.Width() && " + y_c +
332            " < args.dst_tensor.Height()) {\n";
333       c += "    FLT4 res0 = TO_FLT4(r" + id + ") + bias_val;\n";
334       c += "    args.dst_tensor.Write(res0, " + x_c + ", " + y_c + ", Z);\n";
335       c += "  }\n";
336     }
337   }
338   c += "}\n";
339   return c;
340 }
341 
BindArguments(ArgumentsBinder * args)342 absl::Status ConvolutionTransposed3x3::BindArguments(ArgumentsBinder* args) {
343   RETURN_IF_ERROR(args->SetInt("filter_offset", 4 * 9 * src_[0]->Slices()));
344   const int padding_x =
345       padding_.x >= 1 ? (padding_.x - 1) / 2 : (padding_.x - 2) / 2;
346   const int padding_y =
347       padding_.y >= 1 ? (padding_.y - 1) / 2 : (padding_.y - 2) / 2;
348   RETURN_IF_ERROR(args->SetInt("padding_x", padding_x * src_[0]->Batch()));
349   return args->SetInt("padding_y", padding_y);
350 }
351 
GetPossibleKernelWorkGroups(TuningType tuning_type,const GpuInfo & gpu_info,const KernelInfo & kernel_info,std::vector<int3> * work_groups) const352 void ConvolutionTransposed3x3::GetPossibleKernelWorkGroups(
353     TuningType tuning_type, const GpuInfo& gpu_info,
354     const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
355   if (weights_upload_type_ == WeightsUploadType::LOCAL_MEM_ASYNC ||
356       weights_upload_type_ == WeightsUploadType::LOCAL_MEM_BY_THREADS) {
357     work_groups->push_back(work_group_size_);
358     return;
359   }
360   GetPossibleWorkGroupsConv(tuning_type, gpu_info, kernel_info, grid_size_,
361                             work_groups);
362 }
363 
GetGridSize() const364 int3 ConvolutionTransposed3x3::GetGridSize() const {
365   const int grid_x = DivideRoundUp(dst_[0]->Width(), 2) * dst_[0]->Batch();
366   const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
367   const int grid_z = dst_[0]->Slices();
368   return int3(grid_x, grid_y, grid_z);
369 }
370 
GetSpatialWeightsRemap() const371 std::vector<int> ConvolutionTransposed3x3::GetSpatialWeightsRemap() const {
372   const int padding_x_rem = abs(padding_.x) % 2;
373   const int padding_y_rem = abs(padding_.y) % 2;
374 
375   std::vector<int> remap;
376   if (padding_x_rem == 1 && padding_y_rem == 1) {
377     return std::vector<int>{4, 5, 3, 7, 1, 8, 6, 2, 0};
378   } else if (padding_x_rem == 0 && padding_y_rem == 1) {
379     return std::vector<int>{5, 3, 4, 8, 6, 2, 0, 7, 1};
380   } else if (padding_x_rem == 1 && padding_y_rem == 0) {
381     return std::vector<int>{7, 1, 8, 6, 2, 0, 4, 5, 3};
382   } else {  // padding_x_rem == 0 && padding_y_rem == 0
383     return std::vector<int>{8, 6, 2, 0, 7, 1, 5, 3, 4};
384   }
385 }
386 
UploadWeights(const tflite::gpu::Tensor<OHWI,DataType::FLOAT32> & weights)387 void ConvolutionTransposed3x3::UploadWeights(
388     const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights) {
389   const int flt_count =
390       GetTotalElementsCountForLayout(GetWeightsDescription(), weights.shape);
391 
392   DataType weights_type = definition_.precision == CalculationsPrecision::F32
393                               ? DataType::FLOAT32
394                               : DataType::FLOAT16;
395 
396   BufferDescriptor desc;
397   desc.element_type = weights_type;
398   desc.element_size = 4;
399   desc.memory_type =
400       weights_upload_type_ ==
401               ConvolutionTransposed3x3::WeightsUploadType::CONSTANT_MEM
402           ? MemoryType::CONSTANT
403           : MemoryType::GLOBAL;
404   desc.size = flt_count * SizeOf(desc.element_type);
405   desc.data.resize(desc.size);
406 
407   RearrangeWeights(weights, GetWeightsDescription(), weights_type,
408                    absl::MakeSpan(desc.data));
409 
410   args_.AddObject("weights",
411                   absl::make_unique<BufferDescriptor>(std::move(desc)));
412 }
413 
IsConvolutionTransposed3x3Supported(const OperationDef & definition,const ConvolutionTransposedAttributes & attr)414 bool IsConvolutionTransposed3x3Supported(
415     const OperationDef& definition,
416     const ConvolutionTransposedAttributes& attr) {
417   return attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
418          attr.stride.w == 2 && attr.stride.h == 2;
419 }
420 
CreateConvolutionTransposed3x3(const GpuInfo & gpu_info,const OperationDef & definition,const ConvolutionTransposedAttributes & attr)421 ConvolutionTransposed3x3 CreateConvolutionTransposed3x3(
422     const GpuInfo& gpu_info, const OperationDef& definition,
423     const ConvolutionTransposedAttributes& attr) {
424   const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h);
425   ConvolutionTransposed3x3 result(definition, gpu_info, padding);
426   result.UploadWeights(attr.weights);
427 
428   TensorLinearDescriptor desc;
429   desc.storage_type = LinearStorageType::TEXTURE_2D;
430   desc.element_type = definition.GetDataType();
431   desc.UploadLinearData(attr.bias);
432   result.args_.AddObject(
433       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
434   return result;
435 }
436 
CreateConvolutionTransposed3x3DynamicWeights(const GpuInfo & gpu_info,const OperationDef & definition,const ConvolutionTransposedAttributes & attr)437 ConvolutionTransposed3x3 CreateConvolutionTransposed3x3DynamicWeights(
438     const GpuInfo& gpu_info, const OperationDef& definition,
439     const ConvolutionTransposedAttributes& attr) {
440   OperationDef new_def = definition;
441   new_def.src_tensors = {
442       definition.src_tensors[0]};  // leaving only src_tensor def, weights defs
443                                    // will be added later
444   const DataType weights_type = definition.GetDataType();
445   // add 1 src_tensor(buffer) for weights
446   new_def.src_tensors.push_back(
447       {weights_type, TensorStorageType::BUFFER, Layout::HWC});
448 
449   const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h);
450   ConvolutionTransposed3x3 result(new_def, gpu_info, padding);
451 
452   TensorLinearDescriptor desc;
453   desc.storage_type = LinearStorageType::TEXTURE_2D;
454   desc.element_type = new_def.GetDataType();
455   desc.UploadLinearData(attr.bias);
456   result.args_.AddObject(
457       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
458   return result;
459 }
460 
461 }  // namespace gpu
462 }  // namespace tflite
463