1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/utils/functions.h"
16 
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/substitute.h"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/graph_def_util.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/framework/versions.pb.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/utils.h"
32 #include "tensorflow/core/lib/strings/scanner.h"
33 
34 namespace tensorflow {
35 namespace grappler {
36 
37 namespace {
38 
RegisterFunctionBodyOutputs(const OpRegistrationData & registration,const NodeDef & node,GrapplerFunctionConnectivity * connectivity)39 Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
40                                    const NodeDef& node,
41                                    GrapplerFunctionConnectivity* connectivity) {
42   tensorflow::NameRangeMap outputs_range_map;
43   TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
44       node, registration.op_def, nullptr, &outputs_range_map));
45   connectivity->RegisterFunctionBodyOutputs(node.name(),
46                                             std::move(outputs_range_map));
47   return Status::OK();
48 }
49 
RegisterFunctionBodyOutputs(const FunctionLibraryDefinition & flib,const NodeDef & node,GrapplerFunctionConnectivity * connectivity)50 Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib,
51                                    const NodeDef& node,
52                                    GrapplerFunctionConnectivity* connectivity) {
53   const OpRegistrationData* registration;
54   TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
55   return RegisterFunctionBodyOutputs(*registration, node, connectivity);
56 }
57 
58 // Replace the placeholder attribute values with the values specified in
59 // instantiation attributes.
ResolveFunctionBodyNodeAttrPlaceholders(const AttrSlice & func_instantiation_attr,NodeDef * node)60 Status ResolveFunctionBodyNodeAttrPlaceholders(
61     const AttrSlice& func_instantiation_attr, NodeDef* node) {
62   for (auto& attr : *node->mutable_attr()) {
63     const string& placeholder = attr.second.placeholder();
64     if (placeholder.empty()) continue;
65 
66     const AttrValue* attr_value = func_instantiation_attr.Find(placeholder);
67     if (attr_value) {
68       attr.second = *attr_value;
69     } else {
70       return errors::InvalidArgument("Can't resolve placeholder: ",
71                                      placeholder);
72     }
73   }
74   return Status::OK();
75 }
76 
77 }  // namespace
78 
RegisterInputArgExpansion(InputArgExpansion input_arg_expansion)79 void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
80     InputArgExpansion input_arg_expansion) {
81   string input_name = input_arg_expansion.input_name;
82   const auto& placeholders = input_arg_expansion.placeholders;
83 
84   for (int i = 0; i < placeholders.size(); ++i) {
85     const string& placeholder = input_arg_expansion.placeholders[i];
86     input_arg_placeholders_.insert(
87         {placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
88   }
89   input_arg_expansions_.insert(
90       {std::move(input_name), std::move(input_arg_expansion)});
91 }
92 
RegisterFunctionBodyOutputs(const string & node_name,tensorflow::NameRangeMap && outputs)93 void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
94     const string& node_name, tensorflow::NameRangeMap&& outputs) {
95   function_body_outputs_[node_name] = std::move(outputs);
96 }
97 
ExpandFunctionDefInput(const string & func_def_input,std::vector<string> * graph_def_inputs) const98 Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
99     const string& func_def_input, std::vector<string>* graph_def_inputs) const {
100   using ::tensorflow::strings::Scanner;
101 
102   if (IsControlInput(func_def_input)) {
103     graph_def_inputs->push_back(func_def_input);
104     return Status::OK();
105   }
106 
107   // Parse input format: "node_name[:node_output][:position]"
108   string node_name;
109   string node_output;
110   int position = -1;
111 
112   StringPiece capture;
113   StringPiece remaining;
114 
115   // Parse "node_name"
116   if (Scanner(func_def_input)
117           .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
118           .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
119           .GetResult(&remaining, &capture)) {
120     node_name = string(capture.data(), capture.size());
121   }
122 
123   // Parse "node_output" if it exists
124   if (Scanner(remaining)
125           .OneLiteral(":")
126           .RestartCapture()
127           .One(strings::Scanner::LETTER)
128           .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
129           .GetResult(&remaining, &capture)) {
130     node_output = string(capture.data(), capture.size());
131   }
132 
133   // Parse "position" if it exists
134   if (Scanner(remaining)
135           .OneLiteral(":")
136           .RestartCapture()
137           .Many(strings::Scanner::DIGIT)
138           .GetResult(nullptr, &capture)) {
139     CHECK(strings::safe_strto32(capture, &position));
140   }
141 
142   // If "node_output" is not empty, it must be an output of a function body node
143   bool is_function_body_output = !node_output.empty();
144 
145   // Function input argument: "node_name[:position]"
146   if (!is_function_body_output) {
147     auto input_arg = input_arg_expansions_.find(node_name);
148     if (input_arg != input_arg_expansions_.end()) {
149       const InputArgExpansion& input_arg_expansion = input_arg->second;
150       const auto& placeholders = input_arg_expansion.placeholders;
151 
152       if (position == -1) {
153         // If position is not defined use all placeholders
154         graph_def_inputs->reserve(placeholders.size());
155         for (const string& placeholder : placeholders) {
156           graph_def_inputs->push_back(placeholder);
157         }
158       } else {
159         if (position > input_arg_expansion.placeholders.size() - 1) {
160           return errors::InvalidArgument("Invalid input ", node_name,
161                                          "position: ", position,
162                                          " (out of range)");
163         }
164         graph_def_inputs->push_back(input_arg_expansion.placeholders[position]);
165       }
166 
167       return Status::OK();
168     }
169   }
170 
171   // Function body output: "node_name:node_output[:position]"
172   if (is_function_body_output) {
173     auto function_body_outputs = function_body_outputs_.find(node_name);
174     if (function_body_outputs != function_body_outputs_.end()) {
175       const tensorflow::NameRangeMap& outputs = function_body_outputs->second;
176       auto output = outputs.find(node_output);
177       if (output != outputs.end()) {
178         const auto& output_range = output->second;
179 
180         if (position == -1) {
181           graph_def_inputs->reserve(graph_def_inputs->size() +
182                                     output_range.second - output_range.first);
183           // If position is not defined expand node output range
184           for (int i = output_range.first; i < output_range.second; ++i) {
185             graph_def_inputs->push_back(
186                 i == 0 ? node_name : absl::StrCat(node_name, ":", i));
187           }
188         } else {
189           if (position > (output_range.second - output_range.first)) {
190             return errors::InvalidArgument(
191                 "Invalid node ", node_name, " output ", node_output,
192                 " position: ", position, " (out of range)");
193           }
194           int pos = output_range.first + position;
195           graph_def_inputs->push_back(
196               pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
197         }
198 
199         return Status::OK();
200       }
201     }
202   }
203 
204   return errors::InvalidArgument("Failed to expand a function def input: ",
205                                  func_def_input);
206 }
207 
ExpandNodeInputs(NodeDef * function_body_node) const208 Status GrapplerFunctionConnectivity::ExpandNodeInputs(
209     NodeDef* function_body_node) const {
210   std::vector<string> expanded_inputs;
211 
212   for (const string& function_def_input : function_body_node->input()) {
213     TF_RETURN_IF_ERROR(
214         ExpandFunctionDefInput(function_def_input, &expanded_inputs));
215   }
216 
217   function_body_node->clear_input();
218   for (string& expanded_input : expanded_inputs)
219     function_body_node->add_input(std::move(expanded_input));
220   return Status::OK();
221 }
222 
AsFunctionDefInput(const string & graph_def_input,string * func_def_input) const223 Status GrapplerFunctionConnectivity::AsFunctionDefInput(
224     const string& graph_def_input, string* func_def_input) const {
225   if (IsControlInput(graph_def_input)) {
226     *func_def_input = graph_def_input;
227     return Status::OK();
228   }
229 
230   const TensorId tensor = ParseTensorName(graph_def_input);
231   DCHECK_GE(tensor.index(), 0);
232 
233   const absl::string_view node_name = tensor.node();
234   const int index = tensor.index();
235 
236   // Check if it's an input arg placeholder
237   if (tensor.index() == 0) {
238     const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
239     if (is_input_placeholder != input_arg_placeholders_.end()) {
240       const InputArgPlaceholder& placeholder = is_input_placeholder->second;
241       *func_def_input =
242           absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
243       return Status::OK();
244     }
245   }
246 
247   // It must be output from one of the function body nodes
248   const auto is_body_output = function_body_outputs_.find(tensor.node());
249   if (is_body_output != function_body_outputs_.end()) {
250     const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
251 
252     for (const auto& el : outputs_range_map) {
253       const auto& output_name = el.first;
254       const auto& output_range = el.second;
255       if (index >= output_range.first && index < output_range.second) {
256         int pos = index - output_range.first;
257         *func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
258         return Status::OK();
259       }
260     }
261   }
262 
263   return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
264 }
265 
AsFunctionDefNode(NodeDef * function_body_node) const266 Status GrapplerFunctionConnectivity::AsFunctionDefNode(
267     NodeDef* function_body_node) const {
268   string func_def_input;
269 
270   for (int i = 0; i < function_body_node->input_size(); ++i) {
271     TF_RETURN_IF_ERROR(
272         AsFunctionDefInput(function_body_node->input(i), &func_def_input));
273     function_body_node->set_input(i, func_def_input);
274   }
275 
276   return Status::OK();
277 }
278 
GetTypeAttr(const string & type_attr_name,DataType * data_type) const279 Status GrapplerFunctionItemInstantiation::GetTypeAttr(
280     const string& type_attr_name, DataType* data_type) const {
281   const AttrValue* type_attr = func_instantiation_attr_.Find(type_attr_name);
282   if (type_attr == nullptr) {
283     return errors::InvalidArgument("Type attribute ", type_attr_name,
284                                    " is not defined");
285   } else if (type_attr->type() == DT_INVALID) {
286     return errors::InvalidArgument("Type attribute ", type_attr_name,
287                                    " is not defined with a valid type");
288   } else {
289     *data_type = type_attr->type();
290   }
291   return Status::OK();
292 }
293 
GetArgType(const OpDef::ArgDef & arg,DataType * data_type) const294 Status GrapplerFunctionItemInstantiation::GetArgType(
295     const OpDef::ArgDef& arg, DataType* data_type) const {
296   if (arg.type() != DT_INVALID) {
297     *data_type = arg.type();
298   } else {
299     if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) {
300       return errors::InvalidArgument(
301           "Arguments with sequence of tensors are not supported. Unsupported "
302           "argument name: ",
303           arg.name());
304     }
305     TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type));
306   }
307   return Status::OK();
308 }
309 
GrapplerFunctionItem(string func_name,string description,AttrSlice func_attr,std::vector<InputArgExpansion> input_arg_expansions,std::vector<OutputArgExpansion> output_arg_expansions,std::vector<ControlOutput> control_outputs,const int graph_def_version,const bool is_stateful,GraphDef && function_body)310 GrapplerFunctionItem::GrapplerFunctionItem(
311     string func_name, string description, AttrSlice func_attr,
312     std::vector<InputArgExpansion> input_arg_expansions,
313     std::vector<OutputArgExpansion> output_arg_expansions,
314     std::vector<ControlOutput> control_outputs, const int graph_def_version,
315     const bool is_stateful, GraphDef&& function_body)
316     : description_(std::move(description)),
317       func_attr_(func_attr),
318       input_arg_expansions_(std::move(input_arg_expansions)),
319       output_arg_expansions_(std::move(output_arg_expansions)),
320       control_outputs_(std::move(control_outputs)),
321       is_stateful_(is_stateful) {
322   id = std::move(func_name);
323   graph = std::move(function_body);
324 
325   graph.mutable_versions()->set_producer(graph_def_version);
326   // Fill the feed nodes with input placeholders.
327   for (const InputArgExpansion& input_arg : input_arg_expansions_) {
328     for (const string& placeholder : input_arg.placeholders) {
329       feed.push_back({placeholder, Tensor()});
330     }
331   }
332   // Fill the fetch nodes with outputs.
333   for (const OutputArgExpansion& output_arg : output_arg_expansions_) {
334     for (const string& output_node : output_arg.output_nodes) {
335       fetch.push_back(output_node);
336     }
337   }
338   // We must keep all control output nodes.
339   for (const ControlOutput& control_output : control_outputs_) {
340     keep_ops.push_back(control_output.node_name);
341   }
342 
343   // Tensorflow functions execution semantics is different from the main graph,
344   // and we need to preserve it when we do graph optimizations.
345   optimization_options().allow_pruning_stateful_and_dataset_ops = false;
346 }
347 
description() const348 const string& GrapplerFunctionItem::description() const { return description_; }
349 
inputs() const350 const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const {
351   return input_arg_expansions_;
352 }
353 
input(int i) const354 const InputArgExpansion& GrapplerFunctionItem::input(int i) const {
355   return input_arg_expansions_[i];
356 }
357 
input_size() const358 const std::size_t GrapplerFunctionItem::input_size() const {
359   return input_arg_expansions_.size();
360 }
361 
outputs() const362 const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const {
363   return output_arg_expansions_;
364 }
365 
output(int i) const366 const OutputArgExpansion& GrapplerFunctionItem::output(int i) const {
367   return output_arg_expansions_[i];
368 }
369 
output_size() const370 const std::size_t GrapplerFunctionItem::output_size() const {
371   return output_arg_expansions_.size();
372 }
373 
control_outputs() const374 const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
375     const {
376   return control_outputs_;
377 }
378 
control_output_size() const379 const std::size_t GrapplerFunctionItem::control_output_size() const {
380   return control_outputs_.size();
381 }
382 
func_attr() const383 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
384 
function_body() const385 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
386 
mutable_function_body()387 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
388 
is_stateful() const389 bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; }
390 
SwapFunctionBody(GraphDef && other)391 GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
392   graph.Swap(&other);
393   return *this;
394 }
395 
HasParametrizedType(const FunctionDef & func)396 bool HasParametrizedType(const FunctionDef& func) {
397   const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
398     return !arg.type_attr().empty() || !arg.number_attr().empty() ||
399            !arg.type_list_attr().empty();
400   };
401 
402   const auto& input = func.signature().input_arg();
403   const auto& output = func.signature().output_arg();
404   return std::any_of(input.begin(), input.end(), is_type_parametrized) ||
405          std::any_of(output.begin(), output.end(), is_type_parametrized);
406 }
407 
HasParametrizedBody(const FunctionDef & func)408 bool HasParametrizedBody(const FunctionDef& func) {
409   const auto is_parametrized = [&](const NodeDef& node) {
410     for (const auto& attr : node.attr()) {
411       if (!attr.second.placeholder().empty()) return true;
412     }
413     return false;
414   };
415   return std::any_of(func.node_def().begin(), func.node_def().end(),
416                      is_parametrized);
417 }
418 
IsParametrized(const FunctionDef & func)419 bool IsParametrized(const FunctionDef& func) {
420   return HasParametrizedType(func) || HasParametrizedBody(func);
421 }
422 
InstantiationTypeParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,DataType> * type_parameters)423 Status InstantiationTypeParameters(
424     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
425     absl::flat_hash_map<string, DataType>* type_parameters) {
426   if (!type_parameters->empty()) {
427     return errors::InvalidArgument("Type parameters output map must be empty");
428   }
429 
430   GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
431 
432   const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
433     // Check if it's unknown and unresolved type.
434     if (arg.type() == DT_INVALID &&
435         type_parameters->find(arg.type_attr()) == type_parameters->end()) {
436       DataType data_type;
437       TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
438       type_parameters->insert({arg.type_attr(), data_type});
439     }
440     return Status::OK();
441   };
442 
443   for (const auto& input : func.signature().input_arg())
444     TF_RETURN_IF_ERROR(resolve_type_attr(input));
445   for (const auto& output : func.signature().output_arg())
446     TF_RETURN_IF_ERROR(resolve_type_attr(output));
447 
448   return Status::OK();
449 }
450 
InstantiationBodyParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,AttrValue> * body_parameters)451 Status InstantiationBodyParameters(
452     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
453     absl::flat_hash_map<string, AttrValue>* body_parameters) {
454   if (!body_parameters->empty()) {
455     return errors::InvalidArgument("Body parameters output map must be empty");
456   }
457 
458   for (const NodeDef& func_body_node : func.node_def()) {
459     for (auto& attr : func_body_node.attr()) {
460       const string& placeholder = attr.second.placeholder();
461 
462       if (placeholder.empty() ||
463           body_parameters->find(placeholder) != body_parameters->end()) {
464         continue;
465       }
466 
467       const AttrValue* placeholder_value =
468           func_instantiation_attr.Find(placeholder);
469       if (placeholder_value) {
470         body_parameters->insert({placeholder, *placeholder_value});
471       } else {
472         return errors::InvalidArgument("Can't resolve placeholder: ",
473                                        placeholder);
474       }
475     }
476   }
477 
478   return Status::OK();
479 }
480 
MakeGrapplerFunctionItem(const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)481 Status MakeGrapplerFunctionItem(const FunctionDef& func,
482                                 const AttrSlice& func_instantiation_attr,
483                                 const FunctionLibraryDefinition& flib,
484                                 const int graph_def_version,
485                                 GrapplerFunctionItem* item) {
486   const OpDef& signature = func.signature();
487 
488   if (signature.name().empty()) {
489     return errors::InvalidArgument("Function name must be specified");
490   }
491 
492   // Function types will be resolved from function instantiation attributes. All
493   // other attributes will be lost during conversion to FunctionDef.
494   for (const OpDef::AttrDef& attr : signature.attr()) {
495     if (attr.type() != "type") {
496       return errors::InvalidArgument(
497           "Function signature must have only type attributes");
498     }
499   }
500 
501   // Helper methods to lookup function instantiation attributes
502   GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
503 
504   // Mapping from FunctionDef input format (name[:output][:position]) to
505   // GraphDef input format (name[:position])
506   GrapplerFunctionConnectivity connectivity;
507 
508   // Instantiate function body into a statically defined graph def.
509   GraphDef function_body;
510 
511   // Function body shares the library with the graph that instantiated it. We do
512   // not need a full copy of the function library, just the reachable subset.
513   *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
514 
515   VLOG(3) << absl::Substitute(
516       "Deleted $0 unreachable functions from the Grappler function item "
517       "instantiation of $1 (library size = $2)",
518       flib.num_functions() - function_body.library().function_size(),
519       signature.name(), function_body.library().function_size());
520 
521   // TODO(ezhulenev): support functions with tensor sequence inputs/outputs
522 
523   // Make sure that there are no tensor lists in inputs or outputs.
524   for (const OpDef::ArgDef& input : signature.input_arg()) {
525     if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
526       return errors::InvalidArgument(
527           "Inputs with lists of tensors are not supported. Input: ",
528           input.name());
529     }
530   }
531   for (const OpDef::ArgDef& output : signature.output_arg()) {
532     if (!output.type_list_attr().empty() || !output.number_attr().empty()) {
533       return errors::InvalidArgument(
534           "Outputs with lists of tensors are not supported. Output: ",
535           output.name());
536     }
537   }
538 
539   std::vector<InputArgExpansion> inputs;
540   inputs.reserve(signature.input_arg_size());
541 
542   // For each input argument create a placeholder in function body.
543   for (const OpDef::ArgDef& input : signature.input_arg()) {
544     DataType input_data_type;
545     TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type));
546 
547     NodeDef* placeholder = function_body.add_node();
548     placeholder->set_name(input.name());
549     placeholder->set_op("Placeholder");
550     (*placeholder->mutable_attr())["dtype"].set_type(input_data_type);
551     (*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank(
552         true);
553 
554     InputArgExpansion input_expansion{/*input_name=*/input.name(),
555                                       /*data_type=*/input_data_type,
556                                       /*is_ref=*/input.is_ref(),
557                                       /*placeholders=*/{input.name()}};
558     connectivity.RegisterInputArgExpansion(input_expansion);
559     inputs.push_back(std::move(input_expansion));
560   }
561 
562   // Keep names of all nodes in the function body to guarantee that we do not
563   // add an identity with a duplicate name.
564   absl::flat_hash_set<absl::string_view> func_body_nodes;
565 
566   // Generate unique output node name: "${out_arg_name}_output_node_${index}".
567   const auto output_node_name = [&func_body_nodes](const OpDef::ArgDef& out,
568                                                    int index) -> string {
569     string name = absl::StrCat(out.name(), "_output_node_", index);
570     int i = 1;
571     while (func_body_nodes.find(name) != func_body_nodes.end()) {
572       name = absl::StrCat(out.name(), "_output_node_", index, "_", i++);
573     }
574     return name;
575   };
576 
577   // Add all function nodes to the function body.
578   for (const NodeDef& func_def_node : func.node_def()) {
579     func_body_nodes.insert(func_def_node.name());
580 
581     NodeDef* new_node = function_body.add_node();
582     *new_node = func_def_node;
583 
584     const OpRegistrationData* registration;
585     TF_RETURN_IF_ERROR(flib.LookUp(func_def_node.op(), &registration));
586 
587     // Resolve all placeholder values using function instantiation attributes.
588     TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders(
589         func_instantiation_attr, new_node));
590 
591     // Register node output range in a function connectivity.
592     TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node,
593                                                    &connectivity));
594   }
595 
596   // Rewrite inputs to use GraphDef format
597   for (NodeDef& node : *function_body.mutable_node()) {
598     TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
599   }
600 
601   std::vector<OutputArgExpansion> outputs;
602   outputs.reserve(signature.output_arg_size());
603 
604   // For each function output argument we create an Identity node in the
605   // function body, that reads output tensor from the function body node.
606   for (const OpDef::ArgDef& out : signature.output_arg()) {
607     DataType output_data_type;
608     TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type));
609 
610     std::vector<string> output_tensors;
611     auto ret = func.ret().find(out.name());
612     TF_RETURN_IF_ERROR(
613         ret != func.ret().end()
614             // Expand outputs using provided output mapping
615             ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)
616             // Otherwise output must be one of the function inputs
617             : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
618 
619     absl::InlinedVector<string, 1> output_nodes;
620     for (int i = 0; i < output_tensors.size(); ++i) {
621       const string& output_tensor = output_tensors[i];
622 
623       NodeDef* identity = function_body.add_node();
624       identity->set_name(output_node_name(out, i));
625       identity->set_op("Identity");
626       (*identity->mutable_attr())["T"].set_type(output_data_type);
627       identity->add_input(output_tensor);
628 
629       output_nodes.push_back(identity->name());
630     }
631 
632     OutputArgExpansion output{/*output_name=*/out.name(),
633                               /*data_type=*/output_data_type,
634                               /*is_ref=*/out.is_ref(),
635                               /*output_nodes=*/std::move(output_nodes)};
636     outputs.push_back(std::move(output));
637   }
638 
639   // Control outputs ensure that all side-effectful nodes in the function body
640   // will execute, even if they are not required to compute regular output args.
641   std::vector<ControlOutput> control_outputs;
642   control_outputs.reserve(func.control_ret_size());
643   for (const auto& control_ret : func.control_ret()) {
644     control_outputs.push_back({control_ret.first, control_ret.second});
645   }
646 
647   *item = GrapplerFunctionItem(
648       /*func_name=*/signature.name(),
649       /*description=*/signature.description(),
650       /*func_attr=*/AttrSlice(&func.attr()), std::move(inputs),
651       std::move(outputs), std::move(control_outputs), graph_def_version,
652       signature.is_stateful(), std::move(function_body));
653   return Status::OK();
654 }
655 
MakeGrapplerFunctionItem(const FunctionDef & func,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)656 Status MakeGrapplerFunctionItem(const FunctionDef& func,
657                                 const FunctionLibraryDefinition& flib,
658                                 const int graph_def_version,
659                                 GrapplerFunctionItem* item) {
660   return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version,
661                                   item);
662 }
663 
664 // Register GrapplerFunctionItem input arg expansion and function body outputs
665 // in the GrapplerFunctionConnectivity.
RegisterGrapplerFunctionConnectivity(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib,GrapplerFunctionConnectivity * connectivity)666 Status RegisterGrapplerFunctionConnectivity(
667     const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
668     GrapplerFunctionConnectivity* connectivity) {
669   for (const InputArgExpansion& input : item.inputs()) {
670     connectivity->RegisterInputArgExpansion(input);
671   }
672   for (const NodeDef& func_body_node : item.function_body().node()) {
673     TF_RETURN_IF_ERROR(
674         RegisterFunctionBodyOutputs(flib, func_body_node, connectivity));
675   }
676   return Status::OK();
677 }
678 
ReplaceInputWithConst(const NodeDef & input_const,int input_index,GrapplerFunctionItem * item)679 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
680                              GrapplerFunctionItem* item) {
681   if (!IsConstant(input_const)) {
682     return errors::InvalidArgument("Input node ", input_const.name(),
683                                    " is not a constant");
684   }
685 
686   auto& inputs = item->input_arg_expansions_;
687 
688   // Find input arg expansion and input placeholder position in it for the
689   // given function input position.
690   InputArgExpansion* input_arg_expansion = nullptr;
691   int placeholder_idx = input_index;
692 
693   for (InputArgExpansion& input : inputs) {
694     if (placeholder_idx < input.placeholders.size()) {
695       input_arg_expansion = &input;
696       break;
697     }
698     placeholder_idx -= input.placeholders.size();
699   }
700 
701   if (input_arg_expansion == nullptr) {
702     return errors::InvalidArgument("Input placeholder not found: input_index=",
703                                    input_index, " function=", item->id);
704   }
705 
706   // Delete placeholder from input expansion.
707   string placeholder_name = input_arg_expansion->placeholders[placeholder_idx];
708   input_arg_expansion->placeholders.erase(
709       input_arg_expansion->placeholders.begin() + placeholder_idx);
710 
711   // Delete empty input expansions.
712   inputs.erase(std::remove_if(inputs.begin(), inputs.end(),
713                               [](const InputArgExpansion& input) {
714                                 return input.placeholders.empty();
715                               }),
716                inputs.end());
717 
718   // Replace placeholder node in the function body with a const node.
719   for (NodeDef& node : *item->graph.mutable_node()) {
720     if (node.name() == placeholder_name) {
721       node = input_const;
722       node.set_name(placeholder_name);
723       node.clear_input();   // remove potential control inputs
724       node.clear_device();  // device placement is defined by instantiating node
725     }
726   }
727 
728   return Status::OK();
729 }
730 
RemoveFunctionOutputs(const absl::flat_hash_set<int> & remove_outputs,GrapplerFunctionItem * item,std::vector<std::pair<int,int>> * output_mapping)731 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
732                              GrapplerFunctionItem* item,
733                              std::vector<std::pair<int, int>>* output_mapping) {
734   DCHECK(output_mapping->empty());
735 
736   // Code below assumes that we do not support tensor list outputs and there is
737   // a 1-to-1 mapping between output tensor and output argument expansion.
738   for (const OutputArgExpansion& out_arg : item->outputs()) {
739     DCHECK(out_arg.output_nodes.size() == 1)
740         << "Output arg expansion must have single output";
741   }
742 
743   // Do some sanity checking of the removed outputs positions.
744   for (int remove_output : remove_outputs) {
745     if (remove_output < 0 || remove_output >= item->output_size()) {
746       return errors::InvalidArgument(
747           "Function output index is out of bound: index=", remove_output,
748           " max_output_index=", item->output_size());
749     }
750   }
751 
752   absl::flat_hash_set<const OutputArgExpansion*> remove_output_args;
753   const auto is_remove_output_arg = [&](const OutputArgExpansion& output) {
754     return remove_output_args.find(&output) != remove_output_args.end();
755   };
756 
757   for (int i = 0; i < item->output_size(); ++i) {
758     const OutputArgExpansion& output = item->output(i);
759     if (remove_outputs.find(i) != remove_outputs.end()) {
760       VLOG(3) << "Remove functions output: output_name=" << output.output_name
761               << "(index = " << i << ")";
762       remove_output_args.insert(&output);
763     } else if (!remove_output_args.empty()) {
764       // Add output mapping only if output position changed.
765       output_mapping->push_back({i, i - remove_output_args.size()});
766     }
767   }
768 
769   auto& o = item->output_arg_expansions_;
770   o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
771 
772   return Status::OK();
773 }
774 
MakeFunctionDef(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib,FunctionDef * func)775 Status MakeFunctionDef(const GrapplerFunctionItem& item,
776                        const FunctionLibraryDefinition& flib,
777                        FunctionDef* func) {
778   func->mutable_signature()->set_name(item.id);
779   func->mutable_signature()->set_description(item.description());
780   func->mutable_signature()->set_is_stateful(item.is_stateful());
781 
782   // Keep track of placeholders that were added to the graph in place of
783   // expanded function input arguments.
784   absl::flat_hash_set<absl::string_view> input_placeholders;
785   for (const InputArgExpansion& input_arg : item.inputs()) {
786     for (const string& placeholder : input_arg.placeholders) {
787       input_placeholders.insert(placeholder);
788     }
789   }
790 
791   // Keep track of identity nodes that were added to the graph in place of
792   // expanded function output arguments.
793   absl::flat_hash_set<absl::string_view> output_nodes;
794   for (const OutputArgExpansion& output_arg : item.outputs()) {
795     for (const string& output_node : output_arg.output_nodes) {
796       output_nodes.insert(output_node);
797     }
798   }
799 
800   // If the output identity node was not modified by any optimizer, we can
801   // bypass it and returns the function value from its input.
802   absl::flat_hash_map<absl::string_view, string> output_tensors;
803   for (const NodeDef& func_body_node : item.function_body().node()) {
804     if (!IsIdentity(func_body_node)) continue;
805 
806     const string& node_name = func_body_node.name();
807     if (output_nodes.find(node_name) != output_nodes.end()) {
808       // Grappler optimizers might optimize nodes in the fanin of the output
809       // node, and forward their control dependencies. We can't express control
810       // dependencies in a function signature, so we have to keep the node.
811       if (func_body_node.input_size() == 1) {
812         VLOG(3) << "Bypass function output node: " << node_name << " -> "
813                 << func_body_node.input(0);
814         output_tensors.emplace(node_name, func_body_node.input(0));
815       } else {
816         VLOG(3) << "Keep function output node: " << node_name;
817       }
818     }
819   }
820 
821   // Return output tensor name (input of the output node) if it's safe to bypass
822   // output node, otherwise returns the output node name.
823   const auto output_tensor =
824       [&output_tensors](const OutputArgExpansion& output_arg) -> const string& {
825     const string& output_node = output_arg.output_nodes[0];
826     const auto is_output_tensor = output_tensors.find(output_node);
827     return is_output_tensor == output_tensors.end() ? output_node
828                                                     : is_output_tensor->second;
829   };
830 
831   // Build a GrapplerFunctionConnectivity from inputs and new function body.
832   GrapplerFunctionConnectivity connectivity;
833   TF_RETURN_IF_ERROR(
834       RegisterGrapplerFunctionConnectivity(item, flib, &connectivity));
835 
836   // Add function input arguments.
837   for (const InputArgExpansion& input_arg : item.inputs()) {
838     DCHECK(input_arg.placeholders.size() == 1)  // do some sanity checking
839         << "Inputs of tensor lists are not supported";
840 
841     OpDef::ArgDef arg_def;
842     arg_def.set_name(input_arg.input_name);
843     arg_def.set_type(input_arg.data_type);
844     arg_def.set_is_ref(input_arg.is_ref);
845     *func->mutable_signature()->add_input_arg() = arg_def;
846   }
847 
848   // Add function output arguments.
849   for (const OutputArgExpansion& output_arg : item.outputs()) {
850     DCHECK(output_arg.output_nodes.size() == 1)  // do some sanity checking
851         << "Outputs of tensor lists are not supported";
852 
853     OpDef::ArgDef arg_def;
854     arg_def.set_name(output_arg.output_name);
855     arg_def.set_type(output_arg.data_type);
856     arg_def.set_is_ref(output_arg.is_ref);
857     *func->mutable_signature()->add_output_arg() = arg_def;
858 
859     TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(
860         output_tensor(output_arg),
861         &(*func->mutable_ret())[output_arg.output_name]));
862   }
863 
864   // Add function control outputs.
865   for (const ControlOutput& control_out : item.control_outputs()) {
866     func->mutable_control_ret()->insert(
867         {control_out.output_name, control_out.node_name});
868     *func->mutable_signature()->add_control_output() = control_out.output_name;
869   }
870 
871   // Copy function definition specific attributes.
872   for (const auto& attr : item.func_attr()) {
873     const auto& attr_name = attr.first;
874     const auto& attr_value = attr.second;
875     (*func->mutable_attr())[attr_name] = attr_value;
876   }
877 
878   // Copy function body nodes to the FunctionDef and update input format
879   for (const NodeDef& func_node : item.function_body().node()) {
880     const string& name = func_node.name();
881 
882     // Do not copy input placeholders.
883     if (IsPlaceholder(func_node) && input_placeholders.count(name)) continue;
884     // Do not copy output nodes that we bypassed.
885     if (IsIdentity(func_node) && output_tensors.count(name)) continue;
886 
887     NodeDef* func_def_node = func->add_node_def();
888     *func_def_node = func_node;
889     TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node));
890   }
891 
892   return Status::OK();
893 }
894 
895 }  // end namespace grappler
896 }  // end namespace tensorflow
897