1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_format.h"
20 #include "absl/strings/str_join.h"
21 #include "absl/types/variant.h"
22 #include "tensorflow/lite/delegates/gpu/common/types.h"
23
24 namespace tflite {
25 namespace gpu {
26 namespace gl {
27 namespace variable_accessor_internal {
28
29 // Parse the following regex manually
30 // name(\[index\])?(\.field)?
Parse(absl::string_view input)31 VariableReference Parse(absl::string_view input) {
32 VariableReference ref;
33 auto start_index = input.find('[');
34 if (start_index != std::string::npos) {
35 auto end_index = input.rfind(']');
36 if (end_index == std::string::npos) {
37 return ref;
38 }
39 ref.index = input.substr(start_index + 1, end_index - start_index - 1);
40 ref.name = input.substr(0, start_index);
41 ref.field = input.substr(end_index + 1);
42 } else {
43 auto dot = input.find('.');
44 if (dot != std::string::npos) {
45 ref.name = input.substr(0, dot);
46 ref.field = input.substr(dot);
47 } else {
48 ref.name = input;
49 }
50 }
51 return ref;
52 }
53
54 } // namespace variable_accessor_internal
55
56 namespace {
57
58 struct VariableTypeGetter {
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter59 std::string operator()(int) const { return "int"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter60 std::string operator()(const int2&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter61 std::string operator()(const std::vector<int2>&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter62 std::string operator()(const int4&) const { return "ivec4"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter63 std::string operator()(unsigned int) const { return "uint"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter64 std::string operator()(const uint4&) const { return "uvec4"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter65 std::string operator()(float) const { return "float"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter66 std::string operator()(const float2&) const { return "vec2"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter67 std::string operator()(const float4&) const { return "vec4"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter68 std::string operator()(const std::vector<float4>&) const { return "vec4"; }
69 };
70
71 // Returns GLSL uniform type of the given variable.
GetVariableType(const Variable::ValueType & value)72 std::string GetVariableType(const Variable::ValueType& value) {
73 return absl::visit(VariableTypeGetter(), value);
74 }
75
// Visitor that reports how many elements a value consists of.
struct LengthGetter {
  // Anything that is not a std::vector counts as exactly one element.
  template <typename Single>
  int operator()(const Single& /*value*/) const {
    return 1;
  }

  // A std::vector reports its element count.
  template <typename Element>
  int operator()(const std::vector<Element>& elements) const {
    return static_cast<int>(elements.size());
  }
};
86
GetLength(const Variable::ValueType & value)87 int GetLength(const Variable::ValueType& value) {
88 return absl::visit(LengthGetter(), value);
89 }
90
91 template <typename T>
FormatValue(std::string * result,T t)92 void FormatValue(std::string* result, T t) {
93 absl::StrAppend(result, t);
94 }
95
96 template <>
FormatValue(std::string * result,float t)97 void FormatValue(std::string* result, float t) {
98 absl::StrAppend(result, absl::StrFormat("%.9ff", t));
99 }
100
// Converts every element of `data` to its string form via FormatValue.
//
// absl::StrJoin with a custom formatter requires the formatter to use string
// rather than std::string. Because of that compatibility issue the elements
// are turned into strings first and joined afterwards by the caller.
template <typename T, int N>
std::vector<std::string> ToString(const std::array<T, N>& data) {
  std::vector<std::string> formatted;
  formatted.reserve(N);
  for (const T& element : data) {
    formatted.emplace_back();
    FormatValue(&formatted.back(), element);
  }
  return formatted;
}
112
113 struct ConstGenerator {
114 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator115 void operator()(T t) const {
116 FormatValue(result, t);
117 }
118
119 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator120 void operator()(const Vec2<T>& v) const {
121 absl::StrAppend(result, VariableTypeGetter()(v), "(",
122 absl::StrJoin(ToString<T, 2>(v.data_), ","), ")");
123 }
124
125 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator126 void operator()(const Vec3<T>& v) const {
127 absl::StrAppend(result, VariableTypeGetter()(v), "(",
128 absl::StrJoin(ToString<T, 3>(v.data_), ","), ")");
129 }
130
131 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator132 void operator()(const Vec4<T>& v) const {
133 absl::StrAppend(result, VariableTypeGetter()(v), "(",
134 absl::StrJoin(ToString<T, 4>(v.data_), ","), ")");
135 }
136
137 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator138 void operator()(const std::vector<T>& v) const {
139 std::string type = VariableTypeGetter()(v);
140 absl::StrAppend(result, type, "[", v.size(), "](");
141 bool first = true;
142 for (const auto& i : v) {
143 if (first) {
144 first = false;
145 } else {
146 absl::StrAppend(result, ",");
147 }
148 (*this)(i);
149 }
150 absl::StrAppend(result, ")");
151 }
152
153 std::string* result;
154 };
155
156 // Appends string representation of a variable value.
GetValue(const Variable::ValueType & value,std::string * result)157 void GetValue(const Variable::ValueType& value, std::string* result) {
158 absl::visit(ConstGenerator{result}, value);
159 }
160
161 struct SharedVariableDeclarationGenerator {
162 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::SharedVariableDeclarationGenerator163 void operator()(const T&) const {
164 absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
165 " ", variable.name, ";\n");
166 }
167
168 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::SharedVariableDeclarationGenerator169 void operator()(const std::vector<T>& v) const {
170 absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
171 " ", variable.name);
172 if (v.empty()) {
173 // Normalize the size of the shared array to that of the WorkGroupSize
174 absl::StrAppend(
175 result,
176 "[gl_WorkGroupSize.z * gl_WorkGroupSize.y * gl_WorkGroupSize.x];\n");
177 } else {
178 // Use the specified size
179 absl::StrAppend(result, "[", v.size(), "];\n");
180 }
181 }
182
183 const Variable& variable;
184 std::string* result;
185 };
186
GenerateSharedVariableDeclaration(const Variable & variable,std::string * result)187 void GenerateSharedVariableDeclaration(const Variable& variable,
188 std::string* result) {
189 absl::visit(SharedVariableDeclarationGenerator{variable, result},
190 variable.value);
191 }
192
193 struct UniformParameterDeclarationGenerator {
194 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::UniformParameterDeclarationGenerator195 void operator()(const T&) const {
196 absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
197 variable.name, ";\n");
198 }
199
200 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::UniformParameterDeclarationGenerator201 void operator()(const std::vector<T>& v) const {
202 absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
203 variable.name, "[", v.size(), "];\n");
204 }
205
206 const Variable& variable;
207 std::string* result;
208 };
209
GenerateUniformParameterDeclaration(const Variable & variable,std::string * result)210 void GenerateUniformParameterDeclaration(const Variable& variable,
211 std::string* result) {
212 absl::visit(UniformParameterDeclarationGenerator{variable, result},
213 variable.value);
214 }
215
216 struct VulkanPushConstantGenerator {
217 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanPushConstantGenerator218 void operator()(const T&) const {
219 absl::StrAppend(result, " ", GetVariableType(variable.value), " ",
220 variable.name, ";\n");
221 }
222
223 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanPushConstantGenerator224 void operator()(const std::vector<T>& v) const {
225 absl::StrAppend(result, " ", GetVariableType(variable.value), " ",
226 variable.name, "[", v.size(), "];\n");
227 }
228
229 const Variable& variable;
230 std::string* result;
231 };
232
GenerateVulkanPushConstant(const Variable & variable,std::string * result)233 void GenerateVulkanPushConstant(const Variable& variable, std::string* result) {
234 absl::visit(VulkanPushConstantGenerator{variable, result}, variable.value);
235 }
236
// Visitor that tells whether a value is variable-length, i.e. a std::vector.
struct VariableLengthGetter {
  // Every non-vector type is fixed-length.
  template <typename Single>
  bool operator()(const Single& /*value*/) const {
    return false;
  }

  // Any std::vector, regardless of its current size, is variable-length.
  template <typename Element>
  bool operator()(const std::vector<Element>& /*elements*/) const {
    return true;
  }
};
247
248 struct VulkanConstantGenerator {
249 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanConstantGenerator250 void operator()(const T&) const {
251 const std::string variable_type = GetVariableType(variable.value);
252
253 // Vulkan specialization constants are used for scalar types, all other
254 // types go in push (uniform) constants.
255 if (variable_type == "int" || variable_type == "uint" ||
256 variable_type == "float") {
257 absl::StrAppend(result, "layout(constant_id = ", *constant_id, ") const ",
258 variable_type, " ", variable.name, " = ");
259 // Always set the default values to zero to generate generic cacheable
260 // shaders.
261 absl::StrAppend(result, (variable_type == "float" ? "0.0" : "0"), ";\n");
262 (*constant_id)++;
263 } else {
264 non_scalar_variables->push_back(variable);
265 }
266 }
267
268 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanConstantGenerator269 void operator()(const std::vector<T>& v) const {
270 non_scalar_variables->push_back(variable);
271 }
272
273 const Variable& variable;
274 int* const constant_id;
275 std::vector<Variable>* non_scalar_variables;
276 std::string* result;
277 };
278
GenerateVulkanConstant(const Variable & variable,int * constant_id,std::vector<Variable> * non_scalar_variables,std::string * result)279 void GenerateVulkanConstant(const Variable& variable, int* constant_id,
280 std::vector<Variable>* non_scalar_variables,
281 std::string* result) {
282 absl::visit(VulkanConstantGenerator{variable, constant_id,
283 non_scalar_variables, result},
284 variable.value);
285 }
286
287 class VulkanConstantsProcessor {
288 public:
ProcessVulkanConstant(const Variable & variable,std::string * result)289 void ProcessVulkanConstant(const Variable& variable, std::string* result) {
290 GenerateVulkanConstant(variable, &constant_id_, &non_scalar_variables_,
291 result);
292 }
293
GeneratePushConstantsDeclarations(std::string * result)294 void GeneratePushConstantsDeclarations(std::string* result) {
295 if (!non_scalar_variables_.empty()) {
296 *result += "\nlayout(push_constant) uniform pushConstants {\n";
297 for (const auto& variable : non_scalar_variables_) {
298 GenerateVulkanPushConstant(variable, result);
299 }
300 *result += "};\n";
301 }
302 }
303
304 protected:
305 // Reserve the first three specialization constants slots for the
306 // workgroup size.
307 int constant_id_ = 3;
308 std::vector<Variable> non_scalar_variables_;
309 };
310
311 // Returns true if value is a vector
IsVariableLength(const Variable::ValueType & value)312 bool IsVariableLength(const Variable::ValueType& value) {
313 return absl::visit(VariableLengthGetter(), value);
314 }
315
// Component selector. X..W carry the values 0..3; UNKNOWN carries the value
// 4.
enum Field : uint8_t {
  X = 0,
  Y = 1,
  Z = 2,
  W = 3,
  UNKNOWN = 4,
};
317
ToField(absl::string_view field_name)318 Field ToField(absl::string_view field_name) {
319 if (field_name.size() == 2 && field_name[0] == '.') {
320 switch (field_name[1]) {
321 case 'x':
322 return Field::X;
323 case 'y':
324 return Field::Y;
325 case 'z':
326 return Field::Z;
327 case 'w':
328 return Field::W;
329 }
330 }
331 return Field::UNKNOWN;
332 }
333
334 struct FieldAccessor {
335 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor336 void operator()(const T&) const {}
337
338 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor339 void operator()(const Vec2<T>& v) const {
340 FormatValue(result, v[field]);
341 }
342
343 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor344 void operator()(const Vec3<T>& v) const {
345 FormatValue(result, v[field]);
346 }
347
348 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor349 void operator()(const Vec4<T>& v) const {
350 FormatValue(result, v[field]);
351 }
352
353 Field field;
354 std::string* result;
355 };
356
357 // Appends formatted value of the given field.
GetValue(const Variable::ValueType & value,Field field,std::string * result)358 void GetValue(const Variable::ValueType& value, Field field,
359 std::string* result) {
360 absl::visit(FieldAccessor{field, result}, value);
361 }
362
363 struct FieldChecker {
364 // For trivial as well as variable-length types indexed access is not allowed.
365 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker366 bool operator()(const T&) const {
367 return false;
368 }
369
370 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker371 bool operator()(const Vec2<T>& v) const {
372 return field < v.size();
373 }
374
375 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker376 bool operator()(const Vec3<T>& v) const {
377 return field < v.size();
378 }
379
380 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker381 bool operator()(const Vec4<T>& v) const {
382 return field < v.size();
383 }
384
385 template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker386 bool operator()(const std::vector<T>&) const {
387 // technically accessing [0] element of an empty vector is UB, but we need
388 // only type information for this check. Therefore, construct default T and
389 // use it instead.
390 T t;
391 return (*this)(t);
392 }
393
394 Field field;
395 };
396
397 // Returns true if field has field access and field is not out of bounds.
HasField(const Variable::ValueType & value,Field field)398 bool HasField(const Variable::ValueType& value, Field field) {
399 return absl::visit(FieldChecker{field}, value);
400 }
401
AssembleAccessor(absl::string_view name,absl::string_view index,absl::string_view field,std::string * result)402 void AssembleAccessor(absl::string_view name, absl::string_view index,
403 absl::string_view field, std::string* result) {
404 if (index.empty()) {
405 absl::StrAppend(result, name, field);
406 } else {
407 absl::StrAppend(result, name, "[", index, "]", field);
408 }
409 }
410
411 } // namespace
412
Rewrite(absl::string_view input,std::string * output)413 RewriteStatus VariableAccessor::Rewrite(absl::string_view input,
414 std::string* output) {
415 auto ref = variable_accessor_internal::Parse(input);
416 if (ref.name.empty()) {
417 absl::StrAppend(output, "INVALID_SYNTAX");
418 return RewriteStatus::ERROR;
419 }
420
421 auto it =
422 name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
423 if (it == name_to_variable_.end()) {
424 // Uniform with this name is not registered.
425 return RewriteStatus::NOT_RECOGNIZED;
426 }
427 const auto& value = it->second.value;
428
429 if (!ref.index.empty() && !IsVariableLength(value)) {
430 // Trying to access variable by index, but it is not variable-length.
431 absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX");
432 return RewriteStatus::ERROR;
433 }
434
435 Field f = ToField(ref.field);
436 if (!ref.field.empty() && !HasField(value, f)) {
437 // Trying to access a variable by field, but it does not have it.
438 absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD");
439 return RewriteStatus::ERROR;
440 }
441
442 // Error checks are complete now.
443
444 // All variable-length variables are encoded as-is without inlining.
445 if (!inline_values_ || IsVariableLength(value)) {
446 AssembleAccessor(it->second.name, ref.index, ref.field, output);
447 } else {
448 // Parameter + field is replaced with field value.
449 if (f != Field::UNKNOWN) {
450 GetValue(value, f, output);
451 } else {
452 // Parameter is accessed directly.
453 GetValue(value, output);
454 }
455 }
456 return RewriteStatus::SUCCESS;
457 }
458
AddSharedVariable(Variable && variable)459 bool VariableAccessor::AddSharedVariable(Variable&& variable) {
460 const std::string name = variable.name;
461 if (!name_to_variable_.insert({name, std::move(variable)}).second) {
462 return false;
463 }
464 shared_variables_.insert(name);
465 return true;
466 }
467
AddUniformParameter(Variable && variable)468 bool VariableAccessor::AddUniformParameter(Variable&& variable) {
469 const std::string name = variable.name;
470 if (!name_to_variable_.insert({name, std::move(variable)}).second) {
471 return false;
472 }
473 uniform_parameters_.insert(name);
474 return true;
475 }
476
IsEmptyVariableLength(const Variable & variable) const477 bool VariableAccessor::IsEmptyVariableLength(const Variable& variable) const {
478 const auto& value = variable.value;
479 return IsVariableLength(value) && GetLength(value) == 0;
480 }
481
GetConstDeclarations() const482 std::string VariableAccessor::GetConstDeclarations() const {
483 // Variable length variables are declared as const and accessed via variable
484 // with index.
485 std::string declarations;
486 for (const auto& variable : name_to_variable_) {
487 // Skip shared variables.
488 const std::string& variable_name = variable.second.name;
489 if (shared_variables_.find(variable_name) != shared_variables_.end()) {
490 continue;
491 }
492
493 const auto& value = variable.second.value;
494 if (IsVariableLength(value)) {
495 absl::StrAppend(&declarations, "const ", GetVariableType(value), " ",
496 variable_name, "[] = ");
497 GetValue(value, &declarations);
498 absl::StrAppend(&declarations, ";\n");
499 }
500 }
501 return declarations;
502 }
503
GetSharedVariableDeclarations() const504 std::string VariableAccessor::GetSharedVariableDeclarations() const {
505 std::string declarations;
506 for (const auto& name : shared_variables_) {
507 const auto& variable = name_to_variable_.at(name);
508 GenerateSharedVariableDeclaration(variable, &declarations);
509 }
510 return declarations;
511 }
512
GetUniformParameterDeclarations() const513 std::string VariableAccessor::GetUniformParameterDeclarations() const {
514 std::string declarations;
515 if (!inline_values_) {
516 if (vulkan_support_) {
517 VulkanConstantsProcessor processor;
518 for (const auto& name : uniform_parameters_) {
519 const auto& variable = name_to_variable_.at(name);
520 processor.ProcessVulkanConstant(variable, &declarations);
521 }
522 processor.GeneratePushConstantsDeclarations(&declarations);
523 } else {
524 for (const auto& name : uniform_parameters_) {
525 const auto& variable = name_to_variable_.at(name);
526 GenerateUniformParameterDeclaration(variable, &declarations);
527 }
528 }
529 }
530 return declarations;
531 }
532
GetUniformParameters() const533 std::vector<Variable> VariableAccessor::GetUniformParameters() const {
534 std::vector<Variable> variables;
535 if (!inline_values_) {
536 variables.reserve(name_to_variable_.size());
537 // Keep the order of the variables consistent with that of the declarations
538 for (const auto& name : uniform_parameters_) {
539 const auto& variable = name_to_variable_.at(name);
540 variables.push_back(variable);
541 }
542 }
543 return variables;
544 }
545
546 } // namespace gl
547 } // namespace gpu
548 } // namespace tflite
549