1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"
17 
#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"
32 
33 namespace tflite {
34 namespace gpu {
35 namespace gl {
36 namespace {
37 
38 // Rewrites names of all variables according to returned values from the
39 // given NameFunctor.
40 class VariableRewriter : public InlineRewrite {
41  public:
VariableRewriter(const std::string & inline_delimiter,const NameFunctor & name_func)42   VariableRewriter(const std::string& inline_delimiter,
43                    const NameFunctor& name_func)
44       : inline_delimiter_(inline_delimiter), name_func_(name_func) {}
45 
Rewrite(absl::string_view input,std::string * output)46   RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
47     auto ref = variable_accessor_internal::Parse(input);
48     if (ref.name.empty()) {
49       absl::StrAppend(output, "INVALID_SYNTAX");
50       return RewriteStatus::ERROR;
51     }
52 
53     auto it =
54         name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
55     if (it == name_to_variable_.end()) {
56       return RewriteStatus::NOT_RECOGNIZED;
57     }
58 
59     // reconstruct access using the new name.
60     absl::StrAppend(output, inline_delimiter_, it->second.name);
61     if (!ref.index.empty()) {
62       absl::StrAppend(output, "[", ref.index, "]");
63     }
64     absl::StrAppend(output, ref.field, inline_delimiter_);
65     return RewriteStatus::SUCCESS;
66   }
67 
68   // Return true if variable was successfully added.
AddVariable(Variable && variable)69   bool AddVariable(Variable&& variable) {
70     std::string old_name = variable.name;
71     variable.name = name_func_(old_name);
72     return name_to_variable_.insert({old_name, std::move(variable)}).second;
73   }
74 
75   // Returns a collection of uniform parameters with updated names.
GetUniformParameters() const76   std::vector<Variable> GetUniformParameters() const {
77     std::vector<Variable> variables;
78     variables.reserve(name_to_variable_.size());
79     for (const auto& variable : name_to_variable_) {
80       variables.push_back(variable.second);
81     }
82     return variables;
83   }
84 
85  private:
86   const std::string inline_delimiter_;
87   const NameFunctor name_func_;
88 
89   absl::flat_hash_map<std::string, Variable> name_to_variable_;
90 };
91 
92 // Rewrites names of all objects according to returned values from the
93 // given NameFunctor.
94 class ObjectRewriter : public InlineRewrite {
95  public:
ObjectRewriter(const std::string & inline_delimiter,const NameFunctor & name_func)96   ObjectRewriter(const std::string& inline_delimiter,
97                  const NameFunctor& name_func)
98       : inline_delimiter_(inline_delimiter), name_func_(name_func) {}
99 
Rewrite(absl::string_view input,std::string * output)100   RewriteStatus Rewrite(absl::string_view input, std::string* output) final {
101     // Splits 'a = b' into {'a','b'}.
102     std::pair<absl::string_view, absl::string_view> n =
103         absl::StrSplit(input, absl::MaxSplits('=', 1), absl::SkipWhitespace());
104     if (n.first.empty()) {
105       return RewriteStatus::NOT_RECOGNIZED;
106     }
107 
108     if (n.second.empty()) {
109       return RewriteRead(absl::StripAsciiWhitespace(n.first), output);
110     }
111     return RewriteWrite(absl::StripAsciiWhitespace(n.first),
112                         absl::StripAsciiWhitespace(n.second), output);
113   }
114 
115   // Return true if an object was successfully added.
AddObject(const std::string & name,Object object)116   bool AddObject(const std::string& name, Object object) {
117     std::string new_name = name_func_(name);
118     return name_to_object_.insert({name, {new_name, std::move(object)}}).second;
119   }
120 
121   // Returns a collection of registered objects with updated names.
GetObjects() const122   std::vector<std::pair<std::string, Object>> GetObjects() const {
123     std::vector<std::pair<std::string, Object>> objects;
124     objects.reserve(name_to_object_.size());
125     for (const auto& o : name_to_object_) {
126       objects.push_back(o.second);
127     }
128     return objects;
129   }
130 
131  private:
RewriteRead(absl::string_view location,std::string * output)132   RewriteStatus RewriteRead(absl::string_view location, std::string* output) {
133     auto element = object_accessor_internal::ParseElement(location);
134     if (element.object_name.empty()) {
135       absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT");
136       return RewriteStatus::ERROR;
137     }
138     auto it = name_to_object_.find(
139         std::string(element.object_name.data(), element.object_name.size()));
140     if (it == name_to_object_.end()) {
141       return RewriteStatus::NOT_RECOGNIZED;
142     }
143     absl::StrAppend(output, inline_delimiter_, it->second.first, "[",
144                     absl::StrJoin(element.indices, ","), "]",
145                     inline_delimiter_);
146     return RewriteStatus::SUCCESS;
147   }
148 
RewriteWrite(absl::string_view location,absl::string_view value,std::string * output)149   RewriteStatus RewriteWrite(absl::string_view location,
150                              absl::string_view value, std::string* output) {
151     // name[index1, index2...] = value
152     auto element = object_accessor_internal::ParseElement(location);
153     if (element.object_name.empty()) {
154       absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT");
155       return RewriteStatus::ERROR;
156     }
157     auto it = name_to_object_.find(
158         std::string(element.object_name.data(), element.object_name.size()));
159     if (it == name_to_object_.end()) {
160       return RewriteStatus::NOT_RECOGNIZED;
161     }
162     absl::StrAppend(output, inline_delimiter_, it->second.first, "[",
163                     absl::StrJoin(element.indices, ","), "] = ", value,
164                     inline_delimiter_);
165     return RewriteStatus::SUCCESS;
166   }
167 
168   const std::string inline_delimiter_;
169   const NameFunctor name_func_;
170 
171   absl::flat_hash_map<std::string, std::pair<std::string, Object>>
172       name_to_object_;
173 };
174 
175 }  // namespace
176 
Rename(const NameFunctor & name_func,GeneratedCode * code)177 absl::Status Rename(const NameFunctor& name_func, GeneratedCode* code) {
178   VariableRewriter variable_rewriter("$", name_func);
179   ObjectRewriter object_rewriter("$", name_func);
180   for (auto&& uniform_parameter : code->parameters) {
181     if (!variable_rewriter.AddVariable(std::move(uniform_parameter))) {
182       return absl::InternalError("Variable name already exists");
183     }
184   }
185   for (auto&& object : code->objects) {
186     if (!object_rewriter.AddObject(object.first, std::move(object.second))) {
187       return absl::InternalError("Object name already exists");
188     }
189   }
190   TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
191   preprocessor.AddRewrite(&variable_rewriter);
192   preprocessor.AddRewrite(&object_rewriter);
193   std::string source_code;
194   RETURN_IF_ERROR(preprocessor.Rewrite(code->source_code, &source_code));
195   code->source_code = source_code;
196   code->parameters = variable_rewriter.GetUniformParameters();
197   code->objects = object_rewriter.GetObjects();
198   return absl::OkStatus();
199 }
200 
201 }  // namespace gl
202 }  // namespace gpu
203 }  // namespace tflite
204