1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/utils/functions.h"
16 
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_replace.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph_def_util.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/utils.h"
34 #include "tensorflow/core/lib/strings/scanner.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
GrapplerFunctionItem(string func_name,string description,AttrSlice func_attr,std::vector<const FunctionDef::ArgAttrs * > arg_attr,std::vector<InputArgInstantiation> input_args,std::vector<OutputArgInstantiation> output_args,std::vector<ControlOutput> control_outputs,const int graph_def_version,const bool is_stateful,GraphDef && function_body)39 GrapplerFunctionItem::GrapplerFunctionItem(
40     string func_name, string description, AttrSlice func_attr,
41     std::vector<const FunctionDef::ArgAttrs*> arg_attr,
42     std::vector<InputArgInstantiation> input_args,
43     std::vector<OutputArgInstantiation> output_args,
44     std::vector<ControlOutput> control_outputs, const int graph_def_version,
45     const bool is_stateful, GraphDef&& function_body)
46     : description_(std::move(description)),
47       func_attr_(func_attr),
48       arg_attr_(std::move(arg_attr)),
49       input_args_(std::move(input_args)),
50       output_args_(std::move(output_args)),
51       control_outputs_(std::move(control_outputs)),
52       is_stateful_(is_stateful) {
53   id = std::move(func_name);
54   graph = std::move(function_body);
55   graph.mutable_versions()->set_producer(graph_def_version);
56 
57   // Fill the feed nodes with function input arguments.
58   for (const InputArgInstantiation& input_arg : input_args_) {
59     feed.push_back({input_arg.node_name, Tensor()});
60   }
61   // Fill the fetch nodes with outputs.
62   for (const OutputArgInstantiation& output_arg : output_args_) {
63     fetch.push_back(output_arg.node_name);
64   }
65   // We must keep all control output nodes.
66   for (const ControlOutput& control_output : control_outputs_) {
67     keep_ops.push_back(control_output.node_name);
68   }
69 
70   // Tensorflow functions execution semantics is different from the main graph,
71   // and we need to preserve it when we do graph optimizations.
72   optimization_options().allow_pruning_stateful_and_dataset_ops = false;
73 }
74 
description() const75 const string& GrapplerFunctionItem::description() const { return description_; }
76 
inputs() const77 const std::vector<InputArgInstantiation>& GrapplerFunctionItem::inputs() const {
78   return input_args_;
79 }
80 
input(int i) const81 const InputArgInstantiation& GrapplerFunctionItem::input(int i) const {
82   return input_args_[i];
83 }
84 
input_size() const85 const std::size_t GrapplerFunctionItem::input_size() const {
86   return input_args_.size();
87 }
88 
outputs() const89 const std::vector<OutputArgInstantiation>& GrapplerFunctionItem::outputs()
90     const {
91   return output_args_;
92 }
93 
output(int i) const94 const OutputArgInstantiation& GrapplerFunctionItem::output(int i) const {
95   return output_args_[i];
96 }
97 
output_size() const98 const std::size_t GrapplerFunctionItem::output_size() const {
99   return output_args_.size();
100 }
101 
control_outputs() const102 const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
103     const {
104   return control_outputs_;
105 }
106 
control_output_size() const107 const std::size_t GrapplerFunctionItem::control_output_size() const {
108   return control_outputs_.size();
109 }
110 
func_attr() const111 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
112 
113 const std::vector<const FunctionDef::ArgAttrs*>&
arg_attr() const114 GrapplerFunctionItem::arg_attr() const {
115   return arg_attr_;
116 }
117 
function_body() const118 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
119 
mutable_function_body()120 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
121 
is_stateful() const122 bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; }
123 
SwapFunctionBody(GraphDef && other)124 GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
125   graph.Swap(&other);
126   return *this;
127 }
128 
HasParametrizedType(const FunctionDef & func)129 bool HasParametrizedType(const FunctionDef& func) {
130   const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
131     return !arg.type_attr().empty() || !arg.number_attr().empty() ||
132            !arg.type_list_attr().empty();
133   };
134 
135   const auto& input = func.signature().input_arg();
136   const auto& output = func.signature().output_arg();
137   return std::any_of(input.begin(), input.end(), is_type_parametrized) ||
138          std::any_of(output.begin(), output.end(), is_type_parametrized);
139 }
140 
HasParametrizedBody(const FunctionDef & func)141 bool HasParametrizedBody(const FunctionDef& func) {
142   const auto is_parametrized = [&](const NodeDef& node) {
143     for (const auto& attr : node.attr()) {
144       if (!attr.second.placeholder().empty()) return true;
145     }
146     return false;
147   };
148   return std::any_of(func.node_def().begin(), func.node_def().end(),
149                      is_parametrized);
150 }
151 
IsParametrized(const FunctionDef & func)152 bool IsParametrized(const FunctionDef& func) {
153   return HasParametrizedType(func) || HasParametrizedBody(func);
154 }
155 
InstantiationTypeParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,DataType> * type_parameters)156 Status InstantiationTypeParameters(
157     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
158     absl::flat_hash_map<string, DataType>* type_parameters) {
159   if (!type_parameters->empty()) {
160     return errors::InvalidArgument("Type parameters output map must be empty");
161   }
162 
163   const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) -> Status {
164     if (!arg.type_attr().empty()) {
165       DataType dtype;
166       TF_RETURN_IF_ERROR(
167           GetNodeAttr(func_instantiation_attr, arg.type_attr(), &dtype));
168       type_parameters->emplace(arg.type_attr(), dtype);
169 
170     } else if (!arg.type_list_attr().empty()) {
171       std::vector<DataType> dtypes;
172       TF_RETURN_IF_ERROR(
173           GetNodeAttr(func_instantiation_attr, arg.type_list_attr(), &dtypes));
174       int index = 0;
175       for (const DataType& dtype : dtypes) {
176         type_parameters->emplace(absl::StrCat(arg.type_list_attr(), ":", index),
177                                  dtype);
178         ++index;
179       }
180     }
181     return Status::OK();
182   };
183 
184   for (const auto& input : func.signature().input_arg())
185     TF_RETURN_IF_ERROR(resolve_type_attr(input));
186   for (const auto& output : func.signature().output_arg())
187     TF_RETURN_IF_ERROR(resolve_type_attr(output));
188 
189   return Status::OK();
190 }
191 
InstantiationBodyParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,AttrValue> * body_parameters)192 Status InstantiationBodyParameters(
193     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
194     absl::flat_hash_map<string, AttrValue>* body_parameters) {
195   if (!body_parameters->empty()) {
196     return errors::InvalidArgument("Body parameters output map must be empty");
197   }
198 
199   for (const NodeDef& func_body_node : func.node_def()) {
200     for (auto& attr : func_body_node.attr()) {
201       const string& placeholder = attr.second.placeholder();
202 
203       if (placeholder.empty() || body_parameters->contains(placeholder)) {
204         continue;
205       }
206 
207       const AttrValue* placeholder_value =
208           func_instantiation_attr.Find(placeholder);
209       if (placeholder_value) {
210         body_parameters->insert({placeholder, *placeholder_value});
211       } else {
212         return errors::InvalidArgument("Can't resolve placeholder: ",
213                                        placeholder);
214       }
215     }
216   }
217 
218   return Status::OK();
219 }
220 
MakeGrapplerFunctionItem(const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)221 Status MakeGrapplerFunctionItem(const FunctionDef& func,
222                                 const AttrSlice& func_instantiation_attr,
223                                 const FunctionLibraryDefinition& flib,
224                                 const int graph_def_version,
225                                 GrapplerFunctionItem* item) {
226   const OpDef& signature = func.signature();
227 
228   if (signature.name().empty()) {
229     return errors::InvalidArgument("Function name must be specified");
230   }
231 
232   // Function types will be resolved from function instantiation attributes. All
233   // other attributes will be lost during conversion to FunctionDef.
234   for (const OpDef::AttrDef& attr : signature.attr()) {
235     if (attr.type() != "type") {
236       return errors::InvalidArgument(
237           "Function signature must have only type attributes");
238     }
239   }
240 
241   // Instantiate function into a statically defined FunctionBody Graph.
242   std::unique_ptr<FunctionBody> fbody;
243   TF_RETURN_IF_ERROR(
244       FunctionDefToBodyHelper(func, func_instantiation_attr, &flib, &fbody));
245 
246   GraphDef function_body;
247   fbody->graph->ToGraphDef(&function_body);
248 
249   // Function body shares the library with the graph that instantiated it. We do
250   // not need a full copy of the function library, just the reachable subset.
251   *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
252 
253   VLOG(3) << absl::Substitute(
254       "Deleted $0 unreachable functions from the Grappler function item "
255       "instantiation of $1 (library size = $2)",
256       flib.num_functions() - function_body.library().function_size(),
257       signature.name(), function_body.library().function_size());
258 
259   const int num_instantiated_inputs = fbody->arg_types.size();
260   const int num_instantiated_outputs = fbody->ret_types.size();
261 
262   std::vector<InputArgInstantiation> inputs;
263   inputs.reserve(num_instantiated_inputs);
264 
265   for (int in_id = 0; in_id < num_instantiated_inputs; ++in_id) {
266     const Node* node = fbody->arg_nodes[in_id];
267     const DataType& dtype = fbody->arg_types[in_id];
268     inputs.emplace_back(node->name(), dtype);
269   }
270 
271   std::vector<OutputArgInstantiation> outputs;
272   outputs.reserve(num_instantiated_outputs);
273 
274   for (int out_id = 0; out_id < num_instantiated_outputs; ++out_id) {
275     const Node* node = fbody->ret_nodes[out_id];
276     const DataType& dtype = fbody->ret_types[out_id];
277     outputs.emplace_back(node->name(), dtype);
278   }
279 
280   // Control outputs ensure that all side-effectful nodes in the function body
281   // will execute, even if they are not required to compute regular output args.
282   std::vector<ControlOutput> control_outputs;
283   control_outputs.reserve(func.control_ret_size());
284   for (const auto& control_ret : func.control_ret()) {
285     control_outputs.push_back({control_ret.first, control_ret.second});
286   }
287 
288   std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr);
289   for (const auto& attr : func.arg_attr()) {
290     arg_attr.at(attr.first) = &attr.second;
291   }
292 
293   *item = GrapplerFunctionItem(
294       /*func_name=*/signature.name(),
295       /*description=*/signature.description(),
296       /*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr),
297       std::move(inputs), std::move(outputs), std::move(control_outputs),
298       graph_def_version, signature.is_stateful(), std::move(function_body));
299   return Status::OK();
300 }
301 
MakeGrapplerFunctionItem(const FunctionDef & func,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)302 Status MakeGrapplerFunctionItem(const FunctionDef& func,
303                                 const FunctionLibraryDefinition& flib,
304                                 const int graph_def_version,
305                                 GrapplerFunctionItem* item) {
306   return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version,
307                                   item);
308 }
309 
ReplaceInputWithConst(const NodeDef & input_const,int input_index,GrapplerFunctionItem * item)310 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
311                              GrapplerFunctionItem* item) {
312   if (!IsConstant(input_const)) {
313     return errors::InvalidArgument("Input node is not a constant: ",
314                                    SummarizeNodeDef(input_const));
315   }
316   const int item_input_size = item->input_size();
317   if (input_index < 0 || input_index >= item_input_size) {
318     return errors::InvalidArgument(
319         "Function input index is out of bound: index=", input_index,
320         " input_size=", item->input_size());
321   }
322 
323   const InputArgInstantiation& input_arg = item->input(input_index);
324 
325   for (NodeDef& node : *item->graph.mutable_node()) {
326     // Replace '_Arg' node in the function body with a 'Const' node.
327     if (node.name() == input_arg.node_name) {
328       node = input_const;
329       node.set_name(input_arg.node_name);
330       node.clear_input();
331       node.clear_device();  // device placement is defined by instantiating node
332     }
333 
334     // Update index in all inputs after the removed const input.
335     if (IsArg(node)) {
336       auto attrs = AttrSlice(node);
337       int index;
338       TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
339       if (index >= input_index) {
340         (*node.mutable_attr())["index"].set_i(index - 1);
341       }
342     }
343   }
344 
345   item->input_args_.erase(item->input_args_.begin() + input_index);
346   item->arg_attr_.erase(item->arg_attr_.begin() + input_index);
347 
348   return Status::OK();
349 }
350 
RemoveFunctionOutputs(const absl::flat_hash_set<int> & remove_outputs,GrapplerFunctionItem * item,std::vector<std::pair<int,int>> * output_mapping)351 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
352                              GrapplerFunctionItem* item,
353                              std::vector<std::pair<int, int>>* output_mapping) {
354   DCHECK(output_mapping->empty());
355 
356   // Do some sanity checking of the removed outputs positions.
357   for (int remove_output : remove_outputs) {
358     const int item_output_size = item->output_size();
359     if (remove_output < 0 || remove_output >= item_output_size) {
360       return errors::InvalidArgument(
361           "Function output index is out of bound: index=", remove_output,
362           " output_size=", item->output_size());
363     }
364   }
365 
366   absl::flat_hash_set<const OutputArgInstantiation*> remove_output_args;
367   const auto is_remove_output_arg = [&](const OutputArgInstantiation& output) {
368     return remove_output_args.find(&output) != remove_output_args.end();
369   };
370 
371   for (int i = 0, end = item->output_size(); i < end; ++i) {
372     const OutputArgInstantiation& output = item->output(i);
373     if (remove_outputs.contains(i)) {
374       VLOG(3) << "Remove functions output: name=" << output.node_name
375               << "(index = " << i << ")";
376       remove_output_args.insert(&output);
377     } else if (!remove_output_args.empty()) {
378       // Add output mapping only if output position changed.
379       output_mapping->push_back({i, i - remove_output_args.size()});
380     }
381   }
382 
383   // Update 'index' attribute in all '_Retval' nodes that are in output mapping.
384   for (NodeDef& node : *item->graph.mutable_node()) {
385     if (IsRetval(node)) {
386       auto attrs = AttrSlice(node);
387       int index;
388       TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
389 
390       for (const auto& mapping : *output_mapping) {
391         const int from = mapping.first;
392         const int to = mapping.second;
393         if (index == from) {
394           (*node.mutable_attr())["index"].set_i(to);
395         }
396       }
397     }
398   }
399 
400   auto& o = item->output_args_;
401   o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
402 
403   return Status::OK();
404 }
405 
406 namespace {
407 
408 // FunctionDef uses different connectivity encoding for the function body nodes,
409 // than a GraphDef (see function.proto for details). This is a helper class that
410 // converts inputs in GraphDef format (node[:position]) to the FunctionDef
411 // format (node:output[:position]).
412 class MakeFunctionDefHelper {
413  public:
414   MakeFunctionDefHelper() = default;
415 
416   Status Initialize(const GrapplerFunctionItem& item,
417                     const FunctionLibraryDefinition& flib);
418 
419   // Converts input name from GraphDef format (name[:position]) to the
420   // FunctionDef input format (name[:output][:position]) using registered input
421   // arg instantiations and function body outputs.
422   Status AsFunctionDefInput(const string& graph_def_input,
423                             string* func_def_input) const;
424 
425   // Updates Node inputs from GraphDef to FunctionDef format.
426   Status AsFunctionDefNode(NodeDef* function_body_node) const;
427 
IsInputNode(const NodeDef & node) const428   bool IsInputNode(const NodeDef& node) const {
429     return input_nodes_.contains(node.name());
430   }
431 
IsOutputNode(const NodeDef & node) const432   bool IsOutputNode(const NodeDef& node) const {
433     return output_nodes_.contains(node.name());
434   }
435 
436  private:
437   absl::flat_hash_set<absl::string_view> input_nodes_;
438   absl::flat_hash_set<absl::string_view> output_nodes_;
439   // Mapping from function body node name to output names range map.
440   absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
441 };
442 
Initialize(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib)443 Status MakeFunctionDefHelper::Initialize(
444     const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib) {
445   for (const InputArgInstantiation& input_arg : item.inputs()) {
446     input_nodes_.insert(input_arg.node_name);
447   }
448   for (const OutputArgInstantiation& output_arg : item.outputs()) {
449     output_nodes_.insert(output_arg.node_name);
450   }
451 
452   for (const NodeDef& node : item.function_body().node()) {
453     const OpRegistrationData* registration;
454     TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
455 
456     tensorflow::NameRangeMap outputs_range_map;
457     TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
458         node, registration->op_def, nullptr, &outputs_range_map));
459 
460     function_body_outputs_.emplace(node.name(), std::move(outputs_range_map));
461   }
462 
463   return Status::OK();
464 }
465 
AsFunctionDefInput(const string & graph_def_input,string * func_def_input) const466 Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input,
467                                                  string* func_def_input) const {
468   if (IsControlInput(graph_def_input)) {
469     *func_def_input = graph_def_input;
470     return Status::OK();
471   }
472 
473   const SafeTensorId tensor = ParseTensorName(graph_def_input);
474   DCHECK_GE(tensor.index(), 0);
475 
476   // Graph def input corresponds to one of the function inputs.
477   const auto is_input = input_nodes_.find(tensor.node());
478   if (is_input != input_nodes_.end()) {
479     DCHECK_EQ(tensor.index(), 0);
480     *func_def_input = tensor.node();
481     return Status::OK();
482   }
483 
484   // Or it must be output from one of the function body nodes
485   const auto is_body_output = function_body_outputs_.find(tensor.node());
486   if (is_body_output != function_body_outputs_.end()) {
487     const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
488 
489     for (const auto& el : outputs_range_map) {
490       const auto& output_name = el.first;
491       const auto& output_range = el.second;
492       if (tensor.index() >= output_range.first &&
493           tensor.index() < output_range.second) {
494         *func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":",
495                                        tensor.index() - output_range.first);
496         return Status::OK();
497       }
498     }
499   }
500 
501   return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
502 }
503 
AsFunctionDefNode(NodeDef * function_body_node) const504 Status MakeFunctionDefHelper::AsFunctionDefNode(
505     NodeDef* function_body_node) const {
506   string func_def_input;
507 
508   for (int i = 0; i < function_body_node->input_size(); ++i) {
509     TF_RETURN_IF_ERROR(
510         AsFunctionDefInput(function_body_node->input(i), &func_def_input));
511     function_body_node->set_input(i, func_def_input);
512   }
513 
514   return Status::OK();
515 }
516 
517 }  // namespace
518 
MakeFunctionDef(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib,FunctionDef * func)519 Status MakeFunctionDef(const GrapplerFunctionItem& item,
520                        const FunctionLibraryDefinition& flib,
521                        FunctionDef* func) {
522   func->mutable_signature()->set_name(item.id);
523   func->mutable_signature()->set_description(item.description());
524   func->mutable_signature()->set_is_stateful(item.is_stateful());
525 
526   MakeFunctionDefHelper helper;
527   TF_RETURN_IF_ERROR(helper.Initialize(item, flib));
528 
529   // Mapping from the '_Retval' node name to the output tensor.
530   absl::flat_hash_map<absl::string_view, string> output_tensors;
531   for (const NodeDef& func_body_node : item.function_body().node()) {
532     if (!helper.IsOutputNode(func_body_node)) continue;
533     if (func_body_node.input_size() != 1) {
534       return errors::Internal("_Retval node must have single input: ",
535                               SummarizeNodeDef(func_body_node));
536     }
537     output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
538   }
539 
540   for (const InputArgInstantiation& input_arg : item.inputs()) {
541     OpDef::ArgDef arg_def;
542     arg_def.set_name(input_arg.node_name);
543     arg_def.set_type(input_arg.data_type);
544     arg_def.set_is_ref(IsRefType(input_arg.data_type));
545     *func->mutable_signature()->add_input_arg() = arg_def;
546   }
547 
548   // Add function output arguments.
549   for (const OutputArgInstantiation& output_arg : item.outputs()) {
550     const string output_name =
551         absl::StrReplaceAll(output_arg.node_name, {{"_RetVal", ""}});
552 
553     OpDef::ArgDef arg_def;
554     arg_def.set_name(output_name);
555     arg_def.set_type(output_arg.data_type);
556     arg_def.set_is_ref(IsRefType(output_arg.data_type));
557     *func->mutable_signature()->add_output_arg() = arg_def;
558 
559     auto it = output_tensors.find(output_arg.node_name);
560     if (it == output_tensors.end()) {
561       return errors::Internal(
562           "Can't find an output tensor for the output node: ",
563           output_arg.node_name);
564     }
565 
566     TF_RETURN_IF_ERROR(helper.AsFunctionDefInput(
567         it->second, &(*func->mutable_ret())[output_name]));
568   }
569 
570   // Add function control outputs.
571   for (const ControlOutput& control_out : item.control_outputs()) {
572     func->mutable_control_ret()->insert(
573         {control_out.output_name, control_out.node_name});
574     *func->mutable_signature()->add_control_output() = control_out.output_name;
575   }
576 
577   // Copy function definition specific attributes.
578   for (const auto& attr : item.func_attr()) {
579     const auto& attr_name = attr.first;
580     const auto& attr_value = attr.second;
581     (*func->mutable_attr())[attr_name] = attr_value;
582   }
583 
584   // Copy function arg attributes.
585   for (int i = 0, end = item.arg_attr().size(); i < end; ++i) {
586     const auto* attr = item.arg_attr().at(i);
587     if (attr != nullptr) {
588       (*func->mutable_arg_attr())[i] = *attr;
589     }
590   }
591 
592   // Copy function body nodes to the FunctionDef and update input format
593   for (const NodeDef& func_node : item.function_body().node()) {
594     // Skip original `_Arg` and `_Retval` nodes. If node was converted to some
595     // other type (e.g. inputs converted to placeholders), we need to check that
596     // it's not registered as function input or output node.
597     if (IsArg(func_node) || IsRetval(func_node) ||
598         helper.IsInputNode(func_node) || helper.IsOutputNode(func_node))
599       continue;
600 
601     NodeDef* func_def_node = func->add_node_def();
602     *func_def_node = func_node;
603     TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node));
604   }
605 
606   return Status::OK();
607 }
608 
609 }  // end namespace grappler
610 }  // end namespace tensorflow
611