1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/substitute.h"
20 #include "tensorflow/lite/delegates/gpu/common/shape.h"
21 #include "tensorflow/lite/delegates/gpu/common/util.h"
22 
23 namespace tflite {
24 namespace gpu {
25 namespace {
GetReadImageFromDataType(DataType data_type)26 std::string GetReadImageFromDataType(DataType data_type) {
27   if (data_type == DataType::FLOAT32) {
28     return "read_imagef";
29   } else if (data_type == DataType::FLOAT16) {
30     return "read_imageh";
31   } else {
32     return "error";
33   }
34 }
35 
GetWriteImageFromDataType(DataType data_type)36 std::string GetWriteImageFromDataType(DataType data_type) {
37   if (data_type == DataType::FLOAT32) {
38     return "write_imagef";
39   } else if (data_type == DataType::FLOAT16) {
40     return "write_imageh";
41   } else {
42     return "error";
43   }
44 }
45 
AddressModeToCLSampler(AddressMode address_mode)46 std::string AddressModeToCLSampler(AddressMode address_mode) {
47   switch (address_mode) {
48     case AddressMode::kDontCare:
49       return "smp_none";
50     case AddressMode::kZero:
51       return "smp_zero";
52   }
53 }
54 
55 }  // namespace
56 
ToString(TensorStorageType type)57 std::string ToString(TensorStorageType type) {
58   switch (type) {
59     case TensorStorageType::UNKNOWN:
60       return "TensorStorageType::UNKNOWN";
61     case TensorStorageType::BUFFER:
62       return "TensorStorageType::BUFFER";
63     case TensorStorageType::TEXTURE_ARRAY:
64       return "TensorStorageType::TEXTURE_ARRAY";
65     case TensorStorageType::TEXTURE_2D:
66       return "TensorStorageType::TEXTURE_2D";
67     case TensorStorageType::TEXTURE_3D:
68       return "TensorStorageType::TEXTURE_3D";
69     case TensorStorageType::SINGLE_TEXTURE_2D:
70       return "TensorStorageType::SINGLE_TEXTURE_2D";
71     case TensorStorageType::IMAGE_BUFFER:
72       return "TensorStorageType::IMAGE_BUFFER";
73   }
74 }
75 
// Move constructor. The GPUObjectDescriptor base is move-constructed from
// `desc`; data_type, storage_type, layout and shape are copied from `desc`,
// and `data` is moved out of it.
// NOTE(review): the derived fields are read after std::move(desc) was passed
// to the base constructor. This presumably relies on the base constructor
// consuming only the base subobject -- confirm against GPUObjectDescriptor.
TensorDescriptor::TensorDescriptor(TensorDescriptor&& desc)
    : GPUObjectDescriptor(std::move(desc)),
      data_type(desc.data_type),
      storage_type(desc.storage_type),
      layout(desc.layout),
      shape(desc.shape),
      data(std::move(desc.data)) {}
operator =(TensorDescriptor && desc)83 TensorDescriptor& TensorDescriptor::operator=(TensorDescriptor&& desc) {
84   if (this != &desc) {
85     std::swap(data_type, desc.data_type);
86     std::swap(storage_type, desc.storage_type);
87     std::swap(layout, desc.layout);
88     std::swap(shape, desc.shape);
89     data = std::move(desc.data);
90     GPUObjectDescriptor::operator=(std::move(desc));
91   }
92   return *this;
93 }
94 
GetGPUResources() const95 GPUResources TensorDescriptor::GetGPUResources() const {
96   GPUResources resources;
97   resources.ints.push_back("slice_stride");
98   if (HasAxis(Axis::WIDTH)) {
99     resources.ints.push_back("width");
100     resources.ints.push_back("width_div2");
101     resources.ints.push_back("width_div4");
102     resources.ints.push_back("width_batched");
103     resources.ints.push_back("width_batched_div2");
104     resources.ints.push_back("width_batched_div4");
105   }
106   if (HasAxis(Axis::HEIGHT)) {
107     resources.ints.push_back("height");
108   }
109   if (HasAxis(Axis::CHANNELS)) {
110     resources.ints.push_back("slices");
111     resources.ints.push_back("channels");
112   }
113   if (HasAxis(Axis::BATCH)) {
114     resources.ints.push_back("batch");
115   }
116   if (HasAxis(Axis::DEPTH)) {
117     resources.ints.push_back("depth");
118   }
119   if (storage_type == TensorStorageType::BUFFER) {
120     GPUBufferDescriptor desc;
121     desc.data_type = data_type;
122     desc.access_type = access_type_;
123     desc.element_size = 4;
124     auto it1 = state_vars_.find("ElementsX2");
125     if (it1 != state_vars_.end() && it1->second == "true") {
126       desc.element_size = 8;
127     }
128     auto it2 = state_vars_.find("ElementsX4");
129     if (it2 != state_vars_.end() && it2->second == "true") {
130       desc.element_size = 16;
131     }
132     resources.buffers.push_back({"buffer", desc});
133   } else if (storage_type == TensorStorageType::SINGLE_TEXTURE_2D ||
134              storage_type == TensorStorageType::TEXTURE_2D) {
135     GPUImage2DDescriptor desc;
136     desc.data_type = data_type;
137     desc.access_type = access_type_;
138     resources.images2d.push_back({"image2d", desc});
139   } else if (storage_type == TensorStorageType::TEXTURE_ARRAY) {
140     GPUImage2DArrayDescriptor desc;
141     desc.data_type = data_type;
142     desc.access_type = access_type_;
143     resources.image2d_arrays.push_back({"image2d_array", desc});
144   } else if (storage_type == TensorStorageType::TEXTURE_3D) {
145     GPUImage3DDescriptor desc;
146     desc.data_type = data_type;
147     desc.access_type = access_type_;
148     resources.images3d.push_back({"image3d", desc});
149   } else if (storage_type == TensorStorageType::IMAGE_BUFFER) {
150     if (access_type_ == AccessType::READ) {
151       GPUImageBufferDescriptor desc;
152       desc.data_type = data_type;
153       desc.access_type = access_type_;
154       resources.image_buffers.push_back({"image_buffer", desc});
155     } else {
156       GPUBufferDescriptor desc;
157       desc.data_type = data_type;
158       desc.access_type = access_type_;
159       desc.element_size = 4;
160       resources.buffers.push_back({"buffer", desc});
161     }
162   }
163   return resources;
164 }
165 
PerformSelector(const GpuInfo & gpu_info,const std::string & selector,const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const166 absl::Status TensorDescriptor::PerformSelector(
167     const GpuInfo& gpu_info, const std::string& selector,
168     const std::vector<std::string>& args,
169     const std::vector<std::string>& template_args, std::string* result) const {
170   if (selector == "Width") {
171     *result = GetWidth();
172     return absl::OkStatus();
173   } else if (selector == "Height") {
174     *result = "height";
175     return absl::OkStatus();
176   } else if (selector == "Slices") {
177     *result = "slices";
178     return absl::OkStatus();
179   } else if (selector == "SliceStride") {
180     *result = "slice_stride";
181     return absl::OkStatus();
182   } else if (selector == "Channels") {
183     *result = "channels";
184     return absl::OkStatus();
185   } else if (selector == "Batch") {
186     if (HasAxis(Axis::BATCH)) {
187       *result = "batch";
188     } else {
189       *result = "1";
190     }
191     return absl::OkStatus();
192   } else if (selector == "Depth") {
193     *result = "depth";
194     return absl::OkStatus();
195   } else if (selector == "SetBatchRef") {
196     if (args.size() != 1) {
197       return absl::InvalidArgumentError(
198           "Unsupported arguments in SetBatchRef selector");
199     }
200     state_vars_["batch_id"] = args[0];
201     *result = "";
202     return absl::OkStatus();
203   } else if (selector == "Read") {
204     return PerformReadSelector(gpu_info, args, template_args, result);
205   } else if (selector == "Write") {
206     return PerformWriteSelector(gpu_info, args, result);
207   } else if (selector == "WriteLinear") {
208     return PerformWriteLinearSelector(gpu_info, args, result);
209   } else if (selector == "Write2D") {
210     return PerformWrite2DSelector(gpu_info, args, result);
211   } else if (selector == "GetAddress") {
212     return PerformGetAddressSelector(args, result);
213   } else if (selector == "GetPtrWithSliceOffset") {
214     return PerformGetPtrWithSliceOffsetSelector(args, result);
215   } else if (selector == "GetWHOffset") {
216     return PerformGetWHOffsetSelector(args, result);
217   } else if (selector == "GetHandle") {
218     return PerformGetHandleSelector(args, result);
219   } else {
220     return absl::NotFoundError(absl::StrCat(
221         "TensorDescriptor don't have selector with name - ", selector));
222   }
223 }
224 
PerformReadSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const225 absl::Status TensorDescriptor::PerformReadSelector(
226     const GpuInfo& gpu_info, const std::vector<std::string>& args,
227     const std::vector<std::string>& template_args, std::string* result) const {
228   DataType read_as_type = data_type;
229   if (!template_args.empty()) {
230     if (template_args.size() != 1) {
231       return absl::NotFoundError(
232           "Unrecognized Read selector template arguments.");
233     } else {
234       RETURN_IF_ERROR(
235           GetDataTypeFromTemplateArgs(template_args[0], &read_as_type));
236     }
237   }
238   if (args.size() == 1) {  // function overload for 1D linear types.
239     if (storage_type == TensorStorageType::BUFFER ||
240         storage_type == TensorStorageType::IMAGE_BUFFER) {
241       *result = Read(gpu_info, read_as_type, {args[0]});
242       return absl::OkStatus();
243     } else {
244       return absl::InvalidArgumentError(
245           "Read selector with single argument can be used only with linear "
246           "storage types(BUFFER or IMAGE_BUFFER)");
247     }
248   }
249   std::string xc;
250   std::string yc;
251   std::string zc;
252   std::string sc;
253   std::string bc;
254   bool parsed = ParseCoordsFromArgs(args, 0, &xc, &yc, &zc, &sc, &bc);
255   if (args.size() < 2 || !parsed) {
256     return absl::NotFoundError("Unrecognized Read selector");
257   }
258 
259   *result = Read(gpu_info, read_as_type, GetPhysicalCoords(xc, yc, zc, sc, bc));
260   return absl::OkStatus();
261 }
262 
GetLinkingContextFromWriteSelector(const std::vector<std::string> & args,std::string * value_name,std::string * x_coord,std::string * y_coord,std::string * s_coord) const263 absl::Status TensorDescriptor::GetLinkingContextFromWriteSelector(
264     const std::vector<std::string>& args, std::string* value_name,
265     std::string* x_coord, std::string* y_coord, std::string* s_coord) const {
266   std::string xc;
267   std::string yc;
268   std::string zc;
269   std::string sc;
270   std::string bc;
271   bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
272   if (args.size() < 2 || !parsed) {
273     return absl::NotFoundError("Unrecognized Write selector");
274   }
275   *value_name = args[0];
276   if (HasAxis(Axis::BATCH) && !IsBatchedWidth()) {
277     *x_coord = absl::StrCat("((", xc, ") * batch + (", bc, "))");
278   } else {
279     *x_coord = absl::StrCat("(", xc, ")");
280   }
281   *y_coord = absl::StrCat("(", yc, ")");
282   *s_coord = absl::StrCat("(", sc, ")");
283   return absl::OkStatus();
284 }
285 
PerformWriteSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const286 absl::Status TensorDescriptor::PerformWriteSelector(
287     const GpuInfo& gpu_info, const std::vector<std::string>& args,
288     std::string* result) const {
289   std::string xc;
290   std::string yc;
291   std::string zc;
292   std::string sc;
293   std::string bc;
294   bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
295   if (args.size() < 2 || !parsed) {
296     return absl::NotFoundError("Unrecognized Write selector");
297   }
298   *result = Write(gpu_info, args[0], GetPhysicalCoords(xc, yc, zc, sc, bc));
299   return absl::OkStatus();
300 }
301 
PerformWriteLinearSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const302 absl::Status TensorDescriptor::PerformWriteLinearSelector(
303     const GpuInfo& gpu_info, const std::vector<std::string>& args,
304     std::string* result) const {
305   if (storage_type != TensorStorageType::BUFFER &&
306       storage_type != TensorStorageType::IMAGE_BUFFER) {
307     return absl::InvalidArgumentError(
308         "WriteLinear selector can be used only with linear "
309         "storages(BUFFER/IMAGE_BUFFER)");
310   }
311   if (args.size() != 2) {
312     return absl::NotFoundError("Unrecognized WriteLinear selector");
313   }
314   *result = Write(gpu_info, args[0], {args[1]});
315   return absl::OkStatus();
316 }
317 
PerformWrite2DSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const318 absl::Status TensorDescriptor::PerformWrite2DSelector(
319     const GpuInfo& gpu_info, const std::vector<std::string>& args,
320     std::string* result) const {
321   if (storage_type != TensorStorageType::TEXTURE_2D) {
322     return absl::InvalidArgumentError(
323         "Write2D selector can be used only with 2d "
324         "storages(TEXTURE_2D)");
325   }
326   if (args.size() != 3) {
327     return absl::NotFoundError("Unrecognized Write2D selector");
328   }
329   *result = Write(gpu_info, args[0], {args[1], args[2]});
330   return absl::OkStatus();
331 }
332 
Read(const GpuInfo & gpu_info,DataType read_as_type,const std::vector<std::string> & coords) const333 std::string TensorDescriptor::Read(
334     const GpuInfo& gpu_info, DataType read_as_type,
335     const std::vector<std::string>& coords) const {
336   const std::string read_as =
337       read_as_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
338   const bool need_conversion = read_as_type != data_type;
339   const std::string metal_type =
340       read_as_type == DataType::FLOAT32 ? "float4" : "half4";
341   switch (storage_type) {
342     case TensorStorageType::BUFFER:
343       if (read_as_type == data_type) {
344         return absl::StrCat("buffer[", coords[0], "]");
345       } else {
346         std::string conversion;
347         if (gpu_info.IsApiMetal()) {
348           conversion = metal_type;
349         } else if (gpu_info.IsApiOpenCl()) {
350           if (read_as_type == DataType::FLOAT16) {
351             conversion = "convert_half4";
352           } else if (read_as_type == DataType::FLOAT32) {
353             conversion = "convert_float4";
354           }
355         }
356         return absl::StrCat(conversion, "(buffer[", coords[0], "])");
357       }
358     case TensorStorageType::TEXTURE_2D:
359     case TensorStorageType::SINGLE_TEXTURE_2D:
360       if (gpu_info.IsApiOpenCl()) {
361         return absl::Substitute("$0(image2d, $1, (int2)($2, $3))", read_as,
362                                 AddressModeToCLSampler(AddressModeFromState()),
363                                 coords[0], coords[1]);
364       } else if (gpu_info.IsApiMetal()) {
365         std::string result = absl::Substitute("image2d.read(ushort2($0, $1))",
366                                               coords[0], coords[1]);
367         if (need_conversion) {
368           result = metal_type + "(" + result + ")";
369         }
370         return result;
371       } else {
372         return "";
373       }
374     case TensorStorageType::TEXTURE_3D:
375       if (gpu_info.IsApiOpenCl()) {
376         return absl::Substitute("$0(image3d, $1, (int4)($2, $3, $4, 0))",
377                                 read_as,
378                                 AddressModeToCLSampler(AddressModeFromState()),
379                                 coords[0], coords[1], coords[2]);
380       } else if (gpu_info.IsApiMetal()) {
381         std::string result =
382             absl::Substitute("image3d.read(ushort3($0, $1, $2))", coords[0],
383                              coords[1], coords[2]);
384         if (need_conversion) {
385           result = metal_type + "(" + result + ")";
386         }
387         return result;
388       } else {
389         return "";
390       }
391     case TensorStorageType::TEXTURE_ARRAY:
392       if (gpu_info.IsApiOpenCl()) {
393         return absl::Substitute("$0(image2d_array, $1, (int4)($2, $3, $4, 0))",
394                                 read_as,
395                                 AddressModeToCLSampler(AddressModeFromState()),
396                                 coords[0], coords[1], coords[2]);
397       } else if (gpu_info.IsApiMetal()) {
398         std::string result =
399             absl::Substitute("image2d_array.read(ushort2($0, $1), $2)",
400                              coords[0], coords[1], coords[2]);
401         if (need_conversion) {
402           result = metal_type + "(" + result + ")";
403         }
404         return result;
405       } else {
406         return "";
407       }
408     case TensorStorageType::IMAGE_BUFFER:
409       if (gpu_info.IsApiOpenCl()) {
410         return absl::StrCat(read_as, "(image_buffer, ", coords[0], ")");
411       } else if (gpu_info.IsApiMetal()) {
412         std::string result =
413             absl::Substitute("image_buffer.read(uint($0))", coords[0]);
414         if (need_conversion) {
415           result = metal_type + "(" + result + ")";
416         }
417         return result;
418       } else {
419         return "";
420       }
421     case TensorStorageType::UNKNOWN:
422       return "";
423   }
424 }
425 
Write(const GpuInfo & gpu_info,const std::string & var_name,const std::vector<std::string> & coords) const426 std::string TensorDescriptor::Write(
427     const GpuInfo& gpu_info, const std::string& var_name,
428     const std::vector<std::string>& coords) const {
429   switch (storage_type) {
430     case TensorStorageType::BUFFER:
431     case TensorStorageType::IMAGE_BUFFER:
432       return absl::StrCat("buffer[", coords[0], "] = ", var_name);
433     case TensorStorageType::SINGLE_TEXTURE_2D:
434     case TensorStorageType::TEXTURE_2D:
435       if (gpu_info.IsApiOpenCl()) {
436         return absl::Substitute("$0(image2d, (int2)($1, $2), $3)",
437                                 GetWriteImageFromDataType(data_type), coords[0],
438                                 coords[1], var_name);
439       } else if (gpu_info.IsApiMetal()) {
440         return absl::Substitute("image2d.write($0, ushort2($1, $2))", var_name,
441                                 coords[0], coords[1]);
442       } else {
443         return "";
444       }
445     case TensorStorageType::TEXTURE_3D:
446       if (gpu_info.IsApiOpenCl()) {
447         return absl::Substitute("$0(image3d, (int4)($1, $2, $3, 0), $4)",
448                                 GetWriteImageFromDataType(data_type), coords[0],
449                                 coords[1], coords[2], var_name);
450       } else if (gpu_info.IsApiMetal()) {
451         return absl::Substitute("image3d.write($0, ushort3($1, $2, $3))",
452                                 var_name, coords[0], coords[1], coords[2]);
453       } else {
454         return "";
455       }
456     case TensorStorageType::TEXTURE_ARRAY:
457       if (gpu_info.IsApiOpenCl()) {
458         return absl::Substitute("$0(image2d_array, (int4)($1, $2, $3, 0), $4)",
459                                 GetWriteImageFromDataType(data_type), coords[0],
460                                 coords[1], coords[2], var_name);
461       } else if (gpu_info.IsApiMetal()) {
462         return absl::Substitute("image2d_array.write($0, ushort2($1, $2), $3)",
463                                 var_name, coords[0], coords[1], coords[2]);
464       } else {
465         return "";
466       }
467     case TensorStorageType::UNKNOWN:
468       return "";
469   }
470 }
471 
PerformGetAddressSelector(const std::vector<std::string> & args,std::string * result) const472 absl::Status TensorDescriptor::PerformGetAddressSelector(
473     const std::vector<std::string>& args, std::string* result) const {
474   std::string xc;
475   std::string yc;
476   std::string zc;
477   std::string sc;
478   std::string bc;
479   bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
480   if (args.size() < 3 || !parsed) {
481     return absl::NotFoundError("Unrecognized GetAddress selector");
482   }
483 
484   *result = DeclareAddress(args[0],
485                            GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
486   return absl::OkStatus();
487 }
488 
PerformGetPtrWithSliceOffsetSelector(const std::vector<std::string> & args,std::string * result) const489 absl::Status TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(
490     const std::vector<std::string>& args, std::string* result) const {
491   if (storage_type != TensorStorageType::BUFFER) {
492     return absl::InvalidArgumentError(
493         "GetPtrWithSliceOffset selector can be used only with BUFFER");
494   }
495   if (args.size() != 1) {
496     return absl::NotFoundError(absl::StrCat(
497         "GetPtrWithSliceOffset require one argument(slice coordinate), but ",
498         args.size(), " was passed"));
499   }
500   *result = absl::StrCat("buffer + ", args[0], " * slice_stride");
501   return absl::OkStatus();
502 }
503 
PerformGetWHOffsetSelector(const std::vector<std::string> & args,std::string * result) const504 absl::Status TensorDescriptor::PerformGetWHOffsetSelector(
505     const std::vector<std::string>& args, std::string* result) const {
506   if (storage_type != TensorStorageType::BUFFER &&
507       storage_type != TensorStorageType::IMAGE_BUFFER) {
508     return absl::InvalidArgumentError(
509         "GetWHOffset selector can be used only with BUFFER/IMAGE_BUFFER");
510   }
511   if (args.size() != 2) {
512     return absl::NotFoundError(absl::StrCat(
513         "GetWHOffset require two arguments(X and Y coordinates), but ",
514         args.size(), " was passed"));
515   }
516   if (HasAxis(Axis::BATCH) && !IsBatchedWidth()) {
517     auto it = state_vars_.find("batch_id");
518     std::string batch_id;
519     if (it == state_vars_.end()) {
520       return absl::NotFoundError(
521           "Not found batch_id. Should be setted up by SetBatchRef(). method");
522     } else {
523       batch_id = it->second;
524     }
525     *result = absl::StrCat("((", args[1], ") * ", GetWidth(), " + (", args[0],
526                            ")) * batch + (", batch_id, ")");
527   } else {
528     *result =
529         absl::StrCat("(", args[1], ") * ", GetWidth(), " + (", args[0], ")");
530   }
531   return absl::OkStatus();
532 }
533 
PerformGetHandleSelector(const std::vector<std::string> & args,std::string * result) const534 absl::Status TensorDescriptor::PerformGetHandleSelector(
535     const std::vector<std::string>& args, std::string* result) const {
536   if (!args.empty()) {
537     return absl::NotFoundError(
538         absl::StrCat("GetHandle does not require arguments, but ", args.size(),
539                      " was passed"));
540   }
541   switch (storage_type) {
542     case TensorStorageType::BUFFER:
543       *result = "buffer";
544       return absl::OkStatus();
545     case TensorStorageType::IMAGE_BUFFER:
546       if (access_type_ == AccessType::READ) {
547         *result = "image_buffer";
548       } else {
549         *result = "buffer";
550       }
551       return absl::OkStatus();
552     case TensorStorageType::TEXTURE_2D:
553     case TensorStorageType::SINGLE_TEXTURE_2D:
554       *result = "image2d";
555       return absl::OkStatus();
556     case TensorStorageType::TEXTURE_ARRAY:
557       *result = "image2d_array";
558       return absl::OkStatus();
559     case TensorStorageType::TEXTURE_3D:
560       *result = "image3d";
561       return absl::OkStatus();
562     case TensorStorageType::UNKNOWN:
563       return absl::UnavailableError("Unknown type");
564   }
565 }
566 
DeclareAddress(const std::string & var_name,const std::string & address) const567 std::string TensorDescriptor::DeclareAddress(const std::string& var_name,
568                                              const std::string& address) const {
569   return absl::StrCat(StorageTypeToAddressType(), " ", var_name, " = ", address,
570                       ";");
571 }
572 
StorageTypeToAddressType() const573 std::string TensorDescriptor::StorageTypeToAddressType() const {
574   switch (storage_type) {
575     case TensorStorageType::BUFFER:
576     case TensorStorageType::IMAGE_BUFFER:
577       return "int";
578     case TensorStorageType::TEXTURE_2D:
579     case TensorStorageType::SINGLE_TEXTURE_2D:
580       return "int2";
581     case TensorStorageType::TEXTURE_ARRAY:
582     case TensorStorageType::TEXTURE_3D:
583       return "int4";
584     case TensorStorageType::UNKNOWN:
585       return "";
586   }
587 }
588 
GetPhysicalCoordsWHS(const std::string & x,const std::string & y,const std::string & s) const589 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHS(
590     const std::string& x, const std::string& y, const std::string& s) const {
591   switch (storage_type) {
592     case TensorStorageType::BUFFER:
593     case TensorStorageType::IMAGE_BUFFER:
594       return {absl::Substitute("((($2) * height + ($1)) * $3 + ($0))", x, y, s,
595                                GetWidth())};
596     case TensorStorageType::TEXTURE_2D:
597       return {absl::Substitute("($0)", x),
598               absl::Substitute("(($0) * slices + ($1))", y, s)};
599     case TensorStorageType::SINGLE_TEXTURE_2D:
600       return {absl::Substitute("($0)", x), absl::Substitute("($0)", y)};
601     case TensorStorageType::TEXTURE_ARRAY:
602     case TensorStorageType::TEXTURE_3D:
603       return {absl::Substitute("($0)", x), absl::Substitute("($0)", y),
604               absl::Substitute("($0)", s)};
605     case TensorStorageType::UNKNOWN:
606       return {""};
607     default:
608       return {""};
609   }
610 }
611 
GetPhysicalCoordsWHSB(const std::string & x,const std::string & y,const std::string & s,const std::string & b) const612 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHSB(
613     const std::string& x, const std::string& y, const std::string& s,
614     const std::string& b) const {
615   switch (storage_type) {
616     case TensorStorageType::BUFFER:
617     case TensorStorageType::IMAGE_BUFFER:
618       return {absl::Substitute(
619           "(((($3) * height + $2) * width + ($1)) * batch + ($0))", b, x, y,
620           s)};
621     case TensorStorageType::TEXTURE_2D:
622       return {absl::Substitute("(($0) * batch + ($1))", x, b),
623               absl::Substitute("(($0) * slices + ($1))", y, s)};
624     case TensorStorageType::SINGLE_TEXTURE_2D:
625       return {absl::Substitute("(($0) * batch + ($1))", x, b),
626               absl::Substitute("($0)", y)};
627     case TensorStorageType::TEXTURE_ARRAY:
628     case TensorStorageType::TEXTURE_3D:
629       return {absl::Substitute("(($0) * batch + ($1))", x, b),
630               absl::Substitute("($0)", y), absl::Substitute("($0)", s)};
631     case TensorStorageType::UNKNOWN:
632       return {""};
633     default:
634       return {""};
635   }
636 }
637 
GetPhysicalCoordsWHDS(const std::string & x,const std::string & y,const std::string & z,const std::string & s) const638 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHDS(
639     const std::string& x, const std::string& y, const std::string& z,
640     const std::string& s) const {
641   switch (storage_type) {
642     case TensorStorageType::BUFFER:
643     case TensorStorageType::IMAGE_BUFFER:
644       return {absl::Substitute(
645           "(((($3) * slices + ($2)) * height + ($1)) * $4 + ($0))", x, y, s, z,
646           GetWidth())};
647     case TensorStorageType::TEXTURE_2D:
648       return {absl::Substitute("(($0) * depth + ($1))", x, z),
649               absl::Substitute("(($0) * slices + ($1))", y, s)};
650     case TensorStorageType::SINGLE_TEXTURE_2D:
651       return {absl::Substitute("(($0) * depth + ($1))", x, z),
652               absl::Substitute("($0)", y)};
653     case TensorStorageType::TEXTURE_ARRAY:
654     case TensorStorageType::TEXTURE_3D:
655       return {absl::Substitute("($0)", x), absl::Substitute("($0)", y),
656               absl::Substitute("(($0) * slices + ($1))", z, s)};
657     case TensorStorageType::UNKNOWN:
658       return {""};
659     default:
660       return {""};
661   }
662 }
663 
GetPhysicalCoordsWHDSB(const std::string & x,const std::string & y,const std::string & z,const std::string & s,const std::string & b) const664 std::vector<std::string> TensorDescriptor::GetPhysicalCoordsWHDSB(
665     const std::string& x, const std::string& y, const std::string& z,
666     const std::string& s, const std::string& b) const {
667   switch (storage_type) {
668     case TensorStorageType::BUFFER:
669     case TensorStorageType::IMAGE_BUFFER:
670       return {absl::Substitute(
671           "((((($4) * slices + ($3)) * height + $2) * width + ($1)) * batch + "
672           "($0))",
673           b, x, y, s, z)};
674     case TensorStorageType::TEXTURE_2D:
675       return {absl::Substitute("((($0)*batch + ($1))*depth + ($2))", x, b, z),
676               absl::Substitute("(($0) * slices + ($1))", y, s)};
677     case TensorStorageType::SINGLE_TEXTURE_2D:
678       return {absl::Substitute("((($0)*batch + ($1))*depth + ($2))", x, b, z),
679               absl::Substitute("($0)", y)};
680     case TensorStorageType::TEXTURE_ARRAY:
681     case TensorStorageType::TEXTURE_3D:
682       return {absl::Substitute("(($0) * batch + ($1))", x, b),
683               absl::Substitute("($0)", y),
684               absl::Substitute("(($0) * slices + ($1))", z, s)};
685     case TensorStorageType::UNKNOWN:
686       return {""};
687     default:
688       return {""};
689   }
690 }
691 
GetGlobalAddressNoDeclaration(const std::string & xc,const std::string & yc,const std::string & zc,const std::string & sc,const std::string & bc) const692 std::string TensorDescriptor::GetGlobalAddressNoDeclaration(
693     const std::string& xc, const std::string& yc, const std::string& zc,
694     const std::string& sc, const std::string& bc) const {
695   auto coords = GetPhysicalCoords(xc, yc, zc, sc, bc);
696   switch (storage_type) {
697     case TensorStorageType::BUFFER:
698     case TensorStorageType::IMAGE_BUFFER: {
699       return coords[0];
700     }
701     case TensorStorageType::TEXTURE_2D:
702     case TensorStorageType::SINGLE_TEXTURE_2D:
703       return absl::Substitute("(int2)($0, $1)", coords[0], coords[1]);
704     case TensorStorageType::TEXTURE_ARRAY:
705     case TensorStorageType::TEXTURE_3D:
706       return absl::Substitute("(int4)($0, $1, $2, 0)", coords[0], coords[1],
707                               coords[2]);
708     case TensorStorageType::UNKNOWN:
709       return "error";
710   }
711 }
712 
GetPhysicalCoords(const std::string & xc,const std::string & yc,const std::string & zc,const std::string & sc,const std::string & bc) const713 std::vector<std::string> TensorDescriptor::GetPhysicalCoords(
714     const std::string& xc, const std::string& yc, const std::string& zc,
715     const std::string& sc, const std::string& bc) const {
716   if (layout == Layout::HWC || (IsBatchedWidth() && layout == Layout::BHWC)) {
717     return GetPhysicalCoordsWHS(xc, yc, sc);
718   } else if (layout == Layout::BHWC) {
719     return GetPhysicalCoordsWHSB(xc, yc, sc, bc);
720   } else if (layout == Layout::HWDC ||
721              (IsBatchedWidth() && layout == Layout::BHWDC)) {
722     return GetPhysicalCoordsWHDS(xc, yc, zc, sc);
723   } else if (layout == Layout::BHWDC) {
724     return GetPhysicalCoordsWHDSB(xc, yc, zc, sc, bc);
725   } else {
726     return {""};
727   }
728 }
729 
GetDataTypeFromTemplateArgs(const std::string & template_arg,DataType * result) const730 absl::Status TensorDescriptor::GetDataTypeFromTemplateArgs(
731     const std::string& template_arg, DataType* result) const {
732   std::string read_type = template_arg;
733   if (read_type == "FLT" || read_type == "ACCUM_FLT") {
734     auto it = state_vars_.find(read_type);
735     if (it == state_vars_.end()) {
736       return absl::UnavailableError(absl::StrCat(
737           "Read selector template argument ", read_type, " uninitialized."));
738     } else {
739       read_type = it->second;
740     }
741   }
742 
743   if (read_type == "half") {
744     *result = DataType::FLOAT16;
745   } else if (read_type == "float") {
746     *result = DataType::FLOAT32;
747   } else {
748     return absl::NotFoundError(absl::StrCat(
749         "Unrecognized Read selector template argument - ", read_type));
750   }
751   return absl::OkStatus();
752 }
753 
HasAxis(Axis axis) const754 bool TensorDescriptor::HasAxis(Axis axis) const {
755   if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) {
756     return true;
757   }
758   if (axis == Axis::BATCH &&
759       (layout == Layout::BHWC || layout == Layout::BHWDC)) {
760     return true;
761   }
762   if (axis == Axis::DEPTH &&
763       (layout == Layout::HWDC || layout == Layout::BHWDC)) {
764     return true;
765   }
766   return false;
767 }
768 
GetWidthSize(BHWDC shape) const769 int TensorDescriptor::GetWidthSize(BHWDC shape) const {
770   int width = shape.w;
771   auto it1 = state_vars_.find("ElementsX2");
772   if (it1 != state_vars_.end() && it1->second == "true") {
773     width /= 2;
774   }
775   auto it2 = state_vars_.find("ElementsX4");
776   if (it2 != state_vars_.end() && it2->second == "true") {
777     width /= 4;
778   }
779   auto it = state_vars_.find("BatchedWidth");
780   if (it != state_vars_.end() && it->second == "true") {
781     width *= shape.b;
782   }
783   return width;
784 }
785 
GetSliceStrideSize(BHWDC shape) const786 int TensorDescriptor::GetSliceStrideSize(BHWDC shape) const {
787   if (IsBatchedWidth()) {
788     return GetWidthSize(shape) * shape.h;
789   } else {
790     if (HasAxis(Axis::BATCH)) {
791       return GetWidthSize(shape) * shape.h * shape.b;
792     } else {
793       return GetWidthSize(shape) * shape.h;
794     }
795   }
796 }
797 
SetAddressMode(AddressMode mode)798 void TensorDescriptor::SetAddressMode(AddressMode mode) {
799   if (mode == AddressMode::kZero) {
800     state_vars_["TextureMode"] = "ZERO";
801   } else {
802     state_vars_["TextureMode"] = "DONT_CARE";
803   }
804 }
805 
ParseCoordsFromArgs(const std::vector<std::string> & args,int offset,std::string * xc,std::string * yc,std::string * zc,std::string * sc,std::string * bc) const806 bool TensorDescriptor::ParseCoordsFromArgs(const std::vector<std::string>& args,
807                                            int offset, std::string* xc,
808                                            std::string* yc, std::string* zc,
809                                            std::string* sc,
810                                            std::string* bc) const {
811   if (HasAxis(Axis::WIDTH)) {
812     if (offset >= args.size()) return false;
813     *xc = args[offset++];
814   }
815   if (HasAxis(Axis::HEIGHT)) {
816     if (offset >= args.size()) return false;
817     *yc = args[offset++];
818   }
819   if (HasAxis(Axis::DEPTH)) {
820     if (offset >= args.size()) return false;
821     *zc = args[offset++];
822   }
823   if (HasAxis(Axis::CHANNELS)) {
824     if (offset >= args.size()) {
825       auto it = state_vars_.find("slice_id");
826       if (it == state_vars_.end()) {
827         return false;
828       } else {
829         *sc = it->second;
830       }
831     } else {
832       *sc = args[offset++];
833     }
834   }
835   if (HasAxis(Axis::BATCH) && !IsBatchedWidth()) {
836     if (offset >= args.size()) {
837       auto it = state_vars_.find("batch_id");
838       if (it == state_vars_.end()) {
839         return false;
840       } else {
841         *bc = it->second;
842       }
843     } else {
844       *bc = args[offset++];
845     }
846   }
847   return true;
848 }
849 
IsBatchedWidth() const850 bool TensorDescriptor::IsBatchedWidth() const {
851   auto it = state_vars_.find("BatchedWidth");
852   return it != state_vars_.end() && it->second == "true";
853 }
854 
GetWidth() const855 std::string TensorDescriptor::GetWidth() const {
856   std::string div;
857   auto it1 = state_vars_.find("ElementsX2");
858   if (it1 != state_vars_.end() && it1->second == "true") {
859     div = "_div2";
860   }
861   auto it2 = state_vars_.find("ElementsX4");
862   if (it2 != state_vars_.end() && it2->second == "true") {
863     div = "_div4";
864   }
865   auto it = state_vars_.find("BatchedWidth");
866   if (it != state_vars_.end() && it->second == "true") {
867     return "width_batched" + div;
868   } else {
869     return "width" + div;
870   }
871 }
872 
AddressModeFromState() const873 AddressMode TensorDescriptor::AddressModeFromState() const {
874   auto it = state_vars_.find("TextureMode");
875   if (it != state_vars_.end()) {
876     if (it->second == "ZERO") {
877       return AddressMode::kZero;
878     } else {
879       return AddressMode::kDontCare;
880     }
881   } else {
882     return AddressMode::kDontCare;
883   }
884 }
885 
UploadData(const tflite::gpu::Tensor<BHWC,DataType::FLOAT32> & src)886 void TensorDescriptor::UploadData(
887     const tflite::gpu::Tensor<BHWC, DataType::FLOAT32>& src) {
888   shape = BHWDC(src.shape.b, src.shape.h, src.shape.w, 1, src.shape.c);
889   UploadData(src.data.data());
890 }
891 
UploadData(const tflite::gpu::Tensor<HWC,DataType::FLOAT32> & src)892 void TensorDescriptor::UploadData(
893     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src) {
894   shape = BHWDC(1, src.shape.h, src.shape.w, 1, src.shape.c);
895   UploadData(src.data.data());
896 }
897 
UploadData(const tflite::gpu::Tensor<Linear,DataType::FLOAT32> & src)898 void TensorDescriptor::UploadData(
899     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src) {
900   shape = BHWDC(1, 1, 1, 1, src.shape.v);
901   UploadData(src.data.data());
902 }
903 
UploadData(const float * src)904 void TensorDescriptor::UploadData(const float* src) {
905   int aligned_channels = storage_type == TensorStorageType::SINGLE_TEXTURE_2D
906                              ? shape.c
907                              : AlignByN(shape.c, 4);
908   int elements_count = shape.b * shape.w * shape.h * shape.d * aligned_channels;
909   data.resize(elements_count * SizeOf(data_type));
910   if (data_type == DataType::FLOAT32) {
911     float* gpu_data = reinterpret_cast<float*>(data.data());
912     DataFromBHWDC(src, shape, *this, gpu_data);
913   } else {
914     half* gpu_data = reinterpret_cast<half*>(data.data());
915     DataFromBHWDC(src, shape, *this, gpu_data);
916   }
917 }
918 
SupportsZeroClamp(const Axis & axis) const919 bool TensorDescriptor::SupportsZeroClamp(const Axis& axis) const {
920   switch (storage_type) {
921     case TensorStorageType::UNKNOWN:
922       return false;
923     case TensorStorageType::BUFFER:
924     case TensorStorageType::IMAGE_BUFFER:
925       return false;
926     case TensorStorageType::TEXTURE_ARRAY:
927     case TensorStorageType::TEXTURE_2D:
928     case TensorStorageType::SINGLE_TEXTURE_2D:
929       return axis == Axis::WIDTH || axis == Axis::HEIGHT;
930     case TensorStorageType::TEXTURE_3D:
931       return axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::DEPTH;
932   }
933 }
934 
CanReadOutOfBorder(const Axis & axis) const935 bool TensorDescriptor::CanReadOutOfBorder(const Axis& axis) const {
936   switch (storage_type) {
937     case TensorStorageType::UNKNOWN:
938       return false;
939     case TensorStorageType::BUFFER:
940       return false;
941     case TensorStorageType::IMAGE_BUFFER:
942     case TensorStorageType::TEXTURE_2D:
943     case TensorStorageType::TEXTURE_3D:
944     case TensorStorageType::SINGLE_TEXTURE_2D:
945     case TensorStorageType::TEXTURE_ARRAY:
946       return true;
947   }
948 }
949 
IsLinear() const950 bool TensorDescriptor::IsLinear() const {
951   return storage_type == TensorStorageType::BUFFER ||
952          storage_type == TensorStorageType::IMAGE_BUFFER;
953 }
954 
ReturnsZeroForNegOneRead() const955 bool TensorDescriptor::ReturnsZeroForNegOneRead() const {
956   return storage_type == TensorStorageType::IMAGE_BUFFER;
957 }
958 
959 namespace {
GetLinearIndex(const TensorDescriptor & desc,const BHWDC & shape,int b,int x,int y,int d,int s,int sub_c)960 int GetLinearIndex(const TensorDescriptor& desc, const BHWDC& shape, int b,
961                    int x, int y, int d, int s, int sub_c) {
962   const int slices = DivideRoundUp(shape.c, 4);
963   switch (desc.storage_type) {
964     case TensorStorageType::BUFFER:
965     case TensorStorageType::IMAGE_BUFFER:
966     case TensorStorageType::TEXTURE_ARRAY:
967     case TensorStorageType::TEXTURE_3D:
968       return ((((d * slices + s) * shape.h + y) * shape.w + x) * shape.b + b) *
969                  4 +
970              sub_c;  // DSHWBC4
971     case TensorStorageType::TEXTURE_2D:
972       return ((((y * slices + s) * shape.w + x) * shape.b + b) * shape.d + d) *
973                  4 +
974              sub_c;  // HSWBDC4
975     case TensorStorageType::SINGLE_TEXTURE_2D:
976       return (((y * shape.w + x) * shape.b + b) * shape.d + d) * shape.c +
977              sub_c;  // HWBDC
978     case TensorStorageType::UNKNOWN:
979       return -1;
980   }
981 }
982 
GetChannelsAlignment(const TensorDescriptor & desc,const BHWDC & shape)983 int GetChannelsAlignment(const TensorDescriptor& desc, const BHWDC& shape) {
984   return desc.storage_type == TensorStorageType::SINGLE_TEXTURE_2D ? shape.c
985                                                                    : 4;
986 }
987 }  // namespace
988 
989 template <typename T>
DataFromBHWDC(const float * src,const BHWDC & shape,const TensorDescriptor & desc,T * dst)990 void DataFromBHWDC(const float* src, const BHWDC& shape,
991                    const TensorDescriptor& desc, T* dst) {
992   const int channels_alignment = GetChannelsAlignment(desc, shape);
993   const int slices = DivideRoundUp(shape.c, 4);
994   for (int b = 0; b < shape.b; ++b) {
995     for (int s = 0; s < slices; ++s) {
996       for (int y = 0; y < shape.h; ++y) {
997         for (int x = 0; x < shape.w; ++x) {
998           for (int d = 0; d < shape.d; ++d) {
999             for (int c = 0; c < channels_alignment; ++c) {
1000               float value;
1001               if (s * 4 + c < shape.c) {
1002                 const int cpu_index =
1003                     shape.LinearIndex({b, y, x, d, s * 4 + c});
1004                 value = src[cpu_index];
1005               } else {
1006                 value = 0.0f;
1007               }
1008               int gpu_index = GetLinearIndex(desc, shape, b, x, y, d, s, c);
1009               dst[gpu_index] = value;
1010             }
1011           }
1012         }
1013       }
1014     }
1015   }
1016 }
1017 
1018 template void DataFromBHWDC<float>(const float* src, const BHWDC& shape,
1019                                    const TensorDescriptor& desc, float* dst);
1020 template void DataFromBHWDC<half>(const float* src, const BHWDC& shape,
1021                                   const TensorDescriptor& desc, half* dst);
1022 
1023 template <typename T>
DataToBHWDC(const T * src,const BHWDC & shape,const TensorDescriptor & desc,float * dst)1024 void DataToBHWDC(const T* src, const BHWDC& shape, const TensorDescriptor& desc,
1025                  float* dst) {
1026   const int channels_alignment = GetChannelsAlignment(desc, shape);
1027   const int slices = DivideRoundUp(shape.c, 4);
1028   for (int b = 0; b < shape.b; ++b) {
1029     for (int s = 0; s < slices; ++s) {
1030       for (int y = 0; y < shape.h; ++y) {
1031         for (int x = 0; x < shape.w; ++x) {
1032           for (int d = 0; d < shape.d; ++d) {
1033             for (int c = 0; c < channels_alignment; ++c) {
1034               if (s * 4 + c >= shape.c) {
1035                 continue;
1036               }
1037               int cpu_index = shape.LinearIndex({b, y, x, d, s * 4 + c});
1038               int gpu_index = GetLinearIndex(desc, shape, b, x, y, d, s, c);
1039               dst[cpu_index] = src[gpu_index];
1040             }
1041           }
1042         }
1043       }
1044     }
1045   }
1046 }
1047 
1048 template void DataToBHWDC<float>(const float* src, const BHWDC& shape,
1049                                  const TensorDescriptor& desc, float* dst);
1050 template void DataToBHWDC<half>(const half* src, const BHWDC& shape,
1051                                 const TensorDescriptor& desc, float* dst);
1052 
1053 }  // namespace gpu
1054 }  // namespace tflite
1055