1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/function_api_info.h"
17 
18 #include <string>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 
24 namespace tensorflow {
25 namespace grappler {
FunctionApiInfo()26 FunctionApiInfo::FunctionApiInfo() {}
~FunctionApiInfo()27 FunctionApiInfo::~FunctionApiInfo() {}
28 
Init(const FunctionDef & function_def)29 Status FunctionApiInfo::Init(const FunctionDef& function_def) {
30   function_type_ = FunctionApiInfo::FunctionType::INFERENCE;
31   for (const auto& attr : function_def.attr()) {
32     if (attr.first == "api_preferred_device") {
33       preferred_device_ = attr.second.s();
34     }
35     if (attr.first == "api_implements") {
36       interface_name_ = attr.second.s();
37     }
38     if (attr.first == "forward_function_name") {
39       function_type_ = FunctionApiInfo::FunctionType::BACKWARD;
40       pairing_function_name_ = attr.second.s();
41     }
42     if (attr.first == "backward_function_name") {
43       function_type_ = FunctionApiInfo::FunctionType::FORWARD;
44       pairing_function_name_ = attr.second.s();
45     }
46   }
47 
48   input_arg_dtypes_.reserve(function_def.signature().input_arg_size());
49   for (const auto& input_arg : function_def.signature().input_arg()) {
50     input_arg_dtypes_.emplace_back(input_arg.type());
51   }
52   output_arg_dtypes_.reserve(function_def.signature().output_arg_size());
53   for (const auto& output_arg : function_def.signature().output_arg()) {
54     output_arg_dtypes_.emplace_back(output_arg.type());
55   }
56 
57   if (interface_name_.empty() && !preferred_device_.empty()) {
58     return errors::InvalidArgument(
59         "Function '", function_def.signature().name(),
60         "' has a preferred device, but does not implement an interface");
61   }
62   return Status::OK();
63 }
64 
preferred_device() const65 const string& FunctionApiInfo::preferred_device() const {
66   return preferred_device_;
67 }
68 
interface_name() const69 const string& FunctionApiInfo::interface_name() const {
70   return interface_name_;
71 }
72 
function_type() const73 const FunctionApiInfo::FunctionType FunctionApiInfo::function_type() const {
74   return function_type_;
75 }
76 
pairing_function_name() const77 const string& FunctionApiInfo::pairing_function_name() const {
78   return pairing_function_name_;
79 }
80 
input_arg_dtypes() const81 const DataTypeVector& FunctionApiInfo::input_arg_dtypes() const {
82   return input_arg_dtypes_;
83 }
84 
output_arg_dtypes() const85 const DataTypeVector& FunctionApiInfo::output_arg_dtypes() const {
86   return output_arg_dtypes_;
87 }
88 
FunctionLibraryApiInfo()89 FunctionLibraryApiInfo::FunctionLibraryApiInfo() {}
~FunctionLibraryApiInfo()90 FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {}
91 
92 namespace {
IsSameArgDef(const OpDef::ArgDef & arg1,const OpDef::ArgDef & arg2)93 bool IsSameArgDef(const OpDef::ArgDef& arg1, const OpDef::ArgDef& arg2) {
94   if (arg1.type() != arg2.type()) return false;
95   if (arg1.type_attr() != arg2.type_attr()) return false;
96   if (arg1.number_attr() != arg2.number_attr()) return false;
97   if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
98   if (arg1.is_ref() != arg2.is_ref()) return false;
99   return true;
100 }
101 
IsSameSignature(const FunctionDef & f1,const FunctionDef & f2,const bool check_inputs,const bool check_outputs)102 bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2,
103                      const bool check_inputs, const bool check_outputs) {
104   const auto& sig1 = f1.signature();
105   const auto& sig2 = f2.signature();
106   // Functions have positional semantics, so we don't check for names.
107   if (check_inputs) {
108     if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
109     for (int k = 0; k < sig1.input_arg_size(); ++k) {
110       if (!IsSameArgDef(sig1.input_arg(k), sig2.input_arg(k))) return false;
111     }
112   }
113   if (check_outputs) {
114     if (f1.ret().size() != f2.ret().size()) return false;
115     if (sig1.output_arg_size() != sig2.output_arg_size()) return false;
116     for (int k = 0; k < sig1.output_arg_size(); ++k) {
117       if (!IsSameArgDef(sig1.output_arg(k), sig2.output_arg(k))) return false;
118     }
119   }
120   return true;
121 }
122 
ValidateSignature(const string & interface_name,const std::vector<const FunctionDef * > & equiv_funcs,const FunctionApiInfo::FunctionType function_type)123 Status ValidateSignature(const string& interface_name,
124                          const std::vector<const FunctionDef*>& equiv_funcs,
125                          const FunctionApiInfo::FunctionType function_type) {
126   if (equiv_funcs.size() < 2) return Status::OK();
127   for (size_t k = 1; k < equiv_funcs.size(); ++k) {
128     const bool check_input =
129         (function_type == FunctionApiInfo::FunctionType::INFERENCE ||
130          function_type == FunctionApiInfo::FunctionType::FORWARD);
131     const bool check_output =
132         (function_type == FunctionApiInfo::FunctionType::INFERENCE ||
133          function_type == FunctionApiInfo::FunctionType::BACKWARD);
134     if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k], check_input,
135                          check_output)) {
136       return errors::InvalidArgument(
137           "Functions '", equiv_funcs[0]->signature().name(), "' and '",
138           equiv_funcs[k]->signature().name(), "' both implement '",
139           interface_name, "' but their signatures do not match.");
140     }
141   }
142   return Status::OK();
143 }
144 
ValidateSignatures(const std::unordered_map<string,std::vector<const FunctionDef * >> & intf_to_func,const FunctionApiInfo::FunctionType function_type)145 Status ValidateSignatures(
146     const std::unordered_map<string, std::vector<const FunctionDef*>>&
147         intf_to_func,
148     const FunctionApiInfo::FunctionType function_type) {
149   for (const auto& item : intf_to_func)
150     TF_RETURN_IF_ERROR(
151         ValidateSignature(item.first, item.second, function_type));
152   return Status::OK();
153 }
154 }  // namespace
155 
Init(const FunctionDefLibrary & function_library)156 Status FunctionLibraryApiInfo::Init(
157     const FunctionDefLibrary& function_library) {
158   std::unordered_map<string, std::vector<const FunctionDef*>> infer_funcs;
159   std::unordered_map<string, std::vector<const FunctionDef*>> fwd_funcs;
160   std::unordered_map<string, std::vector<const FunctionDef*>> bwd_funcs;
161   for (const auto& function : function_library.function()) {
162     std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
163     TF_RETURN_IF_ERROR(func_info->Init(function));
164     // Ignore the function if it does not implement any interface.
165     if (func_info->interface_name().empty()) continue;
166 
167     const string& function_name = function.signature().name();
168     const string& interface_name = func_info->interface_name();
169     VLOG(3) << "Got " << func_info->function_type()
170             << " function: " << function_name
171             << " with interface: " << interface_name;
172     switch (func_info->function_type()) {
173       case FunctionApiInfo::FunctionType::INFERENCE:
174         intf_to_inference_funcs_[interface_name].emplace_back(function_name);
175         infer_funcs[interface_name].emplace_back(&function);
176         break;
177       case FunctionApiInfo::FunctionType::FORWARD:
178         intf_to_forward_funcs_[interface_name].emplace_back(function_name);
179         fwd_funcs[interface_name].emplace_back(&function);
180         break;
181       case FunctionApiInfo::FunctionType::BACKWARD:
182         intf_to_backward_funcs_[interface_name].emplace_back(function_name);
183         bwd_funcs[interface_name].emplace_back(&function);
184         break;
185       default:
186         return errors::InvalidArgument("Unrecognized function type: ",
187                                        func_info->function_type());
188     }
189     func_info_[function_name] = std::move(func_info);
190   }
191   TF_RETURN_IF_ERROR(ValidateSignatures(
192       infer_funcs, FunctionApiInfo::FunctionType::INFERENCE));
193   TF_RETURN_IF_ERROR(
194       ValidateSignatures(fwd_funcs, FunctionApiInfo::FunctionType::FORWARD));
195   TF_RETURN_IF_ERROR(
196       ValidateSignatures(bwd_funcs, FunctionApiInfo::FunctionType::BACKWARD));
197   return Status::OK();
198 }
199 
GetEquivalentImplementations(const string & function_name,std::vector<string> * other_functions) const200 Status FunctionLibraryApiInfo::GetEquivalentImplementations(
201     const string& function_name, std::vector<string>* other_functions) const {
202   const auto func_it = func_info_.find(function_name);
203   if (func_it == func_info_.end()) return Status::OK();
204   const FunctionApiInfo* func_info = func_it->second.get();
205 
206   absl::flat_hash_map<string, std::vector<string>>::const_iterator it;
207   switch (func_info->function_type()) {
208     case FunctionApiInfo::FunctionType::INFERENCE:
209       it = intf_to_inference_funcs_.find(func_info->interface_name());
210       break;
211     case FunctionApiInfo::FunctionType::FORWARD:
212       it = intf_to_forward_funcs_.find(func_info->interface_name());
213       break;
214     case FunctionApiInfo::FunctionType::BACKWARD:
215       it = intf_to_backward_funcs_.find(func_info->interface_name());
216       break;
217     default:
218       return errors::InvalidArgument("Unrecognized function type: ",
219                                      func_info->function_type());
220   }
221 
222   for (const auto& func_name : it->second) {
223     if (func_name == function_name) continue;
224     other_functions->emplace_back(func_name);
225   }
226   return Status::OK();
227 }
228 
GetApiInfo(const string & function_name) const229 const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
230     const string& function_name) const {
231   const auto it = func_info_.find(function_name);
232   if (it == func_info_.end()) return nullptr;
233   return it->second.get();
234 }
235 
236 }  // end namespace grappler
237 }  // end namespace tensorflow
238