/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/types/any.h"
25 #include "absl/types/variant.h"
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/operations.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29 #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
30
31 namespace tflite {
32 namespace gpu {
33 namespace gl {
34 namespace {
35
// Builds a (from, to) substitution pair that renames the shader variable
// `value_<n>` to `value_<k>`.
std::pair<std::string, std::string> MakeValueReplacement(int n, int k) {
  const std::string prefix = "value_";
  return std::make_pair(prefix + std::to_string(n),
                        prefix + std::to_string(k));
}
39
// Builds a (from, to) substitution pair that renames `input_data_<n>` to
// `input_data_<k>` in shader source and parameter names.
std::pair<std::string, std::string> MakeDataReplacement(int n, int k) {
  const std::string prefix = "input_data_";
  std::string from = prefix + std::to_string(n);
  std::string to = prefix + std::to_string(k);
  return std::make_pair(std::move(from), std::move(to));
}
43
44 } // namespace
45
ApplyToNode(Node * node,GraphFloat32 * graph)46 TransformResult FuseAutoInput::ApplyToNode(Node* node, GraphFloat32* graph) {
47 auto& node_attr =
48 absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
49 auto& node_code = node_attr.code;
50
51 if (node_code.input != IOStructure::AUTO) {
52 return {TransformStatus::SKIPPED, ""};
53 }
54 uint3 workgroup = node_code.workgroup;
55
56 auto node_outputs = graph->FindOutputs(node->id);
57
58 // Check which inputs could be fused into the current node.
59 std::vector<std::pair<Node*, int>> nodes_to_fuse;
60 std::vector<std::pair<ValueId, int>> input_values;
61 int input_num = -1;
62 for (auto input_value : graph->FindInputs(node->id)) {
63 input_num++;
64 const ValueId input_id = input_value->id;
65 input_values.push_back({input_id, input_num});
66
67 if (graph->FindConsumers(input_id).size() > 1) {
68 continue; // input is consumed by >1 nodes
69 }
70 Node* input_producer = graph->FindProducer(input_id);
71 if (input_producer == nullptr) {
72 continue; // graph's input
73 }
74 if (graph->FindOutputs(input_producer->id).size() != 1) {
75 continue; // input node has more than one output
76 }
77 auto& input_producer_attr = absl::any_cast<const CompiledNodeAttributes&>(
78 input_producer->operation.attributes);
79 if (input_producer_attr.code.output != IOStructure::AUTO) {
80 continue;
81 }
82 if (input_producer_attr.code.workload != node_code.workload &&
83 uint3() != input_producer_attr.code.workload) {
84 continue;
85 }
86 if (input_producer_attr.code.workgroup != uint3()) {
87 // New fused node should fuse only a single shader that has pre-defined
88 // workgroup. Such shader is considered "heavy". Do not fuse two heavy
89 // shaders into one.
90 // TODO(eignasheva): make sure it still works.
91 if (workgroup != uint3()) {
92 continue;
93 }
94 workgroup = input_producer_attr.code.workgroup;
95 }
96 nodes_to_fuse.push_back({input_producer, input_num});
97 input_values.pop_back(); // this value will not be used as input.
98 }
99 if (nodes_to_fuse.empty()) {
100 return {TransformStatus::SKIPPED, ""};
101 }
102
103 // Skip fusions which will result in duplicate inputs, e.g. diamond shapes.
104 {
105 absl::flat_hash_set<ValueId> all_inputs;
106 for (const auto& node_to_fuse : nodes_to_fuse) {
107 for (const auto& input : graph->FindInputs(node_to_fuse.first->id)) {
108 if (all_inputs.find(input->id) != all_inputs.end()) {
109 return {TransformStatus::SKIPPED, ""};
110 }
111 all_inputs.insert(input->id);
112 }
113 }
114 for (const auto& input : graph->FindInputs(node->id)) {
115 if (all_inputs.find(input->id) != all_inputs.end()) {
116 return {TransformStatus::SKIPPED, ""};
117 }
118 all_inputs.insert(input->id);
119 }
120 }
121
122 // Break connections between current node and its inputs.
123 for (auto value : graph->FindInputs(node->id)) {
124 if (!graph->RemoveConsumer(node->id, value->id).ok()) {
125 return {TransformStatus::INVALID, ""};
126 }
127 }
128
129 std::string operation_type;
130 std::string source_code;
131 std::string values;
132
133 // Node source code need to be appended later to the end.
134 std::swap(source_code, node_code.source_code);
135
136 // Indicates value_k that is beyond originally declared [0..n] values,
137 // therefore, it can be used by newly added dependencies.
138 int extra_input_num = input_num;
139 input_num = 0;
140
141 // Fuse all nodes into one.
142 for (auto input_and_num : nodes_to_fuse) {
143 auto& input = input_and_num.first;
144 auto& attr =
145 absl::any_cast<CompiledNodeAttributes&>(input->operation.attributes);
146 auto super_inputs = graph->FindInputs(input->id);
147
148 // Replace all internal references in the input source code. For example:
149 // source code "value_0 = max(0, value_0);" will be rewritten into
150 // "value_2 = max(0, value_2);"
151 std::vector<std::pair<std::string, std::string>> replacements;
152 for (int i = 0; i < super_inputs.size(); ++i) {
153 // Node source code uses value_N to access output value from the fused
154 // node. Use correct reference.
155 //
156 // Here value_N does not correspond to input_N anymore. Instead it tracks
157 // value_n and input_m independently. Value_index uses an index needed
158 // for the "final" shader, while input_num preserves the order of inputs.
159 // For example:
160 // Shader A: input_0, input_1
161 // value_0 = value_0 > value_1 ? value_0 : value_1;
162 //
163 // Shader B: input_0
164 // value_0 = max(0, value_0);
165 //
166 // AddShader: input_0, input_1
167 // value_0 = value_0 + value_1;
168 //
169 // Fused shader is going to have 3 inputs: input_0 (A), input_1 (A),
170 // input_2 (B). But Shader B need to store result in value_1, because
171 // AddShader refers to it as 'value_1'. So, fused shader will look as
172 // follows:
173 //
174 // // Shader A
175 // vec4 value_0 = input_data_0.data[gid.x, gid.y, gid.z];
176 // vec4 value_2 = input_data_1.data[gid.x, gid.y, gid.z];
177 // value_0 = value_0 > value_2 ? value_0 : value_2;
178 //
179 // // Shader B
180 // vec4 value_1 = input_data_2.data[gid.x, gid.y, gid.z];
181 // value_1 = max(0, value_1);
182 //
183 // // AddShader
184 // value_0 = value_0 + value_1;
185 //
186 // output_data_0.data[gid.x, gid.y, gid.z] = value_0;
187 int value_index = i == 0 ? input_and_num.second : ++extra_input_num;
188 replacements.push_back(MakeValueReplacement(i, value_index));
189 replacements.push_back(MakeDataReplacement(i, input_num));
190
191 // Declare input values based on the input structure of the merged node.
192 // This code copies what shader_codegen would do automatically.
193 if (attr.code.input == IOStructure::AUTO) {
194 absl::StrAppend(&values, " value_", value_index, " = $input_data_",
195 input_num, "[gid.x, gid.y, gid.z]$;\n");
196 }
197
198 if (!graph->AddConsumer(node->id, super_inputs[i]->id).ok()) {
199 return {TransformStatus::INVALID, ""};
200 }
201 input_num++;
202 }
203
204 // Also rename all _h and _w parameters to the new names.
205 for (auto& param : attr.code.parameters) {
206 param.name = absl::StrReplaceAll(param.name, replacements);
207 }
208 attr.code.source_code =
209 absl::StrReplaceAll(attr.code.source_code, replacements);
210
211 // Merge all objects, parameters and source code.
212 if (!MergeCode(&attr, &node_attr).ok()) {
213 return {TransformStatus::INVALID, "Unable to merge the code"};
214 }
215 absl::StrAppend(&node_attr.code.source_code, "{\n", attr.code.source_code,
216 "\n}");
217
218 if (!operation_type.empty()) {
219 operation_type += ",";
220 }
221 operation_type += input->operation.type;
222
223 if (!graph->DeleteNode(input->id).ok()) {
224 return {TransformStatus::INVALID, ""};
225 }
226 }
227
228 // Add back all inputs that are used directly by the fused node.
229 for (int i = 0; i < input_values.size(); i++) {
230 if (node_code.input == IOStructure::AUTO) {
231 absl::StrAppend(&values, " value_", input_values[i].second,
232 " = $input_data_", input_num,
233 "[gid.x, gid.y, gid.z]$;\n");
234 }
235 if (!graph->AddConsumer(node->id, input_values[i].first).ok()) {
236 return {TransformStatus::INVALID, ""};
237 }
238 input_num++;
239 }
240
241 node_code.input = IOStructure::ONLY_DEFINITIONS;
242
243 absl::StrAppend(&node->operation.type, "(", operation_type, ")");
244 node_code.source_code =
245 absl::StrCat(values, node_code.source_code, "{//FUSED",
246 node->operation.type, "\n", source_code, "\n}");
247
248 return {TransformStatus::APPLIED, ""};
249 }
250
251 } // namespace gl
252 } // namespace gpu
253 } // namespace tflite
254