/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"

#include <algorithm>

#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

// Given original input types and argument index mapping, return the new input
// types.
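// For example, in_types = {DT_FLOAT, DT_RESOURCE, DT_INT32} with
// index_mapping = {0, 2, 1} yields {DT_FLOAT, DT_INT32, DT_RESOURCE}.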
std::vector<DataType> ShuffleInputDataTypeAttribute(
    const std::vector<DataType>& in_types,
    const std::vector<int>& index_mapping) {
  std::vector<DataType> result(index_mapping.size());
  for (int i = 0, end = in_types.size(); i < end; i++) {
    result[index_mapping.at(i)] = in_types[i];
  }
  return result;
}

// Given the original input types, check whether the function needs to be
// rewritten (a rewrite is needed unless all DT_RESOURCE inputs are already at
// the end). If the function needs to be rewritten, `resource_input_count` will
// be set to the number of DT_RESOURCE inputs, and `index_mapping` will hold
// the mapping from original input index to rearranged input index.
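// For example, in_types = {DT_INT32, DT_RESOURCE, DT_FLOAT} needs a rewrite:
// resource_input_count = 1 and index_mapping = {0, 2, 1}, i.e. the DT_RESOURCE
// input moves to the end. in_types = {DT_INT32, DT_RESOURCE} needs no rewrite
// because its DT_RESOURCE inputs are already last.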
Status InputTypesNeedsRearrange(const std::vector<DataType>& in_types,
                                bool* need_rewrite, int* resource_input_count,
                                std::vector<int>* index_mapping) {
  int first_resource_index = -1;
  for (int i = 0, end = in_types.size(); i < end; i++) {
    DataType type = in_types[i];
    if (type == DT_RESOURCE) {
      first_resource_index = i;
      break;
    }
  }
  if (first_resource_index == -1) {
    // No resource input. No need to rewrite.
    *need_rewrite = false;
    return Status::OK();
  }

  *need_rewrite = false;
  for (int i = first_resource_index + 1, end = in_types.size(); i < end; i++) {
    if (in_types[i] != DT_RESOURCE) {
      *need_rewrite = true;
      break;
    }
  }
  if (!*need_rewrite) {
    return Status::OK();
  }

  *resource_input_count = 0;
  for (int i = 0, end = in_types.size(); i < end; i++) {
    DataType type = in_types[i];
    if (type == DT_RESOURCE) {
      ++(*resource_input_count);
    }
  }
  int non_resource_index = 0,
      resource_index = in_types.size() - *resource_input_count;
  index_mapping->resize(in_types.size());
  for (int i = 0, end = in_types.size(); i < end; i++) {
    if (in_types[i] != DT_RESOURCE) {
      (*index_mapping)[i] = non_resource_index;
      non_resource_index++;
    } else {
      (*index_mapping)[i] = resource_index;
      resource_index++;
    }
  }

  return Status::OK();
}

// Given the mapping between original input index and rearranged input index,
// reorder the input edges for the node.
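// For example, with index_mapping = {0, 2, 1} the edge feeding input 1 is
// reconnected to input 2 and the edge feeding input 2 to input 1.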
Status ReorderInputEdges(Graph* g, Node* n,
                         const std::vector<int>& index_mapping) {
  std::vector<const Edge*> input_edges;
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    input_edges.push_back(e);
  }
  for (const Edge* e : input_edges) {
    Node* src = e->src();
    int src_output = e->src_output();
    int dst_input = e->dst_input();
    int new_dst_input = index_mapping.at(dst_input);
    g->RemoveEdge(e);
    g->AddEdge(src, src_output, n, new_dst_input)->DebugString();
  }
  return Status::OK();
}

// For a While node, given the mapping between original input index and
// rearranged input index, reorder the output edges for the node. DT_RESOURCE
// outputs are removed from the node, and consumers of such an output are
// rewired to the node's corresponding input instead.
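// For example, with 3 inputs rearranged by index_mapping = {0, 2, 1}
// (resource_input_count = 1): an edge from output 1 maps to rearranged index
// 2, which is >= 3 - 1, so its consumers are rewired to the tensor feeding
// input 2 of the While node rather than to an output of the node.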
Status ReorderOutputEdges(Graph* g, Node* n, int input_count,
                          int resource_input_count,
                          const std::vector<int>& index_mapping) {
  std::vector<const Edge*> output_edges;
  for (const Edge* e : n->out_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    output_edges.push_back(e);
  }
  for (const Edge* e : output_edges) {
    int src_output = e->src_output();
    int new_src_output = index_mapping.at(src_output);
    Node* dst = e->dst();
    int dst_input = e->dst_input();
    g->RemoveEdge(e);

    if (new_src_output < input_count - resource_input_count) {
      g->AddEdge(n, new_src_output, dst, dst_input);
    } else {
      const Edge* input_edge;
      TF_RETURN_IF_ERROR(n->input_edge(new_src_output, &input_edge));
      g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
    }
  }
  return Status::OK();
}

// Given the mapping between original input index and rearranged input index,
// change the "index" attribute for _Arg nodes.
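// For example, with index_mapping = {0, 2, 1}, _Arg[1] is re-indexed to 2 and
// _Arg[2] to 1.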
void RearrangeArgNodes(
    const gtl::InlinedVector<Node*, 4>* arg_nodes,  // non-absl ok
    const std::vector<int>& index_mapping) {
  for (int i = 0, end = arg_nodes->size(); i < end; i++) {
    Node* n = (*arg_nodes)[i];
    int new_index = index_mapping.at(i);
    n->ClearAttr("index");
    n->AddAttr("index", new_index);
  }
}

// Given all _Retval nodes in the function, compute `retval_index_mapping`, the
// mapping from original _Retval index to rearranged _Retval index (DT_RESOURCE
// return values are dropped), and `resource_retval_to_arg`, the mapping from
// each DT_RESOURCE _Retval index to the index of the _Arg node it reads from.
// We assume that every DT_RESOURCE _Retval node comes directly from an _Arg
// node; otherwise an error is returned.
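// For example, with return types {DT_FLOAT, DT_RESOURCE, DT_INT32} where
// _Retval[1] reads from _Arg[3]: retval_index_mapping = {0: 0, 2: 1} and
// resource_retval_to_arg = {1: 3}.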
Status CalculateRetvalRearrange(
    const gtl::InlinedVector<Node*, 4>& ret_nodes,  // non-absl ok
    std::map<int, int>* retval_index_mapping,
    std::map<int, int>* resource_retval_to_arg) {
  for (int i = 0, end = ret_nodes.size(); i < end; i++) {
    Node* n = ret_nodes[i];
    DataType t;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &t));
    if (t != DT_RESOURCE) {
      int new_retval_index = retval_index_mapping->size();
      retval_index_mapping->insert(std::make_pair(i, new_retval_index));
      continue;
    }

    const Edge* e;
    TF_RETURN_IF_ERROR(n->input_edge(0, &e));
    if (!e->src()->IsArg()) {
      return errors::Unimplemented(
          "Resource _Retval node's input does not come from _Arg "
          "directly: ",
          e->DebugString());
    }
    Node* arg = e->src();
    int src_index;
    TF_RETURN_IF_ERROR(GetNodeAttr(arg->def(), "index", &src_index));
    resource_retval_to_arg->insert(std::make_pair(i, src_index));
  }
  return Status::OK();
}

// Given the original output types and the return value index mapping, return
// the new output types. Note that DT_RESOURCE entries are removed.
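// For example, out_types = {DT_FLOAT, DT_RESOURCE, DT_INT32} with
// index_mapping = {0: 0, 2: 1} yields {DT_FLOAT, DT_INT32}.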
std::vector<DataType> ShuffleOutputDataTypeAttribute(
    const std::vector<DataType>& out_types,
    const std::map<int, int>& index_mapping) {
  std::vector<DataType> result(index_mapping.size());
  for (int i = 0, end = out_types.size(); i < end; i++) {
    auto iter = index_mapping.find(i);
    if (iter != index_mapping.end()) {
      result[iter->second] = out_types[i];
    }
  }
  return result;
}

// For a StatefulPartitionedCall node, given the mapping between original
// output index and rearranged output index, reorder the output edges for the
// node. DT_RESOURCE outputs are removed from the node, and consumers of such
// an output are rewired to the node's corresponding input instead.
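// For example, with retval_index_mapping = {0: 0, 2: 1} and
// resource_retval_to_arg = {1: 3}: a consumer of output 2 is rewired to
// output 1, and a consumer of output 1 is rewired to the tensor feeding
// input 3 of the node.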
Status RearrangeOutputEdges(Node* n, Graph* g,
                            const std::map<int, int>& retval_index_mapping,
                            const std::map<int, int>& resource_retval_to_arg) {
  std::vector<const Edge*> out_edges;
  for (const Edge* e : n->out_edges()) {
    if (!e->IsControlEdge()) {
      out_edges.push_back(e);
    }
  }
  for (const Edge* e : out_edges) {
    Node* dst = e->dst();
    int dst_input = e->dst_input();
    int src_output = e->src_output();
    auto iter = retval_index_mapping.find(src_output);
    if (iter == retval_index_mapping.end()) {
      TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
                   resource_retval_to_arg.end());
      g->RemoveEdge(e);
      const Edge* input_edge;
      TF_RETURN_IF_ERROR(
          n->input_edge(resource_retval_to_arg.at(src_output), &input_edge));
      g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
    } else {
      g->RemoveEdge(e);
      g->AddEdge(n, iter->second, dst, dst_input);
    }
  }
  return Status::OK();
}

// Given the mapping between original output index and rearranged output index,
// change the "index" attribute for _Retval nodes. Note that DT_RESOURCE
// _Retval nodes are removed.
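// For example, with retval_index_mapping = {0: 0, 2: 1}: _Retval[0] keeps
// index 0, _Retval[1] (DT_RESOURCE) is removed from the graph, and _Retval[2]
// becomes _Retval[1].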
void RearrangeRetvalNodes(
    const gtl::InlinedVector<Node*, 4>& ret_nodes,  // non-absl ok
    Graph* g, const std::map<int, int>& retval_index_mapping) {
  for (int i = 0, end = ret_nodes.size(); i < end; i++) {
    Node* n = ret_nodes[i];
    auto iter = retval_index_mapping.find(i);
    if (iter == retval_index_mapping.end()) {
      g->RemoveNode(n);
    } else {
      n->ClearAttr("index");
      n->AddAttr("index", iter->second);
    }
  }
}

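// Rewrites a While node whose inputs need rearranging: moves all DT_RESOURCE
// inputs to the end, updates the "T" attribute, reorders the input and output
// edges, and rewrites the cond and body functions to match. Sets
// `node_rewritten` to whether any rewrite happened.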
Status MaybeRewriteWhileNode(
    std::function<Status(const NameAttrList&, const FunctionBody**)>
        get_function_body_fn,
    Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
  // Check if this While node needs rewrite.
  std::vector<DataType> types;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types));
  bool input_need_rearrange;
  int resource_input_count;
  std::vector<int> index_mapping;
  TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
      types, &input_need_rearrange, &resource_input_count, &index_mapping));
  if (!input_need_rearrange) {
    *node_rewritten = false;
    return Status::OK();
  }

  *node_rewritten = true;

  // Modify the "T" attribute for this While node.
  std::vector<DataType> new_types =
      ShuffleInputDataTypeAttribute(types, index_mapping);
  n->ClearAttr("T");
  n->AddAttr("T", new_types);

  // Reorder input and output edges.
  TF_RETURN_IF_ERROR(ReorderInputEdges(g, n, index_mapping));
  TF_RETURN_IF_ERROR(ReorderOutputEdges(g, n, types.size(),
                                        resource_input_count, index_mapping));

  // Modify the cond and body functions.
  for (auto const& attr_name : std::vector<string>{"cond", "body"}) {
    NameAttrList attr_value;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value));
    const FunctionBody* fbody;
    TF_RETURN_IF_ERROR(get_function_body_fn(attr_value, &fbody));

    // Check that each resource _Arg of the While node is returned at the same
    // index, i.e. that we do not have cases like this:
    // tf.while_loop(
    //     cond,
    //     lambda resource_var1, resource_var2: [resource_var2, resource_var1],
    //     [resource_var1, resource_var2])
    if (attr_name == "body") {
      for (int i = 0, end = fbody->ret_nodes.size(); i < end; i++) {
        Node* ret_node = fbody->ret_nodes[i];
        DataType dtype;
        TF_RETURN_IF_ERROR(GetNodeAttr(ret_node->def(), "T", &dtype));
        if (dtype != DT_RESOURCE) {
          continue;
        }

        Node* input_node;
        TF_RETURN_IF_ERROR(ret_node->input_node(0, &input_node));
        while (input_node->IsIdentity()) {
          TF_RETURN_IF_ERROR(input_node->input_node(0, &input_node));
        }
        if (input_node->IsArg()) {
          int index;
          TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
          if (index != i) {
            return errors::Unimplemented("While node ", n->DebugString(),
                                         " has resource _Retval[", i,
                                         "] coming from _Arg[", index, "]");
          }
        } else {
          return errors::Unimplemented("Encountered node ",
                                       input_node->DebugString(),
                                       " while tracing _Arg node for _Retval[",
                                       i, "] of While node ", n->DebugString());
        }
      }
    }

    RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
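    // For the body function, drop DT_RESOURCE _Retval nodes (the While node
    // now forwards them directly from its inputs) and update the indices of
    // the remaining _Retval nodes.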
    if (attr_name == "body") {
      for (int i = 0, end = fbody->ret_nodes.size(); i < end; i++) {
        Node* ret_node = fbody->ret_nodes[i];
        int new_index = index_mapping.at(i);
        if (new_index < types.size() - resource_input_count) {
          ret_node->ClearAttr("index");
          ret_node->AddAttr("index", new_index);
        } else {
          fbody->graph->RemoveNode(ret_node);
        }
      }
    }

    // Save the new FunctionDef.
    FunctionDef new_fdef;
    string new_name =
        fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_"));
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
    TF_RETURN_IF_ERROR(
        fld->AddFunctionDef(new_fdef, fld->GetStackTraces(attr_value.name())));

    // Change the node to use the rewritten function.
    attr_value.set_name(new_name);
    n->ClearAttr(attr_name);
    n->AddAttr(attr_name, attr_value);
  }
  return Status::OK();
}

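// Rewrites an If node if its inputs need rearranging or it has DT_RESOURCE
// outputs: moves all DT_RESOURCE data inputs to the end, removes DT_RESOURCE
// outputs (rewiring their consumers to the corresponding inputs), updates the
// "Tin"/"Tout" attributes, and rewrites the then_branch and else_branch
// functions to match. Sets `node_rewritten` to whether any rewrite happened.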
Status MaybeRewriteIfNode(
    std::function<Status(const NameAttrList&, const FunctionBody**)>
        get_function_body_fn,
    Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten,
    const FunctionLibraryDefinition* global_fld) {
  // This node needs a rewrite when either of these is true:
  // 1) Tin has DT_RESOURCE inputs that require rearranging;
  // 2) Tout has DT_RESOURCE.
  std::vector<DataType> in_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &in_types));
  bool input_need_rearrange;
  int resource_input_count;
  std::vector<int> index_mapping;
  TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
      in_types, &input_need_rearrange, &resource_input_count, &index_mapping));
  std::vector<DataType> out_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tout", &out_types));
  bool has_resource_output = std::find(out_types.begin(), out_types.end(),
                                       DT_RESOURCE) != out_types.end();
  if (!input_need_rearrange && !has_resource_output) {
    *node_rewritten = false;
    return Status::OK();
  }

  *node_rewritten = true;

  if (input_need_rearrange) {
    // Reorder input edges.
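    // Input 0 of the If node is the predicate and stays in place; data inputs
    // start at index 1, hence the -1/+1 offsets around index_mapping below.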
    std::vector<const Edge*> input_edges;
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge() || e->dst_input() == 0) {
        continue;
      }
      input_edges.push_back(e);
    }
    for (const Edge* e : input_edges) {
      Node* src = e->src();
      int src_output = e->src_output();
      int dst_input = e->dst_input();
      int new_dst_input = index_mapping.at(dst_input - 1) + 1;
      g->RemoveEdge(e);
      g->AddEdge(src, src_output, n, new_dst_input)->DebugString();
    }

    // Change the Tin attribute.
    std::vector<DataType> new_in_types =
        ShuffleInputDataTypeAttribute(in_types, index_mapping);
    n->ClearAttr("Tin");
    n->AddAttr("Tin", new_in_types);
  }

  std::map<int, int> resource_retval_to_arg, retval_index_mapping;
  for (auto const& attr_name :
       std::vector<string>{"then_branch", "else_branch"}) {
    NameAttrList f;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
    const FunctionBody* fbody;
    TF_RETURN_IF_ERROR(get_function_body_fn(f, &fbody));

    if (input_need_rearrange) {
      // Change the _Arg node indices.
      RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
    }

    if (has_resource_output) {
      // A resource _Retval must come directly from a resource _Arg; otherwise
      // it is not supported.
      TF_RETURN_IF_ERROR(CalculateRetvalRearrange(
          fbody->ret_nodes, &retval_index_mapping, &resource_retval_to_arg));

      // Change indices for _Retval nodes.
      RearrangeRetvalNodes(fbody->ret_nodes, fbody->graph,
                           retval_index_mapping);
    }

    // Save the new FunctionDef.
    FunctionDef new_fdef;
    string new_name =
        fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
    const StackTracesMap& stack_traces =
        fld->GetStackTraces(f.name()).empty() && global_fld
            ? global_fld->GetStackTraces(f.name())
            : fld->GetStackTraces(f.name());
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, stack_traces));

    // Change the node to use the rewritten function.
    f.set_name(new_name);
    n->ClearAttr(attr_name);
    n->AddAttr(attr_name, f);
  }

  if (has_resource_output) {
    // Rearrange output edges.
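    // A DT_RESOURCE output has no corresponding output on the rewritten node;
    // its consumers are rewired to the tensor feeding the matching data input
    // (+1 skips the predicate input at index 0).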
    std::vector<const Edge*> out_edges;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge()) {
        out_edges.push_back(e);
      }
    }
    for (const Edge* e : out_edges) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      int src_output = e->src_output();
      auto iter = retval_index_mapping.find(src_output);
      if (iter == retval_index_mapping.end()) {
        TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
                     resource_retval_to_arg.end());
        g->RemoveEdge(e);
        const Edge* input_edge;
        TF_RETURN_IF_ERROR(n->input_edge(
            resource_retval_to_arg.at(src_output) + 1, &input_edge));
        g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
      } else {
        g->RemoveEdge(e);
        g->AddEdge(n, iter->second, dst, dst_input);
      }
    }

    // Change Tout attribute for the node.
    std::vector<DataType> new_out_types =
        ShuffleOutputDataTypeAttribute(out_types, retval_index_mapping);
    n->ClearAttr("Tout");
    n->AddAttr("Tout", new_out_types);
  }
  return Status::OK();
}

}  // namespace

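// Inlines StatefulPartitionedCall nodes in `g`, then rewrites While and If
// nodes (and the functions they call) so that their DT_RESOURCE inputs appear
// after all non-resource inputs and DT_RESOURCE outputs of If nodes are
// removed.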
Status RearrangeFunctionArguments(
    std::function<Status(const NameAttrList&, const FunctionBody**)>
        get_function_body_fn,
    Graph* g, FunctionLibraryDefinition* fld,
    const FunctionLibraryDefinition* global_fld) {
  // Inline StatefulPartitionedCall nodes.
  std::vector<Node*> call_nodes;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "StatefulPartitionedCall") {
      call_nodes.push_back(n);
    }
  }
  for (Node* n : call_nodes) {
    NameAttrList func_name_attrs;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func_name_attrs));
    const FunctionBody* fbody;
    TF_RETURN_IF_ERROR(get_function_body_fn(func_name_attrs, &fbody));
    InlineFunctionBodyOptions opts;
    Status s = InlineFunctionBody(*fld, g, n, fbody, opts);
    // Inlining might fail because the function is marked with the _noinline
    // attribute; in that case the node is left as-is.
    s.IgnoreError();
    FixupSourceAndSinkEdges(g);
  }

  // Rewrite If/While nodes.
  for (Node* n : g->nodes()) {
    if (n->IsWhileNode()) {
      bool node_rewritten;
      TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld,
                                               &node_rewritten));
    } else if (n->IsIfNode()) {
      bool node_rewritten;
      TF_RETURN_IF_ERROR(MaybeRewriteIfNode(get_function_body_fn, g, n, fld,
                                            &node_rewritten, global_fld));
    }
  }

  return Status::OK();
}

}  // namespace tensorflow