1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/function.h"
17
18 #include <map>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/core/framework/common_shape_fns.h"
24 #include "tensorflow/core/framework/function.pb_text.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/util/equal_graph_def.h"
34
35 namespace tensorflow {
36
37 // Extracts the actual type from "attr_values" based on its definition
38 // "arg_def".
39 //
40 // If "arg_def" is a N*T type, *is_type_list is set to false, and
41 // *dtypes is set to be a vector of size N and each element is T.
42 //
43 // If "arg_def" is a list(type), *is_type_list is set to true, and
44 // *dtypes is set to be a vector of types specified in attrs for
45 // arg_def.
46 //
47 // Otherwise (arg_def is a simple type T), *is_type_list is set to
48 // false, and *dtypes is set to a single element vector, whose only
49 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)50 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
51 bool* is_type_list, DataTypeVector* dtypes) {
52 dtypes->clear();
53 if (!arg_def.type_list_attr().empty()) {
54 const AttrValue* v = attrs.Find(arg_def.type_list_attr());
55 if (v == nullptr) {
56 return errors::NotFound("type attr not found: ",
57 arg_def.type_list_attr());
58 }
59 *is_type_list = true;
60 for (int i = 0; i < v->list().type_size(); ++i) {
61 dtypes->push_back(v->list().type(i));
62 }
63 return Status::OK();
64 }
65
66 *is_type_list = false;
67 int num = 1;
68 if (!arg_def.number_attr().empty()) {
69 const AttrValue* v = attrs.Find(arg_def.number_attr());
70 if (v == nullptr) {
71 return errors::NotFound("type attr not found: ", arg_def.type_attr());
72 }
73 num = v->i();
74 }
75
76 DataType dtype;
77 if (arg_def.type() != DT_INVALID) {
78 dtype = arg_def.type();
79 } else if (arg_def.type_attr().empty()) {
80 dtype = DT_INVALID;
81 } else {
82 const AttrValue* v = attrs.Find(arg_def.type_attr());
83 if (v == nullptr) {
84 return errors::NotFound("type attr not found: ", arg_def.type_attr());
85 }
86 dtype = v->type();
87 }
88 dtypes->resize(num, dtype);
89 return Status::OK();
90 }
91
92 namespace {
93
94 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)95 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
96 SetAttrValue(val, &((*ndef->mutable_attr())[name]));
97 }
98
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)99 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
100 // attr_values should specify all attrs defined in fdef.
101 for (const auto& a : sig.attr()) {
102 const AttrValue* v = attr_values.Find(a.name());
103 if (!v) {
104 return errors::NotFound("Attr ", a.name(), " is not found from ",
105 SummarizeOpDef(sig));
106 }
107 Status status = AttrValueHasType(*v, a.type());
108 if (!status.ok()) {
109 errors::AppendToMessage(&status, "for attr '", a.name(), "'");
110 return status;
111 }
112 }
113
114 // TODO(josh11b): Enable this code once it works with function gradients.
115 // Right now the C++ function gradient code assumes it can pass
116 // all the attrs of the function to the gradient, and any attrs that
117 // the gradient doesn't care about will be ignored.
118 #if 0
119 if (attr_values.size() != sig.attr_size()) {
120 for (const auto& a : attr_values) {
121 // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
122 bool found = false;
123 for (const auto& s : sig.attr()) {
124 if (a.first == s.name()) {
125 found = true;
126 break;
127 }
128 }
129 if (!found) {
130 return errors::NotFound("Attr ", a.first, " is not found in ",
131 SummarizeOpDef(sig));
132 }
133 }
134 }
135 #endif
136
137 return Status::OK();
138 }
139
140 // A helper class for instantiating functions. This contains shared information
141 // like the resulting graph and node name index.
142 class FunctionInstantiationHelper {
143 public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)144 FunctionInstantiationHelper(GetFunctionSignature get_function,
145 InstantiationResult* result)
146 : get_function_(std ::move(get_function)), result_(*result) {
147 result_.nodes.clear();
148 }
149
150 // Builds index for nodes that can be used as node's input arguments.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values)151 Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
152 AttrSlice attr_values) {
153 bool is_type_list;
154 DataTypeVector dtypes;
155 TF_RETURN_IF_ERROR(
156 ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
157 CHECK_GE(dtypes.size(), size_t{1});
158 int arg_index = result_.nodes.size();
159 TF_RETURN_IF_ERROR(
160 AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
161 // Creates dtypes.size() nodes in the graph.
162 for (size_t i = 0; i < dtypes.size(); ++i) {
163 TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
164 {true, arg_index, 0, false, {dtypes[i]}}));
165 DCHECK_EQ(arg_index, result_.nodes.size());
166 string name = arg_def.name();
167 if (dtypes.size() > 1) {
168 strings::StrAppend(&name, "_", i);
169 }
170 NodeDef* gnode = AddNode(name);
171 gnode->set_op("_Arg");
172 AddAttr("T", dtypes[i], gnode);
173 AddAttr("index", arg_index, gnode);
174 result_.arg_types.push_back(dtypes[i]);
175 ++arg_index;
176 }
177 return Status::OK();
178 }
179
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)180 Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
181 const int arg_index) {
182 const OpDef* node_sig = nullptr;
183 TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
184 if (node_sig->output_arg_size() == 0) {
185 return AddItem(node.name(), {false, arg_index, 0, false, {}});
186 }
187 const int num_retval = node_sig->output_arg_size();
188 int start = 0;
189 bool is_type_list;
190 DataTypeVector dtypes;
191 for (int i = 0; i < num_retval; ++i) {
192 TF_RETURN_IF_ERROR(
193 ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
194 // Note that we rely on the backwards-compatibility test enforcing
195 // that output_arg(*).name() doesn't change here.
196 const string base_name =
197 strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
198 TF_RETURN_IF_ERROR(
199 AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
200 for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
201 TF_RETURN_IF_ERROR(
202 AddItem(strings::StrCat(base_name, ":", j),
203 {false, arg_index, start + j, false, {dtypes[j]}}));
204 }
205 start += dtypes.size();
206 }
207 return Status::OK();
208 }
209
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)210 Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
211 const OpDef* fnode_sig = nullptr;
212 TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
213 NodeDef* gnode = AddNode(fnode.name());
214 gnode->set_op(fnode.op());
215 gnode->set_device(fnode.device());
216 int gnode_idx = nodes_.size() - 1;
217
218 // Input
219 const int num_args = fnode_sig->input_arg_size();
220 bool is_type_list; // ignored
221 DataTypeVector dtypes;
222 int fnode_arg_index = 0;
223 for (int i = 0; i < num_args; ++i) {
224 TF_RETURN_IF_ERROR(
225 ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
226 // Consume inputs (indexed by fnode_arg_index) until we have
227 // matched each element of dtypes (indexed by j).
228 for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
229 if (fnode_arg_index >= fnode.input_size()) {
230 // Should never happen if we computed dtypes correctly.
231 return errors::InvalidArgument(
232 "Attempt to access beyond input size: ", fnode_arg_index,
233 " >= ", fnode.input_size());
234 }
235 // Look up the next input.
236 const string& input_name = fnode.input(fnode_arg_index);
237 const auto* item = GetItemOrNull(input_name);
238 if (item == nullptr) {
239 return errors::InvalidArgument(
240 "input ", input_name, " is not found: ", SummarizeNodeDef(fnode));
241 }
242 if (item->dtypes.size() > dtypes.size() - j) {
243 return errors::InvalidArgument("Input ", input_name, " too long for ",
244 fnode_sig->input_arg(i).name());
245 }
246 // Match up all the elements of this input (indexed by k) with
247 // elements of dtypes (advancing j).
248 for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
249 if (item->dtypes[k] != dtypes[j]) {
250 return errors::InvalidArgument(
251 "input ", fnode_sig->input_arg(i).name(), "[", j,
252 "] expected type ", DataTypeString(dtypes[j]),
253 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
254 input_name, "[", k, "]");
255 }
256 if (item->is_func_arg) {
257 AddInput(gnode_idx, item->nid + k, 0);
258 } else {
259 AddInput(gnode_idx, item->nid, item->idx + k);
260 }
261 }
262 }
263 }
264
265 // Control deps.
266 for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
267 const string& input = fnode.input(i);
268 if (input.empty() || input[0] != '^') {
269 return errors::InvalidArgument("Expected input[", i, "] == '", input,
270 "' to be a control input.");
271 }
272 int nid = -1;
273 const string node_name = input.substr(1);
274 const string node_colon = node_name + ":";
275 const string node_colon_bound = node_name + ";";
276 // index_ is a map sorted lexicographically, so the key we are looking for
277 // must lie in the range [node_name, node_colon_bound).
278 auto it = index_.lower_bound(node_name);
279 while (it != index_.end() && it->first <= node_colon_bound) {
280 if (it->first == node_name ||
281 tensorflow::StringPiece(it->first).starts_with(node_colon)) {
282 nid = it->second.nid;
283 break;
284 }
285 ++it;
286 }
287 if (nid == -1) {
288 return errors::InvalidArgument("input[", i, "] == '", input,
289 "', is not found.");
290 }
291 AddDep(gnode_idx, nid);
292 }
293
294 // Attrs.
295 for (const auto& p : attrs) {
296 (*gnode->mutable_attr())[p.first] = p.second;
297 }
298
299 return Status::OK();
300 }
301
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,int * ret_index)302 Status AddReturnNode(
303 const OpDef::ArgDef& ret_def, AttrSlice attrs,
304 const ::tensorflow::protobuf::Map<string, string>& ret_map,
305 int* ret_index) {
306 auto ret_iter = ret_map.find(ret_def.name());
307 if (ret_iter == ret_map.end()) {
308 return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
309 }
310 bool is_type_list;
311 DataTypeVector dtypes;
312 TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
313 CHECK_GE(dtypes.size(), size_t{1});
314 const auto* item = GetItemOrNull(ret_iter->second);
315 if (item == nullptr) {
316 return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
317 ret_iter->second, " is not found.");
318 }
319 if (dtypes != item->dtypes) {
320 return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
321 " : ", DataTypeVectorString(dtypes),
322 " vs. ",
323 DataTypeVectorString(item->dtypes));
324 }
325 for (size_t i = 0; i < dtypes.size(); ++i) {
326 string name = strings::StrCat(ret_def.name(), "_RetVal");
327 if (dtypes.size() > 1) {
328 strings::StrAppend(&name, "_", i);
329 }
330 NodeDef* gnode = AddNode(name);
331 gnode->set_op("_Retval");
332 AddInput(nodes_.size() - 1, item->nid, item->idx + i);
333 AddAttr("T", dtypes[i], gnode);
334 AddAttr("index", (*ret_index)++, gnode);
335 result_.ret_types.push_back(dtypes[i]);
336 }
337 return Status::OK();
338 }
339
340 // Adds the actual node inputs to the result graph by converting indexes to
341 // the node names.
AddNodeInputs()342 void AddNodeInputs() {
343 for (int i = 0; i < result_.nodes.size(); i++) {
344 NodeInfo& node_info = nodes_[i];
345 for (const auto& p : node_info.data_inputs) {
346 result_.nodes[i].add_input(Name(p.first, p.second));
347 }
348 for (int index : node_info.control_inputs) {
349 result_.nodes[i].add_input(Dep(index));
350 }
351 }
352 }
353
354 private:
355 // This is used to build a small index for all names that can be used as a
356 // node's input arguments.
357 //
358 // If is_func_arg is true, the name is a function's argument. In
359 // this case, the produced graph def has node[nid:nid + dtype.size()].
360 //
361 // Otherwise, the name is a function body's node return value. In
362 // this case, the produced graph def has one node node[nid] and
363 // the node's output index [idx ... idx + num) corresponds to the
364 // named outputs.
365 //
366 // In all cases, "dtype" specifies the data type.
367 struct NameInfoItem {
368 bool is_func_arg;
369 int nid;
370 int idx;
371 bool is_type_list;
372 DataTypeVector dtypes;
373 };
374
375 // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)376 Status AddItem(const string& name, const NameInfoItem& item) {
377 if (!index_.insert({name, item}).second) {
378 return errors::InvalidArgument(
379 strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
380 " name: "),
381 name);
382 }
383 return Status::OK();
384 }
385
GetItemOrNull(const string & name) const386 const NameInfoItem* GetItemOrNull(const string& name) const {
387 return gtl::FindOrNull(index_, name);
388 }
389
Dep(int node_index) const390 string Dep(int node_index) const {
391 return strings::StrCat("^", Name(node_index));
392 }
393
Name(int node_index) const394 string Name(int node_index) const {
395 CHECK_LT(node_index, nodes_.size());
396 return nodes_[node_index].name;
397 }
398
Name(int node_index,int output_index) const399 string Name(int node_index, int output_index) const {
400 if (output_index == 0) {
401 return Name(node_index);
402 } else {
403 return strings::StrCat(Name(node_index), ":", output_index);
404 }
405 }
406
AddNode(const string & name)407 NodeDef* AddNode(const string& name) {
408 result_.nodes.emplace_back();
409 NodeDef* gnode = &result_.nodes.back();
410 gnode->set_name(name);
411 nodes_.push_back({name, {}, {}});
412 CHECK_EQ(result_.nodes.size(), nodes_.size());
413 return gnode;
414 }
415
AddInput(int node_index,int output_node,int output_index)416 void AddInput(int node_index, int output_node, int output_index) {
417 CHECK_LT(node_index, nodes_.size());
418 nodes_[node_index].data_inputs.push_back(
419 std::make_pair(output_node, output_index));
420 }
421
AddDep(int node_index,int dep_index)422 void AddDep(int node_index, int dep_index) {
423 CHECK_LT(node_index, nodes_.size());
424 nodes_[node_index].control_inputs.push_back(dep_index);
425 }
426
427 GetFunctionSignature get_function_;
428 InstantiationResult& result_;
429 // A small index for all names that can be used as a node's input arguments.
430 std::map<string, NameInfoItem> index_;
431 // This contains information about a node in the new graph including the node
432 // names and input nodes' indexes.
433 struct NodeInfo {
434 string name;
435 // Data inputs where <n, k> means arg k of node n.
436 std::vector<std::pair<int, int>> data_inputs;
437 // Control inputs (dependencies).
438 std::vector<int> control_inputs;
439 };
440 // nodes_[i] is the information about result_.nodes[i].
441 std::vector<NodeInfo> nodes_;
442 };
443
444 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)445 string Print(const OpDef::ArgDef& arg) {
446 string out;
447 strings::StrAppend(&out, arg.name(), ":");
448 if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
449 if (!arg.number_attr().empty()) {
450 strings::StrAppend(&out, arg.number_attr(), "*");
451 }
452 if (arg.type() != DT_INVALID) {
453 strings::StrAppend(&out, DataTypeString(arg.type()));
454 } else {
455 strings::StrAppend(&out, arg.type_attr());
456 }
457 if (arg.is_ref()) strings::StrAppend(&out, ")");
458 return out;
459 }
460
461 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)462 string Print(const AttrValue& attr_value) {
463 if (attr_value.value_case() == AttrValue::kType) {
464 return DataTypeString(attr_value.type());
465 } else if ((attr_value.value_case() == AttrValue::kList) &&
466 (attr_value.list().type_size() > 0)) {
467 string ret = "{";
468 for (int i = 0; i < attr_value.list().type_size(); ++i) {
469 if (i > 0) strings::StrAppend(&ret, ", ");
470 strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
471 }
472 strings::StrAppend(&ret, "}");
473 return ret;
474 } else if (attr_value.value_case() == AttrValue::kFunc) {
475 if (attr_value.func().attr_size() == 0) {
476 return attr_value.func().name();
477 }
478 std::vector<string> entries;
479 for (auto p : attr_value.func().attr()) {
480 entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
481 }
482 std::sort(entries.begin(), entries.end());
483 return strings::StrCat(attr_value.func().name(), "[",
484 str_util::Join(entries, ", "), "]");
485 }
486 return SummarizeAttrValue(attr_value);
487 }
488
489 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)490 string Print(const NodeDef& n) {
491 string out;
492 strings::StrAppend(&out, n.name(), " = ", n.op());
493 if (n.attr_size() > 0) {
494 std::vector<string> entries;
495 for (auto& a : n.attr()) {
496 entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
497 }
498 std::sort(entries.begin(), entries.end());
499 strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
500 }
501 strings::StrAppend(&out, "(");
502 std::vector<StringPiece> dat;
503 std::vector<string> dep;
504 for (StringPiece s : n.input()) {
505 if (s.Consume("^")) {
506 dep.push_back(s.ToString());
507 } else {
508 dat.push_back(s);
509 }
510 }
511 strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
512 if (!dep.empty()) {
513 strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
514 }
515 return out;
516 }
517
Print(const FunctionDef & fdef)518 string Print(const FunctionDef& fdef) {
519 string out;
520 const OpDef& sig = fdef.signature();
521 strings::StrAppend(&out, "\n", sig.name());
522 if (sig.attr_size() > 0) {
523 strings::StrAppend(&out, "[");
524 for (int i = 0; i < sig.attr_size(); ++i) {
525 const auto& a = sig.attr(i);
526 if (i > 0) strings::StrAppend(&out, ", ");
527 if (a.type() == "type") {
528 strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
529 } else {
530 strings::StrAppend(&out, a.name(), ":", a.type());
531 }
532 }
533 strings::StrAppend(&out, "]");
534 }
535 strings::StrAppend(&out, "(");
536 for (int i = 0; i < sig.input_arg_size(); ++i) {
537 if (i > 0) strings::StrAppend(&out, ", ");
538 strings::StrAppend(&out, Print(sig.input_arg(i)));
539 }
540 strings::StrAppend(&out, ") -> (");
541 for (int i = 0; i < sig.output_arg_size(); ++i) {
542 if (i > 0) strings::StrAppend(&out, ", ");
543 strings::StrAppend(&out, Print(sig.output_arg(i)));
544 }
545 strings::StrAppend(&out, ") {\n");
546 for (const auto& n : fdef.node_def()) {
547 strings::StrAppend(&out, " ", Print(n), "\n");
548 }
549 for (const auto& r : fdef.ret()) {
550 strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
551 }
552 strings::StrAppend(&out, "}\n");
553 return out;
554 }
555
Print(gtl::ArraySlice<const NodeDef * > nodes)556 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
557 std::vector<const NodeDef*> arg;
558 std::vector<const NodeDef*> ret;
559 std::vector<const NodeDef*> body;
560 for (const NodeDef* n : nodes) {
561 if (n->op() == "_Arg") {
562 arg.push_back(n);
563 } else if (n->op() == "_Retval") {
564 ret.push_back(n);
565 } else {
566 body.push_back(n);
567 }
568 }
569 auto comp = [](const NodeDef* x, const NodeDef* y) {
570 int xi;
571 TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
572 int yi;
573 TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
574 return xi < yi;
575 };
576 std::sort(arg.begin(), arg.end(), comp);
577 std::sort(ret.begin(), ret.end(), comp);
578 string out;
579 strings::StrAppend(&out, "\n(");
580 auto get_type = [](const NodeDef& n) {
581 DataType dt;
582 if (!GetNodeAttr(n, "T", &dt).ok()) {
583 dt = DT_INVALID;
584 }
585 return DataTypeString(dt);
586 };
587 for (size_t i = 0; i < arg.size(); ++i) {
588 const NodeDef* n = arg[i];
589 if (i > 0) strings::StrAppend(&out, ", ");
590 CHECK_GE(n->attr_size(), 2);
591 strings::StrAppend(&out, n->name(), ":", get_type(*n));
592 }
593 strings::StrAppend(&out, ") -> (");
594 for (size_t i = 0; i < ret.size(); ++i) {
595 const NodeDef* n = ret[i];
596 if (i > 0) strings::StrAppend(&out, ", ");
597 CHECK_LE(2, n->attr_size());
598 CHECK_EQ(1, n->input_size());
599 strings::StrAppend(&out, n->input(0), ":", get_type(*n));
600 }
601 strings::StrAppend(&out, ") {\n");
602 for (size_t i = 0; i < body.size(); ++i) {
603 strings::StrAppend(&out, " ", Print(*body[i]), "\n");
604 }
605 strings::StrAppend(&out, "}\n");
606 return out;
607 }
608
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)609 Status AddDefaultAttrs(const string& op,
610 const GetFunctionSignature& get_function,
611 AttrValueMap* attrs) {
612 const OpDef* op_def = nullptr;
613 TF_RETURN_IF_ERROR(get_function(op, &op_def));
614 AttrSlice attr_slice(attrs);
615 for (const auto& attr_def : op_def->attr()) {
616 if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
617 if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
618 return errors::Internal("Somehow duplicated: ", attr_def.name());
619 }
620 }
621 }
622 return Status::OK();
623 }
624
625 } // end namespace
626
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)627 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
628 GetFunctionSignature get_function,
629 InstantiationResult* result) {
630 VLOG(3) << "Instantiation Function: " << Print(fdef);
631
632 const OpDef& sig = fdef.signature();
633 TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
634
635 FunctionInstantiationHelper helper(get_function, result);
636 Status s;
637 for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
638 s = helper.BuildInputArgIndex(arg_def, attr_values);
639 if (!s.ok()) {
640 errors::AppendToMessage(&s, "In ", Print(arg_def));
641 return s;
642 }
643 }
644
645 auto substitute = [attr_values](StringPiece name, AttrValue* val) {
646 if (const AttrValue* v = attr_values.Find(name)) {
647 *val = *v;
648 return true;
649 }
650 return false;
651 };
652
653 // Makes a copy of all attrs in fdef and substitutes placeholders.
654 // After this step, every attr is bound to a concrete value.
655 std::vector<AttrValueMap> node_attrs;
656 node_attrs.resize(fdef.node_def_size());
657 for (int i = 0; i < fdef.node_def_size(); ++i) {
658 for (auto attr : fdef.node_def(i).attr()) {
659 if (!SubstitutePlaceholders(substitute, &attr.second)) {
660 return errors::InvalidArgument("Failed to bind all placeholders in ",
661 SummarizeAttrValue(attr.second));
662 }
663 if (!node_attrs[i].insert(attr).second) {
664 return errors::Internal("Somehow duplicated: ", attr.first);
665 }
666 }
667 TF_RETURN_IF_ERROR(
668 AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
669 }
670
671 for (int i = 0; i < fdef.node_def_size(); ++i) {
672 s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
673 result->nodes.size() + i);
674 if (!s.ok()) {
675 errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
676 return s;
677 }
678 }
679 // Emits one node for each fdef.node_def.
680 for (int i = 0; i < fdef.node_def_size(); ++i) {
681 s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
682 if (!s.ok()) {
683 errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
684 return s;
685 }
686 }
687
688 // Emits nodes for the function's return values.
689 int ret_index = 0;
690 for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
691 s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), &ret_index);
692 if (!s.ok()) {
693 errors::AppendToMessage(&s, "In function output ", Print(ret_def));
694 return s;
695 }
696 }
697
698 // Adds the actual node inputs using the input indexes.
699 helper.AddNodeInputs();
700
701 return Status::OK();
702 }
703
DebugString(const FunctionDef & func_def)704 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
705
DebugString(const GraphDef & instantiated_func_def)706 string DebugString(const GraphDef& instantiated_func_def) {
707 std::vector<const NodeDef*> ptrs;
708 for (const NodeDef& n : instantiated_func_def.node()) {
709 ptrs.push_back(&n);
710 }
711 return Print(ptrs);
712 }
713
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)714 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
715 std::vector<const NodeDef*> ptrs;
716 for (const NodeDef& n : instantiated_func_nodes) {
717 ptrs.push_back(&n);
718 }
719 return Print(ptrs);
720 }
721
DebugStringWhole(const GraphDef & gdef)722 string DebugStringWhole(const GraphDef& gdef) {
723 string ret;
724 for (const auto& fdef : gdef.library().function()) {
725 strings::StrAppend(&ret, Print(fdef));
726 }
727 strings::StrAppend(&ret, "\n");
728 for (const auto& ndef : gdef.node()) {
729 strings::StrAppend(&ret, Print(ndef), "\n");
730 }
731 return ret;
732 }
733
734 namespace {
735
736 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
737 // Python, it's possible to access unset attrs, which returns a default value
738 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)739 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
740 std::map<string, AttrValue> set_attrs;
741 for (auto pair : fdef.attr()) {
742 if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
743 set_attrs[pair.first] = pair.second;
744 }
745 }
746 return set_attrs;
747 }
748
749 } // end namespace
750
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)751 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
752 if (!OpDefEqual(f1.signature(), f2.signature())) return false;
753
754 std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
755 std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
756 if (f1_attrs.size() != f2_attrs.size()) return false;
757 for (auto iter1 : f1_attrs) {
758 auto iter2 = f2_attrs.find(iter1.first);
759 if (iter2 == f2_attrs.end()) return false;
760 if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
761 }
762
763 if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
764 return false;
765 }
766
767 std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
768 std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
769 if (ret1 != ret2) return false;
770
771 return true;
772 }
773
FunctionDefHash(const FunctionDef & fdef)774 uint64 FunctionDefHash(const FunctionDef& fdef) {
775 // signature
776 uint64 h = OpDefHash(fdef.signature());
777
778 // attrs
779 std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
780 for (const auto& p : attrs) {
781 h = Hash64(p.first.data(), p.first.size(), h);
782 h = Hash64Combine(AttrValueHash(p.second), h);
783 }
784
785 // node defs
786 h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
787
788 // output names
789 std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
790 for (const auto& p : ret) {
791 h = Hash64(p.first.data(), p.first.size(), h);
792 h = Hash64(p.second.data(), p.second.size(), h);
793 }
794
795 return h;
796 }
797
// Builds a canonical key "funcname[k1=v1,k2=v2,...]" for an instantiation of
// `funcname` with `attrs` and `options`. Entries are sorted. The non-default
// options (_target, _overlay_lib, _state_handle) are included, so the same
// function with different options yields different keys.
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options) {
  // One entry per attr plus one for each option appended below.
  size_t num_entries = attrs.size();
  if (!options.target.empty()) ++num_entries;
  if (options.overlay_lib) ++num_entries;
  if (!options.state_handle.empty()) ++num_entries;

  std::vector<string> entries;
  entries.reserve(num_entries);
  // Bind by const reference to avoid copying each (name, AttrValue) entry.
  for (const auto& p : attrs) {
    entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
  }
  if (!options.target.empty()) {
    entries.push_back(
        strings::StrCat("_target", "=", str_util::CEscape(options.target)));
  }
  if (options.overlay_lib) {
    entries.push_back(strings::StrCat(
        "_overlay_lib", "=", reinterpret_cast<uintptr_t>(options.overlay_lib)));
  }
  if (!options.state_handle.empty()) {
    entries.push_back(
        strings::StrCat("_state_handle", "=", options.state_handle));
  }
  std::sort(entries.begin(), entries.end());
  return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
820
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)821 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
822 DataTypeSlice ret_types)
823 : arg_types_(arg_types.begin(), arg_types.end()),
824 ret_types_(ret_types.begin(), ret_types.end()) {
825 args_.resize(arg_types_.size());
826 rets_.resize(ret_types_.size());
827 }
828
~FunctionCallFrame()829 FunctionCallFrame::~FunctionCallFrame() {}
830
SetArgs(gtl::ArraySlice<Tensor> args)831 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
832 // Input type checks.
833 if (args.size() != arg_types_.size()) {
834 return errors::InvalidArgument("Expects ", arg_types_.size(),
835 " arguments, but ", args.size(),
836 " is provided");
837 }
838 for (size_t i = 0; i < args.size(); ++i) {
839 if (arg_types_[i] != args[i].dtype()) {
840 return errors::InvalidArgument(
841 "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
842 DataTypeString(args[i].dtype()), " is provided");
843 }
844 args_[i] = args[i];
845 }
846 return Status::OK();
847 }
848
GetRetvals(std::vector<Tensor> * rets) const849 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
850 rets->clear();
851 rets->reserve(rets_.size());
852 for (size_t i = 0; i < rets_.size(); ++i) {
853 const auto& item = rets_[i];
854 if (item.has_val) {
855 rets->push_back(item.val);
856 } else {
857 return errors::Internal("Retval[", i, "] does not have value");
858 }
859 }
860 return Status::OK();
861 }
862
ConsumeRetvals(std::vector<Tensor> * rets)863 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
864 rets->clear();
865 rets->reserve(rets_.size());
866 for (size_t i = 0; i < rets_.size(); ++i) {
867 if (rets_[i].has_val) {
868 rets->emplace_back(std::move(rets_[i].val));
869 } else {
870 return errors::Internal("Retval[", i, "] does not have value");
871 }
872 }
873 return Status::OK();
874 }
875
GetArg(int index,Tensor * val) const876 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
877 if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
878 return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
879 args_.size(), ")");
880 }
881 *val = args_[index];
882 return Status::OK();
883 }
884
SetRetval(int index,const Tensor & val)885 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
886 if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
887 return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
888 rets_.size(), ")");
889 }
890 if (val.dtype() != ret_types_[index]) {
891 return errors::InvalidArgument(
892 "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
893 ", but ", DataTypeString(val.dtype()), " is provided.");
894 }
895 Retval* item = &rets_[index];
896 if (!item->has_val) {
897 item->has_val = true;
898 item->val = val;
899 } else {
900 return errors::Internal("Retval[", index, "] has already been set.");
901 }
902 return Status::OK();
903 }
904
905 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef & fdef_in)906 FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
907 : fdef(fdef_in),
908 // Exact shape inference for functions is handled by ShapeRefiner.
909 // Here we pass a dummy shape inference function for legacy code paths.
910 op_registration_data(fdef.signature(), shape_inference::UnknownShape,
911 true /* is_function */) {}
912
// Copy constructor. The registry pointer and gradient map are copied
// directly. function_defs_ holds its entries through unique_ptr, so each
// FunctionDef is re-added individually. Any failure is fatal (TF_CHECK_OK).
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_), func_grad_(other.func_grad_) {
  for (const auto& entry : other.function_defs_) {
    const FunctionDef& source_def = entry.second->fdef;
    TF_CHECK_OK(AddFunctionDef(source_def));
  }
}
920
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)921 FunctionLibraryDefinition::FunctionLibraryDefinition(
922 const OpRegistryInterface* default_registry,
923 const FunctionDefLibrary& def_lib)
924 : default_registry_(default_registry),
925 function_defs_(def_lib.function_size()) {
926 for (const auto& fdef : def_lib.function()) {
927 // The latter function definition wins.
928 auto& ptr = function_defs_[fdef.signature().name()];
929 ptr.reset(new FunctionDefAndOpRegistration(fdef));
930 }
931 for (const auto& grad : def_lib.gradient()) {
932 func_grad_[grad.function_name()] = grad.gradient_func();
933 }
934 }
935
~FunctionLibraryDefinition()936 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
937
Find(const string & name) const938 const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
939 auto iter = function_defs_.find(name);
940 if (iter == function_defs_.end()) {
941 return nullptr;
942 } else {
943 return &iter->second->fdef;
944 }
945 }
946
AddFunctionDef(const FunctionDef & fdef)947 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
948 bool added;
949 return AddFunctionDefHelper(fdef, &added);
950 }
951
AddFunctionDefHelper(const FunctionDef & fdef,bool * added)952 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
953 bool* added) {
954 *added = false;
955 std::unique_ptr<FunctionDefAndOpRegistration>* entry =
956 &function_defs_[fdef.signature().name()];
957 if (*entry != nullptr) {
958 if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
959 return errors::InvalidArgument(
960 "Cannot add function '", fdef.signature().name(),
961 "' because a different function with the same name already "
962 "exists.");
963 }
964 // Ignore duplicate FunctionDefs
965 return Status::OK();
966 }
967 const OpDef* op_def;
968 if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
969 return errors::InvalidArgument(
970 "Cannot add function '", fdef.signature().name(),
971 "' because an op with the same name already exists.");
972 }
973 entry->reset(new FunctionDefAndOpRegistration(fdef));
974 *added = true;
975 return Status::OK();
976 }
977
AddGradientDef(const GradientDef & grad)978 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
979 bool added;
980 return AddGradientDefHelper(grad, &added);
981 }
982
AddGradientDefHelper(const GradientDef & grad,bool * added)983 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
984 bool* added) {
985 *added = false;
986 string* entry = &func_grad_[grad.function_name()];
987 if (!entry->empty()) {
988 if (*entry != grad.gradient_func()) {
989 return errors::InvalidArgument(
990 "Cannot assign gradient function '", grad.gradient_func(), "' to '",
991 grad.function_name(), "' because it already has gradient function ",
992 "'", *entry, "'");
993 }
994 // Ignore duplicate GradientDefs
995 return Status::OK();
996 }
997 *entry = grad.gradient_func();
998 *added = true;
999 return Status::OK();
1000 }
1001
AddLibrary(const FunctionLibraryDefinition & other)1002 Status FunctionLibraryDefinition::AddLibrary(
1003 const FunctionLibraryDefinition& other) {
1004 // Remember the funcs and grads that we added successfully so that
1005 // we can roll them back on error.
1006 std::vector<string> funcs;
1007 std::vector<string> funcs_with_grads;
1008 Status s;
1009 bool added;
1010 for (auto iter : other.function_defs_) {
1011 s = AddFunctionDefHelper(iter.second->fdef, &added);
1012 if (!s.ok()) {
1013 Remove(funcs, funcs_with_grads);
1014 return s;
1015 }
1016 if (added) {
1017 funcs.push_back(iter.second->fdef.signature().name());
1018 }
1019 }
1020 for (auto iter : other.func_grad_) {
1021 GradientDef grad;
1022 grad.set_function_name(iter.first);
1023 grad.set_gradient_func(iter.second);
1024 s = AddGradientDefHelper(grad, &added);
1025 if (!s.ok()) {
1026 Remove(funcs, funcs_with_grads);
1027 return s;
1028 }
1029 if (added) {
1030 funcs_with_grads.push_back(grad.function_name());
1031 }
1032 }
1033 return Status::OK();
1034 }
1035
AddLibrary(const FunctionDefLibrary & lib_def)1036 Status FunctionLibraryDefinition::AddLibrary(
1037 const FunctionDefLibrary& lib_def) {
1038 // Remember the funcs and grads that we added successfully so that
1039 // we can roll them back on error.
1040 std::vector<string> funcs;
1041 std::vector<string> funcs_with_grads;
1042 Status s;
1043 bool added;
1044 for (const FunctionDef& fdef : lib_def.function()) {
1045 s = AddFunctionDefHelper(fdef, &added);
1046 if (!s.ok()) {
1047 Remove(funcs, funcs_with_grads);
1048 return s;
1049 }
1050 if (added) {
1051 funcs.push_back(fdef.signature().name());
1052 }
1053 }
1054 for (const GradientDef& grad : lib_def.gradient()) {
1055 s = AddGradientDefHelper(grad, &added);
1056 if (!s.ok()) {
1057 Remove(funcs, funcs_with_grads);
1058 return s;
1059 }
1060 if (added) {
1061 funcs_with_grads.push_back(grad.function_name());
1062 }
1063 }
1064 return Status::OK();
1065 }
1066
RemoveFunction(const string & func)1067 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1068 const auto& i = function_defs_.find(func);
1069 if (i == function_defs_.end()) {
1070 return errors::InvalidArgument("Tried to remove non-existent function ",
1071 func);
1072 }
1073 function_defs_.erase(i);
1074 return Status::OK();
1075 }
1076
RemoveGradient(const string & func)1077 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1078 const auto& i = func_grad_.find(func);
1079 if (i == func_grad_.end()) {
1080 return errors::InvalidArgument("Tried to remove non-existent gradient ",
1081 func);
1082 }
1083 func_grad_.erase(i);
1084 return Status::OK();
1085 }
1086
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1087 void FunctionLibraryDefinition::Remove(
1088 const std::vector<string>& funcs,
1089 const std::vector<string>& funcs_with_grads) {
1090 for (const string& f : funcs) {
1091 Status s = RemoveFunction(f);
1092 DCHECK(s.ok());
1093 }
1094 for (const string& f : funcs_with_grads) {
1095 Status s = RemoveGradient(f);
1096 DCHECK(s.ok());
1097 }
1098 }
1099
// Returns the gradient function name registered for `func`, or "" if none.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  auto found = func_grad_.find(func);
  if (found == func_grad_.end()) return string();
  return found->second;
}
1103
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1104 Status FunctionLibraryDefinition::LookUp(
1105 const string& op, const OpRegistrationData** op_reg_data) const {
1106 auto iter = function_defs_.find(op);
1107 if (iter != function_defs_.end()) {
1108 *op_reg_data = &iter->second->op_registration_data;
1109 return Status::OK();
1110 }
1111 return default_registry_->LookUp(op, op_reg_data);
1112 }
1113
GetAttrImpl(const NodeDef & ndef) const1114 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1115 const NodeDef& ndef) const {
1116 if (ndef.op() != kGradientOp) {
1117 // If 'ndef' calls a function and the function's def has the attr,
1118 // returns it.
1119 return Find(ndef.op());
1120 }
1121
1122 // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1123 // Foo's attributes.
1124 const NameAttrList* forward_func_attrs;
1125 if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
1126 return nullptr;
1127 }
1128 const string& func_name = forward_func_attrs->name();
1129 const string& grad_name = FindGradient(func_name);
1130 // If 'func' has a user-defined gradient function, uses the grad
1131 // function's attrs to see if noinline is specified. Otherwise,
1132 // uses func's attrs.
1133 if (!grad_name.empty()) {
1134 return Find(grad_name);
1135 }
1136 return Find(func_name);
1137 }
1138
ToProto() const1139 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1140 FunctionDefLibrary lib;
1141 for (const auto& f : function_defs_) {
1142 *lib.add_function() = f.second->fdef;
1143 }
1144 for (const auto& g : func_grad_) {
1145 GradientDef* gd = lib.add_gradient();
1146 gd->set_function_name(g.first);
1147 gd->set_gradient_func(g.second);
1148 }
1149 return lib;
1150 }
1151
1152 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1153 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1154 const string& attr, T* value) const {
1155 const FunctionDef* fdef = GetAttrImpl(ndef);
1156 if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
1157 return Status::OK();
1158 }
1159 return errors::InvalidArgument("Attr ", attr, " is not defined.");
1160 }
1161
1162 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1163 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1164 T* value) const {
1165 return GetAttr(node.def(), attr, value);
1166 }
1167
1168 #define GET_ATTR(T) \
1169 template Status FunctionLibraryDefinition::GetAttr(const Node&, \
1170 const string&, T*) const; \
1171 template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \
1172 const string&, T*) const;
1173 GET_ATTR(string)
GET_ATTR(bool)1174 GET_ATTR(bool)
1175 #undef GET_ATTR
1176
1177 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1178 if (val.size() >= 2 && val[0] == '$') {
1179 proto.set_placeholder(val.data() + 1, val.size() - 1);
1180 } else {
1181 SetAttrValue(val, &proto);
1182 }
1183 }
1184
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1185 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1186 const string& name,
1187 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1188 AttrValueWrapper ret;
1189 ret.proto.mutable_func()->set_name(name);
1190 for (const auto& a : attrs) {
1191 ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1192 }
1193 return ret;
1194 }
1195
ToNodeDef() const1196 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1197 NodeDef n;
1198 n.set_op(this->op);
1199 n.set_name(this->ret[0]);
1200 for (const auto& a : this->attr) {
1201 n.mutable_attr()->insert({a.first, a.second.proto});
1202 }
1203 for (const string& a : this->arg) {
1204 n.add_input(a);
1205 }
1206 for (const string& d : this->dep) {
1207 n.add_input(strings::StrCat("^", d));
1208 }
1209 return n;
1210 }
1211
1212 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1213 FunctionDef FunctionDefHelper::Create(
1214 const string& function_name, gtl::ArraySlice<string> in_def,
1215 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1216 gtl::ArraySlice<Node> node_def,
1217 gtl::ArraySlice<std::pair<string, string>> ret_def) {
1218 FunctionDef fdef;
1219
1220 // Signature
1221 OpDefBuilder b(function_name);
1222 for (const auto& i : in_def) b.Input(i);
1223 for (const auto& o : out_def) b.Output(o);
1224 for (const auto& a : attr_def) b.Attr(a);
1225
1226 OpRegistrationData op_reg_data;
1227 TF_CHECK_OK(b.Finalize(&op_reg_data));
1228 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1229
1230 // Function body
1231 for (const auto& n : node_def) {
1232 *(fdef.add_node_def()) = n.ToNodeDef();
1233 }
1234
1235 // Returns
1236 for (const auto& r : ret_def) {
1237 fdef.mutable_ret()->insert({r.first, r.second});
1238 }
1239 return fdef;
1240 }
1241
1242 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1243 FunctionDef FunctionDefHelper::Define(const string& name,
1244 gtl::ArraySlice<string> arg_def,
1245 gtl::ArraySlice<string> ret_def,
1246 gtl::ArraySlice<string> attr_def,
1247 gtl::ArraySlice<Node> node_def) {
1248 FunctionDef fdef;
1249 OpDefBuilder b(name);
1250 for (const auto& a : arg_def) b.Input(a);
1251 for (const auto& r : ret_def) b.Output(r);
1252 for (const auto& a : attr_def) b.Attr(a);
1253
1254 OpRegistrationData op_reg_data;
1255 TF_CHECK_OK(b.Finalize(&op_reg_data));
1256 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1257
1258 // Mapping from legacy output names to NodeDef outputs.
1259 std::unordered_map<string, string> ret_index;
1260 for (const auto& a : fdef.signature().input_arg()) {
1261 ret_index[a.name()] = a.name();
1262 }
1263
1264 // For looking up OpDefs
1265 auto* op_def_registry = OpRegistry::Global();
1266
1267 // Function body
1268 for (const auto& src : node_def) {
1269 NodeDef* n = fdef.add_node_def();
1270 n->set_op(src.op);
1271 n->set_name(src.ret[0]);
1272 for (const auto& a : src.attr) {
1273 n->mutable_attr()->insert({a.first, a.second.proto});
1274 }
1275 for (const string& a : src.arg) {
1276 const auto iter = ret_index.find(a);
1277 CHECK(iter != ret_index.end())
1278 << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
1279 n->add_input(iter->second);
1280 }
1281 for (const string& d : src.dep) {
1282 n->add_input(strings::StrCat("^", d));
1283 }
1284
1285 // Add the outputs of this node to ret_index.
1286 const OpDef* op_def = nullptr;
1287 TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1288 CHECK(op_def != nullptr) << n->op();
1289 NameRangeMap output_names;
1290 TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1291 for (const auto& o : output_names) {
1292 CHECK_LE(o.second.second, src.ret.size())
1293 << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
1294 << "' of " << name;
1295 for (int i = o.second.first; i < o.second.second; ++i) {
1296 ret_index[src.ret[i]] =
1297 strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
1298 }
1299 }
1300 }
1301
1302 // Returns
1303 for (const auto& r : fdef.signature().output_arg()) {
1304 const auto iter = ret_index.find(r.name());
1305 CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1306 fdef.mutable_ret()->insert({r.name(), iter->second});
1307 }
1308 return fdef;
1309 }
1310
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1311 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1312 gtl::ArraySlice<string> ret_def,
1313 gtl::ArraySlice<string> attr_def,
1314 gtl::ArraySlice<Node> node_def) {
1315 return Define("_", arg_def, ret_def, attr_def, node_def);
1316 }
1317
1318 namespace gradient {
1319
1320 typedef std::unordered_map<string, Creator> OpGradFactory;
1321
GetOpGradFactory()1322 OpGradFactory* GetOpGradFactory() {
1323 static OpGradFactory* factory = new OpGradFactory;
1324 return factory;
1325 }
1326
RegisterOp(const string & op,Creator func)1327 bool RegisterOp(const string& op, Creator func) {
1328 CHECK(GetOpGradFactory()->insert({op, func}).second)
1329 << "Duplicated gradient for " << op;
1330 return true;
1331 }
1332
GetOpGradientCreator(const string & op,Creator * creator)1333 Status GetOpGradientCreator(const string& op, Creator* creator) {
1334 auto fac = GetOpGradFactory();
1335 auto iter = fac->find(op);
1336 if (iter == fac->end()) {
1337 return errors::NotFound("No gradient defined for op: ", op);
1338 }
1339 *creator = iter->second;
1340 return Status::OK();
1341 }
1342
1343 } // end namespace gradient
1344
1345 } // end namespace tensorflow
1346