1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h"
17
18 #include <algorithm>
19
20 #include "absl/status/status.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
23 #include "tensorflow/lite/delegates/gpu/common/status.h"
24 #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
25 #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
26 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
27
28 namespace tflite {
29 namespace gpu {
30 namespace gl {
31
ShaderCodegen(const CompilationOptions & options,const GpuInfo & gpu_info)32 ShaderCodegen::ShaderCodegen(const CompilationOptions& options,
33 const GpuInfo& gpu_info)
34 : options_(options), gpu_type_(gpu_info.vendor) {}
35
Build(CompiledNodeAttributes attr,ShaderCode * shader_code) const36 absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr,
37 ShaderCode* shader_code) const {
38 VariableAccessor variable_accessor(options_.inline_parameters,
39 options_.vulkan_support);
40 ObjectAccessor object_accessor(gpu_type_ == GpuVendor::kMali,
41 options_.sampler_textures, &variable_accessor);
42
43 const auto add_object = [&](const std::string& name, Object&& object) {
44 if (!object_accessor.AddObject(name, std::forward<Object>(object))) {
45 return absl::AlreadyExistsError(absl::StrCat("Object \"", name, "\""));
46 }
47 return absl::OkStatus();
48 };
49
50 const auto add_uniform_parameter = [&](Variable&& variable) {
51 const std::string name = variable.name;
52 const Variable& const_ref = variable;
53 if (variable_accessor.IsEmptyVariableLength(const_ref)) {
54 return absl::InvalidArgumentError(
55 absl::StrCat("Empty uniform vector value \"", name, "\""));
56 }
57 if (!variable_accessor.AddUniformParameter(std::move(variable))) {
58 return absl::AlreadyExistsError(
59 absl::StrCat("Uniform parameter \"", name, "\""));
60 }
61 return absl::OkStatus();
62 };
63
64 for (auto&& object : attr.code.objects) {
65 RETURN_IF_ERROR(add_object(object.first, std::move(object.second)));
66 }
67
68 for (auto&& variable : attr.code.shared_variables) {
69 const std::string name = variable.name;
70 if (!variable_accessor.AddSharedVariable(std::move(variable))) {
71 return absl::AlreadyExistsError(
72 absl::StrCat("Shared variable \"", name, "\""));
73 }
74 }
75
76 for (auto&& variable : attr.code.parameters) {
77 RETURN_IF_ERROR(add_uniform_parameter(std::move(variable)));
78 }
79
80 int index = 0;
81 for (auto&& input : attr.inputs) {
82 RETURN_IF_ERROR(
83 add_object(absl::StrCat("input_data_", index++), std::move(input)));
84 }
85 index = 0;
86 for (auto&& output : attr.outputs) {
87 RETURN_IF_ERROR(
88 add_object(absl::StrCat("output_data_", index++), std::move(output)));
89 }
90
91 // TODO(akulik): workload params need to go away and be replaced with
92 // output_data_0_w
93 RETURN_IF_ERROR(add_uniform_parameter(
94 {"workload_x", static_cast<int32_t>(attr.code.workload.x)}));
95 RETURN_IF_ERROR(add_uniform_parameter(
96 {"workload_y", static_cast<int32_t>(attr.code.workload.y)}));
97 RETURN_IF_ERROR(add_uniform_parameter(
98 {"workload_z", static_cast<int32_t>(attr.code.workload.z)}));
99
100 // NOTE: If the shader has shared variables it will have to use barriers,
101 // which will conflict with a return at this stage.
102 // Let the user deal with the geometry constraints.
103 const bool has_shared_variables = !attr.code.shared_variables.empty();
104 std::string main_source_code = has_shared_variables ? R"(
105 ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
106 )"
107 : R"(
108 ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
109 if (gid.x >= $workload_x$ || gid.y >= $workload_y$ || gid.z >= $workload_z$) {
110 return;
111 }
112 )";
113
114 switch (attr.code.input) {
115 case IOStructure::ONLY_DEFINITIONS:
116 for (int i = 0; i < attr.inputs.size(); ++i) {
117 absl::StrAppend(&main_source_code, " highp vec4 value_", i,
118 " = vec4(0);\n");
119 }
120 break;
121 case IOStructure::AUTO: {
122 for (int i = 0; i < attr.inputs.size(); ++i) {
123 absl::StrAppend(&main_source_code, " highp vec4 value_", i,
124 " = $input_data_", i, "[gid.x, gid.y, gid.z]$;\n");
125 }
126 break;
127 }
128 }
129
130 main_source_code.append(attr.code.source_code);
131
132 if (attr.code.output == IOStructure::AUTO) {
133 for (int i = 0; i < attr.outputs.size(); ++i) {
134 absl::StrAppend(&main_source_code, " $output_data_", i,
135 "[gid.x, gid.y, gid.z] = value_", i, "$;\n");
136 }
137 }
138
139 // At this point main function is already generated. Now we need to process
140 // object and variable accessors.
141
142 // process objects first. Object accessor may introduce new uniform
143 // parameters that need to be rewritten in the subsequent pass.
144 {
145 TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
146 preprocessor.AddRewrite(&object_accessor);
147 RETURN_IF_ERROR(preprocessor.Rewrite(main_source_code, &main_source_code));
148 }
149
150 {
151 TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/false);
152 preprocessor.AddRewrite(&variable_accessor);
153 RETURN_IF_ERROR(preprocessor.Rewrite(main_source_code, &main_source_code));
154 }
155
156 if (options_.inline_parameters) {
157 main_source_code = absl::StrCat(variable_accessor.GetConstDeclarations(),
158 main_source_code);
159 }
160
161 // partial_source_code is only missing the following which is added later:
162 // #version 310 es
163 // layout(local_size_x = ..., local_size_y = ..., local_size_z = ...) in;
164 const char* precision = options_.allow_precision_loss ? "mediump" : "highp";
165 const std::string partial_source_code = absl::StrCat(
166 "layout(std430) buffer;\n", //
167 "precision ", precision, " float;\n", //
168 object_accessor.GetFunctionsDeclarations(), "\n", //
169 object_accessor.GetObjectDeclarations(), "\n", //
170 variable_accessor.GetUniformParameterDeclarations(), "\n", //
171 variable_accessor.GetSharedVariableDeclarations(), "\n", //
172 "void main() {\n", //
173 main_source_code, //
174 "}");
175 *shader_code =
176 ShaderCode(variable_accessor.GetUniformParameters(),
177 object_accessor.GetObjects(), attr.code.workload,
178 attr.code.workgroup, partial_source_code, attr.node_indices);
179 return absl::OkStatus();
180 }
181
182 } // namespace gl
183 } // namespace gpu
184 } // namespace tflite
185