1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"

#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"
32
33 namespace tflite {
34 namespace gpu {
35 namespace gl {
36 namespace {
37
38 // Rewrites names of all variables according to returned values from the
39 // given NameFunctor.
40 class VariableRewriter : public InlineRewrite {
41 public:
VariableRewriter(const std::string & inline_delimiter,const NameFunctor & name_func)42 VariableRewriter(const std::string& inline_delimiter,
43 const NameFunctor& name_func)
44 : inline_delimiter_(inline_delimiter), name_func_(name_func) {}
45
Rewrite(absl::string_view input,std::string * output)46 RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
47 auto ref = variable_accessor_internal::Parse(input);
48 if (ref.name.empty()) {
49 absl::StrAppend(output, "INVALID_SYNTAX");
50 return RewriteStatus::ERROR;
51 }
52
53 auto it =
54 name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
55 if (it == name_to_variable_.end()) {
56 return RewriteStatus::NOT_RECOGNIZED;
57 }
58
59 // reconstruct access using the new name.
60 absl::StrAppend(output, inline_delimiter_, it->second.name);
61 if (!ref.index.empty()) {
62 absl::StrAppend(output, "[", ref.index, "]");
63 }
64 absl::StrAppend(output, ref.field, inline_delimiter_);
65 return RewriteStatus::SUCCESS;
66 }
67
68 // Return true if variable was successfully added.
AddVariable(Variable && variable)69 bool AddVariable(Variable&& variable) {
70 std::string old_name = variable.name;
71 variable.name = name_func_(old_name);
72 return name_to_variable_.insert({old_name, std::move(variable)}).second;
73 }
74
75 // Returns a collection of uniform parameters with updated names.
GetUniformParameters() const76 std::vector<Variable> GetUniformParameters() const {
77 std::vector<Variable> variables;
78 variables.reserve(name_to_variable_.size());
79 for (const auto& variable : name_to_variable_) {
80 variables.push_back(variable.second);
81 }
82 return variables;
83 }
84
85 private:
86 const std::string inline_delimiter_;
87 const NameFunctor name_func_;
88
89 absl::flat_hash_map<std::string, Variable> name_to_variable_;
90 };
91
92 // Rewrites names of all objects according to returned values from the
93 // given NameFunctor.
94 class ObjectRewriter : public InlineRewrite {
95 public:
ObjectRewriter(const std::string & inline_delimiter,const NameFunctor & name_func)96 ObjectRewriter(const std::string& inline_delimiter,
97 const NameFunctor& name_func)
98 : inline_delimiter_(inline_delimiter), name_func_(name_func) {}
99
Rewrite(absl::string_view input,std::string * output)100 RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
101 // Splits 'a = b' into {'a','b'}.
102 std::pair<absl::string_view, absl::string_view> n =
103 absl::StrSplit(input, absl::MaxSplits('=', 1), absl::SkipWhitespace());
104 if (n.first.empty()) {
105 return RewriteStatus::NOT_RECOGNIZED;
106 }
107
108 if (n.second.empty()) {
109 return RewriteRead(absl::StripAsciiWhitespace(n.first), output);
110 }
111 return RewriteWrite(absl::StripAsciiWhitespace(n.first),
112 absl::StripAsciiWhitespace(n.second), output);
113 }
114
115 // Return true if an object was successfully added.
AddObject(const std::string & name,Object object)116 bool AddObject(const std::string& name, Object object) {
117 std::string new_name = name_func_(name);
118 return name_to_object_.insert({name, {new_name, std::move(object)}}).second;
119 }
120
121 // Returns a collection of registered objects with updated names.
GetObjects() const122 std::vector<std::pair<std::string, Object>> GetObjects() const {
123 std::vector<std::pair<std::string, Object>> objects;
124 objects.reserve(name_to_object_.size());
125 for (const auto& o : name_to_object_) {
126 objects.push_back(o.second);
127 }
128 return objects;
129 }
130
131 private:
RewriteRead(absl::string_view location,std::string * output)132 RewriteStatus RewriteRead(absl::string_view location, std::string* output) {
133 auto element = object_accessor_internal::ParseElement(location);
134 if (element.object_name.empty()) {
135 absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT");
136 return RewriteStatus::ERROR;
137 }
138 auto it = name_to_object_.find(
139 std::string(element.object_name.data(), element.object_name.size()));
140 if (it == name_to_object_.end()) {
141 return RewriteStatus::NOT_RECOGNIZED;
142 }
143 absl::StrAppend(output, inline_delimiter_, it->second.first, "[",
144 absl::StrJoin(element.indices, ","), "]",
145 inline_delimiter_);
146 return RewriteStatus::SUCCESS;
147 }
148
RewriteWrite(absl::string_view location,absl::string_view value,std::string * output)149 RewriteStatus RewriteWrite(absl::string_view location,
150 absl::string_view value, std::string* output) {
151 // name[index1, index2...] = value
152 auto element = object_accessor_internal::ParseElement(location);
153 if (element.object_name.empty()) {
154 absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT");
155 return RewriteStatus::ERROR;
156 }
157 auto it = name_to_object_.find(
158 std::string(element.object_name.data(), element.object_name.size()));
159 if (it == name_to_object_.end()) {
160 return RewriteStatus::NOT_RECOGNIZED;
161 }
162 absl::StrAppend(output, inline_delimiter_, it->second.first, "[",
163 absl::StrJoin(element.indices, ","), "] = ", value,
164 inline_delimiter_);
165 return RewriteStatus::SUCCESS;
166 }
167
168 const std::string inline_delimiter_;
169 const NameFunctor name_func_;
170
171 absl::flat_hash_map<std::string, std::pair<std::string, Object>>
172 name_to_object_;
173 };
174
175 } // namespace
176
Rename(const NameFunctor & name_func,GeneratedCode * code)177 absl::Status Rename(const NameFunctor& name_func, GeneratedCode* code) {
178 VariableRewriter variable_rewriter("$", name_func);
179 ObjectRewriter object_rewriter("$", name_func);
180 for (auto&& uniform_parameter : code->parameters) {
181 if (!variable_rewriter.AddVariable(std::move(uniform_parameter))) {
182 return absl::InternalError("Variable name already exists");
183 }
184 }
185 for (auto&& object : code->objects) {
186 if (!object_rewriter.AddObject(object.first, std::move(object.second))) {
187 return absl::InternalError("Object name already exists");
188 }
189 }
190 TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
191 preprocessor.AddRewrite(&variable_rewriter);
192 preprocessor.AddRewrite(&object_rewriter);
193 std::string source_code;
194 RETURN_IF_ERROR(preprocessor.Rewrite(code->source_code, &source_code));
195 code->source_code = source_code;
196 code->parameters = variable_rewriter.GetUniformParameters();
197 code->objects = object_rewriter.GetObjects();
198 return absl::OkStatus();
199 }
200
201 } // namespace gl
202 } // namespace gpu
203 } // namespace tflite
204