1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
17 
18 #include "absl/strings/ascii.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 
23 namespace tflite {
24 namespace gpu {
25 namespace {
// Returns true if `symbol` can be part of an identifier-like word: an ASCII
// letter, an ASCII digit, or '_'. The check is ASCII-only and independent of
// the current locale; bytes outside the ASCII range are never word symbols.
bool IsWordSymbol(char symbol) {
  if (symbol == '_') {
    return true;
  }
  const bool is_digit = symbol >= '0' && symbol <= '9';
  const bool is_lower = symbol >= 'a' && symbol <= 'z';
  const bool is_upper = symbol >= 'A' && symbol <= 'Z';
  return is_digit || is_lower || is_upper;
}
29 
GetNextWord(const std::string & code,size_t first_position)30 std::string GetNextWord(const std::string& code, size_t first_position) {
31   size_t pos = first_position;
32   char t = code[pos];
33   while (IsWordSymbol(t)) {
34     pos++;
35     t = code[pos];
36   }
37   return code.substr(first_position, pos - first_position);
38 }
39 
HasWord(const std::string & word,const std::string & text)40 bool HasWord(const std::string& word, const std::string& text) {
41   size_t pos = text.find(word);
42   while (pos != std::string::npos) {
43     char prev = pos == 0 ? '.' : text[pos - 1];
44     char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
45     if (!IsWordSymbol(prev) & !IsWordSymbol(next)) {
46       return true;
47     }
48     pos = text.find(word, pos + 1);
49   }
50   return false;
51 }
52 
// Computes the merged name of a scalar argument `arg_name` given `postfix`.
//
// If `arg_name` has the form "<object>_<suffix>" for some entry of
// `object_names`, the postfix is inserted right after the object name
// ("<object><postfix>_<suffix>"). This keeps the argument's name tied to the
// renamed object. The first matching entry of `object_names` wins. Otherwise
// the postfix is simply appended ("<arg_name><postfix>").
std::string RenameArg(const std::vector<std::string>& object_names,
                      const std::string& postfix, const std::string& arg_name) {
  for (const std::string& object_name : object_names) {
    const size_t prefix_len = object_name.size();
    const bool belongs_to_object =
        arg_name.size() > prefix_len &&
        arg_name.compare(0, prefix_len, object_name) == 0 &&
        arg_name[prefix_len] == '_';
    if (belongs_to_object) {
      std::string renamed = object_name;
      renamed += postfix;
      renamed += arg_name.substr(prefix_len);
      return renamed;
    }
  }
  return arg_name + postfix;
}
66 
67 }  // namespace
68 
AddFloat(const std::string & name,float value)69 void Arguments::AddFloat(const std::string& name, float value) {
70   float_values_[name].value = value;
71 }
AddHalf(const std::string & name,half value)72 void Arguments::AddHalf(const std::string& name, half value) {
73   half_values_[name].value = value;
74 }
AddInt(const std::string & name,int value)75 void Arguments::AddInt(const std::string& name, int value) {
76   int_values_[name].value = value;
77 }
78 
AddObjectRef(const std::string & name,AccessType access_type,GPUObjectDescriptorPtr && descriptor_ptr)79 void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
80                              GPUObjectDescriptorPtr&& descriptor_ptr) {
81   descriptor_ptr->SetAccess(access_type);
82   object_refs_[name] = {std::move(descriptor_ptr)};
83 }
84 
AddObject(const std::string & name,GPUObjectDescriptorPtr && descriptor_ptr)85 void Arguments::AddObject(const std::string& name,
86                           GPUObjectDescriptorPtr&& descriptor_ptr) {
87   descriptor_ptr->SetAccess(AccessType::READ);
88   objects_[name] = {std::move(descriptor_ptr)};
89 }
90 
RenameArgs(const std::string & postfix,std::string * code) const91 void Arguments::RenameArgs(const std::string& postfix,
92                            std::string* code) const {
93   static constexpr char kArgsPrefix[] = "args.";
94   size_t next_position = code->find(kArgsPrefix);
95   while (next_position != std::string::npos) {
96     size_t arg_pos = next_position + strlen(kArgsPrefix);
97     std::string arg_name = GetNextWord(*code, arg_pos);
98     code->replace(arg_pos, arg_name.size(), arg_name + postfix);
99     next_position = code->find(kArgsPrefix, arg_pos + arg_name.size());
100   }
101 }
102 
Merge(Arguments && args,const std::string & postfix)103 absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix) {
104   std::vector<std::string> object_names;
105   object_names.reserve(args.object_refs_.size() + args.objects_.size());
106   for (auto& v : args.object_refs_) {
107     object_names.push_back(v.first);
108     const std::string name = v.first + postfix;
109     if (object_refs_.find(name) != object_refs_.end()) {
110       return absl::InvalidArgumentError(
111           absl::StrCat("Object reference name collision. Name - ", name));
112     }
113     object_refs_[name] = {std::move(v.second)};
114   }
115   for (auto& v : args.objects_) {
116     object_names.push_back(v.first);
117     const std::string name = v.first + postfix;
118     if (objects_.find(name) != objects_.end()) {
119       return absl::InvalidArgumentError(
120           absl::StrCat("Object name collision. Name - ", name));
121     }
122     objects_[name] = {std::move(v.second)};
123   }
124   for (const auto& v : args.int_values_) {
125     AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
126   }
127   for (const auto& v : args.float_values_) {
128     AddFloat(RenameArg(object_names, postfix, v.first), v.second.value);
129   }
130   for (const auto& v : args.half_values_) {
131     AddHalf(RenameArg(object_names, postfix, v.first), v.second.value);
132   }
133   return absl::OkStatus();
134 }
135 
ReleaseCPURepresentation()136 void Arguments::ReleaseCPURepresentation() {
137   for (auto& t : objects_) {
138     t.second->Release();
139   }
140 }
141 
GetActiveArguments(const std::string & args_prefix,const std::string & code)142 void Arguments::GetActiveArguments(const std::string& args_prefix,
143                                    const std::string& code) {
144   for (auto& float_val : float_values_) {
145     float_val.second.active = HasWord(args_prefix + float_val.first, code);
146   }
147   for (auto& int_val : int_values_) {
148     int_val.second.active = HasWord(args_prefix + int_val.first, code);
149   }
150   for (auto& half_val : half_values_) {
151     half_val.second.active = HasWord(args_prefix + half_val.first, code);
152   }
153 }
154 
155 }  // namespace gpu
156 }  // namespace tflite
157