1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 
18 #include <map>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/common_shape_fns.h"
24 #include "tensorflow/core/framework/function.pb_text.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/util/equal_graph_def.h"
34 
35 namespace tensorflow {
36 
37 // Extracts the actual type from "attr_values" based on its definition
38 // "arg_def".
39 //
40 // If "arg_def" is a N*T type, *is_type_list is set to false, and
41 // *dtypes is set to be a vector of size N and each element is T.
42 //
43 // If "arg_def" is a list(type), *is_type_list is set to true, and
44 // *dtypes is set to be a vector of types specified in attrs for
45 // arg_def.
46 //
47 // Otherwise (arg_def is a simple type T), *is_type_list is set to
48 // false, and *dtypes is set to a single element vector, whose only
49 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)50 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
51                   bool* is_type_list, DataTypeVector* dtypes) {
52   dtypes->clear();
53   if (!arg_def.type_list_attr().empty()) {
54     const AttrValue* v = attrs.Find(arg_def.type_list_attr());
55     if (v == nullptr) {
56       return errors::NotFound("type attr not found: ",
57                               arg_def.type_list_attr());
58     }
59     *is_type_list = true;
60     for (int i = 0; i < v->list().type_size(); ++i) {
61       dtypes->push_back(v->list().type(i));
62     }
63     return Status::OK();
64   }
65 
66   *is_type_list = false;
67   int num = 1;
68   if (!arg_def.number_attr().empty()) {
69     const AttrValue* v = attrs.Find(arg_def.number_attr());
70     if (v == nullptr) {
71       return errors::NotFound("type attr not found: ", arg_def.type_attr());
72     }
73     num = v->i();
74   }
75 
76   DataType dtype;
77   if (arg_def.type() != DT_INVALID) {
78     dtype = arg_def.type();
79   } else if (arg_def.type_attr().empty()) {
80     dtype = DT_INVALID;
81   } else {
82     const AttrValue* v = attrs.Find(arg_def.type_attr());
83     if (v == nullptr) {
84       return errors::NotFound("type attr not found: ", arg_def.type_attr());
85     }
86     dtype = v->type();
87   }
88   dtypes->resize(num, dtype);
89   return Status::OK();
90 }
91 
92 namespace {
93 
94 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)95 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
96   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
97 }
98 
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)99 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
100   // attr_values should specify all attrs defined in fdef.
101   for (const auto& a : sig.attr()) {
102     const AttrValue* v = attr_values.Find(a.name());
103     if (!v) {
104       return errors::NotFound("Attr ", a.name(), " is not found from ",
105                               SummarizeOpDef(sig));
106     }
107     Status status = AttrValueHasType(*v, a.type());
108     if (!status.ok()) {
109       errors::AppendToMessage(&status, "for attr '", a.name(), "'");
110       return status;
111     }
112   }
113 
114 // TODO(josh11b): Enable this code once it works with function gradients.
115 // Right now the C++ function gradient code assumes it can pass
116 // all the attrs of the function to the gradient, and any attrs that
117 // the gradient doesn't care about will be ignored.
118 #if 0
119   if (attr_values.size() != sig.attr_size()) {
120     for (const auto& a : attr_values) {
121       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
122       bool found = false;
123       for (const auto& s : sig.attr()) {
124         if (a.first == s.name()) {
125           found = true;
126           break;
127         }
128       }
129       if (!found) {
130         return errors::NotFound("Attr ", a.first, " is not found in ",
131                                 SummarizeOpDef(sig));
132       }
133     }
134   }
135 #endif
136 
137   return Status::OK();
138 }
139 
140 // A helper class for instantiating functions. This contains shared information
141 // like the resulting graph and node name index.
142 class FunctionInstantiationHelper {
143  public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)144   FunctionInstantiationHelper(GetFunctionSignature get_function,
145                               InstantiationResult* result)
146       : get_function_(std ::move(get_function)), result_(*result) {
147     result_.nodes.clear();
148   }
149 
150   // Builds index for nodes that can be used as node's input arguments.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values)151   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
152                             AttrSlice attr_values) {
153     bool is_type_list;
154     DataTypeVector dtypes;
155     TF_RETURN_IF_ERROR(
156         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
157     CHECK_GE(dtypes.size(), size_t{1});
158     int arg_index = result_.nodes.size();
159     TF_RETURN_IF_ERROR(
160         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
161     // Creates dtypes.size() nodes in the graph.
162     for (size_t i = 0; i < dtypes.size(); ++i) {
163       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
164                                  {true, arg_index, 0, false, {dtypes[i]}}));
165       DCHECK_EQ(arg_index, result_.nodes.size());
166       string name = arg_def.name();
167       if (dtypes.size() > 1) {
168         strings::StrAppend(&name, "_", i);
169       }
170       NodeDef* gnode = AddNode(name);
171       gnode->set_op("_Arg");
172       AddAttr("T", dtypes[i], gnode);
173       AddAttr("index", arg_index, gnode);
174       result_.arg_types.push_back(dtypes[i]);
175       ++arg_index;
176     }
177     return Status::OK();
178   }
179 
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)180   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
181                               const int arg_index) {
182     const OpDef* node_sig = nullptr;
183     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
184     if (node_sig->output_arg_size() == 0) {
185       return AddItem(node.name(), {false, arg_index, 0, false, {}});
186     }
187     const int num_retval = node_sig->output_arg_size();
188     int start = 0;
189     bool is_type_list;
190     DataTypeVector dtypes;
191     for (int i = 0; i < num_retval; ++i) {
192       TF_RETURN_IF_ERROR(
193           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
194       // Note that we rely on the backwards-compatibility test enforcing
195       // that output_arg(*).name() doesn't change here.
196       const string base_name =
197           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
198       TF_RETURN_IF_ERROR(
199           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
200       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
201         TF_RETURN_IF_ERROR(
202             AddItem(strings::StrCat(base_name, ":", j),
203                     {false, arg_index, start + j, false, {dtypes[j]}}));
204       }
205       start += dtypes.size();
206     }
207     return Status::OK();
208   }
209 
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)210   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
211     const OpDef* fnode_sig = nullptr;
212     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
213     NodeDef* gnode = AddNode(fnode.name());
214     gnode->set_op(fnode.op());
215     gnode->set_device(fnode.device());
216     int gnode_idx = nodes_.size() - 1;
217 
218     // Input
219     const int num_args = fnode_sig->input_arg_size();
220     bool is_type_list;  // ignored
221     DataTypeVector dtypes;
222     int fnode_arg_index = 0;
223     for (int i = 0; i < num_args; ++i) {
224       TF_RETURN_IF_ERROR(
225           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
226       // Consume inputs (indexed by fnode_arg_index) until we have
227       // matched each element of dtypes (indexed by j).
228       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
229         if (fnode_arg_index >= fnode.input_size()) {
230           // Should never happen if we computed dtypes correctly.
231           return errors::InvalidArgument(
232               "Attempt to access beyond input size: ", fnode_arg_index,
233               " >= ", fnode.input_size());
234         }
235         // Look up the next input.
236         const string& input_name = fnode.input(fnode_arg_index);
237         const auto* item = GetItemOrNull(input_name);
238         if (item == nullptr) {
239           return errors::InvalidArgument(
240               "input ", input_name, " is not found: ", SummarizeNodeDef(fnode));
241         }
242         if (item->dtypes.size() > dtypes.size() - j) {
243           return errors::InvalidArgument("Input ", input_name, " too long for ",
244                                          fnode_sig->input_arg(i).name());
245         }
246         // Match up all the elements of this input (indexed by k) with
247         // elements of dtypes (advancing j).
248         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
249           if (item->dtypes[k] != dtypes[j]) {
250             return errors::InvalidArgument(
251                 "input ", fnode_sig->input_arg(i).name(), "[", j,
252                 "] expected type ", DataTypeString(dtypes[j]),
253                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
254                 input_name, "[", k, "]");
255           }
256           if (item->is_func_arg) {
257             AddInput(gnode_idx, item->nid + k, 0);
258           } else {
259             AddInput(gnode_idx, item->nid, item->idx + k);
260           }
261         }
262       }
263     }
264 
265     // Control deps.
266     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
267       const string& input = fnode.input(i);
268       if (input.empty() || input[0] != '^') {
269         return errors::InvalidArgument("Expected input[", i, "] == '", input,
270                                        "' to be a control input.");
271       }
272       int nid = -1;
273       const string node_name = input.substr(1);
274       const string node_colon = node_name + ":";
275       const string node_colon_bound = node_name + ";";
276       // index_ is a map sorted lexicographically, so the key we are looking for
277       // must lie in the range [node_name, node_colon_bound).
278       auto it = index_.lower_bound(node_name);
279       while (it != index_.end() && it->first <= node_colon_bound) {
280         if (it->first == node_name ||
281             tensorflow::StringPiece(it->first).starts_with(node_colon)) {
282           nid = it->second.nid;
283           break;
284         }
285         ++it;
286       }
287       if (nid == -1) {
288         return errors::InvalidArgument("input[", i, "] == '", input,
289                                        "', is not found.");
290       }
291       AddDep(gnode_idx, nid);
292     }
293 
294     // Attrs.
295     for (const auto& p : attrs) {
296       (*gnode->mutable_attr())[p.first] = p.second;
297     }
298 
299     return Status::OK();
300   }
301 
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,int * ret_index)302   Status AddReturnNode(
303       const OpDef::ArgDef& ret_def, AttrSlice attrs,
304       const ::tensorflow::protobuf::Map<string, string>& ret_map,
305       int* ret_index) {
306     auto ret_iter = ret_map.find(ret_def.name());
307     if (ret_iter == ret_map.end()) {
308       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
309     }
310     bool is_type_list;
311     DataTypeVector dtypes;
312     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
313     CHECK_GE(dtypes.size(), size_t{1});
314     const auto* item = GetItemOrNull(ret_iter->second);
315     if (item == nullptr) {
316       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
317                                      ret_iter->second, " is not found.");
318     }
319     if (dtypes != item->dtypes) {
320       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
321                                      " : ", DataTypeVectorString(dtypes),
322                                      " vs. ",
323                                      DataTypeVectorString(item->dtypes));
324     }
325     for (size_t i = 0; i < dtypes.size(); ++i) {
326       string name = strings::StrCat(ret_def.name(), "_RetVal");
327       if (dtypes.size() > 1) {
328         strings::StrAppend(&name, "_", i);
329       }
330       NodeDef* gnode = AddNode(name);
331       gnode->set_op("_Retval");
332       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
333       AddAttr("T", dtypes[i], gnode);
334       AddAttr("index", (*ret_index)++, gnode);
335       result_.ret_types.push_back(dtypes[i]);
336     }
337     return Status::OK();
338   }
339 
340   // Adds the actual node inputs to the result graph by converting indexes to
341   // the node names.
AddNodeInputs()342   void AddNodeInputs() {
343     for (int i = 0; i < result_.nodes.size(); i++) {
344       NodeInfo& node_info = nodes_[i];
345       for (const auto& p : node_info.data_inputs) {
346         result_.nodes[i].add_input(Name(p.first, p.second));
347       }
348       for (int index : node_info.control_inputs) {
349         result_.nodes[i].add_input(Dep(index));
350       }
351     }
352   }
353 
354  private:
355   // This is used to build a small index for all names that can be used as a
356   // node's input arguments.
357   //
358   // If is_func_arg is true, the name is a function's argument.  In
359   // this case, the produced graph def has node[nid:nid + dtype.size()].
360   //
361   // Otherwise, the name is a function body's node return value.  In
362   // this case, the produced graph def has one node node[nid] and
363   // the node's output index [idx ... idx + num) corresponds to the
364   // named outputs.
365   //
366   // In all cases, "dtype" specifies the data type.
367   struct NameInfoItem {
368     bool is_func_arg;
369     int nid;
370     int idx;
371     bool is_type_list;
372     DataTypeVector dtypes;
373   };
374 
375   // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)376   Status AddItem(const string& name, const NameInfoItem& item) {
377     if (!index_.insert({name, item}).second) {
378       return errors::InvalidArgument(
379           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
380                           " name: "),
381           name);
382     }
383     return Status::OK();
384   }
385 
GetItemOrNull(const string & name) const386   const NameInfoItem* GetItemOrNull(const string& name) const {
387     return gtl::FindOrNull(index_, name);
388   }
389 
Dep(int node_index) const390   string Dep(int node_index) const {
391     return strings::StrCat("^", Name(node_index));
392   }
393 
Name(int node_index) const394   string Name(int node_index) const {
395     CHECK_LT(node_index, nodes_.size());
396     return nodes_[node_index].name;
397   }
398 
Name(int node_index,int output_index) const399   string Name(int node_index, int output_index) const {
400     if (output_index == 0) {
401       return Name(node_index);
402     } else {
403       return strings::StrCat(Name(node_index), ":", output_index);
404     }
405   }
406 
AddNode(const string & name)407   NodeDef* AddNode(const string& name) {
408     result_.nodes.emplace_back();
409     NodeDef* gnode = &result_.nodes.back();
410     gnode->set_name(name);
411     nodes_.push_back({name, {}, {}});
412     CHECK_EQ(result_.nodes.size(), nodes_.size());
413     return gnode;
414   }
415 
AddInput(int node_index,int output_node,int output_index)416   void AddInput(int node_index, int output_node, int output_index) {
417     CHECK_LT(node_index, nodes_.size());
418     nodes_[node_index].data_inputs.push_back(
419         std::make_pair(output_node, output_index));
420   }
421 
AddDep(int node_index,int dep_index)422   void AddDep(int node_index, int dep_index) {
423     CHECK_LT(node_index, nodes_.size());
424     nodes_[node_index].control_inputs.push_back(dep_index);
425   }
426 
427   GetFunctionSignature get_function_;
428   InstantiationResult& result_;
429   // A small index for all names that can be used as a node's input arguments.
430   std::map<string, NameInfoItem> index_;
431   // This contains information about a node in the new graph including the node
432   // names and input nodes' indexes.
433   struct NodeInfo {
434     string name;
435     // Data inputs where <n, k> means arg k of node n.
436     std::vector<std::pair<int, int>> data_inputs;
437     // Control inputs (dependencies).
438     std::vector<int> control_inputs;
439   };
440   // nodes_[i] is the information about result_.nodes[i].
441   std::vector<NodeInfo> nodes_;
442 };
443 
444 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)445 string Print(const OpDef::ArgDef& arg) {
446   string out;
447   strings::StrAppend(&out, arg.name(), ":");
448   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
449   if (!arg.number_attr().empty()) {
450     strings::StrAppend(&out, arg.number_attr(), "*");
451   }
452   if (arg.type() != DT_INVALID) {
453     strings::StrAppend(&out, DataTypeString(arg.type()));
454   } else {
455     strings::StrAppend(&out, arg.type_attr());
456   }
457   if (arg.is_ref()) strings::StrAppend(&out, ")");
458   return out;
459 }
460 
461 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)462 string Print(const AttrValue& attr_value) {
463   if (attr_value.value_case() == AttrValue::kType) {
464     return DataTypeString(attr_value.type());
465   } else if ((attr_value.value_case() == AttrValue::kList) &&
466              (attr_value.list().type_size() > 0)) {
467     string ret = "{";
468     for (int i = 0; i < attr_value.list().type_size(); ++i) {
469       if (i > 0) strings::StrAppend(&ret, ", ");
470       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
471     }
472     strings::StrAppend(&ret, "}");
473     return ret;
474   } else if (attr_value.value_case() == AttrValue::kFunc) {
475     if (attr_value.func().attr_size() == 0) {
476       return attr_value.func().name();
477     }
478     std::vector<string> entries;
479     for (auto p : attr_value.func().attr()) {
480       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
481     }
482     std::sort(entries.begin(), entries.end());
483     return strings::StrCat(attr_value.func().name(), "[",
484                            str_util::Join(entries, ", "), "]");
485   }
486   return SummarizeAttrValue(attr_value);
487 }
488 
489 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)490 string Print(const NodeDef& n) {
491   string out;
492   strings::StrAppend(&out, n.name(), " = ", n.op());
493   if (n.attr_size() > 0) {
494     std::vector<string> entries;
495     for (auto& a : n.attr()) {
496       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
497     }
498     std::sort(entries.begin(), entries.end());
499     strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
500   }
501   strings::StrAppend(&out, "(");
502   std::vector<StringPiece> dat;
503   std::vector<string> dep;
504   for (StringPiece s : n.input()) {
505     if (s.Consume("^")) {
506       dep.push_back(s.ToString());
507     } else {
508       dat.push_back(s);
509     }
510   }
511   strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
512   if (!dep.empty()) {
513     strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
514   }
515   return out;
516 }
517 
Print(const FunctionDef & fdef)518 string Print(const FunctionDef& fdef) {
519   string out;
520   const OpDef& sig = fdef.signature();
521   strings::StrAppend(&out, "\n", sig.name());
522   if (sig.attr_size() > 0) {
523     strings::StrAppend(&out, "[");
524     for (int i = 0; i < sig.attr_size(); ++i) {
525       const auto& a = sig.attr(i);
526       if (i > 0) strings::StrAppend(&out, ", ");
527       if (a.type() == "type") {
528         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
529       } else {
530         strings::StrAppend(&out, a.name(), ":", a.type());
531       }
532     }
533     strings::StrAppend(&out, "]");
534   }
535   strings::StrAppend(&out, "(");
536   for (int i = 0; i < sig.input_arg_size(); ++i) {
537     if (i > 0) strings::StrAppend(&out, ", ");
538     strings::StrAppend(&out, Print(sig.input_arg(i)));
539   }
540   strings::StrAppend(&out, ") -> (");
541   for (int i = 0; i < sig.output_arg_size(); ++i) {
542     if (i > 0) strings::StrAppend(&out, ", ");
543     strings::StrAppend(&out, Print(sig.output_arg(i)));
544   }
545   strings::StrAppend(&out, ") {\n");
546   for (const auto& n : fdef.node_def()) {
547     strings::StrAppend(&out, "  ", Print(n), "\n");
548   }
549   for (const auto& r : fdef.ret()) {
550     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
551   }
552   strings::StrAppend(&out, "}\n");
553   return out;
554 }
555 
Print(gtl::ArraySlice<const NodeDef * > nodes)556 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
557   std::vector<const NodeDef*> arg;
558   std::vector<const NodeDef*> ret;
559   std::vector<const NodeDef*> body;
560   for (const NodeDef* n : nodes) {
561     if (n->op() == "_Arg") {
562       arg.push_back(n);
563     } else if (n->op() == "_Retval") {
564       ret.push_back(n);
565     } else {
566       body.push_back(n);
567     }
568   }
569   auto comp = [](const NodeDef* x, const NodeDef* y) {
570     int xi;
571     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
572     int yi;
573     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
574     return xi < yi;
575   };
576   std::sort(arg.begin(), arg.end(), comp);
577   std::sort(ret.begin(), ret.end(), comp);
578   string out;
579   strings::StrAppend(&out, "\n(");
580   auto get_type = [](const NodeDef& n) {
581     DataType dt;
582     if (!GetNodeAttr(n, "T", &dt).ok()) {
583       dt = DT_INVALID;
584     }
585     return DataTypeString(dt);
586   };
587   for (size_t i = 0; i < arg.size(); ++i) {
588     const NodeDef* n = arg[i];
589     if (i > 0) strings::StrAppend(&out, ", ");
590     CHECK_GE(n->attr_size(), 2);
591     strings::StrAppend(&out, n->name(), ":", get_type(*n));
592   }
593   strings::StrAppend(&out, ") -> (");
594   for (size_t i = 0; i < ret.size(); ++i) {
595     const NodeDef* n = ret[i];
596     if (i > 0) strings::StrAppend(&out, ", ");
597     CHECK_LE(2, n->attr_size());
598     CHECK_EQ(1, n->input_size());
599     strings::StrAppend(&out, n->input(0), ":", get_type(*n));
600   }
601   strings::StrAppend(&out, ") {\n");
602   for (size_t i = 0; i < body.size(); ++i) {
603     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
604   }
605   strings::StrAppend(&out, "}\n");
606   return out;
607 }
608 
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)609 Status AddDefaultAttrs(const string& op,
610                        const GetFunctionSignature& get_function,
611                        AttrValueMap* attrs) {
612   const OpDef* op_def = nullptr;
613   TF_RETURN_IF_ERROR(get_function(op, &op_def));
614   AttrSlice attr_slice(attrs);
615   for (const auto& attr_def : op_def->attr()) {
616     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
617       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
618         return errors::Internal("Somehow duplicated: ", attr_def.name());
619       }
620     }
621   }
622   return Status::OK();
623 }
624 
625 }  // end namespace
626 
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)627 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
628                            GetFunctionSignature get_function,
629                            InstantiationResult* result) {
630   VLOG(3) << "Instantiation Function: " << Print(fdef);
631 
632   const OpDef& sig = fdef.signature();
633   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
634 
635   FunctionInstantiationHelper helper(get_function, result);
636   Status s;
637   for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
638     s = helper.BuildInputArgIndex(arg_def, attr_values);
639     if (!s.ok()) {
640       errors::AppendToMessage(&s, "In ", Print(arg_def));
641       return s;
642     }
643   }
644 
645   auto substitute = [attr_values](StringPiece name, AttrValue* val) {
646     if (const AttrValue* v = attr_values.Find(name)) {
647       *val = *v;
648       return true;
649     }
650     return false;
651   };
652 
653   // Makes a copy of all attrs in fdef and substitutes placeholders.
654   // After this step, every attr is bound to a concrete value.
655   std::vector<AttrValueMap> node_attrs;
656   node_attrs.resize(fdef.node_def_size());
657   for (int i = 0; i < fdef.node_def_size(); ++i) {
658     for (auto attr : fdef.node_def(i).attr()) {
659       if (!SubstitutePlaceholders(substitute, &attr.second)) {
660         return errors::InvalidArgument("Failed to bind all placeholders in ",
661                                        SummarizeAttrValue(attr.second));
662       }
663       if (!node_attrs[i].insert(attr).second) {
664         return errors::Internal("Somehow duplicated: ", attr.first);
665       }
666     }
667     TF_RETURN_IF_ERROR(
668         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
669   }
670 
671   for (int i = 0; i < fdef.node_def_size(); ++i) {
672     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
673                                     result->nodes.size() + i);
674     if (!s.ok()) {
675       errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
676       return s;
677     }
678   }
679   // Emits one node for each fdef.node_def.
680   for (int i = 0; i < fdef.node_def_size(); ++i) {
681     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
682     if (!s.ok()) {
683       errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
684       return s;
685     }
686   }
687 
688   // Emits nodes for the function's return values.
689   int ret_index = 0;
690   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
691     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), &ret_index);
692     if (!s.ok()) {
693       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
694       return s;
695     }
696   }
697 
698   // Adds the actual node inputs using the input indexes.
699   helper.AddNodeInputs();
700 
701   return Status::OK();
702 }
703 
DebugString(const FunctionDef & func_def)704 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
705 
DebugString(const GraphDef & instantiated_func_def)706 string DebugString(const GraphDef& instantiated_func_def) {
707   std::vector<const NodeDef*> ptrs;
708   for (const NodeDef& n : instantiated_func_def.node()) {
709     ptrs.push_back(&n);
710   }
711   return Print(ptrs);
712 }
713 
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)714 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
715   std::vector<const NodeDef*> ptrs;
716   for (const NodeDef& n : instantiated_func_nodes) {
717     ptrs.push_back(&n);
718   }
719   return Print(ptrs);
720 }
721 
DebugStringWhole(const GraphDef & gdef)722 string DebugStringWhole(const GraphDef& gdef) {
723   string ret;
724   for (const auto& fdef : gdef.library().function()) {
725     strings::StrAppend(&ret, Print(fdef));
726   }
727   strings::StrAppend(&ret, "\n");
728   for (const auto& ndef : gdef.node()) {
729     strings::StrAppend(&ret, Print(ndef), "\n");
730   }
731   return ret;
732 }
733 
734 namespace {
735 
736 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
737 // Python, it's possible to access unset attrs, which returns a default value
738 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)739 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
740   std::map<string, AttrValue> set_attrs;
741   for (auto pair : fdef.attr()) {
742     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
743       set_attrs[pair.first] = pair.second;
744     }
745   }
746   return set_attrs;
747 }
748 
749 }  // end namespace
750 
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)751 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
752   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
753 
754   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
755   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
756   if (f1_attrs.size() != f2_attrs.size()) return false;
757   for (auto iter1 : f1_attrs) {
758     auto iter2 = f2_attrs.find(iter1.first);
759     if (iter2 == f2_attrs.end()) return false;
760     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
761   }
762 
763   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
764     return false;
765   }
766 
767   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
768   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
769   if (ret1 != ret2) return false;
770 
771   return true;
772 }
773 
FunctionDefHash(const FunctionDef & fdef)774 uint64 FunctionDefHash(const FunctionDef& fdef) {
775   // signature
776   uint64 h = OpDefHash(fdef.signature());
777 
778   // attrs
779   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
780   for (const auto& p : attrs) {
781     h = Hash64(p.first.data(), p.first.size(), h);
782     h = Hash64Combine(AttrValueHash(p.second), h);
783   }
784 
785   // node defs
786   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
787 
788   // output names
789   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
790   for (const auto& p : ret) {
791     h = Hash64(p.first.data(), p.first.size(), h);
792     h = Hash64(p.second.data(), p.second.size(), h);
793   }
794 
795   return h;
796 }
797 
// Builds a string of the form "funcname[entry,entry,...]" from "funcname",
// "attrs" and the instantiation "options".
//
// There is one "name=value" entry per attr, plus one entry for each of the
// following options that is set:
//   _target=<CEscape'd options.target>        if options.target is non-empty
//   _overlay_lib=<options.overlay_lib as int>  if options.overlay_lib is set
//   _state_handle=<options.state_handle>       if it is non-empty
// The entries are sorted before being joined with ",".
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options) {
  const bool has_target = !options.target.empty();
  const bool has_overlay_lib = options.overlay_lib != nullptr;
  const bool has_state_handle = !options.state_handle.empty();

  std::vector<string> entries;
  // Reserve room for every entry that can be appended below, not only the
  // attrs and the target.
  entries.reserve(attrs.size() + (has_target ? 1 : 0) +
                  (has_overlay_lib ? 1 : 0) + (has_state_handle ? 1 : 0));
  // Iterate by const reference so no attr name/value pair is copied.
  for (const auto& p : attrs) {
    entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
  }
  if (has_target) {
    entries.push_back(
        strings::StrCat("_target", "=", str_util::CEscape(options.target)));
  }
  if (has_overlay_lib) {
    // The overlay library is identified by its address.
    entries.push_back(strings::StrCat(
        "_overlay_lib", "=", reinterpret_cast<uintptr_t>(options.overlay_lib)));
  }
  if (has_state_handle) {
    entries.push_back(
        strings::StrCat("_state_handle", "=", options.state_handle));
  }
  // Sort so the result does not depend on the iteration order of "attrs".
  std::sort(entries.begin(), entries.end());
  return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
820 
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)821 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
822                                      DataTypeSlice ret_types)
823     : arg_types_(arg_types.begin(), arg_types.end()),
824       ret_types_(ret_types.begin(), ret_types.end()) {
825   args_.resize(arg_types_.size());
826   rets_.resize(ret_types_.size());
827 }
828 
~FunctionCallFrame()829 FunctionCallFrame::~FunctionCallFrame() {}
830 
SetArgs(gtl::ArraySlice<Tensor> args)831 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
832   // Input type checks.
833   if (args.size() != arg_types_.size()) {
834     return errors::InvalidArgument("Expects ", arg_types_.size(),
835                                    " arguments, but ", args.size(),
836                                    " is provided");
837   }
838   for (size_t i = 0; i < args.size(); ++i) {
839     if (arg_types_[i] != args[i].dtype()) {
840       return errors::InvalidArgument(
841           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
842           DataTypeString(args[i].dtype()), " is provided");
843     }
844     args_[i] = args[i];
845   }
846   return Status::OK();
847 }
848 
GetRetvals(std::vector<Tensor> * rets) const849 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
850   rets->clear();
851   rets->reserve(rets_.size());
852   for (size_t i = 0; i < rets_.size(); ++i) {
853     const auto& item = rets_[i];
854     if (item.has_val) {
855       rets->push_back(item.val);
856     } else {
857       return errors::Internal("Retval[", i, "] does not have value");
858     }
859   }
860   return Status::OK();
861 }
862 
ConsumeRetvals(std::vector<Tensor> * rets)863 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
864   rets->clear();
865   rets->reserve(rets_.size());
866   for (size_t i = 0; i < rets_.size(); ++i) {
867     if (rets_[i].has_val) {
868       rets->emplace_back(std::move(rets_[i].val));
869     } else {
870       return errors::Internal("Retval[", i, "] does not have value");
871     }
872   }
873   return Status::OK();
874 }
875 
GetArg(int index,Tensor * val) const876 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
877   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
878     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
879                                    args_.size(), ")");
880   }
881   *val = args_[index];
882   return Status::OK();
883 }
884 
SetRetval(int index,const Tensor & val)885 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
886   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
887     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
888                                    rets_.size(), ")");
889   }
890   if (val.dtype() != ret_types_[index]) {
891     return errors::InvalidArgument(
892         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
893         ", but ", DataTypeString(val.dtype()), " is provided.");
894   }
895   Retval* item = &rets_[index];
896   if (!item->has_val) {
897     item->has_val = true;
898     item->val = val;
899   } else {
900     return errors::Internal("Retval[", index, "] has already been set.");
901   }
902   return Status::OK();
903 }
904 
905 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef & fdef_in)906     FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
907     : fdef(fdef_in),
908       // Exact shape inference for functions is handled by ShapeRefiner.
909       // Here we pass a dummy shape inference function for legacy code paths.
910       op_registration_data(fdef.signature(), shape_inference::UnknownShape,
911                            true /* is_function */) {}
912 
// Copy constructor. The default registry pointer and the gradient mapping
// are copied directly; every function of "other" is re-added through
// AddFunctionDef(), and a failure to add one is fatal (TF_CHECK_OK).
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_),
      func_grad_(other.func_grad_) {
  for (const auto& it : other.function_defs_) {
    TF_CHECK_OK(AddFunctionDef(it.second->fdef));
  }
}
920 
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)921 FunctionLibraryDefinition::FunctionLibraryDefinition(
922     const OpRegistryInterface* default_registry,
923     const FunctionDefLibrary& def_lib)
924     : default_registry_(default_registry),
925       function_defs_(def_lib.function_size()) {
926   for (const auto& fdef : def_lib.function()) {
927     // The latter function definition wins.
928     auto& ptr = function_defs_[fdef.signature().name()];
929     ptr.reset(new FunctionDefAndOpRegistration(fdef));
930   }
931   for (const auto& grad : def_lib.gradient()) {
932     func_grad_[grad.function_name()] = grad.gradient_func();
933   }
934 }
935 
~FunctionLibraryDefinition()936 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
937 
Find(const string & name) const938 const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
939   auto iter = function_defs_.find(name);
940   if (iter == function_defs_.end()) {
941     return nullptr;
942   } else {
943     return &iter->second->fdef;
944   }
945 }
946 
AddFunctionDef(const FunctionDef & fdef)947 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
948   bool added;
949   return AddFunctionDefHelper(fdef, &added);
950 }
951 
AddFunctionDefHelper(const FunctionDef & fdef,bool * added)952 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
953                                                        bool* added) {
954   *added = false;
955   std::unique_ptr<FunctionDefAndOpRegistration>* entry =
956       &function_defs_[fdef.signature().name()];
957   if (*entry != nullptr) {
958     if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
959       return errors::InvalidArgument(
960           "Cannot add function '", fdef.signature().name(),
961           "' because a different function with the same name already "
962           "exists.");
963     }
964     // Ignore duplicate FunctionDefs
965     return Status::OK();
966   }
967   const OpDef* op_def;
968   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
969     return errors::InvalidArgument(
970         "Cannot add function '", fdef.signature().name(),
971         "' because an op with the same name already exists.");
972   }
973   entry->reset(new FunctionDefAndOpRegistration(fdef));
974   *added = true;
975   return Status::OK();
976 }
977 
AddGradientDef(const GradientDef & grad)978 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
979   bool added;
980   return AddGradientDefHelper(grad, &added);
981 }
982 
AddGradientDefHelper(const GradientDef & grad,bool * added)983 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
984                                                        bool* added) {
985   *added = false;
986   string* entry = &func_grad_[grad.function_name()];
987   if (!entry->empty()) {
988     if (*entry != grad.gradient_func()) {
989       return errors::InvalidArgument(
990           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
991           grad.function_name(), "' because it already has gradient function ",
992           "'", *entry, "'");
993     }
994     // Ignore duplicate GradientDefs
995     return Status::OK();
996   }
997   *entry = grad.gradient_func();
998   *added = true;
999   return Status::OK();
1000 }
1001 
AddLibrary(const FunctionLibraryDefinition & other)1002 Status FunctionLibraryDefinition::AddLibrary(
1003     const FunctionLibraryDefinition& other) {
1004   // Remember the funcs and grads that we added successfully so that
1005   // we can roll them back on error.
1006   std::vector<string> funcs;
1007   std::vector<string> funcs_with_grads;
1008   Status s;
1009   bool added;
1010   for (auto iter : other.function_defs_) {
1011     s = AddFunctionDefHelper(iter.second->fdef, &added);
1012     if (!s.ok()) {
1013       Remove(funcs, funcs_with_grads);
1014       return s;
1015     }
1016     if (added) {
1017       funcs.push_back(iter.second->fdef.signature().name());
1018     }
1019   }
1020   for (auto iter : other.func_grad_) {
1021     GradientDef grad;
1022     grad.set_function_name(iter.first);
1023     grad.set_gradient_func(iter.second);
1024     s = AddGradientDefHelper(grad, &added);
1025     if (!s.ok()) {
1026       Remove(funcs, funcs_with_grads);
1027       return s;
1028     }
1029     if (added) {
1030       funcs_with_grads.push_back(grad.function_name());
1031     }
1032   }
1033   return Status::OK();
1034 }
1035 
AddLibrary(const FunctionDefLibrary & lib_def)1036 Status FunctionLibraryDefinition::AddLibrary(
1037     const FunctionDefLibrary& lib_def) {
1038   // Remember the funcs and grads that we added successfully so that
1039   // we can roll them back on error.
1040   std::vector<string> funcs;
1041   std::vector<string> funcs_with_grads;
1042   Status s;
1043   bool added;
1044   for (const FunctionDef& fdef : lib_def.function()) {
1045     s = AddFunctionDefHelper(fdef, &added);
1046     if (!s.ok()) {
1047       Remove(funcs, funcs_with_grads);
1048       return s;
1049     }
1050     if (added) {
1051       funcs.push_back(fdef.signature().name());
1052     }
1053   }
1054   for (const GradientDef& grad : lib_def.gradient()) {
1055     s = AddGradientDefHelper(grad, &added);
1056     if (!s.ok()) {
1057       Remove(funcs, funcs_with_grads);
1058       return s;
1059     }
1060     if (added) {
1061       funcs_with_grads.push_back(grad.function_name());
1062     }
1063   }
1064   return Status::OK();
1065 }
1066 
RemoveFunction(const string & func)1067 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1068   const auto& i = function_defs_.find(func);
1069   if (i == function_defs_.end()) {
1070     return errors::InvalidArgument("Tried to remove non-existent function ",
1071                                    func);
1072   }
1073   function_defs_.erase(i);
1074   return Status::OK();
1075 }
1076 
RemoveGradient(const string & func)1077 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1078   const auto& i = func_grad_.find(func);
1079   if (i == func_grad_.end()) {
1080     return errors::InvalidArgument("Tried to remove non-existent gradient ",
1081                                    func);
1082   }
1083   func_grad_.erase(i);
1084   return Status::OK();
1085 }
1086 
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1087 void FunctionLibraryDefinition::Remove(
1088     const std::vector<string>& funcs,
1089     const std::vector<string>& funcs_with_grads) {
1090   for (const string& f : funcs) {
1091     Status s = RemoveFunction(f);
1092     DCHECK(s.ok());
1093   }
1094   for (const string& f : funcs_with_grads) {
1095     Status s = RemoveGradient(f);
1096     DCHECK(s.ok());
1097   }
1098 }
1099 
// Returns the name of the gradient function recorded for "func", or an
// empty string if none is recorded.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  const auto iter = func_grad_.find(func);
  if (iter == func_grad_.end()) {
    return "";
  }
  return iter->second;
}
1103 
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1104 Status FunctionLibraryDefinition::LookUp(
1105     const string& op, const OpRegistrationData** op_reg_data) const {
1106   auto iter = function_defs_.find(op);
1107   if (iter != function_defs_.end()) {
1108     *op_reg_data = &iter->second->op_registration_data;
1109     return Status::OK();
1110   }
1111   return default_registry_->LookUp(op, op_reg_data);
1112 }
1113 
GetAttrImpl(const NodeDef & ndef) const1114 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1115     const NodeDef& ndef) const {
1116   if (ndef.op() != kGradientOp) {
1117     // If 'ndef' calls a function and the function's def has the attr,
1118     // returns it.
1119     return Find(ndef.op());
1120   }
1121 
1122   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1123   // Foo's attributes.
1124   const NameAttrList* forward_func_attrs;
1125   if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
1126     return nullptr;
1127   }
1128   const string& func_name = forward_func_attrs->name();
1129   const string& grad_name = FindGradient(func_name);
1130   // If 'func' has a user-defined gradient function, uses the grad
1131   // function's attrs to see if noinline is specified. Otherwise,
1132   // uses func's attrs.
1133   if (!grad_name.empty()) {
1134     return Find(grad_name);
1135   }
1136   return Find(func_name);
1137 }
1138 
ToProto() const1139 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1140   FunctionDefLibrary lib;
1141   for (const auto& f : function_defs_) {
1142     *lib.add_function() = f.second->fdef;
1143   }
1144   for (const auto& g : func_grad_) {
1145     GradientDef* gd = lib.add_gradient();
1146     gd->set_function_name(g.first);
1147     gd->set_gradient_func(g.second);
1148   }
1149   return lib;
1150 }
1151 
1152 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1153 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1154                                           const string& attr, T* value) const {
1155   const FunctionDef* fdef = GetAttrImpl(ndef);
1156   if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
1157     return Status::OK();
1158   }
1159   return errors::InvalidArgument("Attr ", attr, " is not defined.");
1160 }
1161 
1162 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1163 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1164                                           T* value) const {
1165   return GetAttr(node.def(), attr, value);
1166 }
1167 
1168 #define GET_ATTR(T)                                                            \
1169   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
1170                                                      const string&, T*) const; \
1171   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
1172                                                      const string&, T*) const;
1173 GET_ATTR(string)
GET_ATTR(bool)1174 GET_ATTR(bool)
1175 #undef GET_ATTR
1176 
1177 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1178   if (val.size() >= 2 && val[0] == '$') {
1179     proto.set_placeholder(val.data() + 1, val.size() - 1);
1180   } else {
1181     SetAttrValue(val, &proto);
1182   }
1183 }
1184 
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1185 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1186     const string& name,
1187     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1188   AttrValueWrapper ret;
1189   ret.proto.mutable_func()->set_name(name);
1190   for (const auto& a : attrs) {
1191     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1192   }
1193   return ret;
1194 }
1195 
ToNodeDef() const1196 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1197   NodeDef n;
1198   n.set_op(this->op);
1199   n.set_name(this->ret[0]);
1200   for (const auto& a : this->attr) {
1201     n.mutable_attr()->insert({a.first, a.second.proto});
1202   }
1203   for (const string& a : this->arg) {
1204     n.add_input(a);
1205   }
1206   for (const string& d : this->dep) {
1207     n.add_input(strings::StrCat("^", d));
1208   }
1209   return n;
1210 }
1211 
1212 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1213 FunctionDef FunctionDefHelper::Create(
1214     const string& function_name, gtl::ArraySlice<string> in_def,
1215     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1216     gtl::ArraySlice<Node> node_def,
1217     gtl::ArraySlice<std::pair<string, string>> ret_def) {
1218   FunctionDef fdef;
1219 
1220   // Signature
1221   OpDefBuilder b(function_name);
1222   for (const auto& i : in_def) b.Input(i);
1223   for (const auto& o : out_def) b.Output(o);
1224   for (const auto& a : attr_def) b.Attr(a);
1225 
1226   OpRegistrationData op_reg_data;
1227   TF_CHECK_OK(b.Finalize(&op_reg_data));
1228   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1229 
1230   // Function body
1231   for (const auto& n : node_def) {
1232     *(fdef.add_node_def()) = n.ToNodeDef();
1233   }
1234 
1235   // Returns
1236   for (const auto& r : ret_def) {
1237     fdef.mutable_ret()->insert({r.first, r.second});
1238   }
1239   return fdef;
1240 }
1241 
// Builds a FunctionDef named "name" from a description in which node inputs
// and function returns are referred to by plain ("legacy") names.
//
// "arg_def", "ret_def" and "attr_def" are the input, output and attr specs
// handed to OpDefBuilder to build the signature. Each entry of "node_def"
// becomes one NodeDef named after the first entry of its "ret" list. An
// entry of a node's "arg" list must be either the name of a function input
// or a name that appears in the "ret" list of an earlier node; it is
// rewritten to that input name or to "<node>:<output>:<index>" respectively.
//
// Every failure in here is fatal (CHECK / TF_CHECK_OK): a signature that
// cannot be finalized, an unknown node input, an op missing from the global
// op registry, a node with fewer "ret" names than outputs, or a function
// output that no input or node provides.
/* static */
FunctionDef FunctionDefHelper::Define(const string& name,
                                      gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  FunctionDef fdef;
  // Signature
  OpDefBuilder b(name);
  for (const auto& a : arg_def) b.Input(a);
  for (const auto& r : ret_def) b.Output(r);
  for (const auto& a : attr_def) b.Attr(a);

  OpRegistrationData op_reg_data;
  TF_CHECK_OK(b.Finalize(&op_reg_data));
  fdef.mutable_signature()->Swap(&op_reg_data.op_def);

  // Mapping from legacy output names to NodeDef outputs. Function inputs
  // are seeded first and map to themselves.
  std::unordered_map<string, string> ret_index;
  for (const auto& a : fdef.signature().input_arg()) {
    ret_index[a.name()] = a.name();
  }

  // For looking up OpDefs
  auto* op_def_registry = OpRegistry::Global();

  // Function body. Nodes are processed in the order given, so a node can
  // only refer to function inputs and to outputs of nodes listed before it.
  for (const auto& src : node_def) {
    NodeDef* n = fdef.add_node_def();
    n->set_op(src.op);
    // The node is named after its first ret name.
    n->set_name(src.ret[0]);
    for (const auto& a : src.attr) {
      n->mutable_attr()->insert({a.first, a.second.proto});
    }
    // Rewrite each legacy input name to the name recorded in ret_index.
    for (const string& a : src.arg) {
      const auto iter = ret_index.find(a);
      CHECK(iter != ret_index.end())
          << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
      n->add_input(iter->second);
    }
    // Dependencies are passed through with a "^" prefix and are not looked
    // up in ret_index.
    for (const string& d : src.dep) {
      n->add_input(strings::StrCat("^", d));
    }

    // Add the outputs of this node to ret_index.
    const OpDef* op_def = nullptr;
    TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
    CHECK(op_def != nullptr) << n->op();
    NameRangeMap output_names;
    TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
    for (const auto& o : output_names) {
      // o.second is the [first, second) index range that output o.first
      // occupies; src.ret must provide a name for every index in it.
      CHECK_LE(o.second.second, src.ret.size())
          << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
          << "' of " << name;
      for (int i = o.second.first; i < o.second.second; ++i) {
        // Legacy name -> "<node name>:<output name>:<index within output>".
        ret_index[src.ret[i]] =
            strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
      }
    }
  }

  // Returns: every output of the signature must resolve through ret_index.
  for (const auto& r : fdef.signature().output_arg()) {
    const auto iter = ret_index.find(r.name());
    CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
    fdef.mutable_ret()->insert({r.name(), iter->second});
  }
  return fdef;
}
1310 
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1311 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1312                                       gtl::ArraySlice<string> ret_def,
1313                                       gtl::ArraySlice<string> attr_def,
1314                                       gtl::ArraySlice<Node> node_def) {
1315   return Define("_", arg_def, ret_def, attr_def, node_def);
1316 }
1317 
1318 namespace gradient {
1319 
1320 typedef std::unordered_map<string, Creator> OpGradFactory;
1321 
GetOpGradFactory()1322 OpGradFactory* GetOpGradFactory() {
1323   static OpGradFactory* factory = new OpGradFactory;
1324   return factory;
1325 }
1326 
RegisterOp(const string & op,Creator func)1327 bool RegisterOp(const string& op, Creator func) {
1328   CHECK(GetOpGradFactory()->insert({op, func}).second)
1329       << "Duplicated gradient for " << op;
1330   return true;
1331 }
1332 
GetOpGradientCreator(const string & op,Creator * creator)1333 Status GetOpGradientCreator(const string& op, Creator* creator) {
1334   auto fac = GetOpGradFactory();
1335   auto iter = fac->find(op);
1336   if (iter == fac->end()) {
1337     return errors::NotFound("No gradient defined for op: ", op);
1338   }
1339   *creator = iter->second;
1340   return Status::OK();
1341 }
1342 
1343 }  // end namespace gradient
1344 
1345 }  // end namespace tensorflow
1346