1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_format.h"
20 #include "absl/strings/str_join.h"
21 #include "absl/types/variant.h"
22 #include "tensorflow/lite/delegates/gpu/common/types.h"
23 
24 namespace tflite {
25 namespace gpu {
26 namespace gl {
27 namespace variable_accessor_internal {
28 
29 // Parse the following regex manually
30 // name(\[index\])?(\.field)?
Parse(absl::string_view input)31 VariableReference Parse(absl::string_view input) {
32   VariableReference ref;
33   auto start_index = input.find('[');
34   if (start_index != std::string::npos) {
35     auto end_index = input.rfind(']');
36     if (end_index == std::string::npos) {
37       return ref;
38     }
39     ref.index = input.substr(start_index + 1, end_index - start_index - 1);
40     ref.name = input.substr(0, start_index);
41     ref.field = input.substr(end_index + 1);
42   } else {
43     auto dot = input.find('.');
44     if (dot != std::string::npos) {
45       ref.name = input.substr(0, dot);
46       ref.field = input.substr(dot);
47     } else {
48       ref.name = input;
49     }
50   }
51   return ref;
52 }
53 
54 }  // namespace variable_accessor_internal
55 
56 namespace {
57 
58 struct VariableTypeGetter {
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter59   std::string operator()(int) const { return "int"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter60   std::string operator()(const int2&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter61   std::string operator()(const std::vector<int2>&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter62   std::string operator()(const int4&) const { return "ivec4"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter63   std::string operator()(unsigned int) const { return "uint"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter64   std::string operator()(const uint4&) const { return "uvec4"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter65   std::string operator()(float) const { return "float"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter66   std::string operator()(const float2&) const { return "vec2"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter67   std::string operator()(const float4&) const { return "vec4"; }
operator ()tflite::gpu::gl::__anonc6f577c20111::VariableTypeGetter68   std::string operator()(const std::vector<float4>&) const { return "vec4"; }
69 };
70 
71 // Returns GLSL uniform type of the given variable.
GetVariableType(const Variable::ValueType & value)72 std::string GetVariableType(const Variable::ValueType& value) {
73   return absl::visit(VariableTypeGetter(), value);
74 }
75 
// Visitor returning how many elements a value holds: the size of the vector
// for std::vector values, and 1 for every other type.
struct LengthGetter {
  // Array value: one entry per stored element.
  template <typename T>
  int operator()(const std::vector<T>& elements) const {
    return static_cast<int>(elements.size());
  }

  // Any non-array value counts as a single element.
  template <typename T>
  int operator()(const T&) const {
    return 1;
  }
};
86 
GetLength(const Variable::ValueType & value)87 int GetLength(const Variable::ValueType& value) {
88   return absl::visit(LengthGetter(), value);
89 }
90 
91 template <typename T>
FormatValue(std::string * result,T t)92 void FormatValue(std::string* result, T t) {
93   absl::StrAppend(result, t);
94 }
95 
96 template <>
FormatValue(std::string * result,float t)97 void FormatValue(std::string* result, float t) {
98   absl::StrAppend(result, absl::StrFormat("%.9ff", t));
99 }
100 
// Formats every element of `data` with FormatValue and returns the resulting
// strings in element order.
//
// Per the original note here, absl::StrJoin with a custom formatter requires
// the formatter to work on a string type other than std::string. Because of
// that incompatibility the elements are converted to strings up front and the
// caller joins the plain strings.
template <typename T, int N>
std::vector<std::string> ToString(const std::array<T, N>& data) {
  std::vector<std::string> formatted;
  formatted.reserve(N);
  for (const T& element : data) {
    formatted.emplace_back();
    FormatValue(&formatted.back(), element);
  }
  return formatted;
}
112 
113 struct ConstGenerator {
114   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator115   void operator()(T t) const {
116     FormatValue(result, t);
117   }
118 
119   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator120   void operator()(const Vec2<T>& v) const {
121     absl::StrAppend(result, VariableTypeGetter()(v), "(",
122                     absl::StrJoin(ToString<T, 2>(v.data_), ","), ")");
123   }
124 
125   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator126   void operator()(const Vec3<T>& v) const {
127     absl::StrAppend(result, VariableTypeGetter()(v), "(",
128                     absl::StrJoin(ToString<T, 3>(v.data_), ","), ")");
129   }
130 
131   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator132   void operator()(const Vec4<T>& v) const {
133     absl::StrAppend(result, VariableTypeGetter()(v), "(",
134                     absl::StrJoin(ToString<T, 4>(v.data_), ","), ")");
135   }
136 
137   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::ConstGenerator138   void operator()(const std::vector<T>& v) const {
139     std::string type = VariableTypeGetter()(v);
140     absl::StrAppend(result, type, "[", v.size(), "](");
141     bool first = true;
142     for (const auto& i : v) {
143       if (first) {
144         first = false;
145       } else {
146         absl::StrAppend(result, ",");
147       }
148       (*this)(i);
149     }
150     absl::StrAppend(result, ")");
151   }
152 
153   std::string* result;
154 };
155 
156 // Appends string representation of a variable value.
GetValue(const Variable::ValueType & value,std::string * result)157 void GetValue(const Variable::ValueType& value, std::string* result) {
158   absl::visit(ConstGenerator{result}, value);
159 }
160 
161 struct SharedVariableDeclarationGenerator {
162   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::SharedVariableDeclarationGenerator163   void operator()(const T&) const {
164     absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
165                     " ", variable.name, ";\n");
166   }
167 
168   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::SharedVariableDeclarationGenerator169   void operator()(const std::vector<T>& v) const {
170     absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
171                     " ", variable.name);
172     if (v.empty()) {
173       // Normalize the size of the shared array to that of the WorkGroupSize
174       absl::StrAppend(
175           result,
176           "[gl_WorkGroupSize.z * gl_WorkGroupSize.y * gl_WorkGroupSize.x];\n");
177     } else {
178       // Use the specified size
179       absl::StrAppend(result, "[", v.size(), "];\n");
180     }
181   }
182 
183   const Variable& variable;
184   std::string* result;
185 };
186 
GenerateSharedVariableDeclaration(const Variable & variable,std::string * result)187 void GenerateSharedVariableDeclaration(const Variable& variable,
188                                        std::string* result) {
189   absl::visit(SharedVariableDeclarationGenerator{variable, result},
190               variable.value);
191 }
192 
193 struct UniformParameterDeclarationGenerator {
194   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::UniformParameterDeclarationGenerator195   void operator()(const T&) const {
196     absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
197                     variable.name, ";\n");
198   }
199 
200   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::UniformParameterDeclarationGenerator201   void operator()(const std::vector<T>& v) const {
202     absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
203                     variable.name, "[", v.size(), "];\n");
204   }
205 
206   const Variable& variable;
207   std::string* result;
208 };
209 
GenerateUniformParameterDeclaration(const Variable & variable,std::string * result)210 void GenerateUniformParameterDeclaration(const Variable& variable,
211                                          std::string* result) {
212   absl::visit(UniformParameterDeclarationGenerator{variable, result},
213               variable.value);
214 }
215 
216 struct VulkanPushConstantGenerator {
217   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanPushConstantGenerator218   void operator()(const T&) const {
219     absl::StrAppend(result, "  ", GetVariableType(variable.value), " ",
220                     variable.name, ";\n");
221   }
222 
223   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanPushConstantGenerator224   void operator()(const std::vector<T>& v) const {
225     absl::StrAppend(result, "  ", GetVariableType(variable.value), " ",
226                     variable.name, "[", v.size(), "];\n");
227   }
228 
229   const Variable& variable;
230   std::string* result;
231 };
232 
GenerateVulkanPushConstant(const Variable & variable,std::string * result)233 void GenerateVulkanPushConstant(const Variable& variable, std::string* result) {
234   absl::visit(VulkanPushConstantGenerator{variable, result}, variable.value);
235 }
236 
// Visitor that tells whether a value is variable-length, i.e. stored as a
// std::vector.
struct VariableLengthGetter {
  // std::vector values are variable-length.
  template <typename T>
  bool operator()(const std::vector<T>&) const {
    return true;
  }

  // Everything else is not.
  template <typename T>
  bool operator()(const T&) const {
    return false;
  }
};
247 
248 struct VulkanConstantGenerator {
249   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanConstantGenerator250   void operator()(const T&) const {
251     const std::string variable_type = GetVariableType(variable.value);
252 
253     // Vulkan specialization constants are used for scalar types, all other
254     // types go in push (uniform) constants.
255     if (variable_type == "int" || variable_type == "uint" ||
256         variable_type == "float") {
257       absl::StrAppend(result, "layout(constant_id = ", *constant_id, ") const ",
258                       variable_type, " ", variable.name, " = ");
259       // Always set the default values to zero to generate generic cacheable
260       // shaders.
261       absl::StrAppend(result, (variable_type == "float" ? "0.0" : "0"), ";\n");
262       (*constant_id)++;
263     } else {
264       non_scalar_variables->push_back(variable);
265     }
266   }
267 
268   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::VulkanConstantGenerator269   void operator()(const std::vector<T>& v) const {
270     non_scalar_variables->push_back(variable);
271   }
272 
273   const Variable& variable;
274   int* const constant_id;
275   std::vector<Variable>* non_scalar_variables;
276   std::string* result;
277 };
278 
GenerateVulkanConstant(const Variable & variable,int * constant_id,std::vector<Variable> * non_scalar_variables,std::string * result)279 void GenerateVulkanConstant(const Variable& variable, int* constant_id,
280                             std::vector<Variable>* non_scalar_variables,
281                             std::string* result) {
282   absl::visit(VulkanConstantGenerator{variable, constant_id,
283                                       non_scalar_variables, result},
284               variable.value);
285 }
286 
287 class VulkanConstantsProcessor {
288  public:
ProcessVulkanConstant(const Variable & variable,std::string * result)289   void ProcessVulkanConstant(const Variable& variable, std::string* result) {
290     GenerateVulkanConstant(variable, &constant_id_, &non_scalar_variables_,
291                            result);
292   }
293 
GeneratePushConstantsDeclarations(std::string * result)294   void GeneratePushConstantsDeclarations(std::string* result) {
295     if (!non_scalar_variables_.empty()) {
296       *result += "\nlayout(push_constant) uniform pushConstants {\n";
297       for (const auto& variable : non_scalar_variables_) {
298         GenerateVulkanPushConstant(variable, result);
299       }
300       *result += "};\n";
301     }
302   }
303 
304  protected:
305   // Reserve the first three specialization constants slots for the
306   // workgroup size.
307   int constant_id_ = 3;
308   std::vector<Variable> non_scalar_variables_;
309 };
310 
311 // Returns true if value is a vector
IsVariableLength(const Variable::ValueType & value)312 bool IsVariableLength(const Variable::ValueType& value) {
313   return absl::visit(VariableLengthGetter(), value);
314 }
315 
// Component selector for field access. X..W hold the zero-based component
// index; UNKNOWN is numerically larger than every component index.
enum Field : uint8_t {
  X = 0,
  Y = 1,
  Z = 2,
  W = 3,
  UNKNOWN = 4,
};
317 
ToField(absl::string_view field_name)318 Field ToField(absl::string_view field_name) {
319   if (field_name.size() == 2 && field_name[0] == '.') {
320     switch (field_name[1]) {
321       case 'x':
322         return Field::X;
323       case 'y':
324         return Field::Y;
325       case 'z':
326         return Field::Z;
327       case 'w':
328         return Field::W;
329     }
330   }
331   return Field::UNKNOWN;
332 }
333 
334 struct FieldAccessor {
335   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor336   void operator()(const T&) const {}
337 
338   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor339   void operator()(const Vec2<T>& v) const {
340     FormatValue(result, v[field]);
341   }
342 
343   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor344   void operator()(const Vec3<T>& v) const {
345     FormatValue(result, v[field]);
346   }
347 
348   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldAccessor349   void operator()(const Vec4<T>& v) const {
350     FormatValue(result, v[field]);
351   }
352 
353   Field field;
354   std::string* result;
355 };
356 
357 // Appends formatted value of the given field.
GetValue(const Variable::ValueType & value,Field field,std::string * result)358 void GetValue(const Variable::ValueType& value, Field field,
359               std::string* result) {
360   absl::visit(FieldAccessor{field, result}, value);
361 }
362 
363 struct FieldChecker {
364   // For trivial as well as variable-length types indexed access is not allowed.
365   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker366   bool operator()(const T&) const {
367     return false;
368   }
369 
370   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker371   bool operator()(const Vec2<T>& v) const {
372     return field < v.size();
373   }
374 
375   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker376   bool operator()(const Vec3<T>& v) const {
377     return field < v.size();
378   }
379 
380   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker381   bool operator()(const Vec4<T>& v) const {
382     return field < v.size();
383   }
384 
385   template <typename T>
operator ()tflite::gpu::gl::__anonc6f577c20111::FieldChecker386   bool operator()(const std::vector<T>&) const {
387     // technically accessing [0] element of an empty vector is UB, but we need
388     // only type information for this check. Therefore, construct default T and
389     // use it instead.
390     T t;
391     return (*this)(t);
392   }
393 
394   Field field;
395 };
396 
397 // Returns true if field has field access and field is not out of bounds.
HasField(const Variable::ValueType & value,Field field)398 bool HasField(const Variable::ValueType& value, Field field) {
399   return absl::visit(FieldChecker{field}, value);
400 }
401 
AssembleAccessor(absl::string_view name,absl::string_view index,absl::string_view field,std::string * result)402 void AssembleAccessor(absl::string_view name, absl::string_view index,
403                       absl::string_view field, std::string* result) {
404   if (index.empty()) {
405     absl::StrAppend(result, name, field);
406   } else {
407     absl::StrAppend(result, name, "[", index, "]", field);
408   }
409 }
410 
411 }  // namespace
412 
Rewrite(absl::string_view input,std::string * output)413 RewriteStatus VariableAccessor::Rewrite(absl::string_view input,
414                                         std::string* output) {
415   auto ref = variable_accessor_internal::Parse(input);
416   if (ref.name.empty()) {
417     absl::StrAppend(output, "INVALID_SYNTAX");
418     return RewriteStatus::ERROR;
419   }
420 
421   auto it =
422       name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
423   if (it == name_to_variable_.end()) {
424     // Uniform with this name is not registered.
425     return RewriteStatus::NOT_RECOGNIZED;
426   }
427   const auto& value = it->second.value;
428 
429   if (!ref.index.empty() && !IsVariableLength(value)) {
430     // Trying to access variable by index, but it is not variable-length.
431     absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX");
432     return RewriteStatus::ERROR;
433   }
434 
435   Field f = ToField(ref.field);
436   if (!ref.field.empty() && !HasField(value, f)) {
437     // Trying to access a variable by field, but it does not have it.
438     absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD");
439     return RewriteStatus::ERROR;
440   }
441 
442   // Error checks are complete now.
443 
444   // All variable-length variables are encoded as-is without inlining.
445   if (!inline_values_ || IsVariableLength(value)) {
446     AssembleAccessor(it->second.name, ref.index, ref.field, output);
447   } else {
448     // Parameter + field is replaced with field value.
449     if (f != Field::UNKNOWN) {
450       GetValue(value, f, output);
451     } else {
452       // Parameter is accessed directly.
453       GetValue(value, output);
454     }
455   }
456   return RewriteStatus::SUCCESS;
457 }
458 
AddSharedVariable(Variable && variable)459 bool VariableAccessor::AddSharedVariable(Variable&& variable) {
460   const std::string name = variable.name;
461   if (!name_to_variable_.insert({name, std::move(variable)}).second) {
462     return false;
463   }
464   shared_variables_.insert(name);
465   return true;
466 }
467 
AddUniformParameter(Variable && variable)468 bool VariableAccessor::AddUniformParameter(Variable&& variable) {
469   const std::string name = variable.name;
470   if (!name_to_variable_.insert({name, std::move(variable)}).second) {
471     return false;
472   }
473   uniform_parameters_.insert(name);
474   return true;
475 }
476 
IsEmptyVariableLength(const Variable & variable) const477 bool VariableAccessor::IsEmptyVariableLength(const Variable& variable) const {
478   const auto& value = variable.value;
479   return IsVariableLength(value) && GetLength(value) == 0;
480 }
481 
GetConstDeclarations() const482 std::string VariableAccessor::GetConstDeclarations() const {
483   // Variable length variables are declared as const and accessed via variable
484   // with index.
485   std::string declarations;
486   for (const auto& variable : name_to_variable_) {
487     // Skip shared variables.
488     const std::string& variable_name = variable.second.name;
489     if (shared_variables_.find(variable_name) != shared_variables_.end()) {
490       continue;
491     }
492 
493     const auto& value = variable.second.value;
494     if (IsVariableLength(value)) {
495       absl::StrAppend(&declarations, "const ", GetVariableType(value), " ",
496                       variable_name, "[] = ");
497       GetValue(value, &declarations);
498       absl::StrAppend(&declarations, ";\n");
499     }
500   }
501   return declarations;
502 }
503 
GetSharedVariableDeclarations() const504 std::string VariableAccessor::GetSharedVariableDeclarations() const {
505   std::string declarations;
506   for (const auto& name : shared_variables_) {
507     const auto& variable = name_to_variable_.at(name);
508     GenerateSharedVariableDeclaration(variable, &declarations);
509   }
510   return declarations;
511 }
512 
GetUniformParameterDeclarations() const513 std::string VariableAccessor::GetUniformParameterDeclarations() const {
514   std::string declarations;
515   if (!inline_values_) {
516     if (vulkan_support_) {
517       VulkanConstantsProcessor processor;
518       for (const auto& name : uniform_parameters_) {
519         const auto& variable = name_to_variable_.at(name);
520         processor.ProcessVulkanConstant(variable, &declarations);
521       }
522       processor.GeneratePushConstantsDeclarations(&declarations);
523     } else {
524       for (const auto& name : uniform_parameters_) {
525         const auto& variable = name_to_variable_.at(name);
526         GenerateUniformParameterDeclaration(variable, &declarations);
527       }
528     }
529   }
530   return declarations;
531 }
532 
GetUniformParameters() const533 std::vector<Variable> VariableAccessor::GetUniformParameters() const {
534   std::vector<Variable> variables;
535   if (!inline_values_) {
536     variables.reserve(name_to_variable_.size());
537     // Keep the order of the variables consistent with that of the declarations
538     for (const auto& name : uniform_parameters_) {
539       const auto& variable = name_to_variable_.at(name);
540       variables.push_back(variable);
541     }
542   }
543   return variables;
544 }
545 
546 }  // namespace gl
547 }  // namespace gpu
548 }  // namespace tflite
549