1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
17
18 #include <algorithm>
19
20 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/graph_to_functiondef.h"
26 #include "tensorflow/core/graph/algorithm.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/cleanup.h"
29 #include "tensorflow/core/lib/gtl/inlined_vector.h"
30 #include "tensorflow/core/public/session_options.h"
31 #include "tensorflow/core/public/version.h"
32 #include "tensorflow/core/util/dump_graph.h"
33
34 namespace tensorflow {
35
36 namespace {
37
38 // Given original input types and argument index mapping, return the new input
39 // types.
ShuffleInputDataTypeAttribute(const std::vector<DataType> & in_types,const std::vector<int> & index_mapping)40 std::vector<DataType> ShuffleInputDataTypeAttribute(
41 const std::vector<DataType>& in_types,
42 const std::vector<int>& index_mapping) {
43 std::vector<DataType> result(index_mapping.size());
44 for (int i = 0, end = in_types.size(); i < end; i++) {
45 result[index_mapping.at(i)] = in_types[i];
46 }
47 return result;
48 }
49
50 // Given original input types, check if we need to rewrite the function (by
51 // checking if all DT_RESOURCE inputs are in the end). If the function needs to
52 // be rewritten, `resource_input_count` will be set to number of DT_RESOURCE
53 // inputs, and `index_mapping` will hold a mapping for original input index to
54 // rearranged input index.
InputTypesNeedsRearrange(const std::vector<DataType> & in_types,bool * need_rewrite,int * resource_input_count,std::vector<int> * index_mapping)55 Status InputTypesNeedsRearrange(const std::vector<DataType>& in_types,
56 bool* need_rewrite, int* resource_input_count,
57 std::vector<int>* index_mapping) {
58 int first_resource_index = -1;
59 for (int i = 0, end = in_types.size(); i < end; i++) {
60 DataType type = in_types[i];
61 if (type == DT_RESOURCE) {
62 first_resource_index = i;
63 break;
64 }
65 }
66 if (first_resource_index == -1) {
67 // No resource input. No need to rewrite.
68 *need_rewrite = false;
69 return Status::OK();
70 }
71
72 *need_rewrite = false;
73 for (int i = first_resource_index + 1, end = in_types.size(); i < end; i++) {
74 if (in_types[i] != DT_RESOURCE) {
75 *need_rewrite = true;
76 break;
77 }
78 }
79 if (!*need_rewrite) {
80 return Status::OK();
81 }
82
83 *resource_input_count = 0;
84 for (int i = 0, end = in_types.size(); i < end; i++) {
85 DataType type = in_types[i];
86 if (type == DT_RESOURCE) {
87 ++(*resource_input_count);
88 }
89 }
90 int non_resource_index = 0,
91 resource_index = in_types.size() - *resource_input_count;
92 index_mapping->resize(in_types.size());
93 for (int i = 0, end = in_types.size(); i < end; i++) {
94 if (in_types[i] != DT_RESOURCE) {
95 (*index_mapping)[i] = non_resource_index;
96 non_resource_index++;
97 } else {
98 (*index_mapping)[i] = resource_index;
99 resource_index++;
100 }
101 }
102
103 return Status::OK();
104 }
105
106 // Given mapping between original input index and rearranged input index,
107 // reorder input edges for the node.
ReorderInputEdges(Graph * g,Node * n,const std::vector<int> & index_mapping)108 Status ReorderInputEdges(Graph* g, Node* n,
109 const std::vector<int>& index_mapping) {
110 std::vector<const Edge*> input_edges;
111 for (const Edge* e : n->in_edges()) {
112 if (e->IsControlEdge()) {
113 continue;
114 }
115 input_edges.push_back(e);
116 }
117 for (const Edge* e : input_edges) {
118 Node* src = e->src();
119 int src_output = e->src_output();
120 int dst_input = e->dst_input();
121 int new_dst_input = index_mapping.at(dst_input);
122 g->RemoveEdge(e);
123 g->AddEdge(src, src_output, n, new_dst_input)->DebugString();
124 }
125 return Status::OK();
126 }
127
128 // For While node, given mapping between original input index and rearranged
129 // input index, reorder output edges for the node. DT_RESOURCE outputs are
130 // removed from the node and we will use the node's corresponding input for the
131 // edge.
ReorderOutputEdges(Graph * g,Node * n,int input_count,int resource_input_count,const std::vector<int> & index_mapping)132 Status ReorderOutputEdges(Graph* g, Node* n, int input_count,
133 int resource_input_count,
134 const std::vector<int>& index_mapping) {
135 std::vector<const Edge*> output_edges;
136 for (const Edge* e : n->out_edges()) {
137 if (e->IsControlEdge()) {
138 continue;
139 }
140 output_edges.push_back(e);
141 }
142 for (const Edge* e : output_edges) {
143 int src_output = e->src_output();
144 int new_src_output = index_mapping.at(src_output);
145 Node* dst = e->dst();
146 int dst_input = e->dst_input();
147 g->RemoveEdge(e);
148
149 if (new_src_output < input_count - resource_input_count) {
150 g->AddEdge(n, new_src_output, dst, dst_input);
151 } else {
152 const Edge* input_edge;
153 TF_RETURN_IF_ERROR(n->input_edge(new_src_output, &input_edge));
154 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
155 }
156 }
157 return Status::OK();
158 }
159
160 // Given mapping between original input index and rearranged input index, change
161 // "index" attribute for _Arg nodes.
RearrangeArgNodes(const gtl::InlinedVector<Node *,4> * arg_nodes,const std::vector<int> & index_mapping)162 void RearrangeArgNodes(
163 const gtl::InlinedVector<Node*, 4>* arg_nodes, // non-absl ok
164 const std::vector<int>& index_mapping) {
165 for (int i = 0; i < arg_nodes->size(); i++) {
166 Node* n = (*arg_nodes)[i];
167 int new_index = index_mapping.at(i);
168 n->ClearAttr("index");
169 n->AddAttr("index", new_index);
170 }
171 }
172
173 // Given all _Retval nodes in the function, return if we need to rewrite the
174 // function (by checking if we have DT_RESOURCE return values). If we need to
175 // rewrite the function, `retval_index_mapping` will hold the mapping from
176 // original _Retval to rearranged _Retval, and `resource_retval_to_arg` will
177 // hold mapping from DT_RESOURCE _Retval index to its input _Arg index. Here we
178 // assume that all DT_RESOURCE _Retval nodes come from _Arg nodes directly.
CalculateRetvalRearrange(const gtl::InlinedVector<Node *,4> & ret_nodes,std::map<int,int> * retval_index_mapping,std::map<int,int> * resource_retval_to_arg)179 Status CalculateRetvalRearrange(
180 const gtl::InlinedVector<Node*, 4>& ret_nodes, // non-absl ok
181 std::map<int, int>* retval_index_mapping,
182 std::map<int, int>* resource_retval_to_arg) {
183 for (int i = 0, end = ret_nodes.size(); i < end; i++) {
184 Node* n = ret_nodes[i];
185 DataType t;
186 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &t));
187 if (t != DT_RESOURCE) {
188 int new_retval_index = retval_index_mapping->size();
189 retval_index_mapping->insert(std::make_pair(i, new_retval_index));
190 continue;
191 }
192
193 const Edge* e;
194 TF_RETURN_IF_ERROR(n->input_edge(0, &e));
195 if (!e->src()->IsArg()) {
196 return errors::Unimplemented(
197 "Resource _Retval node's input does not come from _Arg "
198 "directly: ",
199 e->DebugString());
200 }
201 Node* arg = e->src();
202 int src_index;
203 TF_RETURN_IF_ERROR(GetNodeAttr(arg->def(), "index", &src_index));
204 resource_retval_to_arg->insert(std::make_pair(i, src_index));
205 }
206 return Status::OK();
207 }
208
209 // Given original output types and return value index mapping, return the new
210 // output types. Notice that DT_RESOURCE will be removed.
ShuffleOutputDataTypeAttribute(const std::vector<DataType> & out_types,const std::map<int,int> & index_mapping)211 std::vector<DataType> ShuffleOutputDataTypeAttribute(
212 const std::vector<DataType>& out_types,
213 const std::map<int, int>& index_mapping) {
214 std::vector<DataType> result(index_mapping.size());
215 for (int i = 0; i < out_types.size(); i++) {
216 auto iter = index_mapping.find(i);
217 if (iter != index_mapping.end()) {
218 result[iter->second] = out_types[i];
219 }
220 }
221 return result;
222 }
223
224 // For StatefulPartitionedCall node, given mapping between original input index
225 // and rearranged input index, reorder output edges for the node. DT_RESOURCE
226 // outputs are removed from the node and we will use the node's corresponding
227 // input for the edge.
RearrangeOutputEdges(Node * n,Graph * g,const std::map<int,int> & retval_index_mapping,const std::map<int,int> & resource_retval_to_arg)228 Status RearrangeOutputEdges(Node* n, Graph* g,
229 const std::map<int, int>& retval_index_mapping,
230 const std::map<int, int>& resource_retval_to_arg) {
231 std::vector<const Edge*> out_edges;
232 for (const Edge* e : n->out_edges()) {
233 if (!e->IsControlEdge()) {
234 out_edges.push_back(e);
235 }
236 }
237 for (const Edge* e : out_edges) {
238 Node* dst = e->dst();
239 int dst_input = e->dst_input();
240 int src_output = e->src_output();
241 auto iter = retval_index_mapping.find(src_output);
242 if (iter == retval_index_mapping.end()) {
243 TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
244 resource_retval_to_arg.end());
245 g->RemoveEdge(e);
246 const Edge* input_edge;
247 TF_RETURN_IF_ERROR(
248 n->input_edge(resource_retval_to_arg.at(src_output), &input_edge));
249 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
250 } else {
251 g->RemoveEdge(e);
252 g->AddEdge(n, iter->second, dst, dst_input);
253 }
254 }
255 return Status::OK();
256 }
257
258 // Given mapping between original output index and rearranged output index,
259 // change "index" attribute for _Retval nodes. Notice that DT_RESOURCE _Retval
260 // nodes will be removed.
RearrangeRetvalNodes(const gtl::InlinedVector<Node *,4> & ret_nodes,Graph * g,const std::map<int,int> & retval_index_mapping)261 void RearrangeRetvalNodes(
262 const gtl::InlinedVector<Node*, 4>& ret_nodes, // non-absl ok
263 Graph* g, const std::map<int, int>& retval_index_mapping) {
264 for (int i = 0, end = ret_nodes.size(); i < end; i++) {
265 Node* n = ret_nodes[i];
266 auto iter = retval_index_mapping.find(i);
267 if (iter == retval_index_mapping.end()) {
268 g->RemoveNode(n);
269 } else {
270 n->ClearAttr("index");
271 n->AddAttr("index", iter->second);
272 }
273 }
274 }
275
MaybeRewriteWhileNode(std::function<Status (const NameAttrList &,const FunctionBody **)> get_function_body_fn,Graph * g,Node * n,FunctionLibraryDefinition * fld,bool * node_rewritten)276 Status MaybeRewriteWhileNode(
277 std::function<Status(const NameAttrList&, const FunctionBody**)>
278 get_function_body_fn,
279 Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) {
280 // Check if this While node needs rewrite.
281 std::vector<DataType> types;
282 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types));
283 bool input_need_rearrange;
284 int resource_input_count;
285 std::vector<int> index_mapping;
286 TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
287 types, &input_need_rearrange, &resource_input_count, &index_mapping));
288 if (!input_need_rearrange) {
289 *node_rewritten = false;
290 return Status::OK();
291 }
292
293 *node_rewritten = true;
294
295 // Modify "T" attribute for this While node.
296 std::vector<DataType> new_types =
297 ShuffleInputDataTypeAttribute(types, index_mapping);
298 n->ClearAttr("T");
299 n->AddAttr("T", new_types);
300
301 // Reorder input and output edges.
302 TF_RETURN_IF_ERROR(ReorderInputEdges(g, n, index_mapping));
303 TF_RETURN_IF_ERROR(ReorderOutputEdges(g, n, types.size(),
304 resource_input_count, index_mapping));
305
306 // Modify cond and body functions.
307 for (auto const& attr_name : std::vector<string>{"cond", "body"}) {
308 NameAttrList attr_value;
309 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value));
310 const FunctionBody* fbody;
311 TF_RETURN_IF_ERROR(get_function_body_fn(attr_value, &fbody));
312
313 // Check that resource _Arg nodes for While node are always returned with
314 // the same index, and we don't have cases like this:
315 // tf.while_loop(
316 // cond,
317 // lambda resource_var1, resource_var2: [resource_var2, resource_var1],
318 // [resource_var1, resource_var2])
319 if (attr_name == "body") {
320 for (int i = 0, end = fbody->ret_nodes.size(); i < end; i++) {
321 Node* n = fbody->ret_nodes[i];
322 DataType dtype;
323 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
324 if (dtype != DT_RESOURCE) {
325 continue;
326 }
327
328 Node* input_node;
329 TF_RETURN_IF_ERROR(n->input_node(0, &input_node));
330 while (input_node->IsIdentity()) {
331 TF_RETURN_IF_ERROR(input_node->input_node(0, &input_node));
332 }
333 if (input_node->IsArg()) {
334 int index;
335 TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
336 if (index != i) {
337 return errors::Unimplemented("While node ", n->DebugString(),
338 " has resource _Retval[", i,
339 "] coming from _Arg[", index, "]");
340 }
341 } else {
342 return errors::Unimplemented("Encountered node ",
343 input_node->DebugString(),
344 " while tracing _Arg node for _Retval[",
345 i, "] of while node ", n->DebugString());
346 }
347 }
348 }
349
350 RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
351 if (attr_name == "body") {
352 for (int i = 0, end = fbody->ret_nodes.size(); i < end; i++) {
353 Node* n = fbody->ret_nodes[i];
354 int new_index = index_mapping.at(i);
355 if (new_index < types.size() - resource_input_count) {
356 n->ClearAttr("index");
357 n->AddAttr("index", new_index);
358 } else {
359 fbody->graph->RemoveNode(n);
360 }
361 }
362 }
363
364 // Save the new FunctionDef.
365 FunctionDef new_fdef;
366 string new_name =
367 fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_"));
368 TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
369 TF_RETURN_IF_ERROR(
370 fld->AddFunctionDef(new_fdef, fld->GetStackTraces(attr_value.name())));
371
372 // Change node to use rewritten function.
373 attr_value.set_name(new_name);
374 n->ClearAttr(attr_name);
375 n->AddAttr(attr_name, attr_value);
376 }
377 return Status::OK();
378 }
379
MaybeRewriteIfNode(std::function<Status (const NameAttrList &,const FunctionBody **)> get_function_body_fn,Graph * g,Node * n,FunctionLibraryDefinition * fld,bool * node_rewritten,const FunctionLibraryDefinition * global_fld)380 Status MaybeRewriteIfNode(
381 std::function<Status(const NameAttrList&, const FunctionBody**)>
382 get_function_body_fn,
383 Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten,
384 const FunctionLibraryDefinition* global_fld) {
385 // This node needs rewrite when either of these is true:
386 // 1) Tin has DT_RESOURCE which requires rearrange;
387 // 2) Tout has DT_RESOURCE.
388 std::vector<DataType> in_types;
389 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &in_types));
390 bool input_need_rearrange;
391 int resource_input_count;
392 std::vector<int> index_mapping;
393 TF_RETURN_IF_ERROR(InputTypesNeedsRearrange(
394 in_types, &input_need_rearrange, &resource_input_count, &index_mapping));
395 std::vector<DataType> out_types;
396 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tout", &out_types));
397 bool has_resource_output = std::find(out_types.begin(), out_types.end(),
398 DT_RESOURCE) != out_types.end();
399 if (!input_need_rearrange && !has_resource_output) {
400 *node_rewritten = false;
401 return Status::OK();
402 }
403
404 *node_rewritten = true;
405
406 if (input_need_rearrange) {
407 // Reorder input edges.
408 std::vector<const Edge*> input_edges;
409 for (const Edge* e : n->in_edges()) {
410 if (e->IsControlEdge() || e->dst_input() == 0) {
411 continue;
412 }
413 input_edges.push_back(e);
414 }
415 for (const Edge* e : input_edges) {
416 Node* src = e->src();
417 int src_output = e->src_output();
418 int dst_input = e->dst_input();
419 int new_dst_input = index_mapping.at(dst_input - 1) + 1;
420 g->RemoveEdge(e);
421 g->AddEdge(src, src_output, n, new_dst_input)->DebugString();
422 }
423
424 // Change Tin attribute.
425 std::vector<DataType> new_in_types =
426 ShuffleInputDataTypeAttribute(in_types, index_mapping);
427 n->ClearAttr("Tin");
428 n->AddAttr("Tin", new_in_types);
429 }
430
431 std::map<int, int> resource_retval_to_arg, retval_index_mapping;
432 for (auto const& attr_name :
433 std::vector<string>{"then_branch", "else_branch"}) {
434 NameAttrList f;
435 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f));
436 const FunctionBody* fbody;
437 TF_RETURN_IF_ERROR(get_function_body_fn(f, &fbody));
438
439 if (input_need_rearrange) {
440 // Change _Arg node index.
441 RearrangeArgNodes(&fbody->arg_nodes, index_mapping);
442 }
443
444 if (has_resource_output) {
445 // Resource _Retval must come from resource _Arg directly, or we do
446 // not support it.
447 TF_RETURN_IF_ERROR(CalculateRetvalRearrange(
448 fbody->ret_nodes, &retval_index_mapping, &resource_retval_to_arg));
449
450 // Change index for _Retval nodes.
451 RearrangeRetvalNodes(fbody->ret_nodes, fbody->graph,
452 retval_index_mapping);
453 }
454
455 // Save the new FunctionDef.
456 FunctionDef new_fdef;
457 string new_name =
458 fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_"));
459 TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef));
460 const StackTracesMap& stack_traces =
461 fld->GetStackTraces(f.name()).empty() && global_fld
462 ? global_fld->GetStackTraces(f.name())
463 : fld->GetStackTraces(f.name());
464 TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, stack_traces));
465
466 // Change node to use rewritten function.
467 f.set_name(new_name);
468 n->ClearAttr(attr_name);
469 n->AddAttr(attr_name, f);
470 }
471
472 if (has_resource_output) {
473 // Rearrange output edges.
474 std::vector<const Edge*> out_edges;
475 for (const Edge* e : n->out_edges()) {
476 if (!e->IsControlEdge()) {
477 out_edges.push_back(e);
478 }
479 }
480 for (const Edge* e : out_edges) {
481 Node* dst = e->dst();
482 int dst_input = e->dst_input();
483 int src_output = e->src_output();
484 auto iter = retval_index_mapping.find(src_output);
485 if (iter == retval_index_mapping.end()) {
486 TF_RET_CHECK(resource_retval_to_arg.find(src_output) !=
487 resource_retval_to_arg.end());
488 g->RemoveEdge(e);
489 const Edge* input_edge;
490 TF_RETURN_IF_ERROR(n->input_edge(
491 resource_retval_to_arg.at(src_output) + 1, &input_edge));
492 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
493 } else {
494 g->RemoveEdge(e);
495 g->AddEdge(n, iter->second, dst, dst_input);
496 }
497 }
498
499 // Change Tout attribute for the node.
500 std::vector<DataType> new_out_types =
501 ShuffleOutputDataTypeAttribute(out_types, retval_index_mapping);
502 n->ClearAttr("Tout");
503 n->AddAttr("Tout", new_out_types);
504 }
505 return Status::OK();
506 }
507
508 } // namespace
509
RearrangeFunctionArguments(std::function<Status (const NameAttrList &,const FunctionBody **)> get_function_body_fn,Graph * g,FunctionLibraryDefinition * fld,const FunctionLibraryDefinition * global_fld)510 Status RearrangeFunctionArguments(
511 std::function<Status(const NameAttrList&, const FunctionBody**)>
512 get_function_body_fn,
513 Graph* g, FunctionLibraryDefinition* fld,
514 const FunctionLibraryDefinition* global_fld) {
515 // Inline StatefulPartitionedCall nodes.
516 std::vector<Node*> call_nodes;
517 for (Node* n : g->nodes()) {
518 if (n->type_string() == "StatefulPartitionedCall") {
519 call_nodes.push_back(n);
520 }
521 }
522 for (Node* n : call_nodes) {
523 NameAttrList func_name_attrs;
524 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func_name_attrs));
525 const FunctionBody* fbody;
526 TF_RETURN_IF_ERROR(get_function_body_fn(func_name_attrs, &fbody));
527 InlineFunctionBodyOptions opts;
528 Status s = InlineFunctionBody(*fld, g, n, fbody, opts);
529 // Inlining might fail because the function is marked with attribute
530 // _noinline.
531 s.IgnoreError();
532 FixupSourceAndSinkEdges(g);
533 }
534
535 // Rewrite If/While nodes.
536 for (Node* n : g->nodes()) {
537 if (n->IsWhileNode()) {
538 bool node_rewritten;
539 TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld,
540 &node_rewritten));
541 } else if (n->IsIfNode()) {
542 bool node_rewritten;
543 TF_RETURN_IF_ERROR(MaybeRewriteIfNode(get_function_body_fn, g, n, fld,
544 &node_rewritten, global_fld));
545 }
546 }
547
548 return Status::OK();
549 }
550
551 } // namespace tensorflow
552