1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 
18 #include <map>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/function.pb_text.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
34 #include "tensorflow/core/lib/gtl/map_util.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/util/device_name_utils.h"
37 #include "tensorflow/core/util/equal_graph_def.h"
38 
39 namespace tensorflow {
40 
41 // Extracts the actual type from "attr_values" based on its definition
42 // "arg_def".
43 //
44 // If "arg_def" is a N*T type, *is_type_list is set to false, and
45 // *dtypes is set to be a vector of size N and each element is T.
46 //
47 // If "arg_def" is a list(type), *is_type_list is set to true, and
48 // *dtypes is set to be a vector of types specified in attrs for
49 // arg_def.
50 //
51 // Otherwise (arg_def is a simple type T), *is_type_list is set to
52 // false, and *dtypes is set to a single element vector, whose only
53 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)54 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
55                   bool* is_type_list, DataTypeVector* dtypes) {
56   dtypes->clear();
57   if (!arg_def.type_list_attr().empty()) {
58     const AttrValue* v = attrs.Find(arg_def.type_list_attr());
59     if (v == nullptr) {
60       return errors::NotFound("type attr not found: ",
61                               arg_def.type_list_attr());
62     }
63     *is_type_list = true;
64     for (int i = 0; i < v->list().type_size(); ++i) {
65       dtypes->push_back(v->list().type(i));
66     }
67     return Status::OK();
68   }
69 
70   *is_type_list = false;
71   int num = 1;
72   if (!arg_def.number_attr().empty()) {
73     const AttrValue* v = attrs.Find(arg_def.number_attr());
74     if (v == nullptr) {
75       return errors::NotFound("type attr not found: ", arg_def.type_attr());
76     }
77     num = v->i();
78   }
79 
80   DataType dtype;
81   if (arg_def.type() != DT_INVALID) {
82     dtype = arg_def.type();
83   } else if (arg_def.type_attr().empty()) {
84     dtype = DT_INVALID;
85   } else {
86     const AttrValue* v = attrs.Find(arg_def.type_attr());
87     if (v == nullptr) {
88       return errors::NotFound("type attr not found: ", arg_def.type_attr());
89     }
90     dtype = v->type();
91   }
92   dtypes->resize(num, dtype);
93   return Status::OK();
94 }
95 
96 namespace {
97 
98 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)99 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
100   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
101 }
102 
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)103 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
104   // attr_values should specify all attrs defined in fdef.
105   for (const auto& a : sig.attr()) {
106     const AttrValue* v = attr_values.Find(a.name());
107     if (!v) {
108       return errors::NotFound("Attr ", a.name(), " is not found from ",
109                               SummarizeOpDef(sig));
110     }
111     Status status = AttrValueHasType(*v, a.type());
112     if (!status.ok()) {
113       errors::AppendToMessage(&status, "for attr '", a.name(), "'");
114       return status;
115     }
116   }
117 
118 // TODO(josh11b): Enable this code once it works with function gradients.
119 // Right now the C++ function gradient code assumes it can pass
120 // all the attrs of the function to the gradient, and any attrs that
121 // the gradient doesn't care about will be ignored.
122 #if 0
123   if (attr_values.size() != sig.attr_size()) {
124     for (const auto& a : attr_values) {
125       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
126       bool found = false;
127       for (const auto& s : sig.attr()) {
128         if (a.first == s.name()) {
129           found = true;
130           break;
131         }
132       }
133       if (!found) {
134         return errors::NotFound("Attr ", a.first, " is not found in ",
135                                 SummarizeOpDef(sig));
136       }
137     }
138   }
139 #endif
140 
141   return Status::OK();
142 }
143 
144 // A helper class for instantiating functions. This contains shared information
145 // like the resulting graph and node name index.
146 class FunctionInstantiationHelper {
147  public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)148   FunctionInstantiationHelper(GetFunctionSignature get_function,
149                               InstantiationResult* result)
150       : get_function_(std ::move(get_function)), result_(*result) {
151     result_.nodes.clear();
152   }
153 
154   // Builds index for nodes that can be used as node's input arguments.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,bool ints_on_device)155   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
156                             bool ints_on_device) {
157     bool is_type_list;
158     DataTypeVector dtypes;
159     TF_RETURN_IF_ERROR(
160         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
161     CHECK_GE(dtypes.size(), size_t{1});
162     int arg_index = result_.nodes.size();
163     TF_RETURN_IF_ERROR(
164         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
165     // Creates dtypes.size() nodes in the graph.
166     for (size_t i = 0; i < dtypes.size(); ++i) {
167       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
168                                  {true, arg_index, 0, false, {dtypes[i]}}));
169       DCHECK_EQ(arg_index, result_.nodes.size());
170       string name = arg_def.name();
171       if (dtypes.size() > 1) {
172         strings::StrAppend(&name, "_", i);
173       }
174       NodeDef* gnode = AddNode(name);
175       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
176         gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
177       } else {
178         gnode->set_op(FunctionLibraryDefinition::kArgOp);
179       }
180       AddAttr("T", dtypes[i], gnode);
181       AddAttr("index", arg_index, gnode);
182       result_.arg_types.push_back(dtypes[i]);
183       ++arg_index;
184     }
185     return Status::OK();
186   }
187 
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)188   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
189                               const int arg_index) {
190     const OpDef* node_sig = nullptr;
191     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
192     if (node_sig->output_arg_size() == 0) {
193       return AddItem(node.name(), {false, arg_index, 0, false, {}});
194     }
195     const int num_retval = node_sig->output_arg_size();
196     int start = 0;
197     bool is_type_list;
198     DataTypeVector dtypes;
199     for (int i = 0; i < num_retval; ++i) {
200       TF_RETURN_IF_ERROR(
201           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
202       // Note that we rely on the backwards-compatibility test enforcing
203       // that output_arg(*).name() doesn't change here.
204       const string base_name =
205           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
206       TF_RETURN_IF_ERROR(
207           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
208       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
209         TF_RETURN_IF_ERROR(
210             AddItem(strings::StrCat(base_name, ":", j),
211                     {false, arg_index, start + j, false, {dtypes[j]}}));
212       }
213       start += dtypes.size();
214     }
215     return Status::OK();
216   }
217 
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)218   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
219     const OpDef* fnode_sig = nullptr;
220     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
221     NodeDef* gnode = AddNode(fnode.name());
222     gnode->set_op(fnode.op());
223     gnode->set_device(fnode.device());
224     int gnode_idx = nodes_.size() - 1;
225 
226     // Input
227     const int num_args = fnode_sig->input_arg_size();
228     bool is_type_list;  // ignored
229     DataTypeVector dtypes;
230     int fnode_arg_index = 0;
231     for (int i = 0; i < num_args; ++i) {
232       TF_RETURN_IF_ERROR(
233           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
234       // Consume inputs (indexed by fnode_arg_index) until we have
235       // matched each element of dtypes (indexed by j).
236       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
237         if (fnode_arg_index >= fnode.input_size()) {
238           // Should never happen if we computed dtypes correctly.
239           return errors::InvalidArgument(
240               "Attempt to access beyond input size: ", fnode_arg_index,
241               " >= ", fnode.input_size());
242         }
243         // Look up the next input.
244         const string& input_name = fnode.input(fnode_arg_index);
245         const auto* item = GetItemOrNull(input_name);
246         if (item == nullptr) {
247           return errors::InvalidArgument(
248               "input ", input_name,
249               " is not found: ", FormatNodeDefForError(fnode));
250         }
251         if (item->dtypes.size() > dtypes.size() - j) {
252           return errors::InvalidArgument("Input ", input_name, " too long for ",
253                                          fnode_sig->input_arg(i).name());
254         }
255         // Match up all the elements of this input (indexed by k) with
256         // elements of dtypes (advancing j).
257         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
258           if (item->dtypes[k] != dtypes[j]) {
259             return errors::InvalidArgument(
260                 "input ", fnode_sig->input_arg(i).name(), "[", j,
261                 "] expected type ", DataTypeString(dtypes[j]),
262                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
263                 input_name, "[", k, "]");
264           }
265           if (item->is_func_arg) {
266             AddInput(gnode_idx, item->nid + k, 0);
267           } else {
268             AddInput(gnode_idx, item->nid, item->idx + k);
269           }
270         }
271       }
272     }
273 
274     // Control deps.
275     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
276       const string& input = fnode.input(i);
277       if (input.empty() || input[0] != '^') {
278         return errors::InvalidArgument("Expected input[", i, "] == '", input,
279                                        "' to be a control input.");
280       }
281       int nid = -1;
282       const string node_name = input.substr(1);
283       const string node_colon = node_name + ":";
284       const string node_colon_bound = node_name + ";";
285       // index_ is a map sorted lexicographically, so the key we are looking for
286       // must lie in the range [node_name, node_colon_bound).
287       auto it = index_.lower_bound(node_name);
288       while (it != index_.end() && it->first <= node_colon_bound) {
289         if (it->first == node_name ||
290             tensorflow::str_util::StartsWith(it->first, node_colon)) {
291           nid = it->second.nid;
292           break;
293         }
294         ++it;
295       }
296       if (nid == -1) {
297         return errors::InvalidArgument("input[", i, "] == '", input,
298                                        "', is not found.");
299       }
300       AddDep(gnode_idx, nid);
301     }
302 
303     // Attrs.
304     for (const auto& p : attrs) {
305       (*gnode->mutable_attr())[p.first] = p.second;
306     }
307 
308     return Status::OK();
309   }
310 
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)311   Status AddReturnNode(
312       const OpDef::ArgDef& ret_def, AttrSlice attrs,
313       const ::tensorflow::protobuf::Map<string, string>& ret_map,
314       bool ints_on_device, int* ret_index) {
315     auto ret_iter = ret_map.find(ret_def.name());
316     if (ret_iter == ret_map.end()) {
317       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
318     }
319     bool is_type_list;
320     DataTypeVector dtypes;
321     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
322     CHECK_GE(dtypes.size(), size_t{1});
323     const auto* item = GetItemOrNull(ret_iter->second);
324     if (item == nullptr) {
325       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
326                                      ret_iter->second, " is not found.");
327     }
328     if (dtypes != item->dtypes) {
329       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
330                                      " : ", DataTypeVectorString(dtypes),
331                                      " vs. ",
332                                      DataTypeVectorString(item->dtypes));
333     }
334     for (size_t i = 0; i < dtypes.size(); ++i) {
335       string name = strings::StrCat(ret_def.name(), "_RetVal");
336       if (dtypes.size() > 1) {
337         strings::StrAppend(&name, "_", i);
338       }
339       NodeDef* gnode = AddNode(name);
340       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
341         gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
342       } else {
343         gnode->set_op(FunctionLibraryDefinition::kRetOp);
344       }
345       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
346       AddAttr("T", dtypes[i], gnode);
347       AddAttr("index", (*ret_index)++, gnode);
348       result_.ret_types.push_back(dtypes[i]);
349     }
350     return Status::OK();
351   }
352 
353   // Adds the actual node inputs to the result graph by converting indexes to
354   // the node names.
AddNodeInputs()355   void AddNodeInputs() {
356     for (int i = 0; i < result_.nodes.size(); i++) {
357       NodeInfo& node_info = nodes_[i];
358       for (const auto& p : node_info.data_inputs) {
359         result_.nodes[i].add_input(Name(p.first, p.second));
360       }
361       for (int index : node_info.control_inputs) {
362         result_.nodes[i].add_input(Dep(index));
363       }
364     }
365   }
366 
367  private:
368   // This is used to build a small index for all names that can be used as a
369   // node's input arguments.
370   //
371   // If is_func_arg is true, the name is a function's argument.  In
372   // this case, the produced graph def has node[nid:nid + dtype.size()].
373   //
374   // Otherwise, the name is a function body's node return value.  In
375   // this case, the produced graph def has one node node[nid] and
376   // the node's output index [idx ... idx + num) corresponds to the
377   // named outputs.
378   //
379   // In all cases, "dtype" specifies the data type.
380   struct NameInfoItem {
381     bool is_func_arg;
382     int nid;
383     int idx;
384     bool is_type_list;
385     DataTypeVector dtypes;
386   };
387 
388   // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)389   Status AddItem(const string& name, const NameInfoItem& item) {
390     if (!index_.insert({name, item}).second) {
391       return errors::InvalidArgument(
392           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
393                           " name: "),
394           name);
395     }
396     return Status::OK();
397   }
398 
GetItemOrNull(const string & name) const399   const NameInfoItem* GetItemOrNull(const string& name) const {
400     return gtl::FindOrNull(index_, name);
401   }
402 
Dep(int node_index) const403   string Dep(int node_index) const {
404     return strings::StrCat("^", Name(node_index));
405   }
406 
Name(int node_index) const407   string Name(int node_index) const {
408     CHECK_LT(node_index, nodes_.size());
409     return nodes_[node_index].name;
410   }
411 
Name(int node_index,int output_index) const412   string Name(int node_index, int output_index) const {
413     if (output_index == 0) {
414       return Name(node_index);
415     } else {
416       return strings::StrCat(Name(node_index), ":", output_index);
417     }
418   }
419 
AddNode(const string & name)420   NodeDef* AddNode(const string& name) {
421     result_.nodes.emplace_back();
422     NodeDef* gnode = &result_.nodes.back();
423     gnode->set_name(name);
424     nodes_.push_back({name, {}, {}});
425     CHECK_EQ(result_.nodes.size(), nodes_.size());
426     return gnode;
427   }
428 
AddInput(int node_index,int output_node,int output_index)429   void AddInput(int node_index, int output_node, int output_index) {
430     CHECK_LT(node_index, nodes_.size());
431     nodes_[node_index].data_inputs.push_back(
432         std::make_pair(output_node, output_index));
433   }
434 
AddDep(int node_index,int dep_index)435   void AddDep(int node_index, int dep_index) {
436     CHECK_LT(node_index, nodes_.size());
437     nodes_[node_index].control_inputs.push_back(dep_index);
438   }
439 
440   GetFunctionSignature get_function_;
441   InstantiationResult& result_;
442   // A small index for all names that can be used as a node's input arguments.
443   std::map<string, NameInfoItem> index_;
444   // This contains information about a node in the new graph including the node
445   // names and input nodes' indexes.
446   struct NodeInfo {
447     string name;
448     // Data inputs where <n, k> means arg k of node n.
449     std::vector<std::pair<int, int>> data_inputs;
450     // Control inputs (dependencies).
451     std::vector<int> control_inputs;
452   };
453   // nodes_[i] is the information about result_.nodes[i].
454   std::vector<NodeInfo> nodes_;
455 };
456 
457 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)458 string Print(const OpDef::ArgDef& arg) {
459   string out;
460   strings::StrAppend(&out, arg.name(), ":");
461   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
462   if (!arg.number_attr().empty()) {
463     strings::StrAppend(&out, arg.number_attr(), "*");
464   }
465   if (arg.type() != DT_INVALID) {
466     strings::StrAppend(&out, DataTypeString(arg.type()));
467   } else {
468     strings::StrAppend(&out, arg.type_attr());
469   }
470   if (arg.is_ref()) strings::StrAppend(&out, ")");
471   return out;
472 }
473 
474 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)475 string Print(const AttrValue& attr_value) {
476   if (attr_value.value_case() == AttrValue::kType) {
477     return DataTypeString(attr_value.type());
478   } else if ((attr_value.value_case() == AttrValue::kList) &&
479              (attr_value.list().type_size() > 0)) {
480     string ret = "{";
481     for (int i = 0; i < attr_value.list().type_size(); ++i) {
482       if (i > 0) strings::StrAppend(&ret, ", ");
483       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
484     }
485     strings::StrAppend(&ret, "}");
486     return ret;
487   } else if (attr_value.value_case() == AttrValue::kFunc) {
488     if (attr_value.func().attr_size() == 0) {
489       return attr_value.func().name();
490     }
491     std::vector<string> entries;
492     for (auto p : attr_value.func().attr()) {
493       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
494     }
495     std::sort(entries.begin(), entries.end());
496     return strings::StrCat(attr_value.func().name(), "[",
497                            str_util::Join(entries, ", "), "]");
498   }
499   return SummarizeAttrValue(attr_value);
500 }
501 
502 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)503 string Print(const NodeDef& n) {
504   string out;
505   strings::StrAppend(&out, n.name(), " = ", n.op());
506   if (n.attr_size() > 0) {
507     std::vector<string> entries;
508     for (auto& a : n.attr()) {
509       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
510     }
511     std::sort(entries.begin(), entries.end());
512     // Add a short device string at the end of all attributes.
513     if (!n.device().empty()) {
514       DeviceNameUtils::ParsedName parsed;
515       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
516         entries.push_back(
517             strings::StrCat("device=", parsed.type, ":", parsed.id));
518       } else {
519         entries.push_back("device=<FAILED_TO_PARSE>");
520       }
521     }
522     strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
523   }
524   strings::StrAppend(&out, "(");
525   std::vector<StringPiece> dat;
526   std::vector<string> dep;
527   for (StringPiece s : n.input()) {
528     if (str_util::ConsumePrefix(&s, "^")) {
529       dep.emplace_back(s);
530     } else {
531       dat.push_back(s);
532     }
533   }
534   strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
535   if (!dep.empty()) {
536     strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
537   }
538   return out;
539 }
540 
Print(const FunctionDef & fdef)541 string Print(const FunctionDef& fdef) {
542   string out;
543   const OpDef& sig = fdef.signature();
544   strings::StrAppend(&out, "\n", sig.name());
545   if (sig.attr_size() > 0) {
546     strings::StrAppend(&out, "[");
547     for (int i = 0; i < sig.attr_size(); ++i) {
548       const auto& a = sig.attr(i);
549       if (i > 0) strings::StrAppend(&out, ", ");
550       if (a.type() == "type") {
551         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
552       } else {
553         strings::StrAppend(&out, a.name(), ":", a.type());
554       }
555     }
556     strings::StrAppend(&out, "]");
557   }
558   strings::StrAppend(&out, "(");
559   for (int i = 0; i < sig.input_arg_size(); ++i) {
560     if (i > 0) strings::StrAppend(&out, ", ");
561     strings::StrAppend(&out, Print(sig.input_arg(i)));
562   }
563   strings::StrAppend(&out, ") -> (");
564   for (int i = 0; i < sig.output_arg_size(); ++i) {
565     if (i > 0) strings::StrAppend(&out, ", ");
566     strings::StrAppend(&out, Print(sig.output_arg(i)));
567   }
568   strings::StrAppend(&out, ") {\n");
569   for (const auto& n : fdef.node_def()) {
570     strings::StrAppend(&out, "  ", Print(n), "\n");
571   }
572   for (const auto& cr : fdef.control_ret()) {
573     strings::StrAppend(&out, "  @return ", cr.first, " = ", cr.second, "\n");
574   }
575   for (const auto& r : fdef.ret()) {
576     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
577   }
578   strings::StrAppend(&out, "}\n");
579   return out;
580 }
581 
Print(gtl::ArraySlice<const NodeDef * > nodes)582 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
583   std::vector<const NodeDef*> arg;
584   std::vector<const NodeDef*> ret;
585   std::vector<const NodeDef*> body;
586   for (const NodeDef* n : nodes) {
587     if (n->op() == FunctionLibraryDefinition::kArgOp ||
588         n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
589       arg.push_back(n);
590     } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
591                n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
592       ret.push_back(n);
593     } else {
594       body.push_back(n);
595     }
596   }
597   auto comp = [](const NodeDef* x, const NodeDef* y) {
598     int xi;
599     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
600     int yi;
601     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
602     return xi < yi;
603   };
604   std::sort(arg.begin(), arg.end(), comp);
605   std::sort(ret.begin(), ret.end(), comp);
606   string out;
607   strings::StrAppend(&out, "\n(");
608   auto get_type_and_device = [](const NodeDef& n) {
609     DataType dt;
610     if (!GetNodeAttr(n, "T", &dt).ok()) {
611       dt = DT_INVALID;
612     }
613     if (!n.device().empty()) {
614       DeviceNameUtils::ParsedName parsed;
615       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
616         return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
617                                parsed.id);
618       } else {
619         LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
620                      << n.op() << ":" << n.name();
621         return strings::StrCat(DataTypeString(dt), "@",
622                                "<FAILED_TO_PARSE_DEVICE>");
623       }
624     }
625     return DataTypeString(dt);
626   };
627   for (size_t i = 0; i < arg.size(); ++i) {
628     const NodeDef* n = arg[i];
629     if (i > 0) strings::StrAppend(&out, ", ");
630     CHECK_GE(n->attr_size(), 2);
631     strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
632   }
633   strings::StrAppend(&out, ") -> (");
634   for (size_t i = 0; i < ret.size(); ++i) {
635     const NodeDef* n = ret[i];
636     if (i > 0) strings::StrAppend(&out, ", ");
637     CHECK_LE(2, n->attr_size());
638 
639     // The _RetVal op should have a unique non-control input. We assert that
640     // here and add it to the output.
641     bool found_non_control_input = false;
642     for (const string& input : n->input()) {
643       if (!input.empty() && input[0] != '^') {
644         DCHECK_EQ(found_non_control_input, false)
645             << "RetVal node has more than one non-control input: "
646             << absl::StrJoin(n->input(), ", ");
647         strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
648         found_non_control_input = true;
649       }
650     }
651     DCHECK_EQ(found_non_control_input, true)
652         << "RetVal did not have any non-control inputs: "
653         << absl::StrJoin(n->input(), ", ");
654   }
655   strings::StrAppend(&out, ") {\n");
656   for (size_t i = 0; i < body.size(); ++i) {
657     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
658   }
659   strings::StrAppend(&out, "}\n");
660   return out;
661 }
662 
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)663 Status AddDefaultAttrs(const string& op,
664                        const GetFunctionSignature& get_function,
665                        AttrValueMap* attrs) {
666   const OpDef* op_def = nullptr;
667   TF_RETURN_IF_ERROR(get_function(op, &op_def));
668   AttrSlice attr_slice(attrs);
669   for (const auto& attr_def : op_def->attr()) {
670     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
671       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
672         return errors::Internal("Somehow duplicated: ", attr_def.name());
673       }
674     }
675   }
676   return Status::OK();
677 }
678 
679 }  // end namespace
680 
681 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)682 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
683                            GetFunctionSignature get_function,
684                            InstantiationResult* result) {
685   VLOG(4) << "Instantiation Function: " << Print(fdef);
686 
687   const OpDef& sig = fdef.signature();
688   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
689 
690   bool ints_on_device =
691       fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
692       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
693 
694   FunctionInstantiationHelper helper(get_function, result);
695   Status s;
696   for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
697     s = helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
698     if (!s.ok()) {
699       errors::AppendToMessage(&s, "In ", Print(arg_def));
700       return s;
701     }
702   }
703 
704   auto substitute = [attr_values](StringPiece name, AttrValue* val) {
705     if (const AttrValue* v = attr_values.Find(name)) {
706       *val = *v;
707       return true;
708     }
709     return false;
710   };
711 
712   // Makes a copy of all attrs in fdef and substitutes placeholders.
713   // After this step, every attr is bound to a concrete value.
714   std::vector<AttrValueMap> node_attrs;
715   node_attrs.resize(fdef.node_def_size());
716   for (int i = 0; i < fdef.node_def_size(); ++i) {
717     for (auto attr : fdef.node_def(i).attr()) {
718       if (!SubstitutePlaceholders(substitute, &attr.second)) {
719         return errors::InvalidArgument("Failed to bind all placeholders in ",
720                                        SummarizeAttrValue(attr.second));
721       }
722       if (!node_attrs[i].insert(attr).second) {
723         return errors::Internal("Somehow duplicated: ", attr.first);
724       }
725     }
726     TF_RETURN_IF_ERROR(
727         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
728   }
729 
730   for (int i = 0; i < fdef.node_def_size(); ++i) {
731     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
732                                     result->nodes.size() + i);
733     if (!s.ok()) {
734       errors::AppendToMessage(&s, "In ",
735                               FormatNodeDefForError(fdef.node_def(i)));
736       return s;
737     }
738   }
739   // Emits one node for each fdef.node_def.
740   for (int i = 0; i < fdef.node_def_size(); ++i) {
741     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
742     if (!s.ok()) {
743       errors::AppendToMessage(&s, "In ",
744                               FormatNodeDefForError(fdef.node_def(i)));
745       return s;
746     }
747   }
748 
749   // Emits nodes for the function's return values.
750   int ret_index = 0;
751   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
752     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
753                              &ret_index);
754     if (!s.ok()) {
755       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
756       return s;
757     }
758   }
759 
760   // Adds the actual node inputs using the input indexes.
761   helper.AddNodeInputs();
762 
763   return Status::OK();
764 }
765 
DebugString(const FunctionDef & func_def)766 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
767 
DebugString(const GraphDef & instantiated_func_def)768 string DebugString(const GraphDef& instantiated_func_def) {
769   std::vector<const NodeDef*> ptrs;
770   for (const NodeDef& n : instantiated_func_def.node()) {
771     ptrs.push_back(&n);
772   }
773   return Print(ptrs);
774 }
775 
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)776 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
777   std::vector<const NodeDef*> ptrs;
778   for (const NodeDef& n : instantiated_func_nodes) {
779     ptrs.push_back(&n);
780   }
781   return Print(ptrs);
782 }
783 
DebugStringWhole(const GraphDef & gdef)784 string DebugStringWhole(const GraphDef& gdef) {
785   string ret;
786   for (const auto& fdef : gdef.library().function()) {
787     strings::StrAppend(&ret, Print(fdef));
788   }
789   strings::StrAppend(&ret, "\n");
790   for (const auto& ndef : gdef.node()) {
791     strings::StrAppend(&ret, Print(ndef), "\n");
792   }
793   return ret;
794 }
795 
796 namespace {
797 
798 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
799 // Python, it's possible to access unset attrs, which returns a default value
800 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)801 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
802   std::map<string, AttrValue> set_attrs;
803   for (auto pair : fdef.attr()) {
804     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
805       set_attrs[pair.first] = pair.second;
806     }
807   }
808   return set_attrs;
809 }
810 
811 }  // end namespace
812 
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)813 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
814   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
815 
816   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
817   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
818   if (f1_attrs.size() != f2_attrs.size()) return false;
819   for (auto iter1 : f1_attrs) {
820     auto iter2 = f2_attrs.find(iter1.first);
821     if (iter2 == f2_attrs.end()) return false;
822     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
823   }
824 
825   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
826     return false;
827   }
828 
829   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
830   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
831   if (ret1 != ret2) return false;
832 
833   std::map<string, string> control_ret1(f1.control_ret().begin(),
834                                         f1.control_ret().end());
835   std::map<string, string> control_ret2(f2.control_ret().begin(),
836                                         f2.control_ret().end());
837   if (control_ret1 != control_ret2) return false;
838 
839   return true;
840 }
841 
FunctionDefHash(const FunctionDef & fdef)842 uint64 FunctionDefHash(const FunctionDef& fdef) {
843   // signature
844   uint64 h = OpDefHash(fdef.signature());
845 
846   // attrs
847   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
848   for (const auto& p : attrs) {
849     h = Hash64(p.first.data(), p.first.size(), h);
850     h = Hash64Combine(AttrValueHash(p.second), h);
851   }
852 
853   // node defs
854   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
855 
856   // output names
857   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
858   for (const auto& p : ret) {
859     h = Hash64(p.first.data(), p.first.size(), h);
860     h = Hash64(p.second.data(), p.second.size(), h);
861   }
862 
863   // control output names
864   std::map<string, string> control_ret(fdef.control_ret().begin(),
865                                        fdef.control_ret().end());
866   for (const auto& p : control_ret) {
867     h = Hash64(p.first.data(), p.first.size(), h);
868     h = Hash64(p.second.data(), p.second.size(), h);
869   }
870 
871   return h;
872 }
873 
874 static constexpr const char* const kExecutorAttr = "_executor";
875 
876 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)877 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
878                                             AttrSlice attrs) {
879   if (!options.executor_type.empty()) {
880     return options.executor_type;
881   } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
882     return executor_attr->s();
883   } else {
884     return string();
885   }
886 }
887 
// Returns a canonical key "funcname[k1=v1,k2=v2,...]" for instantiating
// `funcname` with `attrs` under `options`; entries are sorted so equivalent
// requests produce the same key.
//
// Entries come from:
//   * every attr except "_executor";
//   * the executor type chosen by FunctionLibraryRuntime::ExecutorType();
//   * `options` fields that are set: target, input/output devices,
//     overlay_lib (by pointer value), state_handle and the serialized
//     config_proto.
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options) {
  std::vector<string> entries;
  // Upper bound: one entry per attr and per device, plus at most one each
  // for _target, _overlay_lib, _state_handle, _executor and _config_proto.
  entries.reserve(attrs.size() + options.input_devices.size() +
                  options.output_devices.size() + 5);
  // Bind by const reference to avoid copying each AttrValue.
  for (const auto& p : attrs) {
    if (p.first != kExecutorAttr) {
      entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    }
  }
  if (!options.target.empty()) {
    entries.push_back(
        strings::StrCat("_target", "=", str_util::CEscape(options.target)));
  }
  for (size_t i = 0; i < options.input_devices.size(); ++i) {
    entries.push_back(strings::StrCat(
        "_input_dev", i, "=", str_util::CEscape(options.input_devices[i])));
  }
  for (size_t i = 0; i < options.output_devices.size(); ++i) {
    entries.push_back(strings::StrCat(
        "_output_dev", i, "=", str_util::CEscape(options.output_devices[i])));
  }
  if (options.overlay_lib) {
    entries.push_back(strings::StrCat(
        "_overlay_lib", "=", reinterpret_cast<uintptr_t>(options.overlay_lib)));
  }
  if (!options.state_handle.empty()) {
    entries.push_back(
        strings::StrCat("_state_handle", "=", options.state_handle));
  }
  const string executor_type =
      FunctionLibraryRuntime::ExecutorType(options, attrs);
  if (!executor_type.empty()) {
    entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
  }
  string config_proto_serialized;
  options.config_proto.SerializeToString(&config_proto_serialized);
  if (!config_proto_serialized.empty()) {
    entries.push_back(strings::StrCat(
        "_config_proto", "=", str_util::CEscape(config_proto_serialized)));
  }
  std::sort(entries.begin(), entries.end());
  return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
931 
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)932 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
933                                      DataTypeSlice ret_types)
934     : arg_types_(arg_types.begin(), arg_types.end()),
935       ret_types_(ret_types.begin(), ret_types.end()) {
936   args_.resize(arg_types_.size());
937   rets_.resize(ret_types_.size());
938 }
939 
~FunctionCallFrame()940 FunctionCallFrame::~FunctionCallFrame() {}
941 
SetArgs(gtl::ArraySlice<Tensor> args)942 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
943   // Input type checks.
944   if (args.size() != arg_types_.size()) {
945     return errors::InvalidArgument("Expects ", arg_types_.size(),
946                                    " arguments, but ", args.size(),
947                                    " is provided");
948   }
949   for (size_t i = 0; i < args.size(); ++i) {
950     if (arg_types_[i] != args[i].dtype()) {
951       return errors::InvalidArgument(
952           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
953           DataTypeString(args[i].dtype()), " is provided");
954     }
955     args_[i] = args[i];
956   }
957   return Status::OK();
958 }
959 
GetRetvals(std::vector<Tensor> * rets) const960 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
961   rets->clear();
962   rets->reserve(rets_.size());
963   for (size_t i = 0; i < rets_.size(); ++i) {
964     const auto& item = rets_[i];
965     if (item.has_val) {
966       rets->push_back(item.val);
967     } else {
968       return errors::Internal("Retval[", i, "] does not have value");
969     }
970   }
971   return Status::OK();
972 }
973 
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)974 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
975                                          bool allow_dead_tensors) {
976   rets->clear();
977   rets->reserve(rets_.size());
978   for (size_t i = 0; i < rets_.size(); ++i) {
979     if (rets_[i].has_val) {
980       rets->emplace_back(std::move(rets_[i].val));
981     } else if (allow_dead_tensors) {
982       rets->emplace_back();
983     } else {
984       return errors::Internal("Retval[", i, "] does not have value");
985     }
986   }
987   return Status::OK();
988 }
989 
GetArg(int index,Tensor * val) const990 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
991   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
992     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
993                                    args_.size(), ")");
994   }
995   *val = args_[index];
996   return Status::OK();
997 }
998 
SetRetval(int index,const Tensor & val)999 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1000   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1001     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1002                                    rets_.size(), ")");
1003   }
1004   if (val.dtype() != ret_types_[index]) {
1005     return errors::InvalidArgument(
1006         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1007         ", but ", DataTypeString(val.dtype()), " is provided.");
1008   }
1009   Retval* item = &rets_[index];
1010   if (!item->has_val) {
1011     item->has_val = true;
1012     item->val = val;
1013   } else {
1014     return errors::Internal("Retval[", index, "] has already been set.");
1015   }
1016   return Status::OK();
1017 }
1018 
1019 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef & fdef_in)1020     FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
1021     : fdef(fdef_in),
1022       // Exact shape inference for functions is handled by ShapeRefiner.
1023       // Here we pass a dummy shape inference function for legacy code paths.
1024       op_registration_data(fdef.signature(), shape_inference::UnknownShape,
1025                            true /* is_function */) {}
1026 
// Copy constructor. Every function of `other` is re-added through
// AddFunctionDef(), and the gradient map is copied wholesale. `other.mu_` is
// held in shared mode for the whole copy.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  for (const auto& name_and_entry : other.function_defs_) {
    const FunctionDef& fdef = name_and_entry.second->fdef;
    TF_CHECK_OK(AddFunctionDef(fdef));
  }
  func_grad_ = other.func_grad_;
}
1036 
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1037 FunctionLibraryDefinition::FunctionLibraryDefinition(
1038     const OpRegistryInterface* default_registry,
1039     const FunctionDefLibrary& def_lib)
1040     : default_registry_(default_registry),
1041       function_defs_(def_lib.function_size()) {
1042   for (const auto& fdef : def_lib.function()) {
1043     // The latter function definition wins.
1044     auto& ptr = function_defs_[fdef.signature().name()];
1045     ptr.reset(new FunctionDefAndOpRegistration(fdef));
1046   }
1047   for (const auto& grad : def_lib.gradient()) {
1048     func_grad_[grad.function_name()] = grad.gradient_func();
1049   }
1050 }
1051 
~FunctionLibraryDefinition()1052 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1053 
Contains(const string & func) const1054 bool FunctionLibraryDefinition::Contains(const string& func) const {
1055   tf_shared_lock l(mu_);
1056   return function_defs_.find(func) != function_defs_.end();
1057 }
1058 
Find(const string & func) const1059 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1060   tf_shared_lock l(mu_);
1061   return FindHelper(func);
1062 }
1063 
FindHelper(const string & func) const1064 const FunctionDef* FunctionLibraryDefinition::FindHelper(
1065     const string& func) const {
1066   auto iter = function_defs_.find(func);
1067   if (iter == function_defs_.end()) {
1068     return nullptr;
1069   } else {
1070     return &iter->second->fdef;
1071   }
1072 }
1073 
AddFunctionDef(const FunctionDef & fdef)1074 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
1075   mutex_lock l(mu_);
1076   bool added;
1077   return AddFunctionDefHelper(fdef, &added);
1078 }
1079 
AddFunctionDefHelper(const FunctionDef & fdef,bool * added)1080 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
1081                                                        bool* added) {
1082   *added = false;
1083   std::unique_ptr<FunctionDefAndOpRegistration>* entry =
1084       &function_defs_[fdef.signature().name()];
1085   if (*entry != nullptr) {
1086     if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
1087       return errors::InvalidArgument(
1088           "Cannot add function '", fdef.signature().name(),
1089           "' because a different function with the same name already "
1090           "exists.");
1091     }
1092     // Ignore duplicate FunctionDefs
1093     return Status::OK();
1094   }
1095   const OpDef* op_def;
1096   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1097     return errors::InvalidArgument(
1098         "Cannot add function '", fdef.signature().name(),
1099         "' because an op with the same name already exists.");
1100   }
1101   entry->reset(new FunctionDefAndOpRegistration(fdef));
1102   *added = true;
1103   return Status::OK();
1104 }
1105 
AddGradientDef(const GradientDef & grad)1106 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1107   mutex_lock l(mu_);
1108   bool added;
1109   return AddGradientDefHelper(grad, &added);
1110 }
1111 
AddGradientDefHelper(const GradientDef & grad,bool * added)1112 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1113                                                        bool* added) {
1114   *added = false;
1115   string* entry = &func_grad_[grad.function_name()];
1116   if (!entry->empty()) {
1117     if (*entry != grad.gradient_func()) {
1118       return errors::InvalidArgument(
1119           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1120           grad.function_name(), "' because it already has gradient function ",
1121           "'", *entry, "'");
1122     }
1123     // Ignore duplicate GradientDefs
1124     return Status::OK();
1125   }
1126   *entry = grad.gradient_func();
1127   *added = true;
1128   return Status::OK();
1129 }
1130 
AddLibrary(const FunctionLibraryDefinition & other)1131 Status FunctionLibraryDefinition::AddLibrary(
1132     const FunctionLibraryDefinition& other) {
1133   // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1134   // the duration of the function could lead to deadlock).
1135   FunctionLibraryDefinition clone(other);
1136   mutex_lock l(mu_);
1137   // Remember the funcs and grads that we added successfully so that
1138   // we can roll them back on error.
1139   std::vector<string> funcs;
1140   std::vector<string> funcs_with_grads;
1141   Status s;
1142   bool added;
1143   for (auto iter : clone.function_defs_) {
1144     s = AddFunctionDefHelper(iter.second->fdef, &added);
1145     if (!s.ok()) {
1146       Remove(funcs, funcs_with_grads);
1147       return s;
1148     }
1149     if (added) {
1150       funcs.push_back(iter.second->fdef.signature().name());
1151     }
1152   }
1153   for (auto iter : clone.func_grad_) {
1154     GradientDef grad;
1155     grad.set_function_name(iter.first);
1156     grad.set_gradient_func(iter.second);
1157     s = AddGradientDefHelper(grad, &added);
1158     if (!s.ok()) {
1159       Remove(funcs, funcs_with_grads);
1160       return s;
1161     }
1162     if (added) {
1163       funcs_with_grads.push_back(grad.function_name());
1164     }
1165   }
1166   return Status::OK();
1167 }
1168 
AddLibrary(const FunctionDefLibrary & lib_def)1169 Status FunctionLibraryDefinition::AddLibrary(
1170     const FunctionDefLibrary& lib_def) {
1171   // Remember the funcs and grads that we added successfully so that
1172   // we can roll them back on error.
1173   mutex_lock l(mu_);
1174   std::vector<string> funcs;
1175   std::vector<string> funcs_with_grads;
1176   Status s;
1177   bool added;
1178   for (const FunctionDef& fdef : lib_def.function()) {
1179     s = AddFunctionDefHelper(fdef, &added);
1180     if (!s.ok()) {
1181       Remove(funcs, funcs_with_grads);
1182       return s;
1183     }
1184     if (added) {
1185       funcs.push_back(fdef.signature().name());
1186     }
1187   }
1188   for (const GradientDef& grad : lib_def.gradient()) {
1189     s = AddGradientDefHelper(grad, &added);
1190     if (!s.ok()) {
1191       Remove(funcs, funcs_with_grads);
1192       return s;
1193     }
1194     if (added) {
1195       funcs_with_grads.push_back(grad.function_name());
1196     }
1197   }
1198   return Status::OK();
1199 }
1200 
ReplaceFunction(const string & func,const FunctionDef & fdef)1201 Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
1202                                                   const FunctionDef& fdef) {
1203   mutex_lock l(mu_);
1204   bool added;
1205   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1206   TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added));
1207   return Status::OK();
1208 }
1209 
ReplaceGradient(const GradientDef & grad)1210 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1211   mutex_lock l(mu_);
1212   bool added;
1213   TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1214   TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1215   return Status::OK();
1216 }
1217 
RemoveFunction(const string & func)1218 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1219   mutex_lock l(mu_);
1220   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1221   return Status::OK();
1222 }
1223 
RemoveFunctionHelper(const string & func)1224 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1225   const auto& i = function_defs_.find(func);
1226   if (i == function_defs_.end()) {
1227     return errors::InvalidArgument("Tried to remove non-existent function ",
1228                                    func);
1229   }
1230   function_defs_.erase(i);
1231   return Status::OK();
1232 }
1233 
RemoveGradient(const string & func)1234 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1235   const auto& i = func_grad_.find(func);
1236   if (i == func_grad_.end()) {
1237     return errors::InvalidArgument("Tried to remove non-existent gradient ",
1238                                    func);
1239   }
1240   func_grad_.erase(i);
1241   return Status::OK();
1242 }
1243 
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1244 void FunctionLibraryDefinition::Remove(
1245     const std::vector<string>& funcs,
1246     const std::vector<string>& funcs_with_grads) {
1247   for (const string& f : funcs) {
1248     Status s = RemoveFunctionHelper(f);
1249     DCHECK(s.ok());
1250   }
1251   for (const string& f : funcs_with_grads) {
1252     Status s = RemoveGradient(f);
1253     DCHECK(s.ok());
1254   }
1255 }
1256 
// Returns the gradient function name registered for `func`, or "" if none.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock lock(mu_);
  return FindGradientHelper(func);
}
1261 
// Unlocked variant of FindGradient(); callers hold `mu_`.
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  auto iter = func_grad_.find(func);
  return iter == func_grad_.end() ? string("") : iter->second;
}
1265 
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1266 Status FunctionLibraryDefinition::LookUp(
1267     const string& op, const OpRegistrationData** op_reg_data) const {
1268   tf_shared_lock l(mu_);
1269   auto iter = function_defs_.find(op);
1270   if (iter != function_defs_.end()) {
1271     *op_reg_data = &iter->second->op_registration_data;
1272     return Status::OK();
1273   }
1274   return default_registry_->LookUp(op, op_reg_data);
1275 }
1276 
UniqueFunctionName(StringPiece prefix) const1277 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1278   tf_shared_lock l(mu_);
1279   int index = 0;
1280   string name = strings::StrCat(prefix, index);
1281   while (function_defs_.find(name) != function_defs_.end()) {
1282     ++index;
1283     name = strings::StrCat(prefix, index);
1284   }
1285   return name;
1286 }
1287 
GetAttrImpl(const NodeDef & ndef) const1288 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1289     const NodeDef& ndef) const {
1290   if (ndef.op() != kGradientOp) {
1291     // If 'ndef' calls a function and the function's def has the attr,
1292     // returns it.
1293     return Find(ndef.op());
1294   }
1295 
1296   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1297   // Foo's attributes.
1298   const NameAttrList* forward_func_attrs;
1299   if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
1300     return nullptr;
1301   }
1302   const string& func_name = forward_func_attrs->name();
1303   {
1304     tf_shared_lock l(mu_);
1305     const string& grad_name = FindGradientHelper(func_name);
1306     // If 'func' has a user-defined gradient function, uses the grad
1307     // function's attrs to see if noinline is specified. Otherwise,
1308     // uses func's attrs.
1309     if (!grad_name.empty()) {
1310       return FindHelper(grad_name);
1311     }
1312     return FindHelper(func_name);
1313   }
1314 }
1315 
ListFunctionNames() const1316 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1317   std::vector<string> function_names;
1318   tf_shared_lock l(mu_);
1319   function_names.reserve(function_defs_.size());
1320   for (const auto& it : function_defs_) {
1321     function_names.emplace_back(it.first);
1322   }
1323   return function_names;
1324 }
1325 
ToProto() const1326 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1327   FunctionDefLibrary lib;
1328   tf_shared_lock l(mu_);
1329   for (const auto& f : function_defs_) {
1330     *lib.add_function() = f.second->fdef;
1331   }
1332   for (const auto& g : func_grad_) {
1333     GradientDef* gd = lib.add_gradient();
1334     gd->set_function_name(g.first);
1335     gd->set_gradient_func(g.second);
1336   }
1337   return lib;
1338 }
1339 
1340 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1341 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1342                                           const string& attr, T* value) const {
1343   const FunctionDef* fdef = GetAttrImpl(ndef);
1344   if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
1345     return Status::OK();
1346   }
1347   return errors::InvalidArgument("Attr ", attr, " is not defined.");
1348 }
1349 
1350 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1351 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1352                                           T* value) const {
1353   return GetAttr(node.def(), attr, value);
1354 }
1355 
1356 #define GET_ATTR(T)                                                            \
1357   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
1358                                                      const string&, T*) const; \
1359   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
1360                                                      const string&, T*) const;
1361 GET_ATTR(string)
1362 GET_ATTR(bool)
1363 #undef GET_ATTR
1364 
1365 namespace {
1366 
1367 constexpr char kApiImplements[] = "api_implements";
1368 
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1369 absl::flat_hash_set<string> ReachableFunctions(
1370     const FunctionLibraryDefinition& flib,
1371     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1372   // Functions that are reachable from the graph.
1373   absl::flat_hash_set<string> reachable_funcs;
1374 
1375   // For any functions, if it has attribute "api_implements" =
1376   // "some_interface" and it is reachable, then it means any other
1377   // function with same attribute name and value could also be potentially
1378   // reachable, eg via implementation_selector swapping the
1379   // nodedef.
1380   absl::flat_hash_set<string> reachable_api_interface;
1381 
1382   // Functions might be reachable from the nested function calls, so we keep a
1383   // queue of functions that we have to check.
1384   gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1385 
1386   // Add reachable and not already processed functions to the functions queue.
1387   const auto add_to_func_queue = [&](const string& func_name) {
1388     const FunctionDef* func = flib.Find(func_name);
1389     if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1390       func_queue.push_back(func);
1391     }
1392   };
1393 
1394   // Add all the functions that are reachable from the given node to the queue.
1395   const auto process_node = [&](const NodeDef& node) {
1396     // Node itself can be a call to the function.
1397     add_to_func_queue(node.op());
1398 
1399     // Or node can have an attribute referencing a function.
1400     for (const auto& attr : node.attr()) {
1401       const auto& attr_value = attr.second;
1402 
1403       // 1. AttrValue.func
1404       if (attr_value.has_func()) {
1405         add_to_func_queue(attr_value.func().name());
1406       }
1407 
1408       // 2. AttrValue.ListValue.func
1409       if (attr_value.has_list()) {
1410         for (const auto& func : attr_value.list().func()) {
1411           add_to_func_queue(func.name());
1412         }
1413       }
1414     }
1415   };
1416 
1417   // Add all functions that are directly called from the optimized graph.
1418   std::for_each(nodes.begin(), nodes.end(), process_node);
1419 
1420   // Process all reachable functions.
1421   while (!func_queue.empty()) {
1422     const FunctionDef* func = func_queue.back();
1423     func_queue.pop_back();
1424 
1425     const string& func_name = func->signature().name();
1426     reachable_funcs.insert(func_name);
1427 
1428     const auto attr_it = func->attr().find(kApiImplements);
1429     if (attr_it != func->attr().end()) {
1430       reachable_api_interface.insert(attr_it->second.s());
1431     }
1432 
1433     // Find all the functions called from the function body.
1434     const auto& func_body = func->node_def();
1435     std::for_each(func_body.begin(), func_body.end(), process_node);
1436 
1437     // Check if the function has a registered gradient.
1438     const string grad_func_name = flib.FindGradient(func_name);
1439     if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1440   }
1441 
1442   for (const auto& func_name : flib.ListFunctionNames()) {
1443     const auto& func_def = flib.Find(func_name);
1444     const auto attr_it = func_def->attr().find(kApiImplements);
1445     if (attr_it != func_def->attr().end()) {
1446       if (reachable_api_interface.contains(attr_it->second.s())) {
1447         reachable_funcs.insert(func_name);
1448       }
1449     }
1450   }
1451 
1452   return reachable_funcs;
1453 }
1454 
ReachableFunctionLibraryDefinition(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1455 FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
1456     const FunctionLibraryDefinition& flib,
1457     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1458   absl::flat_hash_set<string> reachable_funcs = ReachableFunctions(flib, nodes);
1459 
1460   FunctionLibraryDefinition reachable_flib(flib.default_registry(),
1461                                            FunctionDefLibrary());
1462 
1463   for (const string& func_name : reachable_funcs) {
1464     const FunctionDef* func = flib.Find(func_name);
1465     DCHECK_NE(func, nullptr);
1466     // That should never fail, because we copy functions from valid flib and use
1467     // the same default registry.
1468     const Status added = reachable_flib.AddFunctionDef(*func);
1469     DCHECK(added.ok());
1470 
1471     const string grad_func_name = flib.FindGradient(func_name);
1472     if (!grad_func_name.empty()) {
1473       GradientDef grad;
1474       grad.set_function_name(func_name);
1475       grad.set_gradient_func(grad_func_name);
1476       // It can only fail if function already has a gradient function.
1477       const Status added_grad = reachable_flib.AddGradientDef(grad);
1478       DCHECK(added_grad.ok());
1479     }
1480   }
1481 
1482   return reachable_flib;
1483 }
1484 
1485 }  // namespace
1486 
ReachableDefinitions(const GraphDef & graph) const1487 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1488     const GraphDef& graph) const {
1489   return ReachableFunctionLibraryDefinition(*this, graph.node());
1490 }
1491 
ReachableDefinitions(const FunctionDef & func) const1492 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1493     const FunctionDef& func) const {
1494   return ReachableFunctionLibraryDefinition(*this, func.node_def());
1495 }
1496 
InitFromString(StringPiece val)1497 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1498   if (val.size() >= 2 && val[0] == '$') {
1499     proto.set_placeholder(val.data() + 1, val.size() - 1);
1500   } else {
1501     SetAttrValue(val, &proto);
1502   }
1503 }
1504 
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1505 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1506     const string& name,
1507     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1508   AttrValueWrapper ret;
1509   ret.proto.mutable_func()->set_name(name);
1510   for (const auto& a : attrs) {
1511     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1512   }
1513   return ret;
1514 }
1515 
ToNodeDef() const1516 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1517   NodeDef n;
1518   n.set_op(this->op);
1519   n.set_name(this->ret[0]);
1520   for (const auto& a : this->attr) {
1521     n.mutable_attr()->insert({a.first, a.second.proto});
1522   }
1523   for (const string& a : this->arg) {
1524     n.add_input(a);
1525   }
1526   for (const string& d : this->dep) {
1527     n.add_input(strings::StrCat("^", d));
1528   }
1529   if (!this->device.empty()) {
1530     n.set_device(this->device);
1531   }
1532   return n;
1533 }
1534 
1535 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1536 FunctionDef FunctionDefHelper::Create(
1537     const string& function_name, gtl::ArraySlice<string> in_def,
1538     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1539     gtl::ArraySlice<Node> node_def,
1540     gtl::ArraySlice<std::pair<string, string>> ret_def,
1541     gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1542   FunctionDef fdef;
1543 
1544   // Signature
1545   OpDefBuilder b(function_name);
1546   for (const auto& i : in_def) b.Input(i);
1547   for (const auto& o : out_def) b.Output(o);
1548   for (const auto& a : attr_def) b.Attr(a);
1549   for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1550 
1551   OpRegistrationData op_reg_data;
1552   TF_CHECK_OK(b.Finalize(&op_reg_data));
1553   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1554 
1555   // Function body
1556   for (const auto& n : node_def) {
1557     *(fdef.add_node_def()) = n.ToNodeDef();
1558   }
1559 
1560   // Returns
1561   for (const auto& r : ret_def) {
1562     fdef.mutable_ret()->insert({r.first, r.second});
1563   }
1564 
1565   // Control returns
1566   for (const auto& cr : control_ret_def) {
1567     fdef.mutable_control_ret()->insert({cr.first, cr.second});
1568   }
1569 
1570   auto* op_def_registry = OpRegistry::Global();
1571   // Check if any op is stateful.
1572   for (const auto& n : node_def) {
1573     const OpDef* op_def = nullptr;
1574     auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1575     // Lookup can fail if e.g. we are calling a function that was not yet
1576     // defined.  If it happens, conservatively assume the op is stateful.
1577     if (!status.ok() || op_def->is_stateful()) {
1578       fdef.mutable_signature()->set_is_stateful(true);
1579     }
1580   }
1581 
1582   return fdef;
1583 }
1584 
1585 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1586 FunctionDef FunctionDefHelper::Create(
1587     const string& function_name, gtl::ArraySlice<string> in_def,
1588     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1589     gtl::ArraySlice<Node> node_def,
1590     gtl::ArraySlice<std::pair<string, string>> ret_def) {
1591   return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1592                 /*control_ret_def=*/{});
1593 }
1594 
1595 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1596 FunctionDef FunctionDefHelper::Define(const string& name,
1597                                       gtl::ArraySlice<string> arg_def,
1598                                       gtl::ArraySlice<string> ret_def,
1599                                       gtl::ArraySlice<string> attr_def,
1600                                       gtl::ArraySlice<Node> node_def) {
1601   FunctionDef fdef;
1602   OpDefBuilder b(name);
1603   for (const auto& a : arg_def) b.Input(a);
1604   for (const auto& r : ret_def) b.Output(r);
1605   for (const auto& a : attr_def) b.Attr(a);
1606 
1607   OpRegistrationData op_reg_data;
1608   TF_CHECK_OK(b.Finalize(&op_reg_data));
1609   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1610 
1611   // Mapping from legacy output names to NodeDef outputs.
1612   std::unordered_map<string, string> ret_index;
1613   for (const auto& a : fdef.signature().input_arg()) {
1614     ret_index[a.name()] = a.name();
1615   }
1616 
1617   // For looking up OpDefs
1618   auto* op_def_registry = OpRegistry::Global();
1619 
1620   // Function body
1621   for (const auto& src : node_def) {
1622     NodeDef* n = fdef.add_node_def();
1623     n->set_op(src.op);
1624     n->set_name(src.ret[0]);
1625     for (const auto& a : src.attr) {
1626       n->mutable_attr()->insert({a.first, a.second.proto});
1627     }
1628     for (const string& a : src.arg) {
1629       const auto iter = ret_index.find(a);
1630       CHECK(iter != ret_index.end())
1631           << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
1632       n->add_input(iter->second);
1633     }
1634     for (const string& d : src.dep) {
1635       n->add_input(strings::StrCat("^", d));
1636     }
1637 
1638     // Add the outputs of this node to ret_index.
1639     const OpDef* op_def = nullptr;
1640     TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1641     CHECK(op_def != nullptr) << n->op();
1642     NameRangeMap output_names;
1643     TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1644     for (const auto& o : output_names) {
1645       CHECK_LE(o.second.second, src.ret.size())
1646           << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
1647           << "' of " << name;
1648       for (int i = o.second.first; i < o.second.second; ++i) {
1649         ret_index[src.ret[i]] =
1650             strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
1651       }
1652     }
1653     if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1654   }
1655 
1656   // Returns
1657   for (const auto& r : fdef.signature().output_arg()) {
1658     const auto iter = ret_index.find(r.name());
1659     CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1660     fdef.mutable_ret()->insert({r.name(), iter->second});
1661   }
1662   return fdef;
1663 }
1664 
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1665 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1666                                       gtl::ArraySlice<string> ret_def,
1667                                       gtl::ArraySlice<string> attr_def,
1668                                       gtl::ArraySlice<Node> node_def) {
1669   return Define("_", arg_def, ret_def, attr_def, node_def);
1670 }
1671 
1672 namespace gradient {
1673 
1674 typedef std::unordered_map<string, Creator> OpGradFactory;
1675 
GetOpGradFactory()1676 OpGradFactory* GetOpGradFactory() {
1677   static OpGradFactory* factory = new OpGradFactory;
1678   return factory;
1679 }
1680 
RegisterOp(const string & op,Creator func)1681 bool RegisterOp(const string& op, Creator func) {
1682   CHECK(GetOpGradFactory()->insert({op, func}).second)
1683       << "Duplicated gradient for " << op;
1684   return true;
1685 }
1686 
GetOpGradientCreator(const string & op,Creator * creator)1687 Status GetOpGradientCreator(const string& op, Creator* creator) {
1688   auto fac = GetOpGradFactory();
1689   auto iter = fac->find(op);
1690   if (iter == fac->end()) {
1691     return errors::NotFound("No gradient defined for op: ", op);
1692   }
1693   *creator = iter->second;
1694   return Status::OK();
1695 }
1696 
1697 }  // end namespace gradient
1698 
1699 }  // namespace tensorflow
1700