1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 
18 #include <ctype.h>
19 
20 #include <map>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/escaping.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/common_shape_fns.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/lib/gtl/map_util.h"
40 #include "tensorflow/core/util/device_name_utils.h"
41 #include "tensorflow/core/util/equal_graph_def.h"
42 
43 namespace tensorflow {
44 
// Namespace-scope definitions for the static constexpr string constants of
// FunctionLibraryDefinition.
// NOTE(review): presumably these members are declared with their initializers
// in function.h; prior to C++17 a static constexpr data member still needs a
// definition like these if it is odr-used — confirm against the header.
/* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceArgOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kIntsOnDeviceAttr;
/* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
55 
56 // Extracts the actual type from "attr_values" based on its definition
57 // "arg_def".
58 //
59 // If "arg_def" is a N*T type, *is_type_list is set to false, and
60 // *dtypes is set to be a vector of size N and each element is T.
61 //
62 // If "arg_def" is a list(type), *is_type_list is set to true, and
63 // *dtypes is set to be a vector of types specified in attrs for
64 // arg_def.
65 //
66 // Otherwise (arg_def is a simple type T), *is_type_list is set to
67 // false, and *dtypes is set to a single element vector, whose only
68 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)69 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
70                   bool* is_type_list, DataTypeVector* dtypes) {
71   dtypes->clear();
72   if (!arg_def.type_list_attr().empty()) {
73     const AttrValue* v = attrs.Find(arg_def.type_list_attr());
74     if (v == nullptr) {
75       return errors::NotFound("type attr not found: ",
76                               arg_def.type_list_attr());
77     }
78     *is_type_list = true;
79     for (int i = 0; i < v->list().type_size(); ++i) {
80       dtypes->push_back(v->list().type(i));
81     }
82     return Status::OK();
83   }
84 
85   *is_type_list = false;
86   int num = 1;
87   if (!arg_def.number_attr().empty()) {
88     const AttrValue* v = attrs.Find(arg_def.number_attr());
89     if (v == nullptr) {
90       return errors::NotFound("type attr not found: ", arg_def.type_attr());
91     }
92     num = v->i();
93   }
94 
95   DataType dtype;
96   if (arg_def.type() != DT_INVALID) {
97     dtype = arg_def.type();
98   } else if (arg_def.type_attr().empty()) {
99     dtype = DT_INVALID;
100   } else {
101     const AttrValue* v = attrs.Find(arg_def.type_attr());
102     if (v == nullptr) {
103       return errors::NotFound("type attr not found: ", arg_def.type_attr());
104     }
105     dtype = v->type();
106   }
107   dtypes->resize(num, dtype);
108   return Status::OK();
109 }
110 
111 namespace {
112 
113 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)114 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
115   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
116 }
117 
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)118 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
119   // attr_values should specify all attrs defined in fdef.
120   for (const auto& a : sig.attr()) {
121     const AttrValue* v = attr_values.Find(a.name());
122     if (!v) {
123       return errors::NotFound("Attr ", a.name(), " is not found from ",
124                               SummarizeOpDef(sig));
125     }
126     Status status = AttrValueHasType(*v, a.type());
127     if (!status.ok()) {
128       errors::AppendToMessage(&status, "for attr '", a.name(), "'");
129       return status;
130     }
131   }
132 
133 // TODO(josh11b): Enable this code once it works with function gradients.
134 // Right now the C++ function gradient code assumes it can pass
135 // all the attrs of the function to the gradient, and any attrs that
136 // the gradient doesn't care about will be ignored.
137 #if 0
138   if (attr_values.size() != sig.attr_size()) {
139     for (const auto& a : attr_values) {
140       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
141       bool found = false;
142       for (const auto& s : sig.attr()) {
143         if (a.first == s.name()) {
144           found = true;
145           break;
146         }
147       }
148       if (!found) {
149         return errors::NotFound("Attr ", a.first, " is not found in ",
150                                 SummarizeOpDef(sig));
151       }
152     }
153   }
154 #endif
155 
156   return Status::OK();
157 }
158 
159 // A helper class for instantiating functions. This contains shared information
160 // like the resulting graph and node name index.
161 class FunctionInstantiationHelper {
162  public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)163   FunctionInstantiationHelper(GetFunctionSignature get_function,
164                               InstantiationResult* result)
165       : get_function_(std ::move(get_function)), result_(*result) {
166     result_.nodes.clear();
167   }
168 
169   // Builds index for nodes that can be used as node's input arguments.
170   // `resource_arg_unique_id`: if non-negative, will be populated to the
171   // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64 resource_arg_unique_id)172   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
173                             const FunctionDef::ArgAttrs* arg_attrs,
174                             bool ints_on_device, int64 resource_arg_unique_id) {
175     bool is_type_list;
176     DataTypeVector dtypes;
177     TF_RETURN_IF_ERROR(
178         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
179     CHECK_GE(dtypes.size(), size_t{1});
180     int arg_index = result_.nodes.size();
181     TF_RETURN_IF_ERROR(
182         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
183     // Creates dtypes.size() nodes in the graph.
184     for (size_t i = 0; i < dtypes.size(); ++i) {
185       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
186                                  {true, arg_index, 0, false, {dtypes[i]}}));
187       DCHECK_EQ(arg_index, result_.nodes.size());
188       string name = arg_def.name();
189       if (dtypes.size() > 1) {
190         strings::StrAppend(&name, "_", i);
191       }
192       NodeDef* gnode = AddNode(name);
193       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
194         gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
195       } else {
196         gnode->set_op(FunctionLibraryDefinition::kArgOp);
197       }
198       DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
199       AddAttr("T", dtype, gnode);
200       AddAttr("index", arg_index, gnode);
201       if (resource_arg_unique_id >= 0) {
202         AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
203       }
204       if (arg_attrs) {
205         for (const auto& arg_attr : arg_attrs->attr()) {
206           AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
207         }
208       }
209       result_.arg_types.push_back(dtypes[i]);
210       ++arg_index;
211     }
212     return Status::OK();
213   }
214 
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)215   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
216                               const int arg_index) {
217     const OpDef* node_sig = nullptr;
218     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
219     if (node_sig->output_arg_size() == 0) {
220       return AddItem(node.name(), {false, arg_index, 0, false, {}});
221     }
222     const int num_retval = node_sig->output_arg_size();
223     int start = 0;
224     bool is_type_list;
225     DataTypeVector dtypes;
226     for (int i = 0; i < num_retval; ++i) {
227       TF_RETURN_IF_ERROR(
228           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
229       // Note that we rely on the backwards-compatibility test enforcing
230       // that output_arg(*).name() doesn't change here.
231       const string base_name =
232           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
233       TF_RETURN_IF_ERROR(
234           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
235       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
236         TF_RETURN_IF_ERROR(
237             AddItem(strings::StrCat(base_name, ":", j),
238                     {false, arg_index, start + j, false, {dtypes[j]}}));
239       }
240       start += dtypes.size();
241     }
242     return Status::OK();
243   }
244 
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)245   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
246     const OpDef* fnode_sig = nullptr;
247     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
248     NodeDef* gnode = AddNode(fnode.name());
249     gnode->set_op(fnode.op());
250     gnode->set_device(fnode.device());
251     int gnode_idx = nodes_.size() - 1;
252 
253     // Input
254     const int num_args = fnode_sig->input_arg_size();
255     bool is_type_list;  // ignored
256     DataTypeVector dtypes;
257     int fnode_arg_index = 0;
258     for (int i = 0; i < num_args; ++i) {
259       TF_RETURN_IF_ERROR(
260           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
261       // Consume inputs (indexed by fnode_arg_index) until we have
262       // matched each element of dtypes (indexed by j).
263       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
264         if (fnode_arg_index >= fnode.input_size()) {
265           // Should never happen if we computed dtypes correctly.
266           return errors::InvalidArgument(
267               "Attempt to access beyond input size: ", fnode_arg_index,
268               " >= ", fnode.input_size());
269         }
270         // Look up the next input.
271         const string& input_name = fnode.input(fnode_arg_index);
272         const auto* item = GetItemOrNull(input_name);
273         if (item == nullptr) {
274           return errors::InvalidArgument(
275               "input ", input_name,
276               " is not found: ", FormatNodeDefForError(fnode));
277         }
278         if (item->dtypes.size() > dtypes.size() - j) {
279           return errors::InvalidArgument("Input ", input_name, " too long for ",
280                                          fnode_sig->input_arg(i).name());
281         }
282         // Match up all the elements of this input (indexed by k) with
283         // elements of dtypes (advancing j).
284         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
285           if (item->dtypes[k] != dtypes[j]) {
286             return errors::InvalidArgument(
287                 "input ", fnode_sig->input_arg(i).name(), "[", j,
288                 "] expected type ", DataTypeString(dtypes[j]),
289                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
290                 input_name, "[", k, "]");
291           }
292           if (item->is_func_arg) {
293             AddInput(gnode_idx, item->nid + k, 0);
294           } else {
295             AddInput(gnode_idx, item->nid, item->idx + k);
296           }
297         }
298       }
299     }
300 
301     // Control deps.
302     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
303       const string& input = fnode.input(i);
304       if (input.empty() || input[0] != '^') {
305         return errors::InvalidArgument("Expected input[", i, "] == '", input,
306                                        "' to be a control input.");
307       }
308       int nid = -1;
309       const string node_name = input.substr(1);
310       const string node_colon = node_name + ":";
311       const string node_colon_bound = node_name + ";";
312       // index_ is a map sorted lexicographically, so the key we are looking for
313       // must lie in the range [node_name, node_colon_bound).
314       auto it = index_.lower_bound(node_name);
315       while (it != index_.end() && it->first <= node_colon_bound) {
316         if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
317           nid = it->second.nid;
318           break;
319         }
320         ++it;
321       }
322       if (nid == -1) {
323         return errors::InvalidArgument("input[", i, "] == '", input,
324                                        "', is not found.");
325       }
326       AddDep(gnode_idx, nid);
327     }
328 
329     // Attrs.
330     for (const auto& p : attrs) {
331       (*gnode->mutable_attr())[p.first] = p.second;
332     }
333 
334     return Status::OK();
335   }
336 
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)337   Status AddReturnNode(
338       const OpDef::ArgDef& ret_def, AttrSlice attrs,
339       const ::tensorflow::protobuf::Map<string, string>& ret_map,
340       bool ints_on_device, int* ret_index) {
341     auto ret_iter = ret_map.find(ret_def.name());
342     if (ret_iter == ret_map.end()) {
343       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
344     }
345     bool is_type_list;
346     DataTypeVector dtypes;
347     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
348     CHECK_GE(dtypes.size(), size_t{1});
349     const auto* item = GetItemOrNull(ret_iter->second);
350     if (item == nullptr) {
351       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
352                                      ret_iter->second, " is not found.");
353     }
354     if (dtypes != item->dtypes) {
355       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
356                                      " : ", DataTypeVectorString(dtypes),
357                                      " vs. ",
358                                      DataTypeVectorString(item->dtypes));
359     }
360     for (size_t i = 0; i < dtypes.size(); ++i) {
361       string name = strings::StrCat(ret_def.name(), "_RetVal");
362       if (dtypes.size() > 1) {
363         strings::StrAppend(&name, "_", i);
364       }
365       NodeDef* gnode = AddNode(name);
366       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
367         gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
368       } else {
369         gnode->set_op(FunctionLibraryDefinition::kRetOp);
370       }
371       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
372       DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
373       AddAttr("T", dtype, gnode);
374       AddAttr("index", (*ret_index)++, gnode);
375       result_.ret_types.push_back(dtypes[i]);
376     }
377     return Status::OK();
378   }
379 
380   // Adds the actual node inputs to the result graph by converting indexes to
381   // the node names.
AddNodeInputs()382   void AddNodeInputs() {
383     for (int i = 0; i < result_.nodes.size(); i++) {
384       NodeInfo& node_info = nodes_[i];
385       for (const auto& p : node_info.data_inputs) {
386         result_.nodes[i].add_input(Name(p.first, p.second));
387       }
388       for (int index : node_info.control_inputs) {
389         result_.nodes[i].add_input(Dep(index));
390       }
391     }
392   }
393 
394  private:
395   // This is used to build a small index for all names that can be used as a
396   // node's input arguments.
397   //
398   // If is_func_arg is true, the name is a function's argument.  In
399   // this case, the produced graph def has node[nid:nid + dtype.size()].
400   //
401   // Otherwise, the name is a function body's node return value.  In
402   // this case, the produced graph def has one node node[nid] and
403   // the node's output index [idx ... idx + num) corresponds to the
404   // named outputs.
405   //
406   // In all cases, "dtype" specifies the data type.
407   struct NameInfoItem {
408     bool is_func_arg;
409     int nid;
410     int idx;
411     bool is_type_list;
412     DataTypeVector dtypes;
413   };
414 
415   // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)416   Status AddItem(const string& name, const NameInfoItem& item) {
417     if (!index_.insert({name, item}).second) {
418       return errors::InvalidArgument(
419           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
420                           " name: "),
421           name);
422     }
423     return Status::OK();
424   }
425 
GetItemOrNull(const string & name) const426   const NameInfoItem* GetItemOrNull(const string& name) const {
427     return gtl::FindOrNull(index_, name);
428   }
429 
Dep(int node_index) const430   string Dep(int node_index) const {
431     return strings::StrCat("^", Name(node_index));
432   }
433 
Name(int node_index) const434   string Name(int node_index) const {
435     CHECK_LT(node_index, nodes_.size());
436     return nodes_[node_index].name;
437   }
438 
Name(int node_index,int output_index) const439   string Name(int node_index, int output_index) const {
440     if (output_index == 0) {
441       return Name(node_index);
442     } else {
443       return strings::StrCat(Name(node_index), ":", output_index);
444     }
445   }
446 
AddNode(const string & name)447   NodeDef* AddNode(const string& name) {
448     result_.nodes.emplace_back();
449     NodeDef* gnode = &result_.nodes.back();
450     gnode->set_name(name);
451     nodes_.push_back({name, {}, {}});
452     CHECK_EQ(result_.nodes.size(), nodes_.size());
453     return gnode;
454   }
455 
AddInput(int node_index,int output_node,int output_index)456   void AddInput(int node_index, int output_node, int output_index) {
457     CHECK_LT(node_index, nodes_.size());
458     nodes_[node_index].data_inputs.push_back(
459         std::make_pair(output_node, output_index));
460   }
461 
AddDep(int node_index,int dep_index)462   void AddDep(int node_index, int dep_index) {
463     CHECK_LT(node_index, nodes_.size());
464     nodes_[node_index].control_inputs.push_back(dep_index);
465   }
466 
467   GetFunctionSignature get_function_;
468   InstantiationResult& result_;
469   // A small index for all names that can be used as a node's input arguments.
470   std::map<string, NameInfoItem> index_;
471   // This contains information about a node in the new graph including the node
472   // names and input nodes' indexes.
473   struct NodeInfo {
474     string name;
475     // Data inputs where <n, k> means arg k of node n.
476     std::vector<std::pair<int, int>> data_inputs;
477     // Control inputs (dependencies).
478     std::vector<int> control_inputs;
479   };
480   // nodes_[i] is the information about result_.nodes[i].
481   std::vector<NodeInfo> nodes_;
482 };
483 
484 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)485 string Print(const OpDef::ArgDef& arg) {
486   string out;
487   strings::StrAppend(&out, arg.name(), ":");
488   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
489   if (!arg.number_attr().empty()) {
490     strings::StrAppend(&out, arg.number_attr(), "*");
491   }
492   if (arg.type() != DT_INVALID) {
493     strings::StrAppend(&out, DataTypeString(arg.type()));
494   } else {
495     strings::StrAppend(&out, arg.type_attr());
496   }
497   if (arg.is_ref()) strings::StrAppend(&out, ")");
498   return out;
499 }
500 
501 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)502 string Print(const AttrValue& attr_value) {
503   if (attr_value.value_case() == AttrValue::kType) {
504     return DataTypeString(attr_value.type());
505   } else if ((attr_value.value_case() == AttrValue::kList) &&
506              (attr_value.list().type_size() > 0)) {
507     string ret = "{";
508     for (int i = 0; i < attr_value.list().type_size(); ++i) {
509       if (i > 0) strings::StrAppend(&ret, ", ");
510       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
511     }
512     strings::StrAppend(&ret, "}");
513     return ret;
514   } else if (attr_value.value_case() == AttrValue::kFunc) {
515     if (attr_value.func().attr_size() == 0) {
516       return attr_value.func().name();
517     }
518     std::vector<string> entries;
519     for (const auto& p : attr_value.func().attr()) {
520       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
521     }
522     std::sort(entries.begin(), entries.end());
523     return strings::StrCat(attr_value.func().name(), "[",
524                            absl::StrJoin(entries, ", "), "]");
525   }
526   return SummarizeAttrValue(attr_value);
527 }
528 
529 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)530 string Print(const NodeDef& n) {
531   string out;
532   strings::StrAppend(&out, n.name(), " = ", n.op());
533   if (n.attr_size() > 0) {
534     std::vector<string> entries;
535     for (auto& a : n.attr()) {
536       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
537     }
538     std::sort(entries.begin(), entries.end());
539     // Add a short device string at the end of all attributes.
540     if (!n.device().empty()) {
541       DeviceNameUtils::ParsedName parsed;
542       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
543         entries.push_back(
544             strings::StrCat("device=", parsed.type, ":", parsed.id));
545       } else {
546         entries.push_back("device=<FAILED_TO_PARSE>");
547       }
548     }
549     strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
550   }
551   strings::StrAppend(&out, "(");
552   std::vector<StringPiece> dat;
553   std::vector<string> dep;
554   for (StringPiece s : n.input()) {
555     if (absl::ConsumePrefix(&s, "^")) {
556       dep.emplace_back(s);
557     } else {
558       dat.push_back(s);
559     }
560   }
561   strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
562   if (!dep.empty()) {
563     strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
564   }
565   return out;
566 }
567 
Print(const FunctionDef & fdef)568 string Print(const FunctionDef& fdef) {
569   string out;
570   const OpDef& sig = fdef.signature();
571   strings::StrAppend(&out, "\n", sig.name());
572   if (sig.attr_size() > 0) {
573     strings::StrAppend(&out, "[");
574     for (int i = 0; i < sig.attr_size(); ++i) {
575       const auto& a = sig.attr(i);
576       if (i > 0) strings::StrAppend(&out, ", ");
577       if (a.type() == "type") {
578         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
579       } else {
580         strings::StrAppend(&out, a.name(), ":", a.type());
581       }
582     }
583     strings::StrAppend(&out, "]");
584   }
585   strings::StrAppend(&out, "(");
586   for (int i = 0; i < sig.input_arg_size(); ++i) {
587     if (i > 0) strings::StrAppend(&out, ", ");
588     strings::StrAppend(&out, Print(sig.input_arg(i)));
589   }
590   strings::StrAppend(&out, ") -> (");
591   for (int i = 0; i < sig.output_arg_size(); ++i) {
592     if (i > 0) strings::StrAppend(&out, ", ");
593     strings::StrAppend(&out, Print(sig.output_arg(i)));
594   }
595   strings::StrAppend(&out, ") {\n");
596   for (const auto& n : fdef.node_def()) {
597     strings::StrAppend(&out, "  ", Print(n), "\n");
598   }
599   for (const auto& cr : fdef.control_ret()) {
600     strings::StrAppend(&out, "  @return ", cr.first, " = ", cr.second, "\n");
601   }
602   for (const auto& r : fdef.ret()) {
603     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
604   }
605   strings::StrAppend(&out, "}\n");
606   return out;
607 }
608 
Print(gtl::ArraySlice<const NodeDef * > nodes)609 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
610   std::vector<const NodeDef*> arg;
611   std::vector<const NodeDef*> ret;
612   std::vector<const NodeDef*> body;
613   for (const NodeDef* n : nodes) {
614     if (n->op() == FunctionLibraryDefinition::kArgOp ||
615         n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
616       arg.push_back(n);
617     } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
618                n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
619       ret.push_back(n);
620     } else {
621       body.push_back(n);
622     }
623   }
624   auto comp = [](const NodeDef* x, const NodeDef* y) {
625     int xi;
626     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
627     int yi;
628     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
629     return xi < yi;
630   };
631   std::sort(arg.begin(), arg.end(), comp);
632   std::sort(ret.begin(), ret.end(), comp);
633   string out;
634   strings::StrAppend(&out, "\n(");
635   auto get_type_and_device = [](const NodeDef& n) {
636     DataType dt;
637     if (!TryGetNodeAttr(n, "T", &dt)) {
638       dt = DT_INVALID;
639     }
640     if (!n.device().empty()) {
641       DeviceNameUtils::ParsedName parsed;
642       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
643         return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
644                                parsed.id);
645       } else {
646         LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
647                      << n.op() << ":" << n.name();
648         return strings::StrCat(DataTypeString(dt), "@",
649                                "<FAILED_TO_PARSE_DEVICE>");
650       }
651     }
652     return DataTypeString(dt);
653   };
654   for (size_t i = 0; i < arg.size(); ++i) {
655     const NodeDef* n = arg[i];
656     if (i > 0) strings::StrAppend(&out, ", ");
657     CHECK_GE(n->attr_size(), 2);
658     strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
659   }
660   strings::StrAppend(&out, ") -> (");
661   for (size_t i = 0; i < ret.size(); ++i) {
662     const NodeDef* n = ret[i];
663     if (i > 0) strings::StrAppend(&out, ", ");
664     CHECK_LE(2, n->attr_size());
665 
666     // The _RetVal op should have a unique non-control input. We assert that
667     // here and add it to the output.
668     bool found_non_control_input = false;
669     for (const string& input : n->input()) {
670       if (!input.empty() && input[0] != '^') {
671         DCHECK_EQ(found_non_control_input, false)
672             << "RetVal node has more than one non-control input: "
673             << absl::StrJoin(n->input(), ", ");
674         strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
675         found_non_control_input = true;
676       }
677     }
678     DCHECK_EQ(found_non_control_input, true)
679         << "RetVal did not have any non-control inputs: "
680         << absl::StrJoin(n->input(), ", ");
681   }
682   strings::StrAppend(&out, ") {\n");
683   for (size_t i = 0; i < body.size(); ++i) {
684     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
685   }
686   strings::StrAppend(&out, "}\n");
687   return out;
688 }
689 
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)690 Status AddDefaultAttrs(const string& op,
691                        const GetFunctionSignature& get_function,
692                        AttrValueMap* attrs) {
693   const OpDef* op_def = nullptr;
694   TF_RETURN_IF_ERROR(get_function(op, &op_def));
695   AttrSlice attr_slice(attrs);
696   for (const auto& attr_def : op_def->attr()) {
697     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
698       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
699         return errors::Internal("Somehow duplicated: ", attr_def.name());
700       }
701     }
702   }
703   return Status::OK();
704 }
705 
706 }  // end namespace
707 
708 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)709 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
710                            GetFunctionSignature get_function,
711                            InstantiationResult* result) {
712   if (VLOG_IS_ON(5)) {
713     const auto& signature = fdef.signature();
714     VLOG(5) << "Instantiate function definition: name=" << signature.name()
715             << " #input_args=" << signature.input_arg_size()
716             << " #output_args=" << signature.output_arg_size()
717             << " #control_output=" << signature.control_output_size();
718     for (const auto& line : str_util::Split(Print(fdef), '\n')) {
719       VLOG(5) << "|| " << line;
720     }
721   }
722 
723   const OpDef& sig = fdef.signature();
724   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
725 
726   bool ints_on_device =
727       fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
728       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
729 
730   FunctionInstantiationHelper helper(get_function, result);
731   Status s;
732   for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
733     const OpDef::ArgDef& arg_def = sig.input_arg(i);
734     auto it = fdef.arg_attr().find(i);
735     const FunctionDef::ArgAttrs* arg_attrs =
736         it != fdef.arg_attr().end() ? &it->second : nullptr;
737     auto resource_id_it = fdef.resource_arg_unique_id().find(i);
738     int64 resource_arg_unique_id =
739         resource_id_it != fdef.resource_arg_unique_id().end()
740             ? resource_id_it->second
741             : -1LL;
742     s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
743                                   ints_on_device, resource_arg_unique_id);
744 
745     if (!s.ok()) {
746       errors::AppendToMessage(&s, "In ", Print(arg_def));
747       return s;
748     }
749   }
750 
751   auto substitute = [attr_values](StringPiece name, AttrValue* val) {
752     if (const AttrValue* v = attr_values.Find(name)) {
753       *val = *v;
754       return true;
755     }
756     return false;
757   };
758 
759   // Makes a copy of all attrs in fdef and substitutes placeholders.
760   // After this step, every attr is bound to a concrete value.
761   std::vector<AttrValueMap> node_attrs;
762   node_attrs.resize(fdef.node_def_size());
763   for (int i = 0; i < fdef.node_def_size(); ++i) {
764     for (auto attr : fdef.node_def(i).attr()) {
765       if (!SubstitutePlaceholders(substitute, &attr.second)) {
766         return errors::InvalidArgument("Failed to bind all placeholders in ",
767                                        SummarizeAttrValue(attr.second));
768       }
769       if (!node_attrs[i].insert(attr).second) {
770         return errors::Internal("Somehow duplicated: ", attr.first);
771       }
772     }
773     TF_RETURN_IF_ERROR(
774         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
775   }
776 
777   for (int i = 0; i < fdef.node_def_size(); ++i) {
778     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
779                                     result->nodes.size() + i);
780     if (!s.ok()) {
781       errors::AppendToMessage(&s, "In ",
782                               FormatNodeDefForError(fdef.node_def(i)));
783       return s;
784     }
785   }
786   // Emits one node for each fdef.node_def.
787   for (int i = 0; i < fdef.node_def_size(); ++i) {
788     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
789     if (!s.ok()) {
790       errors::AppendToMessage(&s, "In ",
791                               FormatNodeDefForError(fdef.node_def(i)));
792       return s;
793     }
794   }
795 
796   // Emits nodes for the function's return values.
797   int ret_index = 0;
798   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
799     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
800                              &ret_index);
801     if (!s.ok()) {
802       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
803       return s;
804     }
805   }
806 
807   // Adds the actual node inputs using the input indexes.
808   helper.AddNodeInputs();
809 
810   return Status::OK();
811 }
812 
DebugString(const FunctionDef & func_def)813 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
814 
DebugString(const GraphDef & instantiated_func_def)815 string DebugString(const GraphDef& instantiated_func_def) {
816   std::vector<const NodeDef*> ptrs;
817   for (const NodeDef& n : instantiated_func_def.node()) {
818     ptrs.push_back(&n);
819   }
820   return Print(ptrs);
821 }
822 
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)823 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
824   std::vector<const NodeDef*> ptrs;
825   for (const NodeDef& n : instantiated_func_nodes) {
826     ptrs.push_back(&n);
827   }
828   return Print(ptrs);
829 }
830 
DebugStringWhole(const GraphDef & gdef)831 string DebugStringWhole(const GraphDef& gdef) {
832   string ret;
833   for (const auto& fdef : gdef.library().function()) {
834     strings::StrAppend(&ret, Print(fdef));
835   }
836   strings::StrAppend(&ret, "\n");
837   for (const auto& ndef : gdef.node()) {
838     strings::StrAppend(&ret, Print(ndef), "\n");
839   }
840   return ret;
841 }
842 
843 namespace {
844 
845 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
846 // Python, it's possible to access unset attrs, which returns a default value
847 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)848 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
849   std::map<string, AttrValue> set_attrs;
850   for (const auto& pair : fdef.attr()) {
851     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
852       set_attrs[pair.first] = pair.second;
853     }
854   }
855   return set_attrs;
856 }
857 
858 }  // end namespace
859 
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)860 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
861   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
862 
863   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
864   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
865   if (f1_attrs.size() != f2_attrs.size()) return false;
866   for (const auto& iter1 : f1_attrs) {
867     auto iter2 = f2_attrs.find(iter1.first);
868     if (iter2 == f2_attrs.end()) return false;
869     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
870   }
871 
872   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
873     return false;
874   }
875 
876   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
877   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
878   if (ret1 != ret2) return false;
879 
880   std::map<string, string> control_ret1(f1.control_ret().begin(),
881                                         f1.control_ret().end());
882   std::map<string, string> control_ret2(f2.control_ret().begin(),
883                                         f2.control_ret().end());
884   if (control_ret1 != control_ret2) return false;
885 
886   return true;
887 }
888 
FunctionDefHash(const FunctionDef & fdef)889 uint64 FunctionDefHash(const FunctionDef& fdef) {
890   // signature
891   uint64 h = OpDefHash(fdef.signature());
892 
893   // attrs
894   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
895   for (const auto& p : attrs) {
896     h = Hash64(p.first.data(), p.first.size(), h);
897     h = Hash64Combine(AttrValueHash(p.second), h);
898   }
899 
900   // node defs
901   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
902 
903   // output names
904   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
905   for (const auto& p : ret) {
906     h = Hash64(p.first.data(), p.first.size(), h);
907     h = Hash64(p.second.data(), p.second.size(), h);
908   }
909 
910   // control output names
911   std::map<string, string> control_ret(fdef.control_ret().begin(),
912                                        fdef.control_ret().end());
913   for (const auto& p : control_ret) {
914     h = Hash64(p.first.data(), p.first.size(), h);
915     h = Hash64(p.second.data(), p.second.size(), h);
916   }
917 
918   return h;
919 }
920 
// Name of the attr that carries the executor type. ExecutorType() reads it
// when the instantiate options do not name an executor, and Canonicalize()
// skips it in the generic attr listing and emits the resolved executor type
// under this key instead.
static constexpr const char* const kExecutorAttr = "_executor";
922 
923 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)924 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
925                                             AttrSlice attrs) {
926   if (!options.executor_type.empty()) {
927     return options.executor_type;
928   } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
929     return executor_attr->s();
930   } else {
931     return string();
932   }
933 }
934 
935 namespace {
936 class AttrKeyAndValue {
937  public:
938   enum ValueRepresentationOp {
939     kRaw,
940     kCEscape,
941   };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)942   AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
943                   ValueRepresentationOp value_op = kRaw)
944       : key_name_(key_name),
945         key_suffix_(key_suffix),
946         value_op_(value_op),
947         value_(std::move(value)) {}
948 
operator <(const AttrKeyAndValue & b) const949   bool operator<(const AttrKeyAndValue& b) const {
950     if (key_name_ != b.key_name_) {
951       return key_name_ < b.key_name_;
952     } else if (key_suffix_ != b.key_suffix_) {
953       return key_suffix_ < b.key_suffix_;
954     } else {
955       return value_ < b.value_;
956     }
957   }
958 
AppendTo(bool first,string * s) const959   void AppendTo(bool first, string* s) const {
960     absl::string_view v;
961     bool add_escaped = false;
962     if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
963       // Use CEscape call below
964       add_escaped = true;
965     } else {
966       // Add raw value contents directly
967       v = value_;
968     }
969     if (key_suffix_ >= 0) {
970       strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
971     } else {
972       strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
973     }
974     if (add_escaped) {
975       strings::StrAppend(s, absl::CEscape(value_));
976     }
977   }
978 
979  private:
NeedsEscaping(const string & s)980   static bool NeedsEscaping(const string& s) {
981     for (auto c : s) {
982       if (!isalnum(c) && (c != ' ')) {
983         return true;
984       }
985     }
986     return false;
987   }
988 
989   absl::string_view key_name_;
990   int key_suffix_;  // -1 if missing
991   ValueRepresentationOp value_op_;
992   string value_;
993 };
994 }  // namespace
995 
GetFunctionResourceInputDevice(const Tensor & input,const int arg_index,const FunctionDef & function_def,absl::flat_hash_map<string,std::vector<string>> * composite_devices)996 string GetFunctionResourceInputDevice(
997     const Tensor& input, const int arg_index, const FunctionDef& function_def,
998     absl::flat_hash_map<string, std::vector<string>>* composite_devices) {
999   const auto& handles = input.flat<ResourceHandle>();
1000   const ResourceHandle& handle0 = handles(0);
1001   string composite_device;
1002   auto iter = function_def.arg_attr().find(arg_index);
1003   if (iter != function_def.arg_attr().end()) {
1004     auto arg_attr = iter->second.attr().find("_composite_device");
1005     if (arg_attr != iter->second.attr().end()) {
1006       composite_device = arg_attr->second.s();
1007     }
1008   }
1009   if (!composite_device.empty()) {
1010     if (composite_devices->find(composite_device) == composite_devices->end()) {
1011       for (int i = 0; i < handles.size(); ++i) {
1012         (*composite_devices)[composite_device].push_back(handles(i).device());
1013       }
1014     }
1015     return composite_device;
1016   } else {
1017     return handle0.device();
1018   }
1019 }
1020 
Canonicalize(const string & funcname,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)1021 string Canonicalize(const string& funcname, AttrSlice attrs,
1022                     const FunctionLibraryRuntime::InstantiateOptions& options) {
1023   absl::InlinedVector<AttrKeyAndValue, 8> entries;
1024   entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
1025                   options.input_devices.size());
1026   for (const auto& p : attrs) {
1027     if (p.first != kExecutorAttr) {
1028       entries.push_back(AttrKeyAndValue(p.first, -1, Print(p.second)));
1029     }
1030   }
1031   if (!options.target.empty()) {
1032     entries.push_back(AttrKeyAndValue("_target", -1, options.target,
1033                                       AttrKeyAndValue::kCEscape));
1034   }
1035   for (int i = 0; i < options.input_devices.size(); ++i) {
1036     entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
1037                                       AttrKeyAndValue::kCEscape));
1038   }
1039   for (int i = 0; i < options.output_devices.size(); ++i) {
1040     entries.push_back(AttrKeyAndValue("_output_dev", i,
1041                                       options.output_devices[i],
1042                                       AttrKeyAndValue::kCEscape));
1043   }
1044   for (const auto& iter : options.input_resource_dtypes_and_shapes) {
1045     entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
1046                                       DataTypeString(iter.second.dtype)));
1047     entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
1048                                       iter.second.shape.DebugString(),
1049                                       AttrKeyAndValue::kCEscape));
1050   }
1051   if (options.lib_def) {
1052     entries.push_back(AttrKeyAndValue(
1053         "_lib_def", -1,
1054         absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
1055   }
1056   if (!options.state_handle.empty()) {
1057     entries.push_back(
1058         AttrKeyAndValue("_state_handle", -1, options.state_handle));
1059   }
1060   string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
1061   if (!executor_type.empty()) {
1062     entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
1063   }
1064   if (options.config_proto.ByteSize() > 0) {
1065     string config_proto_serialized;
1066     options.config_proto.SerializeToString(&config_proto_serialized);
1067     entries.push_back(AttrKeyAndValue("_config_proto", -1,
1068                                       config_proto_serialized,
1069                                       AttrKeyAndValue::kCEscape));
1070   }
1071   std::sort(entries.begin(), entries.end());
1072   string result = strings::StrCat(funcname, "[");
1073   bool first = true;
1074   for (const auto& entry : entries) {
1075     entry.AppendTo(first, &result);
1076     first = false;
1077   }
1078   result += "]";
1079   return result;
1080 }
1081 
// Canonicalizes `funcname` and `attrs` using default-constructed instantiate
// options.
string Canonicalize(const string& funcname, AttrSlice attrs) {
  // Allocated once on first use and never freed.
  static const FunctionLibraryRuntime::InstantiateOptions* default_options =
      new FunctionLibraryRuntime::InstantiateOptions;
  return Canonicalize(funcname, attrs, *default_options);
}
1087 
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1088 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1089                                      DataTypeSlice ret_types)
1090     : arg_types_(arg_types.begin(), arg_types.end()),
1091       ret_types_(ret_types.begin(), ret_types.end()) {
1092   args_.resize(arg_types_.size());
1093   rets_.resize(ret_types_.size());
1094 }
1095 
~FunctionCallFrame()1096 FunctionCallFrame::~FunctionCallFrame() {}
1097 
SetArgs(gtl::ArraySlice<Tensor> args)1098 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1099   // Input type checks.
1100   if (args.size() != arg_types_.size()) {
1101     return errors::InvalidArgument("Expects ", arg_types_.size(),
1102                                    " arguments, but ", args.size(),
1103                                    " is provided");
1104   }
1105   for (size_t i = 0; i < args.size(); ++i) {
1106     if (arg_types_[i] != args[i].dtype()) {
1107       return errors::InvalidArgument(
1108           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1109           DataTypeString(args[i].dtype()), " is provided");
1110     }
1111     args_[i] = args[i];
1112   }
1113   return Status::OK();
1114 }
1115 
GetRetvals(std::vector<Tensor> * rets) const1116 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1117   rets->clear();
1118   rets->reserve(rets_.size());
1119   for (size_t i = 0; i < rets_.size(); ++i) {
1120     const auto& item = rets_[i];
1121     if (item.has_val) {
1122       rets->push_back(item.val);
1123     } else {
1124       return errors::Internal("Retval[", i, "] does not have value");
1125     }
1126   }
1127   return Status::OK();
1128 }
1129 
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1130 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1131                                          bool allow_dead_tensors) {
1132   rets->clear();
1133   rets->reserve(rets_.size());
1134   for (size_t i = 0; i < rets_.size(); ++i) {
1135     if (rets_[i].has_val) {
1136       rets->emplace_back(std::move(rets_[i].val));
1137     } else if (allow_dead_tensors) {
1138       rets->emplace_back();
1139     } else {
1140       return errors::Internal("Retval[", i, "] does not have value");
1141     }
1142   }
1143   return Status::OK();
1144 }
1145 
GetArg(int index,const Tensor ** val)1146 Status FunctionCallFrame::GetArg(int index, const Tensor** val) {
1147   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1148     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1149                                    args_.size(), ")");
1150   }
1151   *val = &args_[index];
1152   return Status::OK();
1153 }
1154 
SetRetval(int index,const Tensor & val)1155 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1156   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1157     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1158                                    rets_.size(), ")");
1159   }
1160   if (val.dtype() != ret_types_[index]) {
1161     return errors::InvalidArgument(
1162         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1163         ", but ", DataTypeString(val.dtype()), " is provided.");
1164   }
1165   Retval* item = &rets_[index];
1166   if (!item->has_val) {
1167     item->has_val = true;
1168     item->val = val;
1169   } else {
1170     return errors::Internal("Retval[", index, "] has already been set.");
1171   }
1172   return Status::OK();
1173 }
1174 
// Copies `fdef_in` and `stack_traces` into the record and builds the op
// registration data from the copied function's signature.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
                                 const StackTracesMap& stack_traces)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      // Note: `fdef` below is the member initialized just above, not the
      // `fdef_in` parameter.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */),
      stack_traces(stack_traces) {}
