1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h"
17
18 #include <cstring>
19 #include <string>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_replace.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/any.h"
25 #include "tensorflow/lite/delegates/gpu/common/model.h"
26 #include "tensorflow/lite/delegates/gpu/common/types.h"
27 #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h"
28 #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
29 #include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
30
31 namespace tflite {
32 namespace gpu {
33 namespace gl {
34 namespace {
35
// Marker that opens an inline block describing an in-place update, e.g.
// "inplace_update:result_12". Declared as a constexpr array so the constant is
// immutable and has internal linkage via the enclosing anonymous namespace;
// the string literal already supplies the terminating NUL.
constexpr char kInplacePrefix[] = "inplace_update:";
37
38 class EmptyInplaceRewrite : public InlineRewrite {
39 public:
Rewrite(absl::string_view input,std::string * output)40 RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
41 if (input.compare(0, strlen(kInplacePrefix), kInplacePrefix) == 0) {
42 num_rewrites_++;
43 return RewriteStatus::SUCCESS;
44 }
45 return RewriteStatus::NOT_RECOGNIZED;
46 }
47
num_rewrites() const48 int num_rewrites() const { return num_rewrites_; }
49
50 private:
51 int num_rewrites_ = 0;
52 };
53
54 // Takes a code as an input. Replaces 'value_0' in the code with a value that
55 // comes in a rewrite. For example:
56 // code: value_0 = max(value_0, 0);
57 // rewrite: inplace_update:result_12 -> result_12 = max(result_12, 0);
58 //
59 class InplaceCodeRewrite : public InlineRewrite {
60 public:
InplaceCodeRewrite(const std::string & code)61 explicit InplaceCodeRewrite(const std::string& code) : code_(code) {}
62
Rewrite(absl::string_view input,std::string * output)63 RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
64 int len = strlen(kInplacePrefix);
65 if (input.compare(0, len, kInplacePrefix) == 0) {
66 auto variable_name = input.substr(len);
67 absl::StrAppend(output,
68 absl::StrReplaceAll(code_, {{"value_0", variable_name}}));
69 return RewriteStatus::SUCCESS;
70 }
71 return RewriteStatus::NOT_RECOGNIZED;
72 }
73
74 private:
75 std::string code_;
76 };
77
78 } // namespace
79
ApplyToNode(Node * node,GraphFloat32 * graph)80 TransformResult RemoveUnusedInplaceUpdates::ApplyToNode(Node* node,
81 GraphFloat32* graph) {
82 auto& attr =
83 absl::any_cast<CompiledNodeAttributes&>(node->operation.attributes);
84 // Remove inplace block by rewriting to empty string.
85 EmptyInplaceRewrite rewrite;
86 TextPreprocessor preprocessor('$', true);
87 preprocessor.AddRewrite(&rewrite);
88 if (!preprocessor.Rewrite(attr.code.source_code, &attr.code.source_code)
89 .ok()) {
90 return {TransformStatus::INVALID, ""};
91 }
92 return {rewrite.num_rewrites() > 0 ? TransformStatus::APPLIED
93 : TransformStatus::SKIPPED,
94 ""};
95 }
96
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)97 TransformResult FuseInplaceUpdate::ApplyToNodesSequence(
98 const std::vector<Node*>& sequence, GraphFloat32* graph) {
99 Node* node1 = sequence.front();
100 Node* node2 = sequence.back();
101 auto& attr1 =
102 absl::any_cast<CompiledNodeAttributes&>(node1->operation.attributes);
103 auto& attr2 =
104 absl::any_cast<CompiledNodeAttributes&>(node2->operation.attributes);
105
106 if (graph->FindInputs(node2->id).size() != 1 ||
107 graph->FindOutputs(node2->id).size() != 1 ||
108 attr2.code.output != IOStructure::AUTO ||
109 attr2.code.input != IOStructure::AUTO ||
110 (attr1.code.workload != attr2.code.workload &&
111 uint3() != attr2.code.workload)) {
112 return {TransformStatus::SKIPPED, ""};
113 }
114
115 // First count of replaces that would happen to check whether rewrite is
116 // needed.
117 {
118 EmptyInplaceRewrite counting_rewrite;
119 TextPreprocessor preprocessor('$', true);
120 preprocessor.AddRewrite(&counting_rewrite);
121 std::string temp;
122 if (!preprocessor.Rewrite(attr1.code.source_code, &temp).ok()) {
123 return {TransformStatus::INVALID, ""};
124 }
125 // no rewrites in the source code. skip it.
126 if (counting_rewrite.num_rewrites() == 0) {
127 return {TransformStatus::SKIPPED, ""};
128 }
129 }
130 if (!MergeCode(&attr2, &attr1).ok()) {
131 return {TransformStatus::INVALID, "Unable to merge two nodes"};
132 }
133 TextPreprocessor preprocessor('$', true);
134 InplaceCodeRewrite rewrite(attr2.code.source_code);
135 preprocessor.AddRewrite(&rewrite);
136 if (!preprocessor.Rewrite(attr1.code.source_code, &attr1.code.source_code)
137 .ok()) {
138 return {TransformStatus::INVALID, ""};
139 }
140 node1->operation.type += "+" + node2->operation.type;
141
142 if (!RemoveFollowingNode(graph, node2, node1).ok()) {
143 return {TransformStatus::INVALID,
144 "Unable to remove node " + std::to_string(node2->id)};
145 }
146 return {TransformStatus::APPLIED, ""};
147 }
148
149 } // namespace gl
150 } // namespace gpu
151 } // namespace tflite
152