1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/types/any.h"
25 #include "absl/types/variant.h"
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/operations.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29 #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
30 
31 namespace tflite {
32 namespace gpu {
33 namespace gl {
34 namespace {
35 
MakeValueReplacement(int n,int k)36 std::pair<std::string, std::string> MakeValueReplacement(int n, int k) {
37   return {absl::StrCat("value_", n), absl::StrCat("value_", k)};
38 }
39 
MakeDataReplacement(int n,int k)40 std::pair<std::string, std::string> MakeDataReplacement(int n, int k) {
41   return {absl::StrCat("input_data_", n), absl::StrCat("input_data_", k)};
42 }
43 
44 }  // namespace
45 
ApplyToNode(Node * node,GraphFloat32 * graph)46 TransformResult FuseAutoInput::ApplyToNode(Node* node, GraphFloat32* graph) {
47   auto& node_attr =
48       absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
49   auto& node_code = node_attr.code;
50 
51   if (node_code.input != IOStructure::AUTO) {
52     return {TransformStatus::SKIPPED, ""};
53   }
54   uint3 workgroup = node_code.workgroup;
55 
56   auto node_outputs = graph->FindOutputs(node->id);
57 
58   // Check which inputs could be fused into the current node.
59   std::vector<std::pair<Node*, int>> nodes_to_fuse;
60   std::vector<std::pair<ValueId, int>> input_values;
61   int input_num = -1;
62   for (auto input_value : graph->FindInputs(node->id)) {
63     input_num++;
64     const ValueId input_id = input_value->id;
65     input_values.push_back({input_id, input_num});
66 
67     if (graph->FindConsumers(input_id).size() > 1) {
68       continue;  // input is consumed by >1 nodes
69     }
70     Node* input_producer = graph->FindProducer(input_id);
71     if (input_producer == nullptr) {
72       continue;  // graph's input
73     }
74     if (graph->FindOutputs(input_producer->id).size() != 1) {
75       continue;  // input node has more than one output
76     }
77     auto& input_producer_attr = absl::any_cast<const CompiledNodeAttributes&>(
78         input_producer->operation.attributes);
79     if (input_producer_attr.code.output != IOStructure::AUTO) {
80       continue;
81     }
82     if (input_producer_attr.code.workload != node_code.workload &&
83         uint3() != input_producer_attr.code.workload) {
84       continue;
85     }
86     if (input_producer_attr.code.workgroup != uint3()) {
87       // New fused node should fuse only a single shader that has pre-defined
88       // workgroup. Such shader is considered "heavy". Do not fuse two heavy
89       // shaders into one.
90       // TODO(eignasheva): make sure it still works.
91       if (workgroup != uint3()) {
92         continue;
93       }
94       workgroup = input_producer_attr.code.workgroup;
95     }
96     nodes_to_fuse.push_back({input_producer, input_num});
97     input_values.pop_back();  // this value will not be used as input.
98   }
99   if (nodes_to_fuse.empty()) {
100     return {TransformStatus::SKIPPED, ""};
101   }
102 
103   // Skip fusions which will result in duplicate inputs, e.g. diamond shapes.
104   {
105     absl::flat_hash_set<ValueId> all_inputs;
106     for (const auto& node_to_fuse : nodes_to_fuse) {
107       for (const auto& input : graph->FindInputs(node_to_fuse.first->id)) {
108         if (all_inputs.find(input->id) != all_inputs.end()) {
109           return {TransformStatus::SKIPPED, ""};
110         }
111         all_inputs.insert(input->id);
112       }
113     }
114     for (const auto& input : graph->FindInputs(node->id)) {
115       if (all_inputs.find(input->id) != all_inputs.end()) {
116         return {TransformStatus::SKIPPED, ""};
117       }
118       all_inputs.insert(input->id);
119     }
120   }
121 
122   // Break connections between current node and its inputs.
123   for (auto value : graph->FindInputs(node->id)) {
124     if (!graph->RemoveConsumer(node->id, value->id).ok()) {
125       return {TransformStatus::INVALID, ""};
126     }
127   }
128 
129   std::string operation_type;
130   std::string source_code;
131   std::string values;
132 
133   // Node source code need to be appended later to the end.
134   std::swap(source_code, node_code.source_code);
135 
136   // Indicates value_k that is beyond originally declared [0..n] values,
137   // therefore, it can be used by newly added dependencies.
138   int extra_input_num = input_num;
139   input_num = 0;
140 
141   // Fuse all nodes into one.
142   for (auto input_and_num : nodes_to_fuse) {
143     auto& input = input_and_num.first;
144     auto& attr =
145         absl::any_cast<CompiledNodeAttributes&>(input->operation.attributes);
146     auto super_inputs = graph->FindInputs(input->id);
147 
148     // Replace all internal references in the input source code. For example:
149     // source code "value_0 = max(0, value_0);" will be rewritten into
150     // "value_2 = max(0, value_2);"
151     std::vector<std::pair<std::string, std::string>> replacements;
152     for (int i = 0; i < super_inputs.size(); ++i) {
153       // Node source code uses value_N to access output value from the fused
154       // node. Use correct reference.
155       //
156       // Here value_N does not correspond to input_N anymore. Instead it tracks
157       // value_n and input_m independently. Value_index uses an index needed
158       // for the "final" shader, while input_num preserves the order of inputs.
159       // For example:
160       //    Shader A: input_0, input_1
161       //    value_0 = value_0 > value_1 ? value_0 : value_1;
162       //
163       //    Shader B:  input_0
164       //    value_0 = max(0, value_0);
165       //
166       //    AddShader: input_0, input_1
167       //    value_0 = value_0 + value_1;
168       //
169       //    Fused shader is going to have 3 inputs: input_0 (A), input_1 (A),
170       //    input_2 (B). But Shader B need to store result in value_1, because
171       //    AddShader refers to it as 'value_1'. So, fused shader will look as
172       //    follows:
173       //
174       //    // Shader A
175       //    vec4 value_0 = input_data_0.data[gid.x, gid.y, gid.z];
176       //    vec4 value_2 = input_data_1.data[gid.x, gid.y, gid.z];
177       //    value_0 = value_0 > value_2 ? value_0 : value_2;
178       //
179       //    // Shader B
180       //    vec4 value_1 = input_data_2.data[gid.x, gid.y, gid.z];
181       //    value_1 = max(0, value_1);
182       //
183       //    // AddShader
184       //    value_0 = value_0 + value_1;
185       //
186       //    output_data_0.data[gid.x, gid.y, gid.z] = value_0;
187       int value_index = i == 0 ? input_and_num.second : ++extra_input_num;
188       replacements.push_back(MakeValueReplacement(i, value_index));
189       replacements.push_back(MakeDataReplacement(i, input_num));
190 
191       // Declare input values based on the input structure of the merged node.
192       // This code copies what shader_codegen would do automatically.
193       if (attr.code.input == IOStructure::AUTO) {
194         absl::StrAppend(&values, "  value_", value_index, " = $input_data_",
195                         input_num, "[gid.x, gid.y, gid.z]$;\n");
196       }
197 
198       if (!graph->AddConsumer(node->id, super_inputs[i]->id).ok()) {
199         return {TransformStatus::INVALID, ""};
200       }
201       input_num++;
202     }
203 
204     // Also rename all _h and _w parameters to the new names.
205     for (auto& param : attr.code.parameters) {
206       param.name = absl::StrReplaceAll(param.name, replacements);
207     }
208     attr.code.source_code =
209         absl::StrReplaceAll(attr.code.source_code, replacements);
210 
211     // Merge all objects, parameters and source code.
212     if (!MergeCode(&attr, &node_attr).ok()) {
213       return {TransformStatus::INVALID, "Unable to merge the code"};
214     }
215     absl::StrAppend(&node_attr.code.source_code, "{\n", attr.code.source_code,
216                     "\n}");
217 
218     if (!operation_type.empty()) {
219       operation_type += ",";
220     }
221     operation_type += input->operation.type;
222 
223     if (!graph->DeleteNode(input->id).ok()) {
224       return {TransformStatus::INVALID, ""};
225     }
226   }
227 
228   // Add back all inputs that are used directly by the fused node.
229   for (int i = 0; i < input_values.size(); i++) {
230     if (node_code.input == IOStructure::AUTO) {
231       absl::StrAppend(&values, "  value_", input_values[i].second,
232                       " = $input_data_", input_num,
233                       "[gid.x, gid.y, gid.z]$;\n");
234     }
235     if (!graph->AddConsumer(node->id, input_values[i].first).ok()) {
236       return {TransformStatus::INVALID, ""};
237     }
238     input_num++;
239   }
240 
241   node_code.input = IOStructure::ONLY_DEFINITIONS;
242 
243   absl::StrAppend(&node->operation.type, "(", operation_type, ")");
244   node_code.source_code =
245       absl::StrCat(values, node_code.source_code, "{//FUSED",
246                    node->operation.type, "\n", source_code, "\n}");
247 
248   return {TransformStatus::APPLIED, ""};
249 }
250 
251 }  // namespace gl
252 }  // namespace gpu
253 }  // namespace tflite
254