1184 
// Copy constructor. Shares `other`'s default registry and copies its
// function and gradient tables while holding a shared lock on `other`.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  func_grad_ = other.func_grad_;
  function_defs_ = other.function_defs_;
}
1192 
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1193 FunctionLibraryDefinition::FunctionLibraryDefinition(
1194     const OpRegistryInterface* default_registry,
1195     const FunctionDefLibrary& def_lib)
1196     : default_registry_(default_registry),
1197       function_defs_(def_lib.function_size()) {
1198   for (const auto& fdef : def_lib.function()) {
1199     // The latter function definition wins.
1200     auto& ptr = function_defs_[fdef.signature().name()];
1201     ptr.reset(new FunctionDefAndOpRegistration(fdef));
1202   }
1203   for (const auto& grad : def_lib.gradient()) {
1204     func_grad_[grad.function_name()] = grad.gradient_func();
1205   }
1206 }
1207 
~FunctionLibraryDefinition()1208 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1209 
Contains(const string & func) const1210 bool FunctionLibraryDefinition::Contains(const string& func) const {
1211   tf_shared_lock l(mu_);
1212   return function_defs_.find(func) != function_defs_.end();
1213 }
1214 
Find(const string & func) const1215 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1216   tf_shared_lock l(mu_);
1217   auto result = FindHelper(func);
1218   if (result) {
1219     return &result->fdef;
1220   } else {
1221     return nullptr;
1222   }
1223 }
1224 
1225 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1226 FunctionLibraryDefinition::FindHelper(const string& func) const {
1227   auto iter = function_defs_.find(func);
1228   if (iter == function_defs_.end()) {
1229     return nullptr;
1230   } else {
1231     return iter->second;
1232   }
1233 }
1234 
AddFunctionDef(const FunctionDef & fdef,const StackTracesMap & stack_traces)1235 Status FunctionLibraryDefinition::AddFunctionDef(
1236     const FunctionDef& fdef, const StackTracesMap& stack_traces) {
1237   mutex_lock l(mu_);
1238   bool added;
1239   return AddFunctionDefHelper(fdef, stack_traces, &added);
1240 }
1241 
AddFunctionDefHelper(const FunctionDef & fdef,const StackTracesMap & stack_traces,bool * added)1242 Status FunctionLibraryDefinition::AddFunctionDefHelper(
1243     const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) {
1244   *added = false;
1245   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1246       function_defs_[fdef.signature().name()];
1247   if (entry) {
1248     if (!FunctionDefsEqual(entry->fdef, fdef)) {
1249       return errors::InvalidArgument(
1250           "Cannot add function '", fdef.signature().name(),
1251           "' because a different function with the same name already "
1252           "exists.");
1253     }
1254     // Ignore duplicate FunctionDefs.
1255     return Status::OK();
1256   }
1257   const OpDef* op_def;
1258   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1259     return errors::InvalidArgument(
1260         "Cannot add function '", fdef.signature().name(),
1261         "' because an op with the same name already exists.");
1262   }
1263   entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces);
1264   *added = true;
1265   return Status::OK();
1266 }
1267 
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1268 Status FunctionLibraryDefinition::AddHelper(
1269     std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1270   *added = false;
1271   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1272       function_defs_[registration->fdef.signature().name()];
1273   if (entry) {
1274     if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1275       return errors::InvalidArgument(
1276           "Cannot add function '", registration->fdef.signature().name(),
1277           "' because a different function with the same name already "
1278           "exists.");
1279     }
1280     // Ignore duplicate FunctionDefs.
1281     return Status::OK();
1282   }
1283   const OpDef* op_def;
1284   if (default_registry_
1285           ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1286           .ok()) {
1287     return errors::InvalidArgument(
1288         "Cannot add function '", registration->fdef.signature().name(),
1289         "' because an op with the same name already exists.");
1290   }
1291   entry = std::move(registration);
1292   *added = true;
1293   return Status::OK();
1294 }
1295 
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1296 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1297     const string& func, const FunctionLibraryDefinition& other) {
1298   if (default_registry_ != other.default_registry_) {
1299     return errors::InvalidArgument(
1300         "Cannot copy function '", func,
1301         "' because CopyFunctionDefFrom() requires that both libraries have the "
1302         "same default registry.");
1303   }
1304   std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1305   {
1306     tf_shared_lock l(other.mu_);
1307     function_def = other.FindHelper(func);
1308   }
1309   if (!function_def) {
1310     return errors::InvalidArgument(
1311         "Cannot copy function '", func,
1312         "' because no function with that name exists in the other library.");
1313   }
1314   {
1315     mutex_lock l(mu_);
1316     std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1317     if (entry) {
1318       if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1319         return errors::InvalidArgument(
1320             "Cannot copy function '", func,
1321             "' because a different function with the same name already "
1322             "exists.");
1323       }
1324     } else {
1325       entry = std::move(function_def);
1326     }
1327   }
1328   return Status::OK();
1329 }
1330 
AddGradientDef(const GradientDef & grad)1331 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1332   mutex_lock l(mu_);
1333   bool added;
1334   return AddGradientDefHelper(grad, &added);
1335 }
1336 
AddGradientDefHelper(const GradientDef & grad,bool * added)1337 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1338                                                        bool* added) {
1339   *added = false;
1340   string* entry = &func_grad_[grad.function_name()];
1341   if (!entry->empty()) {
1342     if (*entry != grad.gradient_func()) {
1343       return errors::InvalidArgument(
1344           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1345           grad.function_name(), "' because it already has gradient function ",
1346           "'", *entry, "'");
1347     }
1348     // Ignore duplicate GradientDefs
1349     return Status::OK();
1350   }
1351   *entry = grad.gradient_func();
1352   *added = true;
1353   return Status::OK();
1354 }
1355 
AddLibrary(const FunctionLibraryDefinition & other)1356 Status FunctionLibraryDefinition::AddLibrary(
1357     const FunctionLibraryDefinition& other) {
1358   // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1359   // the duration of the function could lead to deadlock).
1360   FunctionLibraryDefinition clone(other);
1361   mutex_lock l(mu_);
1362   mutex_lock l2(clone.mu_);
1363   // Remember the funcs and grads that we added successfully so that
1364   // we can roll them back on error.
1365   std::vector<string> funcs;
1366   std::vector<string> funcs_with_grads;
1367   Status s;
1368   bool added;
1369   for (auto iter : clone.function_defs_) {
1370     s = AddHelper(iter.second, &added);
1371     if (!s.ok()) {
1372       Remove(funcs, funcs_with_grads);
1373       return s;
1374     }
1375     if (added) {
1376       funcs.push_back(iter.second->fdef.signature().name());
1377     }
1378   }
1379   for (auto iter : clone.func_grad_) {
1380     GradientDef grad;
1381     grad.set_function_name(iter.first);
1382     grad.set_gradient_func(iter.second);
1383     s = AddGradientDefHelper(grad, &added);
1384     if (!s.ok()) {
1385       Remove(funcs, funcs_with_grads);
1386       return s;
1387     }
1388     if (added) {
1389       funcs_with_grads.push_back(grad.function_name());
1390     }
1391   }
1392   return Status::OK();
1393 }
1394 
AddLibrary(const FunctionDefLibrary & lib_def)1395 Status FunctionLibraryDefinition::AddLibrary(
1396     const FunctionDefLibrary& lib_def) {
1397   // Remember the funcs and grads that we added successfully so that
1398   // we can roll them back on error.
1399   mutex_lock l(mu_);
1400   std::vector<string> funcs;
1401   std::vector<string> funcs_with_grads;
1402   Status s;
1403   bool added;
1404   for (const FunctionDef& fdef : lib_def.function()) {
1405     s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added);
1406     if (!s.ok()) {
1407       Remove(funcs, funcs_with_grads);
1408       return s;
1409     }
1410     if (added) {
1411       funcs.push_back(fdef.signature().name());
1412     }
1413   }
1414   for (const GradientDef& grad : lib_def.gradient()) {
1415     s = AddGradientDefHelper(grad, &added);
1416     if (!s.ok()) {
1417       Remove(funcs, funcs_with_grads);
1418       return s;
1419     }
1420     if (added) {
1421       funcs_with_grads.push_back(grad.function_name());
1422     }
1423   }
1424   return Status::OK();
1425 }
1426 
ReplaceFunction(const string & func,const FunctionDef & fdef)1427 Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
1428                                                   const FunctionDef& fdef) {
1429   mutex_lock l(mu_);
1430   bool added;
1431   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1432   TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added));
1433   return Status::OK();
1434 }
1435 
ReplaceGradient(const GradientDef & grad)1436 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1437   mutex_lock l(mu_);
1438   bool added;
1439   TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1440   TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1441   return Status::OK();
1442 }
1443 
RemoveFunction(const string & func)1444 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1445   mutex_lock l(mu_);
1446   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1447   return Status::OK();
1448 }
1449 
RemoveFunctionHelper(const string & func)1450 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1451   const auto& i = function_defs_.find(func);
1452   if (i == function_defs_.end()) {
1453     return errors::InvalidArgument("Tried to remove non-existent function '",
1454                                    func, "'.");
1455   }
1456   function_defs_.erase(i);
1457   return Status::OK();
1458 }
1459 
Clear()1460 void FunctionLibraryDefinition::Clear() {
1461   mutex_lock l(mu_);
1462   function_defs_.clear();
1463   func_grad_.clear();
1464 }
1465 
RemoveGradient(const string & func)1466 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1467   const auto& i = func_grad_.find(func);
1468   if (i == func_grad_.end()) {
1469     return errors::InvalidArgument("Tried to remove non-existent gradient '",
1470                                    func, "'.");
1471   }
1472   func_grad_.erase(i);
1473   return Status::OK();
1474 }
1475 
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1476 void FunctionLibraryDefinition::Remove(
1477     const std::vector<string>& funcs,
1478     const std::vector<string>& funcs_with_grads) {
1479   for (const string& f : funcs) {
1480     Status s = RemoveFunctionHelper(f);
1481     DCHECK(s.ok());
1482   }
1483   for (const string& f : funcs_with_grads) {
1484     Status s = RemoveGradient(f);
1485     DCHECK(s.ok());
1486   }
1487 }
1488 
// Returns the name of the gradient function registered for `func`, or an
// empty string if there is none. The lookup is done under a shared lock.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock l(mu_);
  const auto it = func_grad_.find(func);
  if (it == func_grad_.end()) {
    return string();
  }
  return it->second;
}
1493 
// Same lookup as FindGradient(), but does not take the library lock itself.
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  const auto it = func_grad_.find(func);
  if (it == func_grad_.end()) {
    return string();
  }
  return it->second;
}
1497 
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1498 Status FunctionLibraryDefinition::LookUp(
1499     const string& op, const OpRegistrationData** op_reg_data) const {
1500   tf_shared_lock l(mu_);
1501   auto iter = function_defs_.find(op);
1502   if (iter != function_defs_.end()) {
1503     *op_reg_data = &iter->second->op_registration_data;
1504     return Status::OK();
1505   }
1506   return default_registry_->LookUp(op, op_reg_data);
1507 }
1508 
UniqueFunctionName(StringPiece prefix) const1509 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1510   tf_shared_lock l(mu_);
1511   int index = 0;
1512   string name = strings::StrCat(prefix, index);
1513   while (function_defs_.find(name) != function_defs_.end()) {
1514     ++index;
1515     name = strings::StrCat(prefix, index);
1516   }
1517   return name;
1518 }
1519 
GetAttrImpl(const NodeDef & ndef) const1520 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1521     const NodeDef& ndef) const {
1522   if (ndef.op() != kGradientOp) {
1523     // If 'ndef' calls a function and the function's def has the attr,
1524     // returns it.
1525     return Find(ndef.op());
1526   }
1527 
1528   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1529   // Foo's attributes.
1530   const NameAttrList* forward_func_attrs;
1531   if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1532     return nullptr;
1533   }
1534   const string& func_name = forward_func_attrs->name();
1535   {
1536     tf_shared_lock l(mu_);
1537     const string& grad_name = FindGradientHelper(func_name);
1538     // If 'func' has a user-defined gradient function, uses the grad
1539     // function's attrs to see if noinline is specified. Otherwise,
1540     // uses func's attrs.
1541     if (!grad_name.empty()) {
1542       if (const auto helper = FindHelper(grad_name)) {
1543         return &(helper->fdef);
1544       } else {
1545         return nullptr;
1546       }
1547     }
1548     if (const auto helper = FindHelper(func_name)) {
1549       return &(helper->fdef);
1550     } else {
1551       return nullptr;
1552     }
1553   }
1554 }
1555 
ListFunctionNames() const1556 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1557   std::vector<string> function_names;
1558   tf_shared_lock l(mu_);
1559   function_names.reserve(function_defs_.size());
1560   for (const auto& it : function_defs_) {
1561     function_names.emplace_back(it.first);
1562   }
1563   return function_names;
1564 }
1565 
ToProto() const1566 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1567   FunctionDefLibrary lib;
1568   tf_shared_lock l(mu_);
1569   for (const auto& f : function_defs_) {
1570     *lib.add_function() = f.second->fdef;
1571   }
1572   for (const auto& g : func_grad_) {
1573     GradientDef* gd = lib.add_gradient();
1574     gd->set_function_name(g.first);
1575     gd->set_gradient_func(g.second);
1576   }
1577   return lib;
1578 }
1579 
1580 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1581 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1582                                           const string& attr, T* value) const {
1583   const FunctionDef* fdef = GetAttrImpl(ndef);
1584   if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1585     return Status::OK();
1586   }
1587   return errors::InvalidArgument("Attr ", attr, " is not defined.");
1588 }
1589 
1590 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1591 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1592                                           T* value) const {
1593   return GetAttr(node.def(), attr, value);
1594 }
1595 
1596 #define GET_ATTR(T)                                                            \
1597   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
1598                                                      const string&, T*) const; \
1599   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
1600                                                      const string&, T*) const;
1601 GET_ATTR(string)
1602 GET_ATTR(bool)
1603 #undef GET_ATTR
1604 
1605 namespace {
1606 
1607 constexpr char kApiImplements[] = "api_implements";
1608 
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1609 std::set<string> ReachableFunctions(
1610     const FunctionLibraryDefinition& flib,
1611     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1612   // Functions that are reachable from the graph.
1613   std::set<string> reachable_funcs;
1614 
1615   // For any functions, if it has attribute "api_implements" =
1616   // "some_interface" and it is reachable, then it means any other
1617   // function with same attribute name and value could also be potentially
1618   // reachable, eg via implementation_selector swapping the nodedef.
1619   absl::flat_hash_set<string> reachable_api_interface;
1620 
1621   // Functions might be reachable from the nested function calls, so we keep a
1622   // queue of functions that we have to check.
1623   gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1624 
1625   // Add reachable and not already processed functions to the functions queue.
1626   const auto add_to_func_queue = [&](const string& func_name) {
1627     const FunctionDef* func = flib.Find(func_name);
1628     if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1629       func_queue.push_back(func);
1630     }
1631   };
1632 
1633   // If any function with certain API name is reachable, all the other functions
1634   // with same API name should also be checked.
1635   const auto add_function_with_api_interface = [&](const string& api_name) {
1636     if (!reachable_api_interface.contains(api_name)) {
1637       reachable_api_interface.insert(api_name);
1638       for (const auto& func_name : flib.ListFunctionNames()) {
1639         const auto& func_def = flib.Find(func_name);
1640         const auto attr_it = func_def->attr().find(kApiImplements);
1641         if (attr_it != func_def->attr().end() &&
1642             attr_it->second.s() == api_name) {
1643           add_to_func_queue(func_name);
1644         }
1645       }
1646     }
1647   };
1648 
1649   // Add all the functions that are reachable from the given node to the queue.
1650   const auto process_node = [&](const NodeDef& node) {
1651     // Node itself can be a call to the function.
1652     add_to_func_queue(node.op());
1653 
1654     // Or node can have an attribute referencing a function.
1655     for (const auto& attr : node.attr()) {
1656       const auto& attr_value = attr.second;
1657 
1658       // 1. AttrValue.func
1659       if (attr_value.has_func()) {
1660         add_to_func_queue(attr_value.func().name());
1661       }
1662 
1663       // 2. AttrValue.ListValue.func
1664       if (attr_value.has_list()) {
1665         for (const auto& func : attr_value.list().func()) {
1666           add_to_func_queue(func.name());
1667         }
1668       }
1669     }
1670   };
1671 
1672   // Add all functions that are directly called from the optimized graph.
1673   std::for_each(nodes.begin(), nodes.end(), process_node);
1674 
1675   // Process all reachable functions.
1676   while (!func_queue.empty()) {
1677     const FunctionDef* func = func_queue.back();
1678     func_queue.pop_back();
1679 
1680     const string& func_name = func->signature().name();
1681     reachable_funcs.insert(func_name);
1682 
1683     const auto attr_it = func->attr().find(kApiImplements);
1684     if (attr_it != func->attr().end()) {
1685       add_function_with_api_interface(attr_it->second.s());
1686     }
1687 
1688     // Find all the functions called from the function body.
1689     const auto& func_body = func->node_def();
1690     std::for_each(func_body.begin(), func_body.end(), process_node);
1691 
1692     // Check if the function has a registered gradient.
1693     const string grad_func_name = flib.FindGradient(func_name);
1694     if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1695   }
1696 
1697   return reachable_funcs;
1698 }
1699 
// Builds a new library, sharing `flib`'s default registry, that contains only
// the functions ReachableFunctions() reports for `nodes`, together with the
// gradient mapping of each such function when `flib` has one.
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
    const FunctionLibraryDefinition& flib,
    const protobuf::RepeatedPtrField<NodeDef>& nodes) {
  const std::set<string> names = ReachableFunctions(flib, nodes);

  FunctionLibraryDefinition pruned(flib.default_registry(),
                                   FunctionDefLibrary());

  for (const string& name : names) {
    // Copying from a valid library that uses the same default registry is not
    // expected to fail; the status is only verified in debug builds.
    Status added = pruned.CopyFunctionDefFrom(name, flib);
    TF_DCHECK_OK(added);

    const string gradient_name = flib.FindGradient(name);
    if (gradient_name.empty()) continue;

    GradientDef gradient;
    gradient.set_function_name(name);
    gradient.set_gradient_func(gradient_name);
    // It can only fail if function already has a gradient function.
    const Status added_grad = pruned.AddGradientDef(gradient);
    TF_DCHECK_OK(added_grad);
  }

  return pruned;
}
1727 
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1728 string AllocatorAttributesToString(
1729     const std::vector<AllocatorAttributes>& attrs) {
1730   string result("[");
1731   // AllocatorAttribute::DebugString produces around 85 bytes now.
1732   result.reserve(100 * attrs.size());
1733   for (const AllocatorAttributes& attr : attrs) {
1734     result.append(attr.DebugString());
1735     result.append(", ");
1736   }
1737   if (!attrs.empty()) {
1738     result.resize(result.size() - 2);
1739   }
1740   result.append("]");
1741   return result;
1742 }
1743 
// Returns "set" for a non-null pointer and "unset" for a null one.
const char* IsSet(void* ptr) {
  if (ptr != nullptr) {
    return "set";
  }
  return "unset";
}
1745 
1746 }  // namespace
1747 
ReachableDefinitions(const GraphDef & graph) const1748 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1749     const GraphDef& graph) const {
1750   return ReachableFunctionLibraryDefinition(*this, graph.node());
1751 }
1752 
ReachableDefinitions(const FunctionDef & func) const1753 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1754     const FunctionDef& func) const {
1755   return ReachableFunctionLibraryDefinition(*this, func.node_def());
1756 }
1757 
DebugString() const1758 string FunctionLibraryRuntime::Options::DebugString() const {
1759   return absl::StrCat(
1760       "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1761       " cancellation_manager=", IsSet(cancellation_manager),
1762       " collective_executor=", IsSet(collective_executor),
1763       " step_container=", IsSet(step_container),
1764       " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1765       " remote_execution=", remote_execution, " source_device=", source_device,
1766       " create_rendezvous=", create_rendezvous,
1767       " allow_dead_tensors=", allow_dead_tensors,
1768       " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1769       " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1770 }
1771 
InitFromString(StringPiece val)1772 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1773   if (val.size() >= 2 && val[0] == '$') {
1774     proto.set_placeholder(val.data() + 1, val.size() - 1);
1775   } else {
1776     SetAttrValue(val, &proto);
1777   }
1778 }
1779 
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1780 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1781     const string& name,
1782     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1783   AttrValueWrapper ret;
1784   ret.proto.mutable_func()->set_name(name);
1785   for (const auto& a : attrs) {
1786     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1787   }
1788   return ret;
1789 }
1790 
ToNodeDef() const1791 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1792   NodeDef n;
1793   n.set_op(this->op);
1794   n.set_name(this->ret[0]);
1795   for (const auto& a : this->attr) {
1796     n.mutable_attr()->insert({a.first, a.second.proto});
1797   }
1798   for (const string& a : this->arg) {
1799     n.add_input(a);
1800   }
1801   for (const string& d : this->dep) {
1802     n.add_input(strings::StrCat("^", d));
1803   }
1804   if (!this->device.empty()) {
1805     n.set_device(this->device);
1806   }
1807   return n;
1808 }
1809 
1810 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1811 FunctionDef FunctionDefHelper::Create(
1812     const string& function_name, gtl::ArraySlice<string> in_def,
1813     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1814     gtl::ArraySlice<Node> node_def,
1815     gtl::ArraySlice<std::pair<string, string>> ret_def,
1816     gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1817   FunctionDef fdef;
1818 
1819   // Signature
1820   OpDefBuilder b(function_name);
1821   for (const auto& i : in_def) b.Input(i);
1822   for (const auto& o : out_def) b.Output(o);
1823   for (const auto& a : attr_def) b.Attr(a);
1824   for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1825 
1826   OpRegistrationData op_reg_data;
1827   TF_CHECK_OK(b.Finalize(&op_reg_data));
1828   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1829 
1830   // Function body
1831   for (const auto& n : node_def) {
1832     *(fdef.add_node_def()) = n.ToNodeDef();
1833   }
1834 
1835   // Returns
1836   for (const auto& r : ret_def) {
1837     fdef.mutable_ret()->insert({r.first, r.second});
1838   }
1839 
1840   // Control returns
1841   for (const auto& cr : control_ret_def) {
1842     fdef.mutable_control_ret()->insert({cr.first, cr.second});
1843   }
1844 
1845   auto* op_def_registry = OpRegistry::Global();
1846   // Check if any op is stateful.
1847   for (const auto& n : node_def) {
1848     const OpDef* op_def = nullptr;
1849     auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1850     // Lookup can fail if e.g. we are calling a function that was not yet
1851     // defined.  If it happens, conservatively assume the op is stateful.
1852     if (!status.ok() || op_def->is_stateful()) {
1853       fdef.mutable_signature()->set_is_stateful(true);
1854     }
1855   }
1856 
1857   return fdef;
1858 }
1859 
1860 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1861 FunctionDef FunctionDefHelper::Create(
1862     const string& function_name, gtl::ArraySlice<string> in_def,
1863     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1864     gtl::ArraySlice<Node> node_def,
1865     gtl::ArraySlice<std::pair<string, string>> ret_def) {
1866   return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1867                 /*control_ret_def=*/{});
1868 }
1869 
1870 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1871 FunctionDef FunctionDefHelper::Define(const string& name,
1872                                       gtl::ArraySlice<string> arg_def,
1873                                       gtl::ArraySlice<string> ret_def,
1874                                       gtl::ArraySlice<string> attr_def,
1875                                       gtl::ArraySlice<Node> node_def) {
1876   FunctionDef fdef;
1877   OpDefBuilder b(name);
1878   for (const auto& a : arg_def) b.Input(a);
1879   for (const auto& r : ret_def) b.Output(r);
1880   for (const auto& a : attr_def) b.Attr(a);
1881 
1882   OpRegistrationData op_reg_data;
1883   TF_CHECK_OK(b.Finalize(&op_reg_data));
1884   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1885 
1886   // Mapping from legacy output names to NodeDef outputs.
1887   std::unordered_map<string, string> ret_index;
1888   for (const auto& a : fdef.signature().input_arg()) {
1889     ret_index[a.name()] = a.name();
1890   }
1891 
1892   // For looking up OpDefs
1893   auto* op_def_registry = OpRegistry::Global();
1894 
1895   // Function body
1896   for (const auto& src : node_def) {
1897     NodeDef* n = fdef.add_node_def();
1898     n->set_op(src.op);
1899     n->set_name(src.ret[0]);
1900     for (const auto& a : src.attr) {
1901       n->mutable_attr()->insert({a.first, a.second.proto});
1902     }
1903     for (const string& a : src.arg) {
1904       const auto iter = ret_index.find(a);
1905       CHECK(iter != ret_index.end())
1906           << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
1907       n->add_input(iter->second);
1908     }
1909     for (const string& d : src.dep) {
1910       n->add_input(strings::StrCat("^", d));
1911     }
1912 
1913     // Add the outputs of this node to ret_index.
1914     const OpDef* op_def = nullptr;
1915     TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1916     CHECK(op_def != nullptr) << n->op();
1917     NameRangeMap output_names;
1918     TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1919     for (const auto& o : output_names) {
1920       CHECK_LE(o.second.second, src.ret.size())
1921           << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
1922           << "' of " << name;
1923       for (int i = o.second.first; i < o.second.second; ++i) {
1924         ret_index[src.ret[i]] =
1925             strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
1926       }
1927     }
1928     if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1929   }
1930 
1931   // Returns
1932   for (const auto& r : fdef.signature().output_arg()) {
1933     const auto iter = ret_index.find(r.name());
1934     CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1935     fdef.mutable_ret()->insert({r.name(), iter->second});
1936   }
1937   return fdef;
1938 }
1939 
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1940 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1941                                       gtl::ArraySlice<string> ret_def,
1942                                       gtl::ArraySlice<string> attr_def,
1943                                       gtl::ArraySlice<Node> node_def) {
1944   return Define("_", arg_def, ret_def, attr_def, node_def);
1945 }
1946 
1947 namespace gradient {
1948 
1949 typedef std::unordered_map<string, Creator> OpGradFactory;
1950 
GetOpGradFactory()1951 OpGradFactory* GetOpGradFactory() {
1952   static OpGradFactory* factory = new OpGradFactory;
1953   return factory;
1954 }
1955 
// Registers `func` as the gradient creator for `op` in the shared table
// returned by GetOpGradFactory(). Registering a second creator for the same
// op is a fatal (CHECK) error. Always returns true.
bool RegisterOp(const string& op, Creator func) {
  // insert() reports via `.second` whether the key was newly added.
  CHECK(GetOpGradFactory()->insert({op, func}).second)
      << "Duplicated gradient for " << op;
  return true;
}
1961 
GetOpGradientCreator(const string & op,Creator * creator)1962 Status GetOpGradientCreator(const string& op, Creator* creator) {
1963   auto fac = GetOpGradFactory();
1964   auto iter = fac->find(op);
1965   if (iter == fac->end()) {
1966     return errors::NotFound("No gradient defined for op: ", op);
1967   }
1968   *creator = iter->second;
1969   return Status::OK();
1970 }
1971 
1972 }  // end namespace gradient
1973 
1974 }  // namespace tensorflow
1975