1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h"
17 
18 #include <cstring>
19 #include <string>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_replace.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/any.h"
25 #include "tensorflow/lite/delegates/gpu/common/model.h"
26 #include "tensorflow/lite/delegates/gpu/common/types.h"
27 #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
28 #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
29 #include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
30 
31 namespace tflite {
32 namespace gpu {
33 namespace gl {
34 namespace {
35 
// Marker at the start of an inline block that requests an in-place update.
// Whatever follows the marker inside the block is treated as the name of the
// variable to update (see InplaceCodeRewrite::Rewrite).
// A constexpr array (rather than a mutable `const char*`) cannot be
// reassigned, and no explicit trailing "\0" is needed: the literal is already
// NUL-terminated and every reader goes through strlen().
constexpr char kInplacePrefix[] = "inplace_update:";
37 
38 class EmptyInplaceRewrite : public InlineRewrite {
39  public:
Rewrite(absl::string_view input,std::string * output)40   RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
41     if (input.compare(0, strlen(kInplacePrefix), kInplacePrefix) == 0) {
42       num_rewrites_++;
43       return RewriteStatus::SUCCESS;
44     }
45     return RewriteStatus::NOT_RECOGNIZED;
46   }
47 
num_rewrites() const48   int num_rewrites() const { return num_rewrites_; }
49 
50  private:
51   int num_rewrites_ = 0;
52 };
53 
54 // Takes a code as an input. Replaces 'value_0' in the code with a value that
55 // comes in a rewrite. For example:
56 //   code:    value_0 = max(value_0, 0);
57 //   rewrite: inplace_update:result_12 -> result_12 = max(result_12, 0);
58 //
59 class InplaceCodeRewrite : public InlineRewrite {
60  public:
InplaceCodeRewrite(const std::string & code)61   explicit InplaceCodeRewrite(const std::string& code) : code_(code) {}
62 
Rewrite(absl::string_view input,std::string * output)63   RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
64     int len = strlen(kInplacePrefix);
65     if (input.compare(0, len, kInplacePrefix) == 0) {
66       auto variable_name = input.substr(len);
67       absl::StrAppend(output,
68                       absl::StrReplaceAll(code_, {{"value_0", variable_name}}));
69       return RewriteStatus::SUCCESS;
70     }
71     return RewriteStatus::NOT_RECOGNIZED;
72   }
73 
74  private:
75   std::string code_;
76 };
77 
78 }  // namespace
79 
ApplyToNode(Node * node,GraphFloat32 * graph)80 TransformResult RemoveUnusedInplaceUpdates::ApplyToNode(Node* node,
81                                                         GraphFloat32* graph) {
82   auto& attr =
83       absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
84   // Remove inplace block by rewriting to empty string.
85   EmptyInplaceRewrite rewrite;
86   TextPreprocessor preprocessor('$', true);
87   preprocessor.AddRewrite(&rewrite);
88   if (!preprocessor.Rewrite(attr.code.source_code, &attr.code.source_code)
89            .ok()) {
90     return {TransformStatus::INVALID, ""};
91   }
92   return {rewrite.num_rewrites() > 0 ? TransformStatus::APPLIED
93                                      : TransformStatus::SKIPPED,
94           ""};
95 }
96 
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)97 TransformResult FuseInplaceUpdate::ApplyToNodesSequence(
98     const std::vector<Node*>& sequence, GraphFloat32* graph) {
99   Node* node1 = sequence.front();
100   Node* node2 = sequence.back();
101   auto& attr1 =
102       absl::any_cast<CompiledNodeAttributes&>(node1->operation.attributes);
103   auto& attr2 =
104       absl::any_cast<CompiledNodeAttributes&>(node2->operation.attributes);
105 
106   if (graph->FindInputs(node2->id).size() != 1 ||
107       graph->FindOutputs(node2->id).size() != 1 ||
108       attr2.code.output != IOStructure::AUTO ||
109       attr2.code.input != IOStructure::AUTO ||
110       (attr1.code.workload != attr2.code.workload &&
111        uint3() != attr2.code.workload)) {
112     return {TransformStatus::SKIPPED, ""};
113   }
114 
115   // First count of replaces that would happen to check whether rewrite is
116   // needed.
117   {
118     EmptyInplaceRewrite counting_rewrite;
119     TextPreprocessor preprocessor('$', true);
120     preprocessor.AddRewrite(&counting_rewrite);
121     std::string temp;
122     if (!preprocessor.Rewrite(attr1.code.source_code, &temp).ok()) {
123       return {TransformStatus::INVALID, ""};
124     }
125     // no rewrites in the source code. skip it.
126     if (counting_rewrite.num_rewrites() == 0) {
127       return {TransformStatus::SKIPPED, ""};
128     }
129   }
130   if (!MergeCode(&attr2, &attr1).ok()) {
131     return {TransformStatus::INVALID, "Unable to merge two nodes"};
132   }
133   TextPreprocessor preprocessor('$', true);
134   InplaceCodeRewrite rewrite(attr2.code.source_code);
135   preprocessor.AddRewrite(&rewrite);
136   if (!preprocessor.Rewrite(attr1.code.source_code, &attr1.code.source_code)
137            .ok()) {
138     return {TransformStatus::INVALID, ""};
139   }
140   node1->operation.type += "+" + node2->operation.type;
141 
142   if (!RemoveFollowingNode(graph, node2, node1).ok()) {
143     return {TransformStatus::INVALID,
144             "Unable to remove node " + std::to_string(node2->id)};
145   }
146   return {TransformStatus::APPLIED, ""};
147 }
148 
149 }  // namespace gl
150 }  // namespace gpu
151 }  // namespace tflite
152