1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"
17
18 #include "absl/strings/ascii.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/types/variant.h"
24 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
25 #include "tensorflow/lite/delegates/gpu/common/types.h"
26
27 namespace tflite {
28 namespace gpu {
29 namespace gl {
30 namespace object_accessor_internal {
31
32 // Splits name[index1, index2...] into 'name' and {'index1', 'index2'...}.
ParseElement(absl::string_view input)33 IndexedElement ParseElement(absl::string_view input) {
34 auto i = input.find('[');
35 if (i == std::string::npos || input.back() != ']') {
36 return {};
37 }
38 return {input.substr(0, i),
39 absl::StrSplit(input.substr(i + 1, input.size() - i - 2), ',',
40 absl::SkipWhitespace())};
41 }
42
43 } // namespace object_accessor_internal
44
45 namespace {
46
MaybeConvertToHalf(DataType data_type,absl::string_view value,std::string * output)47 void MaybeConvertToHalf(DataType data_type, absl::string_view value,
48 std::string* output) {
49 if (data_type == DataType::FLOAT16) {
50 absl::StrAppend(output, "Vec4ToHalf(", value, ")");
51 } else {
52 absl::StrAppend(output, value);
53 }
54 }
55
MaybeConvertFromHalf(DataType data_type,absl::string_view value,std::string * output)56 void MaybeConvertFromHalf(DataType data_type, absl::string_view value,
57 std::string* output) {
58 if (data_type == DataType::FLOAT16) {
59 absl::StrAppend(output, "Vec4FromHalf(", value, ")");
60 } else {
61 absl::StrAppend(output, value);
62 }
63 }
64
65 struct ReadFromTextureGenerator {
operator ()tflite::gpu::gl::__anon12a37c930111::ReadFromTextureGenerator66 RewriteStatus operator()(size_t) const {
67 if (element.indices.size() != 1) {
68 result->append("WRONG_NUMBER_OF_INDICES");
69 return RewriteStatus::ERROR;
70 }
71 // 1D textures are emulated as 2D textures
72 if (sampler_textures) {
73 absl::StrAppend(result, "texelFetch(", element.object_name, ", ivec2(",
74 element.indices[0], ", 0), 0)");
75 } else {
76 absl::StrAppend(result, "imageLoad(", element.object_name, ", ivec2(",
77 element.indices[0], ", 0))");
78 }
79 return RewriteStatus::SUCCESS;
80 }
81
82 template <typename Shape>
operator ()tflite::gpu::gl::__anon12a37c930111::ReadFromTextureGenerator83 RewriteStatus operator()(const Shape&) const {
84 if (element.indices.size() != Shape::size()) {
85 result->append("WRONG_NUMBER_OF_INDICES");
86 return RewriteStatus::ERROR;
87 }
88 if (sampler_textures) {
89 absl::StrAppend(result, "texelFetch(", element.object_name, ", ivec",
90 Shape::size(), "(", absl::StrJoin(element.indices, ", "),
91 "), 0)");
92 } else {
93 absl::StrAppend(result, "imageLoad(", element.object_name, ", ivec",
94 Shape::size(), "(", absl::StrJoin(element.indices, ", "),
95 "))");
96 }
97 return RewriteStatus::SUCCESS;
98 }
99
100 const object_accessor_internal::IndexedElement& element;
101 const bool sampler_textures;
102 std::string* result;
103 };
104
105 struct ReadFromBufferGenerator {
operator ()tflite::gpu::gl::__anon12a37c930111::ReadFromBufferGenerator106 RewriteStatus operator()(size_t) const {
107 if (element.indices.size() != 1) {
108 result->append("WRONG_NUMBER_OF_INDICES");
109 return RewriteStatus::ERROR;
110 }
111 MaybeConvertFromHalf(
112 data_type,
113 absl::StrCat(element.object_name, ".data[", element.indices[0], "]"),
114 result);
115 return RewriteStatus::SUCCESS;
116 }
117
operator ()tflite::gpu::gl::__anon12a37c930111::ReadFromBufferGenerator118 RewriteStatus operator()(const uint2& size) const {
119 if (element.indices.size() == 1) {
120 // access by linear index. Use method above to generate accessor.
121 return (*this)(1U);
122 }
123 if (element.indices.size() != 2) {
124 result->append("WRONG_NUMBER_OF_INDICES");
125 return RewriteStatus::ERROR;
126 }
127 MaybeConvertFromHalf(
128 data_type,
129 absl::StrCat(element.object_name, ".data[", element.indices[0], " + $",
130 element.object_name, "_w$ * (", element.indices[1], ")]"),
131 result);
132 *requires_sizes = true;
133 return RewriteStatus::SUCCESS;
134 }
135
operator ()tflite::gpu::gl::__anon12a37c930111::ReadFromBufferGenerator136 RewriteStatus operator()(const uint3& size) const {
137 if (element.indices.size() == 1) {
138 // access by linear index. Use method above to generate accessor.
139 return (*this)(1U);
140 }
141 if (element.indices.size() != 3) {
142 result->append("WRONG_NUMBER_OF_INDICES");
143 return RewriteStatus::ERROR;
144 }
145 MaybeConvertFromHalf(
146 data_type,
147 absl::StrCat(element.object_name, ".data[", element.indices[0], " + $",
148 element.object_name, "_w$ * (", element.indices[1], " + $",
149 element.object_name, "_h$ * (", element.indices[2], "))]"),
150 result);
151 *requires_sizes = true;
152 return RewriteStatus::SUCCESS;
153 }
154
155 DataType data_type;
156 const object_accessor_internal::IndexedElement& element;
157 std::string* result;
158
159 // indicates that generated code accessed _w and/or _h index variables.
160 bool* requires_sizes;
161 };
162
163 // Generates code for reading an element from an object.
GenerateReadAccessor(const Object & object,const object_accessor_internal::IndexedElement & element,bool sampler_textures,std::string * result,bool * requires_sizes)164 RewriteStatus GenerateReadAccessor(
165 const Object& object,
166 const object_accessor_internal::IndexedElement& element,
167 bool sampler_textures, std::string* result, bool* requires_sizes) {
168 switch (object.object_type) {
169 case ObjectType::BUFFER:
170 return absl::visit(ReadFromBufferGenerator{object.data_type, element,
171 result, requires_sizes},
172 object.size);
173 case ObjectType::TEXTURE:
174 return absl::visit(
175 ReadFromTextureGenerator{element, sampler_textures, result},
176 object.size);
177 case ObjectType::UNKNOWN:
178 return RewriteStatus::ERROR;
179 }
180 }
181
182 struct WriteToBufferGenerator {
operator ()tflite::gpu::gl::__anon12a37c930111::WriteToBufferGenerator183 RewriteStatus operator()(size_t) const {
184 if (element.indices.size() != 1) {
185 result->append("WRONG_NUMBER_OF_INDICES");
186 return RewriteStatus::ERROR;
187 }
188 absl::StrAppend(result, element.object_name, ".data[", element.indices[0],
189 "] = ");
190 MaybeConvertToHalf(data_type, value, result);
191 return RewriteStatus::SUCCESS;
192 }
193
operator ()tflite::gpu::gl::__anon12a37c930111::WriteToBufferGenerator194 RewriteStatus operator()(const uint2& size) const {
195 if (element.indices.size() == 1) {
196 // access by linear index. Use method above to generate accessor.
197 return (*this)(1U);
198 }
199 if (element.indices.size() != 2) {
200 result->append("WRONG_NUMBER_OF_INDICES");
201 return RewriteStatus::ERROR;
202 }
203 absl::StrAppend(result, element.object_name, ".data[", element.indices[0],
204 " + $", element.object_name, "_w$ * (", element.indices[1],
205 ")] = ");
206 MaybeConvertToHalf(data_type, value, result);
207 *requires_sizes = true;
208 return RewriteStatus::SUCCESS;
209 }
210
operator ()tflite::gpu::gl::__anon12a37c930111::WriteToBufferGenerator211 RewriteStatus operator()(const uint3& size) const {
212 if (element.indices.size() == 1) {
213 // access by linear index. Use method above to generate accessor.
214 return (*this)(1U);
215 }
216 if (element.indices.size() != 3) {
217 result->append("WRONG_NUMBER_OF_INDICES");
218 return RewriteStatus::ERROR;
219 }
220 absl::StrAppend(result, element.object_name, ".data[", element.indices[0],
221 " + $", element.object_name, "_w$ * (", element.indices[1],
222 " + $", element.object_name, "_h$ * (", element.indices[2],
223 "))] = ");
224 MaybeConvertToHalf(data_type, value, result);
225 *requires_sizes = true;
226 return RewriteStatus::SUCCESS;
227 }
228
229 DataType data_type;
230 const object_accessor_internal::IndexedElement& element;
231 absl::string_view value;
232 std::string* result;
233
234 // indicates that generated code accessed _w and/or _h index variables.
235 bool* requires_sizes;
236 };
237
238 struct WriteToTextureGenerator {
operator ()tflite::gpu::gl::__anon12a37c930111::WriteToTextureGenerator239 RewriteStatus operator()(size_t) const {
240 if (element.indices.size() != 1) {
241 result->append("WRONG_NUMBER_OF_INDICES");
242 return RewriteStatus::ERROR;
243 }
244 // 1D textures are emulated as 2D textures
245 absl::StrAppend(result, "imageStore(", element.object_name, ", ivec2(",
246 element.indices[0], ", 0), ", value, ")");
247 return RewriteStatus::SUCCESS;
248 }
249
250 template <typename Shape>
operator ()tflite::gpu::gl::__anon12a37c930111::WriteToTextureGenerator251 RewriteStatus operator()(const Shape&) const {
252 if (element.indices.size() != Shape::size()) {
253 result->append("WRONG_NUMBER_OF_INDICES");
254 return RewriteStatus::ERROR;
255 }
256 absl::StrAppend(result, "imageStore(", element.object_name, ", ivec",
257 Shape::size(), "(", absl::StrJoin(element.indices, ", "),
258 "), ", value, ")");
259 return RewriteStatus::SUCCESS;
260 }
261
262 const object_accessor_internal::IndexedElement& element;
263 absl::string_view value;
264 std::string* result;
265 };
266
267 // Generates code for writing value an element in an object.
GenerateWriteAccessor(const Object & object,const object_accessor_internal::IndexedElement & element,absl::string_view value,std::string * result,bool * requires_sizes)268 RewriteStatus GenerateWriteAccessor(
269 const Object& object,
270 const object_accessor_internal::IndexedElement& element,
271 absl::string_view value, std::string* result, bool* requires_sizes) {
272 switch (object.object_type) {
273 case ObjectType::BUFFER:
274 return absl::visit(WriteToBufferGenerator{object.data_type, element,
275 value, result, requires_sizes},
276 object.size);
277 case ObjectType::TEXTURE:
278 return absl::visit(WriteToTextureGenerator{element, value, result},
279 object.size);
280 case ObjectType::UNKNOWN:
281 return RewriteStatus::ERROR;
282 }
283 }
284
ToAccessModifier(AccessType access,bool use_readonly_modifier)285 std::string ToAccessModifier(AccessType access, bool use_readonly_modifier) {
286 switch (access) {
287 case AccessType::READ:
288 return use_readonly_modifier ? " readonly" : "";
289 case AccessType::WRITE:
290 return " writeonly";
291 case AccessType::READ_WRITE:
292 return " restrict";
293 }
294 return " unknown_access";
295 }
296
ToBufferType(DataType data_type)297 std::string ToBufferType(DataType data_type) {
298 switch (data_type) {
299 case DataType::UINT8:
300 case DataType::UINT16:
301 case DataType::UINT32:
302 return "uvec4";
303 case DataType::UINT64:
304 return "u64vec4_not_available_in_glsl";
305 case DataType::INT8:
306 case DataType::INT16:
307 case DataType::INT32:
308 return "ivec4";
309 case DataType::INT64:
310 return "i64vec4_not_available_in_glsl";
311 case DataType::FLOAT16:
312 return "uvec2";
313 case DataType::FLOAT32:
314 return "vec4";
315 case DataType::FLOAT64:
316 return "dvec4";
317 case DataType::UNKNOWN:
318 return "unknown_buffer_type";
319 // Do NOT add `default:'; we want build failure for new enum values.
320 }
321 }
322
323 struct TextureImageTypeGetter {
operator ()tflite::gpu::gl::__anon12a37c930111::TextureImageTypeGetter324 std::string operator()(size_t) const {
325 // 1D textures are emulated as 2D textures
326 return (*this)(uint2());
327 }
328
operator ()tflite::gpu::gl::__anon12a37c930111::TextureImageTypeGetter329 std::string operator()(const uint2&) const {
330 switch (type) {
331 case DataType::UINT16:
332 case DataType::UINT32:
333 return "uimage2D";
334 case DataType::INT16:
335 case DataType::INT32:
336 return "iimage2D";
337 case DataType::FLOAT16:
338 case DataType::FLOAT32:
339 return "image2D";
340 default:
341 return "unknown_image_2d";
342 }
343 }
344
operator ()tflite::gpu::gl::__anon12a37c930111::TextureImageTypeGetter345 std::string operator()(const uint3&) const {
346 switch (type) {
347 case DataType::UINT16:
348 case DataType::UINT32:
349 return "uimage2DArray";
350 case DataType::INT16:
351 case DataType::INT32:
352 return "iimage2DArray";
353 case DataType::FLOAT16:
354 case DataType::FLOAT32:
355 return "image2DArray";
356 default:
357 return "unknown_image_2d_array";
358 }
359 }
360
361 DataType type;
362 };
363
364 struct TextureSamplerTypeGetter {
operator ()tflite::gpu::gl::__anon12a37c930111::TextureSamplerTypeGetter365 std::string operator()(size_t) const {
366 // 1D textures are emulated as 2D textures
367 return (*this)(uint2());
368 }
369
operator ()tflite::gpu::gl::__anon12a37c930111::TextureSamplerTypeGetter370 std::string operator()(const uint2&) const {
371 switch (type) {
372 case DataType::FLOAT16:
373 case DataType::FLOAT32:
374 return "sampler2D";
375 case DataType::INT32:
376 case DataType::INT16:
377 return "isampler2D";
378 case DataType::UINT32:
379 case DataType::UINT16:
380 return "usampler2D";
381 default:
382 return "unknown_sampler2D";
383 }
384 }
385
operator ()tflite::gpu::gl::__anon12a37c930111::TextureSamplerTypeGetter386 std::string operator()(const uint3&) const {
387 switch (type) {
388 case DataType::FLOAT16:
389 case DataType::FLOAT32:
390 return "sampler2DArray";
391 case DataType::INT32:
392 case DataType::INT16:
393 return "isampler2DArray";
394 case DataType::UINT32:
395 case DataType::UINT16:
396 return "usampler2DArray";
397 default:
398 return "unknown_sampler2DArray";
399 }
400 }
401
402 DataType type;
403 };
404
ToImageType(const Object & object,bool sampler_textures)405 std::string ToImageType(const Object& object, bool sampler_textures) {
406 if (sampler_textures && (object.access == AccessType::READ)) {
407 return absl::visit(TextureSamplerTypeGetter{object.data_type}, object.size);
408 } else {
409 return absl::visit(TextureImageTypeGetter{object.data_type}, object.size);
410 }
411 }
412
ToImageLayoutQualifier(DataType type)413 std::string ToImageLayoutQualifier(DataType type) {
414 switch (type) {
415 case DataType::UINT16:
416 return "rgba16ui";
417 case DataType::UINT32:
418 return "rgba32ui";
419 case DataType::INT16:
420 return "rgba16i";
421 case DataType::INT32:
422 return "rgba32i";
423 case DataType::FLOAT16:
424 return "rgba16f";
425 case DataType::FLOAT32:
426 return "rgba32f";
427 default:
428 return "unknown_image_layout";
429 }
430 }
431
ToImagePrecision(DataType type)432 std::string ToImagePrecision(DataType type) {
433 switch (type) {
434 case DataType::UINT16:
435 case DataType::INT16:
436 case DataType::FLOAT16:
437 return "mediump";
438 case DataType::UINT32:
439 case DataType::INT32:
440 case DataType::FLOAT32:
441 return "highp";
442 default:
443 return "unknown_image_precision";
444 }
445 }
446
447 struct SizeParametersAdder {
operator ()tflite::gpu::gl::__anon12a37c930111::SizeParametersAdder448 void operator()(size_t) const {}
449
operator ()tflite::gpu::gl::__anon12a37c930111::SizeParametersAdder450 void operator()(const uint2& size) const {
451 variable_accessor->AddUniformParameter(
452 {absl::StrCat(object_name, "_w"), static_cast<int32_t>(size.x)});
453 }
454
455 // p1 and p2 are padding. For some reason buffer does not map correctly
456 // without it.
operator ()tflite::gpu::gl::__anon12a37c930111::SizeParametersAdder457 void operator()(const uint3& size) const {
458 variable_accessor->AddUniformParameter(
459 {absl::StrCat(object_name, "_w"), static_cast<int32_t>(size.x)});
460 variable_accessor->AddUniformParameter(
461 {absl::StrCat(object_name, "_h"), static_cast<int32_t>(size.y)});
462 }
463
464 absl::string_view object_name;
465 VariableAccessor* variable_accessor;
466 };
467
468 // Adds necessary parameters to parameter accessor that represent object size
469 // needed for indexed access.
470 // - 1D : empty
471 // - 2D : 'int object_name_w'
472 // - 3D : 'int object_name_w' + 'int object_name_h'
AddSizeParameters(absl::string_view object_name,const Object & object,VariableAccessor * parameters)473 void AddSizeParameters(absl::string_view object_name, const Object& object,
474 VariableAccessor* parameters) {
475 absl::visit(SizeParametersAdder{object_name, parameters}, object.size);
476 }
477
GenerateObjectDeclaration(absl::string_view name,const Object & object,std::string * declaration,bool is_mali,bool sampler_textures)478 void GenerateObjectDeclaration(absl::string_view name, const Object& object,
479 std::string* declaration, bool is_mali,
480 bool sampler_textures) {
481 switch (object.object_type) {
482 case ObjectType::BUFFER:
483 // readonly modifier used to fix shader compilation for Mali on Android 8,
484 // see b/111601761
485 absl::StrAppend(declaration, "layout(binding = ", object.binding, ")",
486 ToAccessModifier(object.access, !is_mali), " buffer B",
487 object.binding, " { ", ToBufferType(object.data_type),
488 " data[]; } ", name, ";\n");
489 break;
490 case ObjectType::TEXTURE:
491 if (sampler_textures && (object.access == AccessType::READ)) {
492 absl::StrAppend(declaration, "layout(binding = ", object.binding,
493 ") uniform ", ToImagePrecision(object.data_type), " ",
494 ToImageType(object, sampler_textures), " ", name,
495 ";\n");
496 } else {
497 absl::StrAppend(
498 declaration, "layout(", ToImageLayoutQualifier(object.data_type),
499 ", binding = ", object.binding, ")",
500 ToAccessModifier(object.access, true), " uniform ",
501 ToImagePrecision(object.data_type), " ",
502 ToImageType(object, sampler_textures), " ", name, ";\n");
503 }
504 break;
505 case ObjectType::UNKNOWN:
506 // do nothing.
507 break;
508 }
509 }
510
511 } // namespace
512
Rewrite(absl::string_view input,std::string * output)513 RewriteStatus ObjectAccessor::Rewrite(absl::string_view input,
514 std::string* output) {
515 // Splits 'a =b' into {'a','b'}.
516 std::pair<absl::string_view, absl::string_view> n =
517 absl::StrSplit(input, absl::MaxSplits('=', 1), absl::SkipWhitespace());
518 if (n.first.empty()) {
519 return RewriteStatus::NOT_RECOGNIZED;
520 }
521 if (n.second.empty()) {
522 return RewriteRead(absl::StripAsciiWhitespace(n.first), output);
523 }
524 return RewriteWrite(absl::StripAsciiWhitespace(n.first),
525 absl::StripAsciiWhitespace(n.second), output);
526 }
527
RewriteRead(absl::string_view location,std::string * output)528 RewriteStatus ObjectAccessor::RewriteRead(absl::string_view location,
529 std::string* output) {
530 auto element = object_accessor_internal::ParseElement(location);
531 if (element.object_name.empty()) {
532 return RewriteStatus::NOT_RECOGNIZED;
533 }
534 auto it = name_to_object_.find(
535 std::string(element.object_name.data(), element.object_name.size()));
536 if (it == name_to_object_.end()) {
537 return RewriteStatus::NOT_RECOGNIZED;
538 }
539 bool requires_sizes = false;
540 auto status = GenerateReadAccessor(it->second, element, sampler_textures_,
541 output, &requires_sizes);
542 if (requires_sizes) {
543 AddSizeParameters(it->first, it->second, variable_accessor_);
544 }
545 return status;
546 }
547
RewriteWrite(absl::string_view location,absl::string_view value,std::string * output)548 RewriteStatus ObjectAccessor::RewriteWrite(absl::string_view location,
549 absl::string_view value,
550 std::string* output) {
551 // name[index1, index2...] = value
552 auto element = object_accessor_internal::ParseElement(location);
553 if (element.object_name.empty()) {
554 return RewriteStatus::NOT_RECOGNIZED;
555 }
556 auto it = name_to_object_.find(
557 std::string(element.object_name.data(), element.object_name.size()));
558 if (it == name_to_object_.end()) {
559 return RewriteStatus::NOT_RECOGNIZED;
560 }
561 bool requires_sizes = false;
562 auto status = GenerateWriteAccessor(it->second, element, value, output,
563 &requires_sizes);
564 if (requires_sizes) {
565 AddSizeParameters(it->first, it->second, variable_accessor_);
566 }
567 return status;
568 }
569
AddObject(const std::string & name,Object object)570 bool ObjectAccessor::AddObject(const std::string& name, Object object) {
571 if (object.object_type == ObjectType::UNKNOWN) {
572 return false;
573 }
574 return name_to_object_.insert({name, std::move(object)}).second;
575 }
576
GetObjectDeclarations() const577 std::string ObjectAccessor::GetObjectDeclarations() const {
578 std::string declarations;
579 for (auto& o : name_to_object_) {
580 GenerateObjectDeclaration(o.first, o.second, &declarations, is_mali_,
581 sampler_textures_);
582 }
583 return declarations;
584 }
585
GetFunctionsDeclarations() const586 std::string ObjectAccessor::GetFunctionsDeclarations() const {
587 // If there is a single object SSBO with F16, then we need to output macros
588 // as well.
589 for (const auto& o : name_to_object_) {
590 if (o.second.data_type == DataType::FLOAT16 &&
591 o.second.object_type == ObjectType::BUFFER) {
592 return absl::StrCat(
593 "#define Vec4FromHalf(v) vec4(unpackHalf2x16(v.x), "
594 "unpackHalf2x16(v.y))\n",
595 "#define Vec4ToHalf(v) uvec2(packHalf2x16(v.xy), "
596 "packHalf2x16(v.zw))");
597 }
598 }
599 return "";
600 }
601
GetObjects() const602 std::vector<Object> ObjectAccessor::GetObjects() const {
603 std::vector<Object> objects;
604 for (auto& o : name_to_object_) {
605 objects.push_back(o.second);
606 }
607 return objects;
608 }
609
610 } // namespace gl
611 } // namespace gpu
612 } // namespace tflite
613