1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
17 
18 #include <queue>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/container/node_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
26 #include "tensorflow/compiler/jit/encapsulate_util.h"
27 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
28 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/framework/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/node_def_builder.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/graph/algorithm.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/gtl/cleanup.h"
40 #include "tensorflow/core/lib/gtl/flatset.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/strings/proto_serialization.h"
43 #include "tensorflow/core/lib/strings/str_util.h"
44 #include "tensorflow/core/public/session_options.h"
45 #include "tensorflow/core/public/version.h"
46 #include "tensorflow/core/tpu/tpu_compile_interface.h"
47 #include "tensorflow/core/tpu/tpu_defs.h"
48 #include "tensorflow/core/util/dump_graph.h"
49 
50 namespace tensorflow {
51 
52 namespace {
53 
// Op type names and attribute names used by this pass. All of them are
// file-local (anonymous namespace) compile-time constants.

// Op type of the node that feeds a replicated input into a TPU computation.
constexpr const char* kTPUReplicatedInput = "TPUReplicatedInput";
// Op type "TPUReplicatedOutput".
constexpr const char* kTPUReplicatedOutput = "TPUReplicatedOutput";
// Attribute name "_pivot_for_cluster".
constexpr const char* kPivotForClusterAttr = "_pivot_for_cluster";
// Op type "TPUPartitionedInput".
constexpr const char* kTPUPartitionedInput = "TPUPartitionedInput";
58 
59 // Finds the `index` of an _Arg or _Retval node.
GetIndexAttr(const Node & n,int num_args,int * index)60 Status GetIndexAttr(const Node& n, int num_args, int* index) {
61   TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
62   if (*index < 0 || *index >= num_args) {
63     return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
64                                    *index);
65   }
66   return Status::OK();
67 }
68 
69 // Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
70 // the arguments into the order expected by TPUReplicate computations:
71 // 1) replicated arguments
72 // 2) non-replicated (broadcast) arguments
73 // 3) resource variable arguments
74 // See the documentation of EncapsulateSubgraphsInFunctions for the meaning
75 // of the arguments.
RewriteSubgraph(const std::vector<OutputTensor> & arg_source_tensors,std::unique_ptr<Graph> * graph_ptr,std::vector<int> * input_permutation,std::vector<int> * output_permutation,NodeDef * call_def)76 Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
77                        std::unique_ptr<Graph>* graph_ptr,
78                        std::vector<int>* input_permutation,
79                        std::vector<int>* output_permutation,
80                        NodeDef* call_def) {
81   // Replicated inputs have TPUReplicatedInput nodes as predecessors in the
82   // input graph.
83   auto is_replicated_input = [&](const Node& n, bool* is_packed = nullptr) {
84     CHECK_EQ("_Arg", n.type_string());
85     int index;
86     TF_CHECK_OK(GetIndexAttr(n, arg_source_tensors.size(), &index));
87     bool ret =
88         arg_source_tensors.at(index).node->type_string() == kTPUReplicatedInput;
89     if (is_packed) {
90       if (!ret || !GetNodeAttr(arg_source_tensors.at(index).node->attrs(),
91                                "is_packed", is_packed)
92                        .ok()) {
93         *is_packed = false;
94       }
95     }
96     return ret;
97   };
98 
99   auto get_replicated_input_index = [&](const Node& n) {
100     CHECK_EQ("_Arg", n.type_string());
101     int index;
102     TF_CHECK_OK(GetIndexAttr(n, arg_source_tensors.size(), &index));
103     if (arg_source_tensors.at(index).node->type_string() !=
104         kTPUReplicatedInput) {
105       return -1;
106     }
107     int replicated_index;
108     TF_CHECK_OK(GetNodeAttr(arg_source_tensors.at(index).node->attrs(), "index",
109                             &replicated_index));
110 
111     return replicated_index;
112   };
113 
114   auto is_guaranteed_constant = [&](const Node& n) {
115     bool guaranteed_constant = false;
116     if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
117              .ok()) {
118       return false;
119     }
120     // Replicated input nodes can be marked as guaranteed constants if they are
121     // const.
122     return guaranteed_constant && !is_replicated_input(n);
123   };
124 
125   Graph* graph = graph_ptr->get();
126   Node* metadata_node = nullptr;
127   const int num_args = input_permutation->size();
128   const int num_retvals = output_permutation->size();
129 
130   std::vector<Node*> args;
131   std::vector<Node*> retvals;
132   args.reserve(num_args);
133   retvals.reserve(num_retvals);
134   for (Node* n : graph->nodes()) {
135     if (n->type_string() == "_Arg") {
136       args.push_back(n);
137     } else if (n->type_string() == "_Retval") {
138       retvals.push_back(n);
139     } else if (n->type_string() == "TPUReplicateMetadata") {
140       metadata_node = n;
141     } else if (!str_util::StrContains(n->requested_device(),
142                                       DEVICE_TPU_REPLICATED_CORE)) {
143       // If an operator isn't assigned to a TPU core device, assign it to
144       // TPU_REPLICATED_CORE without a specific core ID. For some operators,
145       // such as variable reads/writes, the operator may be assigned to non-TPU
146       // devices due to colocation.
147       n->set_assigned_device_name(
148           strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE));
149     }
150   }
151 
152   // Read the metadata node and remove it from the graph.
153   if (metadata_node == nullptr) {
154     return errors::InvalidArgument("Missing TPUReplicateMetadata node");
155   }
156 
157   for (const auto& attr : metadata_node->attrs()) {
158     if (attr.first == "computation_shape") {
159       // Convert the deprecated computation_shape attribute into a
160       // num_cores_per_replica value. If a computation_shape is present, it
161       // overrides num_cores_per_replica.
162       std::vector<int> shape;
163       TF_RETURN_IF_ERROR(
164           GetNodeAttr(metadata_node->attrs(), "computation_shape", &shape));
165       if (!shape.empty()) {
166         int64 num_cores_per_replica = 1LL;
167         for (int dim : shape) {
168           num_cores_per_replica *= dim;
169         }
170         call_def->mutable_attr()->erase("num_cores_per_replica");
171         AddNodeAttr("num_cores_per_replica", num_cores_per_replica, call_def);
172       }
173     } else {
174       call_def->mutable_attr()->insert(attr);
175     }
176   }
177   MergeDebugInfo(NodeDebugInfo(metadata_node->def()), call_def);
178   graph->RemoveNode(metadata_node);
179 
180   if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
181     return errors::InvalidArgument("Missing or non-consecutive arguments");
182   }
183 
184   // Reorders the arguments.
185   std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
186     // Non-constants appear before constants
187     bool a_is_guaranteed_constant = is_guaranteed_constant(*a);
188     bool b_is_guaranteed_constant = is_guaranteed_constant(*b);
189     // Non-packed values appear before packed values.
190     bool a_is_packed;
191     bool b_is_packed;
192     // Replicated values appear before non-replicated values.
193     bool a_not_replicated = !is_replicated_input(*a, &a_is_packed);
194     bool b_not_replicated = !is_replicated_input(*b, &b_is_packed);
195     int a_replicated_index = get_replicated_input_index(*a);
196     int b_replicated_index = get_replicated_input_index(*b);
197     // Non-resources appear before resources
198     bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
199     bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
200     // Uses the name as a tiebreaker so the output is deterministic.
201     StringPiece a_name(a->name());
202     StringPiece b_name(b->name());
203     return std::tie(a_is_guaranteed_constant, a_not_replicated, a_is_packed,
204                     a_is_resource, a_replicated_index, a_name) <
205            std::tie(b_is_guaranteed_constant, b_not_replicated, b_is_packed,
206                     b_is_resource, b_replicated_index, b_name);
207   });
208   // Sorts the retvals by name so the order is deterministic.
209   std::sort(retvals.begin(), retvals.end(),
210             [](Node* a, Node* b) { return a->name() < b->name(); });
211 
212   // Computes the permutation to produce the correct argument order, and update
213   // the argument indices.
214   int variable_start_index = num_args;
215   int guaranteed_const_start_index = num_args;
216   for (int i = 0; i < num_args; ++i) {
217     int index;
218     TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
219     if (args[i]->output_type(0) == DT_RESOURCE &&
220         !is_replicated_input(*args[i]) && variable_start_index == num_args) {
221       variable_start_index = i;
222     } else if (is_guaranteed_constant(*args[i]) &&
223                guaranteed_const_start_index == num_args) {
224       guaranteed_const_start_index = i;
225     }
226     (*input_permutation)[index] = i;
227     args[i]->AddAttr("index", i);
228   }
229   VLOG(4) << "variable_start_index: " << variable_start_index
230           << " guaranteed_const_start_index: " << guaranteed_const_start_index;
231 
232   // Computes the permutation to produce the correct retval order, and update
233   // the argument indices.
234   for (int i = 0; i < num_retvals; ++i) {
235     int index;
236     TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
237     (*output_permutation)[index] = i;
238     retvals[i]->AddAttr("index", i);
239   }
240 
241   AddNodeAttr(kTPUReplicateAttr, call_def->name(), call_def);
242   AddNodeAttr("_variable_start_index", variable_start_index, call_def);
243   AddNodeAttr("_guaranteed_const_start_index", guaranteed_const_start_index,
244               call_def);
245 
246   // Uniquify the function name.
247   GraphDef gdef;
248   graph->ToGraphDef(&gdef);
249 
250   // Before serialization, sort each node's control inputs to achieve
251   // determinism. Sorting control inputs could help (but not necessarily)
252   // create a deterministic serialization and fingerprint. Other sources of
253   // nondeterminism include unstable node ordering.
254   SortControlInputs(&gdef);
255   // Fingerprint the function.
256   // Nondeterminism in serialization would not lead to incorrect results, but
257   // may cause spurious cache misses. DeterministicSerialization is a
258   // best-effort deterministic serialization.
259   string serialized;
260   TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
261   uint64 fingerprint =
262       TpuCompileInterface::Get()->FingerprintString(serialized);
263   LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
264   call_def->set_op(strings::StrCat(call_def->op(), "_", fingerprint));
265   return Status::OK();
266 }
267 
EdgeType(const Edge * edge)268 DataType EdgeType(const Edge* edge) {
269   return edge->dst()->input_type(edge->dst_input());
270 }
271 
272 // Adds the control inputs of `node` to `*deps`.
AddControlInputs(const Node & node,gtl::FlatSet<Node * > * deps)273 void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
274   for (const Edge* edge : node.in_edges()) {
275     if (edge->IsControlEdge()) {
276       deps->insert(edge->src());
277     }
278   }
279 }
280 
281 // Adds the control outputs of `node` to `*deps`.
AddControlOutputs(const Node & node,gtl::FlatSet<Node * > * deps)282 void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
283   for (const Edge* edge : node.out_edges()) {
284     if (edge->IsControlEdge()) {
285       deps->insert(edge->dst());
286     }
287   }
288 }
289 
290 // We add Identity nodes for _Arg/_Retval in XLA computation. Remove those
291 // Identity nodes to simplify furthur processing.
RemoveIdentityNodesForArgRetval(Graph * g)292 Status RemoveIdentityNodesForArgRetval(Graph* g) {
293   // Collect Identity nodes for _Arg/_Retval.
294   std::vector<Node*> identity_nodes;
295   for (Node* n : g->nodes()) {
296     if (n->type_string() == "Identity" &&
297         (HasNodeAttr(n->def(), "_tpu_input_identity") ||
298          HasNodeAttr(n->def(), "_tpu_output_identity"))) {
299       identity_nodes.push_back(n);
300     }
301   }
302 
303   // Remove those Identity nodes.
304   for (Node* n : identity_nodes) {
305     const Edge* input_edge;
306     TF_RETURN_IF_ERROR(n->input_edge(0, &input_edge));
307 
308     std::vector<const Edge*> output_edges;
309     for (const Edge* e : n->out_edges()) {
310       output_edges.push_back(e);
311     }
312     for (const Edge* e : output_edges) {
313       if (e->IsControlEdge()) {
314         Node* dst = e->dst();
315         g->RemoveEdge(e);
316         g->AddControlEdge(input_edge->src(), dst);
317       } else {
318         Node* dst = e->dst();
319         int dst_input = e->dst_input();
320         g->RemoveEdge(e);
321         g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
322       }
323     }
324     g->RemoveNode(n);
325   }
326 
327   return Status::OK();
328 }
329 
330 // Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when
331 // 'additional_per_replicate_inputs' are added to the inputs of `xla_node`.
UpdateMirroredVariableIndices(int additional_per_replica_inputs,Node * xla_node)332 Status UpdateMirroredVariableIndices(int additional_per_replica_inputs,
333                                      Node* xla_node) {
334   std::vector<int> mirrored_variable_indices;
335   if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
336       nullptr) {
337     TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
338                                    TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
339                                    &mirrored_variable_indices));
340   }
341 
342   if (!mirrored_variable_indices.empty()) {
343     for (int i = 0; i < mirrored_variable_indices.size(); ++i)
344       mirrored_variable_indices[i] += additional_per_replica_inputs;
345     xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
346     xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
347                       mirrored_variable_indices);
348   }
349   return Status::OK();
350 }
351 
352 // Move outside compilation nodes at the beginning of XLA computation to host.
353 // For XLA computation graph, we will add new _Arg nodes to replace those
354 // outside compilation nodes.
355 // For host graph, we will move those outside compilation nodes to host,
356 // replicate them, and use them as XLA node's input.
MoveHeadOutsideCompilationToHost(const string & outside_compilation_attr_name,const string & xla_func_name,const std::string & cluster_name,Graph * g,Graph * xla_graph,Node * xla_node,Node * pivot_node)357 Status MoveHeadOutsideCompilationToHost(
358     const string& outside_compilation_attr_name, const string& xla_func_name,
359     const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
360     Node* pivot_node) {
361   // Find outside compilation nodes that only have _Arg or other outside
362   // compilation nodes as input. These nodes will be moved to host graph.
363   std::vector<Node*> oc_nodes_at_head;
364   const string kOnlyArgOrOcInputAttrName = "_xla_only_arg_or_oc_input";
365   ReverseDFS(
366       *xla_graph, /*enter=*/nullptr,
367       [&](Node* n) {
368         bool has_non_arg_or_oc_input = false;
369         for (const Edge* e : n->in_edges()) {
370           if (e->src() == xla_graph->source_node()) {
371             continue;
372           }
373           if (!e->src()->IsArg() &&
374               (!HasNodeAttr(e->src()->def(), outside_compilation_attr_name) ||
375                !HasNodeAttr(e->src()->def(), kOnlyArgOrOcInputAttrName))) {
376             has_non_arg_or_oc_input = true;
377             break;
378           }
379         }
380         if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
381             !has_non_arg_or_oc_input &&
382             !HasNodeAttr(n->def(), kXlaIsPlaceholderForArg)) {
383           n->AddAttr(kOnlyArgOrOcInputAttrName, true);
384           oc_nodes_at_head.push_back(n);
385         }
386       },
387       NodeComparatorName());
388   std::vector<Node*> const_nodes_to_remove;
389   for (Node* n : oc_nodes_at_head) {
390     // If a Const node is in "oc_nodes_at_head" but some of its successors are
391     // not, copy this Const node and use the copied node for those successors.
392     if (n->type_string() != "Const") {
393       continue;
394     }
395 
396     std::vector<const Edge*> edges_to_replace;
397     for (const Edge* e : n->out_edges()) {
398       if (!e->IsControlEdge() &&
399           HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) &&
400           !HasNodeAttr(e->dst()->def(), kOnlyArgOrOcInputAttrName)) {
401         edges_to_replace.push_back(e);
402       }
403     }
404     if (edges_to_replace.empty()) {
405       continue;
406     }
407 
408     Node* const_copy = xla_graph->CopyNode(n);
409     for (const Edge* e : edges_to_replace) {
410       Node* dst = e->dst();
411       int dst_input = e->dst_input();
412       xla_graph->RemoveEdge(e);
413       xla_graph->AddEdge(const_copy, 0, dst, dst_input);
414     }
415     // Make sure the copied node can be traced from source node.
416     xla_graph->AddControlEdge(xla_graph->source_node(), const_copy);
417 
418     // If this Const node has no data output any more, remove it later.
419     bool has_output_edge = false;
420     for (const Edge* e : n->out_edges()) {
421       if (!e->IsControlEdge()) {
422         has_output_edge = true;
423         break;
424       }
425     }
426     if (!has_output_edge) {
427       const_nodes_to_remove.push_back(n);
428     }
429   }
430   for (Node* n : const_nodes_to_remove) {
431     xla_graph->RemoveNode(n);
432     oc_nodes_at_head.erase(
433         std::remove(oc_nodes_at_head.begin(), oc_nodes_at_head.end(), n),
434         oc_nodes_at_head.end());
435   }
436   if (VLOG_IS_ON(5)) {
437     for (Node* n : oc_nodes_at_head) {
438       VLOG(5) << "oc_nodes_at_head: " << n->DebugString();
439     }
440   }
441 
442   // Copy all nodes in `oc_nodes_at_head` to host graph, and also replicate
443   // them.
444 
445   // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
446   // will become very expensive in this case because it is doing a linear
447   // search inside. Create an input_edges vector ahead to make the lookups
448   // faster.
449   std::vector<const Edge*> input_edges;
450   TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));
451 
452   std::vector<DataType> input_types;
453   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
454   int num_distributed_vars;
455   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
456                                  &num_distributed_vars));
457   int num_replicas;
458   TF_RETURN_IF_ERROR(
459       GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
460   int old_num_per_replica_inputs =
461       (input_types.size() - num_distributed_vars) / num_replicas;
462   VLOG(5) << "old_num_per_replica_inputs: " << old_num_per_replica_inputs;
463   std::map<Node*, std::vector<Node*>> node_images;
464   for (Node* n : oc_nodes_at_head) {
465     for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
466       NodeDef copy_def = n->def();
467       copy_def.set_name(absl::StrCat(n->name(), "_head_oc/R", replica_id));
468       copy_def.clear_device();
469 
470       Status s;
471       Node* copy_node = g->AddNode(copy_def, &s);
472       TF_RETURN_IF_ERROR(s);
473 
474       copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
475       copy_node->AddAttr(kTPUReplicateAttr, cluster_name);
476 
477       for (const Edge* e : n->in_edges()) {
478         if (e->src() == xla_graph->source_node()) {
479           continue;
480         }
481         // Either e->src() is _Arg node, or it's in `node_images`.
482         if (e->src()->IsArg()) {
483           int index;
484           TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->attrs(), "index", &index));
485           const int new_index =
486               (index < old_num_per_replica_inputs)
487                   ? (old_num_per_replica_inputs * replica_id + index)
488                   : (old_num_per_replica_inputs * num_replicas +
489                      (index - old_num_per_replica_inputs));
490           const Edge* original_edge = input_edges.at(new_index);
491           g->AddEdge(original_edge->src(), original_edge->src_output(),
492                      copy_node, e->dst_input());
493         } else {
494           g->AddEdge(node_images[e->src()][replica_id], e->src_output(),
495                      copy_node, e->dst_input());
496         }
497       }
498 
499       // Add control edge between `copy_node` and `xla_node`, so these outside
500       // compilation nodes will be executed before XLA computation happens.
501       g->AddControlEdge(copy_node, xla_node);
502 
503       // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
504       // belongs to same while loop as `xla_node`.
505       if (pivot_node) {
506         g->AddControlEdge(pivot_node, copy_node);
507       }
508 
509       node_images[n].push_back(copy_node);
510     }
511   }
512 
513   // Record output edges from `oc_nodes_at_head`. We will create an _Arg node
514   // for each of these edges. An obvious optimization here is to deduplicate
515   // these edges by <src, src_output>. But that optimization will complicate
516   // the code, and in practice we usually do not have output edges with the
517   // same <src, src_output>.
518   std::vector<const Edge*> oc_output_edges;
519   std::vector<DataType> new_arg_types;
520   for (Node* n : oc_nodes_at_head) {
521     for (const Edge* e : n->out_edges()) {
522       if (!e->IsControlEdge() &&
523           node_images.find(e->dst()) == node_images.end()) {
524         VLOG(5) << "oc_output_edges: " << e->DebugString();
525         oc_output_edges.push_back(e);
526         new_arg_types.push_back(e->src()->output_type(e->src_output()));
527       }
528     }
529   }
530   int new_num_per_replica_inputs =
531       old_num_per_replica_inputs + oc_output_edges.size();
532   VLOG(5) << "new_num_per_replica_inputs: " << new_num_per_replica_inputs;
533 
534   // Process input edges for XLA node.
535   int num_variables;
536   TF_RETURN_IF_ERROR(
537       GetNodeAttr(xla_node->attrs(), "NumVariables", &num_variables));
538   std::vector<DataType> broadcast_input_types, guaranteed_constant_types;
539   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tbroadcast_inputs",
540                                  &broadcast_input_types));
541   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tguaranteed_constants",
542                                  &guaranteed_constant_types));
543   int num_other_inputs = num_distributed_vars + num_variables +
544                          broadcast_input_types.size() +
545                          guaranteed_constant_types.size();
546   VLOG(5) << "num_other_inputs: " << num_other_inputs;
547 
548   // Update `Tinputs` attribute for `xla_node`.
549   std::vector<DataType> new_input_types;
550   // Order of new_input_types: old per-replica inputs -> new per-replica inputs
551   // -> distributed variables
552   new_input_types.reserve(num_replicas * new_num_per_replica_inputs +
553                           num_distributed_vars);
554   for (int replica_id = 0; replica_id < num_replicas; ++replica_id) {
555     for (int i = 0; i < old_num_per_replica_inputs; ++i) {
556       new_input_types.push_back(input_types[i]);
557     }
558     for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
559          ++i) {
560       new_input_types.push_back(new_arg_types[i - old_num_per_replica_inputs]);
561     }
562   }
563   const int num_new_per_replica_input_types = new_input_types.size();
564   for (int i = input_types.size() - num_distributed_vars;
565        i < input_types.size(); i++) {
566     new_input_types.push_back(input_types[i]);
567   }
568   xla_node->ClearAttr("Tinputs");
569   xla_node->AddAttr("Tinputs", new_input_types);
570 
571   TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices(
572       /*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node));
573 
574   int new_variable_start_index =
575       num_new_per_replica_input_types / num_replicas + num_distributed_vars +
576       broadcast_input_types.size();
577   if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
578     xla_node->ClearAttr("_variable_start_index");
579     xla_node->AddAttr("_variable_start_index", new_variable_start_index);
580   }
581   int new_guaranteed_const_start_index =
582       new_variable_start_index + num_variables;
583   if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
584     xla_node->ClearAttr("_guaranteed_const_start_index");
585     xla_node->AddAttr("_guaranteed_const_start_index",
586                       new_guaranteed_const_start_index);
587   }
588 
589   // Move non per-replica input edges.
590   std::vector<const Edge*> new_input_edges(
591       num_replicas * new_num_per_replica_inputs + num_other_inputs);
592   int end_input_index =
593       num_replicas * new_num_per_replica_inputs + num_other_inputs - 1;
594   int start_input_index = end_input_index + 1 - num_other_inputs;
595   for (int input_index = end_input_index; input_index >= start_input_index;
596        input_index--) {
597     const Edge* e =
598         input_edges.at(input_index - num_replicas * new_arg_types.size());
599     Node* src = e->src();
600     int src_output = e->src_output();
601     g->RemoveEdge(e);
602     const Edge* new_input_edge =
603         g->AddEdge(src, src_output, xla_node, input_index);
604     new_input_edges[input_index] = new_input_edge;
605   }
606 
607   // Re-order old per-replica inputs edges, and add new per-replica input edges.
608   std::vector<std::pair<Node*, int>> per_replica_inputs;
609   std::vector<const Edge*> old_per_replica_edges;
610   for (int i = 0; i < old_num_per_replica_inputs * num_replicas; i++) {
611     const Edge* e = input_edges.at(i);
612     per_replica_inputs.push_back(std::make_pair(e->src(), e->src_output()));
613     old_per_replica_edges.push_back(e);
614   }
615   for (const Edge* e : old_per_replica_edges) {
616     g->RemoveEdge(e);
617   }
618   for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
619     for (int input_index = 0; input_index < old_num_per_replica_inputs;
620          input_index++) {
621       Node* src = per_replica_inputs[replica_id * old_num_per_replica_inputs +
622                                      input_index]
623                       .first;
624       int src_output =
625           per_replica_inputs[replica_id * old_num_per_replica_inputs +
626                              input_index]
627               .second;
628       const Edge* new_input_edge =
629           g->AddEdge(src, src_output, xla_node,
630                      replica_id * new_num_per_replica_inputs + input_index);
631       new_input_edges[input_index] = new_input_edge;
632     }
633     for (int input_index = old_num_per_replica_inputs;
634          input_index < new_num_per_replica_inputs; input_index++) {
635       Node* original_src =
636           oc_output_edges[input_index - old_num_per_replica_inputs]->src();
637       int original_src_output =
638           oc_output_edges[input_index - old_num_per_replica_inputs]
639               ->src_output();
640       Node* src = node_images[original_src][replica_id];
641       const Edge* new_input_edge =
642           g->AddEdge(src, original_src_output, xla_node,
643                      replica_id * new_num_per_replica_inputs + input_index);
644       new_input_edges[input_index] = new_input_edge;
645     }
646   }
647 
648   // Adjust original _Arg nodes in `xla_graph`.
649   for (Node* n : xla_graph->nodes()) {
650     if (n->IsArg()) {
651       int index;
652       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
653       if (index >= old_num_per_replica_inputs) {
654         index += new_arg_types.size();
655         n->ClearAttr("index");
656         n->AddAttr("index", index);
657       }
658     }
659   }
660 
661   // Create new _Arg nodes in `xla_graph`.
662   for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
663        i++) {
664     NodeDefBuilder arg_builder(absl::StrCat("arg_", i),
665                                FunctionLibraryDefinition::kArgOp);
666     arg_builder.Attr("T", new_arg_types[i - old_num_per_replica_inputs]);
667     arg_builder.Attr("index", i);
668     NodeDef arg_def;
669     TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
670     Status s;
671     Node* arg_node = xla_graph->AddNode(arg_def, &s);
672     TF_RETURN_IF_ERROR(s);
673     const Edge* original_edge = oc_output_edges[i - old_num_per_replica_inputs];
674     Node* dst = original_edge->dst();
675     int dst_input = original_edge->dst_input();
676     xla_graph->RemoveEdge(original_edge);
677     xla_graph->AddEdge(arg_node, 0, dst, dst_input);
678   }
679 
680   // For lifted arg nodes:
681   // 1. Add a Placeholder node in `xla_graph`. When we build host side graph
682   //    in ExtractOutsideCompilationPass, we will use this new Placeholder node
683   //    instead of lifted arg node here.
684   // 2. Add an IdentityN node in `g` to indicate its inputs. We will reconnect
685   //    this IdentityN node and this lifted arg node's usage nodes in
686   //    DistributedTPURewritePass.
687   for (Node* n : oc_nodes_at_head) {
688     bool is_lifted_arg;
689     string outside_compilation_attr;
690     if (!TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) ||
691         !TryGetNodeAttr(n->def(), kOutsideCompilationAttr,
692                         &outside_compilation_attr)) {
693       continue;
694     }
695 
696     TF_RET_CHECK(n->IsIdentity());
697     NodeDefBuilder ph_builder(absl::StrCat("placeholder_", n->name()),
698                               "Placeholder");
699     DataType dtype;
700     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
701     ph_builder.Attr("dtype", dtype);
702     ph_builder.Attr(kXlaIsLiftedArgAttrName, true);
703     ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
704     NodeDef ph_def;
705     TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
706     Status s;
707     xla_graph->AddNode(ph_def, &s);
708     TF_RETURN_IF_ERROR(s);
709 
710     Node* input_node;
711     TF_RETURN_IF_ERROR(n->input_node(0, &input_node));
712     TF_RET_CHECK(input_node->type_string() == "_Arg");
713     int index;
714     TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
715     // TODO(b/74023706): for now we only support resource input (e.g. summary
716     // writer), which is non-replicated input. Support replicated input as
717     // well.
718     TF_RET_CHECK(index >= new_num_per_replica_inputs + num_distributed_vars);
719     const Edge* input_edge =
720         new_input_edges.at(num_replicas * new_num_per_replica_inputs + index -
721                            new_num_per_replica_inputs);
722     NodeDefBuilder id_builder(absl::StrCat("lifted_arg_input_", index),
723                               "IdentityN");
724     DataType input_dtype =
725         input_edge->src()->output_type(input_edge->src_output());
726     id_builder.Attr("T", std::vector<DataType>(num_replicas, input_dtype));
727     std::vector<NodeDefBuilder::NodeOut> inputs(
728         num_replicas,
729         NodeDefBuilder::NodeOut{input_edge->src()->name(),
730                                 input_edge->src_output(), input_dtype});
731     id_builder.Attr(kXlaOutsideCompilationInputsAttrName,
732                     outside_compilation_attr);
733     id_builder.Input(inputs);
734     NodeDef id_def;
735     TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
736     Node* id_node = g->AddNode(id_def, &s);
737     TF_RETURN_IF_ERROR(s);
738     for (int i = 0; i < num_replicas; i++) {
739       g->AddEdge(input_edge->src(), input_edge->src_output(), id_node, i);
740     }
741   }
742 
743   // Remove `oc_nodes_at_head`.
744   for (Node* n : oc_nodes_at_head) {
745     xla_graph->RemoveNode(n);
746   }
747 
748   VLOG(4) << "MoveHeadOutsideCompilationToHost host graph: "
749           << DumpGraphToFile(absl::StrCat("move_head_oc_host_", xla_func_name),
750                              *g);
751   VLOG(4) << "MoveHeadOutsideCompilationToHost XLA graph: "
752           << DumpGraphToFile(absl::StrCat("move_head_oc_xla_", xla_func_name),
753                              *xla_graph);
754 
755   return Status::OK();
756 }
757 
758 // If there are any unused _Arg nodes in `xla_graph`, remove them from
759 // `xla_graph` and remove corresponding input edge in host graph `g`.
RemoveUnusedXlaInput(const string & xla_func_name,Graph * g,Graph * xla_graph,Node * xla_node)760 Status RemoveUnusedXlaInput(const string& xla_func_name, Graph* g,
761                             Graph* xla_graph, Node* xla_node) {
762   // Find unused _Arg nodes, and remove them.
763   std::vector<DataType> input_types;
764   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
765   std::vector<int> mirrored_variable_indices;
766   if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
767       nullptr) {
768     TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
769                                    TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
770                                    &mirrored_variable_indices));
771   }
772   std::vector<DataType> broadcast_input_types;
773   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tbroadcast_inputs",
774                                  &broadcast_input_types));
775   std::vector<DataType> guaranteed_constant_types;
776   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tguaranteed_constants",
777                                  &guaranteed_constant_types));
778   int num_variables;
779   TF_RETURN_IF_ERROR(
780       GetNodeAttr(xla_node->def(), "NumVariables", &num_variables));
781   int num_replicas;
782   TF_RETURN_IF_ERROR(
783       GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
784   int num_distributed_vars;
785   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
786                                  &num_distributed_vars));
787   int num_per_replica_inputs =
788       (input_types.size() - num_distributed_vars) / num_replicas;
789   std::set<int> arg_indices_to_remove;
790   std::vector<Node*> arg_nodes_to_update, nodes_to_remove;
791   int num_args = 0, num_removed_per_replica_inputs = 0,
792       num_removed_distributed_vars = 0;
793   for (Node* n : xla_graph->nodes()) {
794     if (!n->IsArg()) {
795       continue;
796     }
797 
798     bool has_output = false;
799     for (const Edge* e : n->out_edges()) {
800       if (e->dst() != xla_graph->sink_node()) {
801         has_output = true;
802         break;
803       }
804     }
805 
806     num_args++;
807     int index;
808     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
809     if (has_output) {
810       arg_nodes_to_update.push_back(n);
811       continue;
812     }
813 
814     arg_indices_to_remove.insert(index);
815     if (index < num_per_replica_inputs) {
816       num_removed_per_replica_inputs++;
817     } else if (index < num_per_replica_inputs + num_distributed_vars) {
818       num_removed_distributed_vars++;
819     }
820     nodes_to_remove.push_back(n);
821   }
822   for (Node* n : nodes_to_remove) {
823     xla_graph->RemoveNode(n);
824   }
825 
826   // Update `index` for other _Arg nodes.
827   std::map<int, int> arg_index_mapping;
828   int new_arg_index = 0;
829   for (int i = 0; i < num_args; i++) {
830     if (arg_indices_to_remove.find(i) != arg_indices_to_remove.end()) {
831       continue;
832     } else {
833       arg_index_mapping[i] = new_arg_index;
834       new_arg_index++;
835     }
836   }
837   for (Node* n : arg_nodes_to_update) {
838     int index;
839     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
840     n->ClearAttr("index");
841     n->AddAttr("index", arg_index_mapping[index]);
842   }
843 
844   // Re-order replicated index edges for `xla_node`.
845 
846   // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
847   // will become very expensive in this case because it is doing a linear search
848   // inside. Create a input_edges vector ahead to make the lookups faster.
849   std::vector<const Edge*> input_edges;
850   TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));
851 
852   const int num_new_per_replica_inputs =
853       num_per_replica_inputs - num_removed_per_replica_inputs;
854   for (int i = 0; i < num_replicas; i++) {
855     for (int j = 0; j < num_per_replica_inputs; j++) {
856       auto iter = arg_index_mapping.find(j);
857       if (iter != arg_index_mapping.end()) {
858         const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
859         Node* src = e->src();
860         int src_output = e->src_output();
861         int dst_input = i * num_new_per_replica_inputs + iter->second;
862 
863         g->RemoveEdge(e);
864         g->AddEdge(src, src_output, xla_node, dst_input);
865       } else {
866         const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
867         g->RemoveEdge(e);
868       }
869     }
870   }
871 
872   // Move other data input edges.
873   for (int i = num_replicas * num_per_replica_inputs;
874        i < xla_node->num_inputs(); i++) {
875     int arg_index =
876         num_per_replica_inputs + i - num_replicas * num_per_replica_inputs;
877     auto iter = arg_index_mapping.find(arg_index);
878     if (iter != arg_index_mapping.end()) {
879       const Edge* e = input_edges.at(i);
880       Node* src = e->src();
881       int src_output = e->src_output();
882       int dst_input = num_replicas * num_new_per_replica_inputs + iter->second -
883                       num_new_per_replica_inputs;
884 
885       g->RemoveEdge(e);
886       g->AddEdge(src, src_output, xla_node, dst_input);
887     } else {
888       const Edge* e = input_edges.at(i);
889       g->RemoveEdge(e);
890     }
891   }
892 
893   // Update attributes for `xla_node`.
894   std::vector<DataType> new_input_types;
895   for (int i = 0; i < num_replicas; i++) {
896     for (int j = 0; j < num_per_replica_inputs; j++) {
897       auto iter = arg_index_mapping.find(j);
898       if (iter != arg_index_mapping.end()) {
899         new_input_types.push_back(input_types[iter->first]);
900       }
901     }
902   }
903   for (int i = 0; i < num_distributed_vars; ++i) {
904     auto iter = arg_index_mapping.find(i + num_per_replica_inputs);
905     if (iter != arg_index_mapping.end()) {
906       new_input_types.push_back(
907           input_types[iter->first - num_per_replica_inputs +
908                       num_per_replica_inputs * num_replicas]);
909     }
910   }
911   xla_node->ClearAttr("Tinputs");
912   xla_node->AddAttr("Tinputs", new_input_types);
913 
914   const int num_new_distributed_vars =
915       num_distributed_vars - num_removed_distributed_vars;
916   xla_node->ClearAttr("num_distributed_variables");
917   xla_node->AddAttr("num_distributed_variables", num_new_distributed_vars);
918 
919   if (!mirrored_variable_indices.empty()) {
920     std::vector<int> new_mirrored_variable_indices;
921     absl::flat_hash_set<int> old_mirrored_variable_indices_set;
922     for (int index : mirrored_variable_indices) {
923       old_mirrored_variable_indices_set.insert(index);
924     }
925     for (int i = 0; i < num_per_replica_inputs + num_distributed_vars; i++) {
926       auto iter = arg_index_mapping.find(i);
927       if (iter != arg_index_mapping.end() &&
928           old_mirrored_variable_indices_set.contains(iter->first)) {
929         new_mirrored_variable_indices.push_back(iter->second);
930       }
931     }
932     xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
933     xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
934                       new_mirrored_variable_indices);
935   }
936 
937   int num_replicated_inputs = num_per_replica_inputs + num_distributed_vars;
938   std::vector<DataType> new_broadcast_input_types;
939   for (int i = 0; i < broadcast_input_types.size(); i++) {
940     int arg_index = num_replicated_inputs + i;
941     if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
942       new_broadcast_input_types.push_back(broadcast_input_types[i]);
943     }
944   }
945   xla_node->ClearAttr("Tbroadcast_inputs");
946   xla_node->AddAttr("Tbroadcast_inputs", new_broadcast_input_types);
947   int new_num_variables = 0;
948   for (int i = 0; i < num_variables; i++) {
949     int arg_index = num_replicated_inputs + broadcast_input_types.size() + i;
950     if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
951       new_num_variables++;
952     }
953   }
954   xla_node->ClearAttr("NumVariables");
955   xla_node->AddAttr("NumVariables", new_num_variables);
956   std::vector<DataType> new_guaranteed_constant_types;
957   for (int i = 0; i < guaranteed_constant_types.size(); i++) {
958     int arg_index = num_replicated_inputs + broadcast_input_types.size() +
959                     num_variables + i;
960     if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
961       new_guaranteed_constant_types.push_back(guaranteed_constant_types[i]);
962     }
963   }
964   xla_node->ClearAttr("Tguaranteed_constants");
965   xla_node->AddAttr("Tguaranteed_constants", new_guaranteed_constant_types);
966 
967   int new_variable_start_index = num_new_per_replica_inputs +
968                                  num_new_distributed_vars +
969                                  new_broadcast_input_types.size();
970   if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
971     xla_node->ClearAttr("_variable_start_index");
972     xla_node->AddAttr("_variable_start_index", new_variable_start_index);
973   }
974   int new_guaranteed_const_start_index =
975       new_variable_start_index + new_num_variables;
976   if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
977     xla_node->ClearAttr("_guaranteed_const_start_index");
978     xla_node->AddAttr("_guaranteed_const_start_index",
979                       new_guaranteed_const_start_index);
980   }
981 
982   VLOG(4) << "RemoveUnusedXlaInput host graph: "
983           << DumpGraphToFile(
984                  absl::StrCat("remove_unused_input_host_", xla_func_name), *g);
985   VLOG(4) << "RemoveUnusedXlaInput XLA graph: "
986           << DumpGraphToFile(
987                  absl::StrCat("remove_unused_input_xla_", xla_func_name),
988                  *xla_graph);
989 
990   return Status::OK();
991 }
992 
993 // Move outside compilation nodes at the end of XLA computation to host.
994 // For XLA computation graph, we will add new _Retval nodes to replace those
995 // outside compilation nodes.
996 // For host graph, we will move those outside compilation nodes to host,
997 // replicate them, and use them as XLA node's output.
MoveTailOutsideCompilationToHost(const string & outside_compilation_attr_name,const string & xla_func_name,const std::string & cluster_name,Graph * g,Graph * xla_graph,Node * xla_node,Node * pivot_node)998 Status MoveTailOutsideCompilationToHost(
999     const string& outside_compilation_attr_name, const string& xla_func_name,
1000     const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
1001     Node* pivot_node) {
1002   // Find outside compilation nodes that only have _Retval or other outside
1003   // compilation nodes as output. These nodes will be moved to host graph.
1004   std::vector<Node*> oc_nodes_at_tail;
1005   const string kOnlyRetOrOcOutputAttrName = "_xla_only_ret_or_oc_output";
1006   DFS(
1007       *xla_graph, /*enter=*/nullptr,
1008       [&](Node* n) {
1009         bool has_non_ret_or_oc_output = false;
1010         for (const Edge* e : n->out_edges()) {
1011           if (e->dst() == xla_graph->sink_node()) {
1012             continue;
1013           }
1014           if (!e->dst()->IsRetval() &&
1015               (!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) ||
1016                !HasNodeAttr(e->dst()->def(), kOnlyRetOrOcOutputAttrName))) {
1017             has_non_ret_or_oc_output = true;
1018             break;
1019           }
1020         }
1021         if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
1022             !has_non_ret_or_oc_output) {
1023           n->AddAttr(kOnlyRetOrOcOutputAttrName, true);
1024           oc_nodes_at_tail.push_back(n);
1025         }
1026       },
1027       NodeComparatorName());
1028   if (VLOG_IS_ON(5)) {
1029     for (Node* n : oc_nodes_at_tail) {
1030       VLOG(5) << "oc_nodes_at_tail: " << n->DebugString();
1031     }
1032   }
1033 
1034   // Record input edges from `oc_nodes_at_tail`. We will create an _Retval node
1035   // for each of these edges. An obvious optimization here is to deduplicate
1036   // these edges by <src, src_output>. But that optimization will complicate
1037   // the code, and in practice we usually do not have input edges with the
1038   // same <src, src_output>.
1039   std::vector<const Edge*> oc_input_edges;
1040   std::vector<DataType> new_ret_types;
1041   for (Node* n : oc_nodes_at_tail) {
1042     for (const Edge* e : n->in_edges()) {
1043       if (!e->IsControlEdge() &&
1044           !HasNodeAttr(e->src()->def(), kOnlyRetOrOcOutputAttrName)) {
1045         VLOG(5) << "oc_input_edges: " << e->DebugString();
1046         oc_input_edges.push_back(e);
1047         new_ret_types.push_back(e->src()->output_type(e->src_output()));
1048       }
1049     }
1050   }
1051   std::vector<DataType> output_types;
1052   TF_RETURN_IF_ERROR(
1053       GetNodeAttr(xla_node->attrs(), "output_types", &output_types));
1054   int num_replicas;
1055   TF_RETURN_IF_ERROR(
1056       GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
1057   int old_num_replicated_outputs = output_types.size() / num_replicas;
1058   int new_num_replicated_outputs =
1059       old_num_replicated_outputs + oc_input_edges.size();
1060   VLOG(5) << "old_num_replicated_outputs: " << old_num_replicated_outputs;
1061   VLOG(5) << "new_num_replicated_outputs: " << new_num_replicated_outputs;
1062 
1063   // Update `output_types` attribute for `xla_node`.
1064   std::vector<DataType> new_output_types;
1065   for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1066     for (int i = 0; i < old_num_replicated_outputs; i++) {
1067       new_output_types.push_back(output_types[i]);
1068     }
1069     for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
1070          i++) {
1071       new_output_types.push_back(new_ret_types[i - old_num_replicated_outputs]);
1072     }
1073   }
1074   xla_node->ClearAttr("output_types");
1075   xla_node->AddAttr("output_types", new_output_types);
1076 
1077   // Re-order old replicated output edges. Since a node could potentially
1078   // connect to multiple nodes, build a vector<vector<pair>> mapping of
1079   // output index to input nodes/index.
1080   // The outer vector represents the output index, the inner vector
1081   // represents the destination node and input index pair with the possibility
1082   // of multiple node/index pairs.
1083   std::vector<std::vector<std::pair<Node*, int>>> replicated_outputs(
1084       old_num_replicated_outputs * num_replicas);
1085   std::vector<const Edge*> old_replicated_edges;
1086   for (const Edge* e : xla_node->out_edges()) {
1087     if (e->src_output() >= 0 &&
1088         e->src_output() < old_num_replicated_outputs * num_replicas) {
1089       replicated_outputs[e->src_output()].push_back(
1090           std::make_pair(e->dst(), e->dst_input()));
1091       old_replicated_edges.push_back(e);
1092     }
1093   }
1094   for (const Edge* e : old_replicated_edges) {
1095     g->RemoveEdge(e);
1096   }
1097   for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1098     for (int output_index = 0; output_index < old_num_replicated_outputs;
1099          output_index++) {
1100       for (const auto& node_input_pair :
1101            replicated_outputs[replica_id * old_num_replicated_outputs +
1102                               output_index]) {
1103         Node* dst = node_input_pair.first;
1104         int dst_input = node_input_pair.second;
1105         g->AddEdge(xla_node,
1106                    replica_id * new_num_replicated_outputs + output_index, dst,
1107                    dst_input);
1108       }
1109     }
1110   }
1111 
1112   // Copy all nodes in `oc_nodes_at_tail` to host graph, and also replicate
1113   // them.
1114   std::map<Node*, std::vector<Node*>> node_images;
1115   for (Node* n : oc_nodes_at_tail) {
1116     for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1117       NodeDef copy_def = n->def();
1118       copy_def.set_name(absl::StrCat(n->name(), "_tail_oc/R", replica_id));
1119       copy_def.clear_device();
1120 
1121       Status s;
1122       Node* copy_node = g->AddNode(copy_def, &s);
1123       TF_RETURN_IF_ERROR(s);
1124 
1125       copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
1126       copy_node->AddAttr(kTPUReplicateAttr, cluster_name);
1127 
1128       for (const Edge* e : n->out_edges()) {
1129         if (e->dst() == xla_graph->sink_node()) {
1130           continue;
1131         }
1132         // Either e->dst() is _Retval, or it's in `node_images`.
1133         if (e->dst()->IsRetval()) {
1134           int index;
1135           TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->attrs(), "index", &index));
1136           for (const auto& output :
1137                replicated_outputs[replica_id * old_num_replicated_outputs +
1138                                   index]) {
1139             // Remove original input edge, if existent.
1140             const Edge* original_edge;
1141             Status s = output.first->input_edge(output.second, &original_edge);
1142             if (s.ok()) {
1143               g->RemoveEdge(original_edge);
1144             }
1145             g->AddEdge(copy_node, e->src_output(), output.first, output.second);
1146           }
1147         } else {
1148           g->AddEdge(copy_node, e->src_output(),
1149                      node_images[e->dst()][replica_id], e->dst_input());
1150         }
1151       }
1152 
1153       // Add attribute "_xla_tail_outside_compilation" to `copy_node`, and add a
1154       // control edge between `xla_node` and `copy_node`. As a result, in later
1155       // rewriting pass, a control edge will be added between `copy_node` and
1156       // "control_after" node for the XLA computation, so `copy_node` will be
1157       // executed before XLA computation's final results.
1158       copy_node->AddAttr("_xla_tail_outside_compilation", true);
1159       g->AddControlEdge(xla_node, copy_node);
1160 
1161       // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
1162       // belongs to same while loop as `xla_node`.
1163       if (pivot_node) {
1164         g->AddControlEdge(pivot_node, copy_node);
1165       }
1166 
1167       node_images[n].push_back(copy_node);
1168     }
1169   }
1170 
1171   // Connect new output values of `xla_node` to dst nodes of `oc_input_edges`.
1172   for (int i = 0; i < new_ret_types.size(); i++) {
1173     const Edge* original_edge = oc_input_edges[i];
1174     for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1175       int src_output = replica_id * new_num_replicated_outputs +
1176                        old_num_replicated_outputs + i;
1177       Node* dst = node_images[original_edge->dst()][replica_id];
1178       g->AddEdge(xla_node, src_output, dst, original_edge->dst_input());
1179     }
1180   }
1181 
1182   // Create new _Retval nodes in `xla_graph`.
1183   for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
1184        i++) {
1185     NodeDefBuilder ret_builder(absl::StrCat("ret_", i),
1186                                FunctionLibraryDefinition::kRetOp);
1187     ret_builder.Attr("T", new_ret_types[i - old_num_replicated_outputs]);
1188     ret_builder.Attr("index", i);
1189     const Edge* original_edge = oc_input_edges[i - old_num_replicated_outputs];
1190     Node* src = original_edge->src();
1191     int src_output = original_edge->src_output();
1192     ret_builder.Input(src->name(), src_output, src->output_type(src_output));
1193     NodeDef ret_def;
1194     TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1195     Status s;
1196     Node* ret_node = xla_graph->AddNode(ret_def, &s);
1197     TF_RETURN_IF_ERROR(s);
1198     xla_graph->RemoveEdge(original_edge);
1199     xla_graph->AddEdge(src, src_output, ret_node, 0);
1200   }
1201 
1202   // Remove `oc_nodes_at_tail`.
1203   for (Node* n : oc_nodes_at_tail) {
1204     xla_graph->RemoveNode(n);
1205   }
1206 
1207   // We cannot leave _Retval with no input. Add a placeholder input, which will
1208   // be removed later with unused _Retval.
1209   std::vector<Node*> unused_rets;
1210   for (Node* n : xla_graph->nodes()) {
1211     if (n->IsRetval() && n->in_edges().empty()) {
1212       unused_rets.push_back(n);
1213     }
1214   }
1215   for (Node* n : unused_rets) {
1216     NodeDefBuilder builder(absl::StrCat("placeholder_", n->name()),
1217                            "Placeholder");
1218     DataType dtype;
1219     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
1220     builder.Attr("dtype", dtype);
1221     builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
1222     NodeDef def;
1223     TF_RETURN_IF_ERROR(builder.Finalize(&def));
1224     Status s;
1225     Node* placeholder = xla_graph->AddNode(def, &s);
1226     TF_RETURN_IF_ERROR(s);
1227     xla_graph->AddEdge(placeholder, 0, n, 0);
1228   }
1229 
1230   VLOG(4) << "MoveTailOutsideCompilationToHost host graph: "
1231           << DumpGraphToFile(absl::StrCat("move_tail_oc_host_", xla_func_name),
1232                              *g);
1233   VLOG(4) << "MoveTaildOutsideCompilationToHost XLA graph: "
1234           << DumpGraphToFile(absl::StrCat("move_tail_oc_xla_", xla_func_name),
1235                              *xla_graph);
1236 
1237   return Status::OK();
1238 }
1239 
ReplaceArgUsedByOutsideCompilationWithPlaceholder(const string & outside_compilation_attr_name,const string & xla_func_name,Graph * g,Graph * xla_graph,Node * xla_node)1240 Status ReplaceArgUsedByOutsideCompilationWithPlaceholder(
1241     const string& outside_compilation_attr_name, const string& xla_func_name,
1242     Graph* g, Graph* xla_graph, Node* xla_node) {
1243   std::vector<DataType> input_types;
1244   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
1245   int num_distributed_vars;
1246   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
1247                                  &num_distributed_vars));
1248   int num_replicas;
1249   TF_RETURN_IF_ERROR(
1250       GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
1251   int num_per_replica_inputs =
1252       (input_types.size() - num_distributed_vars) / num_replicas;
1253 
1254   for (Node* n : xla_graph->op_nodes()) {
1255     if (!n->IsArg()) {
1256       continue;
1257     }
1258 
1259     DataType dtype;
1260     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
1261     // TODO(b/74023706): enable moving normal data tensors.
1262     if (dtype != DT_RESOURCE) {
1263       continue;
1264     }
1265 
1266     std::vector<const Edge*> oc_out_edges;
1267     for (const Edge* e : n->out_edges()) {
1268       if (e->IsControlEdge() ||
1269           !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1270         continue;
1271       }
1272 
1273       oc_out_edges.push_back(e);
1274     }
1275     if (oc_out_edges.empty()) {
1276       continue;
1277     }
1278 
1279     // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
1280     // will become very expensive in this case because it is doing a linear
1281     // search inside. Create an input_edges vector ahead to make the lookups
1282     // faster.
1283     std::vector<const Edge*> input_edges;
1284     TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));
1285 
1286     // Build an IdentityN node to record inputs for this _Arg node.
1287     int index;
1288     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
1289     string oc_identifier = absl::StrCat("oc_only_arg_", index);
1290     NodeDefBuilder id_builder(absl::StrCat(oc_identifier, "_inputs"),
1291                               "IdentityN");
1292     std::vector<DataType> dtypes(num_replicas, dtype);
1293     id_builder.Attr("T", dtypes);
1294     id_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
1295     std::vector<NodeDefBuilder::NodeOut> inputs(num_replicas);
1296     if (index >= num_per_replica_inputs) {
1297       const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
1298                                      (index - num_per_replica_inputs));
1299       for (int i = 0; i < num_replicas; i++) {
1300         inputs[i] =
1301             NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1302                                     e->src()->output_type(e->src_output())};
1303       }
1304     } else {
1305       for (int i = 0; i < num_replicas; i++) {
1306         const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
1307         inputs[i] =
1308             NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1309                                     e->src()->output_type(e->src_output())};
1310       }
1311     }
1312     id_builder.Input(inputs);
1313     NodeDef id_def;
1314     TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
1315     Status s;
1316     Node* id_node = g->AddNode(id_def, &s);
1317     TF_RETURN_IF_ERROR(s);
1318     if (index >= num_per_replica_inputs) {
1319       const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
1320                                      (index - num_per_replica_inputs));
1321       for (int i = 0; i < num_replicas; i++) {
1322         g->AddEdge(e->src(), e->src_output(), id_node, i);
1323       }
1324     } else {
1325       for (int i = 0; i < num_replicas; i++) {
1326         const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
1327         g->AddEdge(e->src(), e->src_output(), id_node, i);
1328       }
1329     }
1330 
1331     for (const Edge* e : oc_out_edges) {
1332       // 'e' will use a new Placeholder node as input.
1333       NodeDefBuilder ph_builder(xla_graph->NewName("ph_for_arg_in_oc_"),
1334                                 "Placeholder");
1335       ph_builder.Attr("dtype", dtype);
1336 
1337       string outside_compilation_attr;
1338       TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
1339                                      &outside_compilation_attr));
1340       ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
1341       ph_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
1342       ph_builder.Attr(kXlaIsPlaceholderForArg, true);
1343       NodeDef ph_def;
1344       TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
1345       Status s;
1346       Node* ph_node = xla_graph->AddNode(ph_def, &s);
1347       TF_RETURN_IF_ERROR(s);
1348       Node* dst = e->dst();
1349       int dst_input = e->dst_input();
1350       xla_graph->RemoveEdge(e);
1351       xla_graph->AddEdge(ph_node, 0, dst, dst_input);
1352       xla_graph->AddControlEdge(xla_graph->source_node(), ph_node);
1353     }
1354   }
1355   VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder host graph: "
1356           << DumpGraphToFile(
1357                  absl::StrCat("replace_oc_only_arg_host_", xla_func_name), *g);
1358   VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder XLA graph: "
1359           << DumpGraphToFile(
1360                  absl::StrCat("replace_oc_only_arg_xla_", xla_func_name),
1361                  *xla_graph);
1362   return Status::OK();
1363 }
1364 
1365 // If there are any unused _Retval nodes in `xla_graph` (whose input is a
1366 // Placeholder node), remove them from `xla_graph` and remove corresponding
1367 // output edge in host graph `g`.
RemoveUnusedXlaOutput(const string & xla_func_name,Graph * g,Graph * xla_graph,Node * xla_node)1368 Status RemoveUnusedXlaOutput(const string& xla_func_name, Graph* g,
1369                              Graph* xla_graph, Node* xla_node) {
1370   // Find unused _Retval nodes, and remove them.
1371   std::vector<DataType> output_types;
1372   TF_RETURN_IF_ERROR(
1373       GetNodeAttr(xla_node->def(), "output_types", &output_types));
1374   int num_replicas;
1375   TF_RETURN_IF_ERROR(
1376       GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
1377   int num_replicated_outputs = output_types.size() / num_replicas;
1378   std::set<int> ret_indices_to_remove;
1379   std::vector<Node*> ret_nodes_to_update, nodes_to_remove;
1380   int num_rets = 0;
1381   for (Node* n : xla_graph->nodes()) {
1382     if (!n->IsRetval()) {
1383       continue;
1384     }
1385 
1386     num_rets++;
1387 
1388     const Edge* e;
1389     TF_RETURN_IF_ERROR(n->input_edge(0, &e));
1390     if (e->src()->type_string() != "Placeholder" ||
1391         !HasNodeAttr(e->src()->def(), kXlaIsPlaceholderForTailOcAttrName)) {
1392       ret_nodes_to_update.push_back(n);
1393       continue;
1394     }
1395 
1396     int index;
1397     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
1398     ret_indices_to_remove.insert(index);
1399     nodes_to_remove.push_back(e->src());
1400     nodes_to_remove.push_back(n);
1401   }
1402   for (Node* n : nodes_to_remove) {
1403     xla_graph->RemoveNode(n);
1404   }
1405 
  // Update `index` for other _Retval nodes.
1407   std::map<int, int> ret_index_mapping;
1408   int new_ret_index = 0;
1409   for (int i = 0; i < num_rets; i++) {
1410     if (ret_indices_to_remove.find(i) != ret_indices_to_remove.end()) {
1411       continue;
1412     } else {
1413       ret_index_mapping[i] = new_ret_index;
1414       new_ret_index++;
1415     }
1416   }
1417   for (Node* n : ret_nodes_to_update) {
1418     int index;
1419     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
1420     n->ClearAttr("index");
1421     n->AddAttr("index", ret_index_mapping[index]);
1422   }
1423 
1424   // Update `output_types` attribute for `xla_node`.
1425   std::vector<DataType> new_output_types;
1426   for (int i = 0; i < num_replicas; i++) {
1427     for (const auto& e : ret_index_mapping) {
1428       new_output_types.push_back(output_types[e.first]);
1429     }
1430   }
1431 
1432   xla_node->ClearAttr("output_types");
1433   xla_node->AddAttr("output_types", new_output_types);
1434 
1435   // Re-order replicated output edges for `xla_node`.
1436   std::vector<std::vector<const Edge*>> output_edges(num_replicas *
1437                                                      num_replicated_outputs);
1438   for (const Edge* e : xla_node->out_edges()) {
1439     if (e->src_output() >= 0 &&
1440         e->src_output() < num_replicas * num_replicated_outputs) {
1441       output_edges[e->src_output()].push_back(e);
1442     }
1443   }
1444   for (int i = 0; i < num_replicas; i++) {
1445     for (int j = 0; j < num_replicated_outputs; j++) {
1446       auto iter = ret_index_mapping.find(j);
1447       if (iter != ret_index_mapping.end()) {
1448         for (const Edge* e : output_edges[i * num_replicated_outputs + j]) {
1449           Node* dst = e->dst();
1450           int dst_input = e->dst_input();
1451           int src_output =
1452               i * (num_replicated_outputs - ret_indices_to_remove.size()) +
1453               iter->second;
1454           g->RemoveEdge(e);
1455           g->AddEdge(xla_node, src_output, dst, dst_input);
1456         }
1457       } else {
1458         TF_RET_CHECK(output_edges[i * num_replicated_outputs + j].empty())
1459             << "Output edge not removed: "
1460             << output_edges[i * num_replicated_outputs + j][0]->DebugString();
1461       }
1462     }
1463   }
1464 
1465   VLOG(4) << "RemoveUnusedXlaOutput host graph: "
1466           << DumpGraphToFile(
1467                  absl::StrCat("remove_unused_output_host_", xla_func_name), *g);
1468   VLOG(4) << "RemoveUnusedXlaOutput XLA graph: "
1469           << DumpGraphToFile(
1470                  absl::StrCat("remove_unused_output_xla_", xla_func_name),
1471                  *xla_graph);
1472 
1473   return Status::OK();
1474 }
1475 
1476 // For data edges between _Arg and _Retval in `xla_graph`, remove them and
1477 // change input/output edges in `g` (host graph). For now, we only consider
1478 // replicated inputs.
RemoveEdgesBetweenArgAndRetval(const string & xla_func_name,Graph * g,Graph * xla_graph,Node * xla_node)1479 Status RemoveEdgesBetweenArgAndRetval(const string& xla_func_name, Graph* g,
1480                                       Graph* xla_graph, Node* xla_node) {
1481   // Collect data edges between _Arg and _Retval.
1482   int num_replicas;
1483   TF_RETURN_IF_ERROR(
1484       GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
1485   std::vector<DataType> input_types;
1486   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
1487   int num_distributed_vars;
1488   TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
1489                                  &num_distributed_vars));
1490   int old_num_per_replica_inputs =
1491       (input_types.size() - num_distributed_vars) / num_replicas;
1492   std::vector<DataType> output_types;
1493   TF_RETURN_IF_ERROR(
1494       GetNodeAttr(xla_node->def(), "output_types", &output_types));
1495   int old_num_outputs = output_types.size() / num_replicas;
1496   std::vector<const Edge*> edges;
1497   for (const Edge* e : xla_graph->edges()) {
1498     if (!e->IsControlEdge() && e->src()->IsArg() && e->dst()->IsRetval()) {
1499       edges.push_back(e);
1500     }
1501   }
1502 
1503   // In host graph `g`, remove output edge from `xla_node` and connect input &
1504   // output directly.
1505   std::vector<std::vector<const Edge*>> xla_node_out_edges(
1506       xla_node->num_outputs());
1507   for (const Edge* e : xla_node->out_edges()) {
1508     if (!e->IsControlEdge()) {
1509       xla_node_out_edges[e->src_output()].push_back(e);
1510     }
1511   }
1512 
1513   // Sometimes `xla_node` can have a lot of inputs, calling Node::input_edge
1514   // will become very expensive in this case because it is doing a linear
1515   // search inside. Create an input_edges vector ahead to make the lookups
1516   // faster.
1517   std::vector<const Edge*> input_edges;
1518   TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));
1519   for (const Edge* e : edges) {
1520     int arg_index;
1521     TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "index", &arg_index));
1522     int ret_index;
1523     TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), "index", &ret_index));
1524 
1525     for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1526       int input_index;
1527       if (arg_index < old_num_per_replica_inputs) {
1528         input_index = replica_id * old_num_per_replica_inputs + arg_index;
1529       } else {
1530         input_index = num_replicas * old_num_per_replica_inputs +
1531                       (arg_index - old_num_per_replica_inputs);
1532       }
1533       const Edge* input_edge = input_edges.at(input_index);
1534 
1535       int output_index = replica_id * old_num_outputs + ret_index;
1536       for (const Edge* output_edge : xla_node_out_edges[output_index]) {
1537         Node* dst = output_edge->dst();
1538         int dst_input = output_edge->dst_input();
1539 
1540         g->RemoveEdge(output_edge);
1541         g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
1542       }
1543     }
1544   }
1545 
1546   // Remove edges from `xla_graph`. Add a Placeholder node for the _Retval node,
1547   // which will be removed by `RemoveUnusedXlaOutput()` later.
1548   for (const Edge* e : edges) {
1549     NodeDefBuilder placeholder_builder(
1550         absl::StrCat("placeholder_", e->dst()->name()), "Placeholder");
1551     placeholder_builder.Attr("dtype", e->src()->output_type(e->src_output()));
1552     placeholder_builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
1553     NodeDef placeholder_def;
1554     TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
1555     Status s;
1556     Node* placeholder_node = xla_graph->AddNode(placeholder_def, &s);
1557     TF_RETURN_IF_ERROR(s);
1558 
1559     Node* dst = e->dst();
1560     int dst_input = e->dst_input();
1561     xla_graph->RemoveEdge(e);
1562     xla_graph->AddEdge(placeholder_node, 0, dst, dst_input);
1563   }
1564 
1565   VLOG(4) << "RemoveUnusedArgRetvalPair host graph: "
1566           << DumpGraphToFile(
1567                  absl::StrCat("remove_unused_arg_ret_host_", xla_func_name),
1568                  *g);
1569   VLOG(4) << "RemoveUnusedArgRetvalPair XLA graph: "
1570           << DumpGraphToFile(
1571                  absl::StrCat("remove_unused_arg_ret_xla_", xla_func_name),
1572                  *xla_graph);
1573 
1574   return Status::OK();
1575 }
1576 
1577 // Remove any TPUReplicatedInput nodes with no output edges. Those nodes are
1578 // usually TPUMirroredVariable handles which are not used by any computations.
RemoveUnusedTPUReplicatedInputs(Graph * graph)1579 void RemoveUnusedTPUReplicatedInputs(Graph* graph) {
1580   for (Node* n : graph->nodes()) {
1581     if (n->type_string() == kTPUReplicatedInput) {
1582       bool has_output = false;
1583       for (const Edge* e : n->out_edges()) {
1584         if (!e->dst()->IsSink()) {
1585           has_output = true;
1586           break;
1587         }
1588       }
1589       if (!has_output) {
1590         // Remove any TPUPartitionedInput node from the src nodes of the
1591         // to-be-removed TPUReplicatedInput node
1592         std::vector<Node*> to_be_removed_src_nodes;
1593         for (const auto& e_in : n->in_edges()) {
1594           if (!e_in->IsControlEdge() &&
1595               e_in->src()->type_string() == kTPUPartitionedInput)
1596             to_be_removed_src_nodes.push_back(e_in->src());
1597         }
1598         graph->RemoveNode(n);
1599         for (Node* node : to_be_removed_src_nodes) {
1600           graph->RemoveNode(node);
1601         }
1602       }
1603     }
1604   }
1605 }
1606 
1607 // We might have duplicated cluster names in the graph, e.g. when a tf.function
1608 // containing tpu_strategy.run() is called multiple times with
1609 // the same inputs. Find clusters with duplicated names and rename them.
RenameClustersWithDuplicatedNames(Graph * g)1610 Status RenameClustersWithDuplicatedNames(Graph* g) {
1611   // Find all TPU clusters by finding all TPUReplicateMetadata nodes.
1612   std::unordered_map<string, std::vector<Node*>> cluster_name_to_metadata_nodes;
1613   std::unordered_set<string> cluster_names;
1614   for (Node* n : g->nodes()) {
1615     if (n->type_string() != "TPUReplicateMetadata") {
1616       continue;
1617     }
1618     string cluster_name;
1619     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &cluster_name));
1620     cluster_name_to_metadata_nodes[cluster_name].push_back(n);
1621     cluster_names.insert(cluster_name);
1622   }
1623   // Look for clusters with duplicated name.
1624   for (const auto& iter : cluster_name_to_metadata_nodes) {
1625     if (iter.second.size() == 1) {
1626       continue;
1627     }
1628 
1629     // Rename clusters.
1630     for (int i = 1; i < iter.second.size(); i++) {
1631       // Find an available cluster name.
1632       string new_cluster_name;
1633       int cluster_name_suffix = 1;
1634       while (true) {
1635         new_cluster_name = absl::StrCat(iter.first, "_", cluster_name_suffix);
1636         if (cluster_names.find(new_cluster_name) == cluster_names.end()) {
1637           break;
1638         }
1639         cluster_name_suffix++;
1640       }
1641       cluster_names.insert(new_cluster_name);
1642 
1643       // Change _tpu_replicate attribute for all nodes in this cluster.
1644       // Start with outputs of TPUReplicateMetadata and follow output edges.
1645       std::queue<Node*> queue;
1646       queue.push(iter.second.at(i));
1647       std::unordered_set<Node*> visited;
1648       while (!queue.empty()) {
1649         Node* n = queue.front();
1650         queue.pop();
1651 
1652         visited.insert(n);
1653 
1654         n->ClearAttr(kTPUReplicateAttr);
1655         n->AddAttr(kTPUReplicateAttr, new_cluster_name);
1656 
1657         string cluster_name;
1658         for (const Edge* e : n->out_edges()) {
1659           if (GetNodeAttr(e->dst()->def(), kTPUReplicateAttr, &cluster_name)
1660                   .ok() &&
1661               cluster_name == iter.first &&
1662               visited.find(e->dst()) == visited.end()) {
1663             queue.push(e->dst());
1664           }
1665         }
1666       }
1667       // Change "_tpu_compilation_status" attr for TPUCompilationResult node.
1668       for (const Edge* e : iter.second.at(i)->out_edges()) {
1669         if (e->dst()->type_string() == "TPUCompilationResult") {
1670           e->dst()->ClearAttr("_tpu_compilation_status");
1671           e->dst()->AddAttr("_tpu_compilation_status", new_cluster_name);
1672         }
1673       }
1674     }
1675   }
1676   return Status::OK();
1677 }
1678 
1679 // Instantiate a function that is associated with a functional control flow
1680 // node. The function name is found by looking up `function_name_attr` of given
1681 // node.
InstantiateAssociatedFunction(const Node & n,absl::string_view function_name_attr,FunctionLibraryDefinition * fld)1682 xla::StatusOr<std::unique_ptr<FunctionBody>> InstantiateAssociatedFunction(
1683     const Node& n, absl::string_view function_name_attr,
1684     FunctionLibraryDefinition* fld) {
1685   std::unique_ptr<FunctionBody> fbody;
1686   NameAttrList func_attr_list;
1687   TF_RETURN_IF_ERROR(GetNodeAttr(n.def(), function_name_attr, &func_attr_list));
1688   const FunctionDef* fdef = fld->Find(func_attr_list.name());
1689   if (fdef == nullptr) {
1690     return errors::Internal("Cannot find ", function_name_attr, " function",
1691                             "for node ", n.DebugString());
1692   }
1693   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1694       *fdef, AttrSlice(&func_attr_list.attr()), fld, &fbody));
1695   return fbody;
1696 }
1697 
1698 // Find inputs of If node that are only used for outside compilation if used at
1699 // all in both if/else branches
FindArgsToLiftForIfNode(const Node & if_node,FunctionLibraryDefinition * fld)1700 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForIfNode(
1701     const Node& if_node, FunctionLibraryDefinition* fld) {
1702   absl::flat_hash_set<int> args_to_lift_indices;
1703   std::vector<DataType> dtypes;
1704   TF_RETURN_IF_ERROR(GetNodeAttr(if_node.def(), "Tin", &dtypes));
1705 
1706   int num_args = dtypes.size();
1707 
1708   for (int i = 0; i < num_args; i++) {
1709     // TODO(b/74023706): enable non resource inputs as well.
1710     if (dtypes[i] == DT_RESOURCE) {
1711       args_to_lift_indices.insert(i);
1712     }
1713   }
1714 
1715   TF_ASSIGN_OR_RETURN(
1716       std::unique_ptr<FunctionBody> then_branch_fbody,
1717       InstantiateAssociatedFunction(if_node, "then_branch", fld));
1718 
1719   TF_ASSIGN_OR_RETURN(
1720       std::unique_ptr<FunctionBody> else_branch_fbody,
1721       InstantiateAssociatedFunction(if_node, "else_branch", fld));
1722 
1723   for (int i = 0; i < num_args; ++i) {
1724     bool used = false;
1725 
1726     const Node* then_arg_node = then_branch_fbody->arg_nodes[i];
1727     for (const Edge* e : then_arg_node->out_edges()) {
1728       used = true;
1729       if (e->IsControlEdge() ||
1730           HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
1731         continue;
1732 
1733       args_to_lift_indices.erase(i);
1734       break;
1735     }
1736 
1737     const Node* else_arg_node = else_branch_fbody->arg_nodes[i];
1738     for (const Edge* e : else_arg_node->out_edges()) {
1739       used = true;
1740       if (e->IsControlEdge() ||
1741           HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
1742         continue;
1743 
1744       args_to_lift_indices.erase(i);
1745       break;
1746     }
1747 
1748     // Do not lift arguments that are not used at all. Otherwise, this unused
1749     // arg would be outside compiled, its output tensor will be forced to
1750     // transfer to host needlessly.
1751     if (!used) args_to_lift_indices.erase(i);
1752   }
1753 
1754   return args_to_lift_indices;
1755 }
1756 
1757 // Find inputs of While node that are:
1758 // 1. not used in cond func,
1759 // 2. only used for outside compilation in body func,
1760 // 3. loop invariant.
1761 // These inputs can be lifted out of the while loop.
FindArgsToLiftForWhileNode(Node * while_node,FunctionLibraryDefinition * fld)1762 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForWhileNode(
1763     Node* while_node, FunctionLibraryDefinition* fld) {
1764   // DT_RESOURCE inputs are candidates.
1765   absl::flat_hash_set<int> result;
1766   std::vector<DataType> dtypes;
1767   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));
1768   for (int i = 0; i < dtypes.size(); i++) {
1769     // TODO(b/74023706): enable non resource inputs as well.
1770     if (dtypes[i] == DT_RESOURCE) {
1771       result.insert(i);
1772     }
1773   }
1774 
1775   // Remove inputs that are used in cond func.
1776   NameAttrList cond_func;
1777   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "cond", &cond_func));
1778   const FunctionDef* cond_fdef = fld->Find(cond_func.name());
1779   if (cond_fdef == nullptr) {
1780     return errors::Internal("Cannot find cond function ", cond_func.name(),
1781                             " for while node ", while_node->DebugString());
1782   }
1783   std::unique_ptr<FunctionBody> cond_fbody;
1784   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1785       *cond_fdef, AttrSlice(&cond_func.attr()), fld, &cond_fbody));
1786   for (int i = 0; i < cond_fbody->arg_nodes.size(); i++) {
1787     const Node* arg_node = cond_fbody->arg_nodes[i];
1788     for (const Edge* e : arg_node->out_edges()) {
1789       if (!e->IsControlEdge()) {
1790         result.erase(i);
1791       }
1792     }
1793   }
1794 
1795   // Remove inputs that are not loop invariant.
1796   NameAttrList body_func;
1797   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_func));
1798   const FunctionDef* body_fdef = fld->Find(body_func.name());
1799   if (body_fdef == nullptr) {
1800     return errors::Internal("Cannot find body function ", body_func.name(),
1801                             " for while node ", while_node->DebugString());
1802   }
1803   std::unique_ptr<FunctionBody> body_fbody;
1804   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1805       *body_fdef, AttrSlice(&body_func.attr()), fld, &body_fbody));
1806   for (int i = 0; i < body_fbody->ret_nodes.size(); i++) {
1807     const Node* node = body_fbody->ret_nodes[i];
1808     do {
1809       TF_RETURN_IF_ERROR(node->input_node(0, &node));
1810     } while (node->IsIdentity());
1811     if (node != body_fbody->arg_nodes[i]) {
1812       result.erase(i);
1813     }
1814   }
1815 
1816   // Remove inputs that only have one output edge (loop invariant, but not used
1817   // in outside compilation).
1818   for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
1819     const Node* arg_node = body_fbody->arg_nodes[i];
1820     int data_edge_count = std::count_if(
1821         arg_node->out_edges().begin(), arg_node->out_edges().end(),
1822         [](const Edge* e) { return !e->IsControlEdge(); });
1823     if (data_edge_count == 1) {
1824       result.erase(i);
1825     }
1826   }
1827 
1828   // Remove inputs that have non-outside-compilation usage.
1829   for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
1830     const Node* arg_node = body_fbody->arg_nodes[i];
1831     for (const Edge* e : arg_node->out_edges()) {
1832       if (!e->dst()->IsRetval() &&
1833           !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1834         result.erase(i);
1835         break;
1836       }
1837     }
1838   }
1839 
1840   return result;
1841 }
1842 
1843 // Find inputs of function call node that are only used for outside compilation.
1844 // These inputs can be lifted out of the function call node.
FindArgsToLiftForCallNode(Node * call_node,const FunctionBody & fbody)1845 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForCallNode(
1846     Node* call_node, const FunctionBody& fbody) {
1847   // DT_RESOURCE inputs are candidates.
1848   absl::flat_hash_set<int> result;
1849   std::vector<DataType> dtypes(call_node->input_types().begin(),
1850                                call_node->input_types().end());
1851   for (int i = 0; i < dtypes.size(); i++) {
1852     // TODO(b/74023706): enable for non resource inputs as well.
1853     if (dtypes[i] == DT_RESOURCE) {
1854       result.insert(i);
1855     }
1856   }
1857 
1858   // Remove inputs that have non-outside-compilation usage, or not used at all.
1859   for (int i = 0; i < fbody.arg_nodes.size(); i++) {
1860     const Node* arg_node = fbody.arg_nodes[i];
1861     if (arg_node->out_edges().empty()) {
1862       result.erase(i);
1863       continue;
1864     }
1865 
1866     for (const Edge* e : arg_node->out_edges()) {
1867       if (!HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1868         result.erase(i);
1869         break;
1870       }
1871     }
1872   }
1873   return result;
1874 }
1875 
1876 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
1877                                       FunctionLibraryDefinition* fld,
1878                                       int* lifted_arg_count, bool* rewritten);
1879 
LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(const FunctionBody & fbody,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,int * lifted_arg_count,absl::optional<string> new_func_name,bool * rewritten)1880 Status LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
1881     const FunctionBody& fbody, FunctionLibraryRuntime* flr,
1882     FunctionLibraryDefinition* fld, int* lifted_arg_count,
1883     absl::optional<string> new_func_name, bool* rewritten) {
1884   *rewritten = false;
1885   TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
1886       fbody.graph, flr, fld, lifted_arg_count, rewritten));
1887 
1888   if (*rewritten) {
1889     FunctionDef rewritten_fdef;
1890     TF_RETURN_IF_ERROR(GraphToFunctionDef(
1891         *(fbody.graph), fbody.fdef.signature().name(), &rewritten_fdef));
1892     if (new_func_name) {
1893       rewritten_fdef.mutable_signature()->set_name(*new_func_name);
1894       TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
1895     } else {
1896       TF_RETURN_IF_ERROR(
1897           fld->ReplaceFunction(fbody.fdef.signature().name(), rewritten_fdef));
1898     }
1899   }
1900 
1901   return Status::OK();
1902 }
1903 
MakeIdentityNodesForArgsToLift(const absl::flat_hash_set<int> & args_to_lift,const int arg_to_input_edge_offset,Graph * g,Node * n,absl::flat_hash_map<int,string> * lifted_arg_index_to_oc_cluster_name,int * lifted_arg_count)1904 Status MakeIdentityNodesForArgsToLift(
1905     const absl::flat_hash_set<int>& args_to_lift,
1906     const int arg_to_input_edge_offset, Graph* g, Node* n,
1907     absl::flat_hash_map<int, string>* lifted_arg_index_to_oc_cluster_name,
1908     int* lifted_arg_count) {
1909   int num_input = n->num_inputs();
1910   for (int arg_index = 0; arg_index < num_input; ++arg_index) {
1911     if (!args_to_lift.contains(arg_index)) continue;
1912 
1913     int input_edge_index = arg_index + arg_to_input_edge_offset;
1914     const Edge* arg_edge;
1915     TF_RETURN_IF_ERROR(n->input_edge(input_edge_index, &arg_edge));
1916 
1917     string node_name =
1918         g->NewName(absl::StrCat("lifted_arg", *lifted_arg_count));
1919     (*lifted_arg_count)++;
1920     (*lifted_arg_index_to_oc_cluster_name)[arg_index] = node_name;
1921     NodeDefBuilder id_builder(node_name, "Identity");
1922     id_builder.Attr("T", n->input_type(input_edge_index));
1923     id_builder.Attr(kOutsideCompilationAttr, id_builder.node_name());
1924     id_builder.Attr(kXlaIsLiftedArgAttrName, true);
1925     id_builder.Input(arg_edge->src()->name(), arg_edge->src_output(),
1926                      n->input_type(input_edge_index));
1927     NodeDef id_def;
1928     TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
1929     Status s;
1930     Node* id_node = g->AddNode(id_def, &s);
1931     TF_RETURN_IF_ERROR(s);
1932     g->AddEdge(arg_edge->src(), arg_edge->src_output(), id_node, 0);
1933     g->AddControlEdge(id_node, n);
1934   }
1935 
1936   return Status::OK();
1937 }
1938 
1939 // Replaces all usages of lifted args with placeholder nodes. Afterwards,
1940 // removing these args should be safe since they no longer have users.
RemoveArgsToLiftFromFunctionBody(const absl::flat_hash_set<int> & args_to_lift,const std::vector<DataType> & arg_dtypes,const absl::flat_hash_map<int,string> & lifted_arg_index_to_oc_cluster_name,const absl::flat_hash_map<int,int> & index_mapping,const FunctionBody * fbody)1941 Status RemoveArgsToLiftFromFunctionBody(
1942     const absl::flat_hash_set<int>& args_to_lift,
1943     const std::vector<DataType>& arg_dtypes,
1944     const absl::flat_hash_map<int, string>& lifted_arg_index_to_oc_cluster_name,
1945     const absl::flat_hash_map<int, int>& index_mapping,
1946     const FunctionBody* fbody) {
1947   for (int i = 0; i < fbody->arg_nodes.size(); ++i) {
1948     Node* arg_node = fbody->arg_nodes[i];
1949 
1950     if (!args_to_lift.contains(i)) {
1951       int new_index = index_mapping.at(i);
1952       arg_node->ClearAttr("index");
1953       arg_node->AddAttr("index", new_index);
1954       arg_node->ClearAttr("T");
1955       arg_node->AddAttr("T", arg_dtypes[i]);
1956       continue;
1957     }
1958 
1959     std::vector<const Edge*> out_edges_to_oc;
1960     for (const Edge* e : arg_node->out_edges()) {
1961       if (HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1962         out_edges_to_oc.push_back(e);
1963       }
1964     }
1965 
1966     for (const Edge* e : out_edges_to_oc) {
1967       string outside_compilation_cluster;
1968       TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
1969                                      &outside_compilation_cluster));
1970       NodeDefBuilder ph_builder(fbody->graph->NewName("lifted_arg"),
1971                                 "Placeholder");
1972       ph_builder.Attr("dtype", arg_dtypes[i]);
1973       ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_cluster);
1974       TF_RET_CHECK(lifted_arg_index_to_oc_cluster_name.contains(i));
1975       ph_builder.Attr(kXlaLiftedArgOutsideCompilationAttrName,
1976                       lifted_arg_index_to_oc_cluster_name.at(i));
1977 
1978       NodeDef ph_def;
1979       TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
1980 
1981       Status s;
1982       Node* ph_node = fbody->graph->AddNode(ph_def, &s);
1983       TF_RETURN_IF_ERROR(s);
1984 
1985       Node* dst = e->dst();
1986       int dst_input = e->dst_input();
1987       fbody->graph->RemoveEdge(e);
1988       fbody->graph->AddEdge(ph_node, 0, dst, dst_input);
1989     }
1990 
1991     fbody->graph->RemoveNode(arg_node);
1992   }
1993 
1994   return Status::OK();
1995 }
1996 
CleanUpInEdges(const absl::flat_hash_map<int,int> & index_mapping,const int arg_to_input_edge_offset,Graph * g,Node * n)1997 Status CleanUpInEdges(const absl::flat_hash_map<int, int>& index_mapping,
1998                       const int arg_to_input_edge_offset, Graph* g, Node* n) {
1999   int num_inputs = n->num_inputs();
2000   for (int i = 0; i < num_inputs; ++i) {
2001     if (i < arg_to_input_edge_offset) continue;
2002 
2003     int arg_idx = i - arg_to_input_edge_offset;
2004     const Edge* e;
2005     TF_RETURN_IF_ERROR(n->input_edge(i, &e));
2006 
2007     // If an edge maps to a lifted argument, simply remove that edge from graph.
2008     if (!index_mapping.contains(arg_idx)) {
2009       g->RemoveEdge(e);
2010       continue;
2011     }
2012 
2013     // If an edge maps to same input port, nothing to do.
2014     if (index_mapping.at(arg_idx) == arg_idx) continue;
2015 
2016     g->AddEdge(e->src(), e->src_output(), n,
2017                index_mapping.at(arg_idx) + arg_to_input_edge_offset);
2018     g->RemoveEdge(e);
2019   }
2020 
2021   return Status::OK();
2022 }
2023 
UpdateTypeAttribute(const absl::flat_hash_map<int,int> & index_mapping,const string & type_attr_name,const std::vector<DataType> & dtypes,Node * n)2024 Status UpdateTypeAttribute(const absl::flat_hash_map<int, int>& index_mapping,
2025                            const string& type_attr_name,
2026                            const std::vector<DataType>& dtypes, Node* n) {
2027   std::vector<DataType> new_dtypes;
2028   new_dtypes.reserve(index_mapping.size());
2029   for (int i = 0; i < dtypes.size(); ++i) {
2030     if (index_mapping.contains(i)) {
2031       new_dtypes.emplace_back(dtypes[i]);
2032     }
2033   }
2034 
2035   n->ClearAttr(type_attr_name);
2036   n->AddAttr(type_attr_name, new_dtypes);
2037 
2038   return Status::OK();
2039 }
2040 
2041 // While V2 always creates Identity node for each While node output, which is
2042 // not necessary for XLA computation. Remove those Identity nodes.
RemoveOutputIdentityNodesForWhileV2(Graph * g,Node * while_node)2043 void RemoveOutputIdentityNodesForWhileV2(Graph* g, Node* while_node) {
2044   std::vector<const Edge*> edges_to_identity_node;
2045   for (const Edge* e : while_node->out_edges()) {
2046     if (!e->IsControlEdge() && e->dst()->IsIdentity()) {
2047       edges_to_identity_node.push_back(e);
2048     }
2049   }
2050   for (const Edge* e : edges_to_identity_node) {
2051     Node* identity = e->dst();
2052     std::vector<const Edge*> out_edges(identity->out_edges().begin(),
2053                                        identity->out_edges().end());
2054     for (const Edge* out_edge : out_edges) {
2055       if (out_edge->IsControlEdge()) {
2056         g->AddControlEdge(while_node, out_edge->dst());
2057       } else {
2058         Node* dst = out_edge->dst();
2059         int dst_input = out_edge->dst_input();
2060         g->RemoveEdge(out_edge);
2061         g->AddEdge(while_node, e->src_output(), dst, dst_input);
2062       }
2063     }
2064     g->RemoveNode(identity);
2065   }
2066 }
2067 
2068 // If corresponding While node output is used, change it to use While node input
2069 // instead.
ReplaceOutputEdgesWithInputEdgeSourceForWhile(const absl::flat_hash_set<int> & args_to_lift,Graph * g,Node * while_node)2070 Status ReplaceOutputEdgesWithInputEdgeSourceForWhile(
2071     const absl::flat_hash_set<int>& args_to_lift, Graph* g, Node* while_node) {
2072   std::vector<const Edge*> edges_to_replace;
2073   for (const Edge* e : while_node->out_edges()) {
2074     if (args_to_lift.contains(e->src_output())) {
2075       edges_to_replace.push_back(e);
2076     }
2077   }
2078   for (const Edge* e : edges_to_replace) {
2079     const Edge* input_edge;
2080     TF_RETURN_IF_ERROR(while_node->input_edge(e->src_output(), &input_edge));
2081     Node* dst = e->dst();
2082     int dst_input = e->dst_input();
2083     g->RemoveEdge(e);
2084     g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
2085   }
2086 
2087   return Status::OK();
2088 }
2089 
2090 // Calculates mapping from argument index before lifting to index afterwards.
ArgIndexMapping(const int num_args,const absl::flat_hash_set<int> & args_to_lift)2091 absl::flat_hash_map<int, int> ArgIndexMapping(
2092     const int num_args, const absl::flat_hash_set<int>& args_to_lift) {
2093   absl::flat_hash_map<int, int> index_mapping;
2094   int new_index = 0;
2095   for (int i = 0; i < num_args; i++) {
2096     if (!args_to_lift.contains(i)) {
2097       index_mapping[i] = new_index;
2098       ++new_index;
2099     }
2100   }
2101 
2102   return index_mapping;
2103 }
2104 
2105 // Remove outputs of While node body function that maps to lifted arguments.
CleanUpRetvalsForWhileBody(const absl::flat_hash_map<int,int> & index_mapping,const std::vector<DataType> & dtypes,FunctionBody * fbody)2106 void CleanUpRetvalsForWhileBody(
2107     const absl::flat_hash_map<int, int>& index_mapping,
2108     const std::vector<DataType>& dtypes, FunctionBody* fbody) {
2109   for (int i = 0; i < fbody->ret_nodes.size(); i++) {
2110     Node* ret_node = fbody->ret_nodes[i];
2111     if (index_mapping.contains(i)) {
2112       int new_index = index_mapping.at(i);
2113       ret_node->ClearAttr("index");
2114       ret_node->AddAttr("index", new_index);
2115       ret_node->ClearAttr("T");
2116       ret_node->AddAttr("T", dtypes[i]);
2117     } else {
2118       fbody->graph->RemoveNode(ret_node);
2119     }
2120   }
2121 }
2122 
LiftOutsideCompilationOnlyArgsFromWhileNode(Graph * g,Node * while_node,FunctionLibraryDefinition * fld,int * lifted_arg_count,bool * rewritten)2123 Status LiftOutsideCompilationOnlyArgsFromWhileNode(
2124     Graph* g, Node* while_node, FunctionLibraryDefinition* fld,
2125     int* lifted_arg_count, bool* rewritten) {
2126   *rewritten = false;
2127 
2128   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2129                       FindArgsToLiftForWhileNode(while_node, fld));
2130   if (args_to_lift.empty()) return Status::OK();
2131 
2132   RemoveOutputIdentityNodesForWhileV2(g, while_node);
2133 
2134   TF_RETURN_IF_ERROR(ReplaceOutputEdgesWithInputEdgeSourceForWhile(
2135       args_to_lift, g, while_node));
2136 
2137   std::vector<DataType> dtypes;
2138   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));
2139 
2140   absl::flat_hash_map<int, int> index_mapping =
2141       ArgIndexMapping(dtypes.size(), args_to_lift);
2142 
2143   // For each lifted arg, add an outside compilation Identity node to send
2144   // it to host.
2145   absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2146   TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2147       args_to_lift, /*arg_to_input_edge_offset=*/0, g, while_node,
2148       &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2149 
2150   // For cond func, remove _Arg nodes.
2151   TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> cond_fbody,
2152                       InstantiateAssociatedFunction(*while_node, "cond", fld));
2153   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2154       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2155       cond_fbody.get()));
2156 
2157   FunctionDef rewritten_cond_fdef;
2158   TF_RETURN_IF_ERROR(GraphToFunctionDef(*(cond_fbody->graph),
2159                                         cond_fbody->fdef.signature().name(),
2160                                         &rewritten_cond_fdef));
2161   TF_RETURN_IF_ERROR(fld->ReplaceFunction(cond_fbody->fdef.signature().name(),
2162                                           rewritten_cond_fdef));
2163 
2164   // For body func, remove _Retval nodes, and replace _Arg nodes with
2165   // Placeholder nodes.
2166   TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
2167                       InstantiateAssociatedFunction(*while_node, "body", fld));
2168 
2169   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2170       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2171       body_fbody.get()));
2172 
2173   CleanUpRetvalsForWhileBody(index_mapping, dtypes, body_fbody.get());
2174 
2175   FunctionDef rewritten_body_fdef;
2176   TF_RETURN_IF_ERROR(GraphToFunctionDef(*(body_fbody->graph),
2177                                         body_fbody->fdef.signature().name(),
2178                                         &rewritten_body_fdef));
2179   TF_RETURN_IF_ERROR(fld->ReplaceFunction(body_fbody->fdef.signature().name(),
2180                                           rewritten_body_fdef));
2181 
2182   // Remove edges from lifted args to While node, and change "T" attr of the
2183   // While node.
2184   TF_RETURN_IF_ERROR(CleanUpInEdges(
2185       index_mapping, /*arg_to_input_edge_offset=*/0, g, while_node));
2186 
2187   TF_RETURN_IF_ERROR(
2188       UpdateTypeAttribute(index_mapping, "T", dtypes, while_node));
2189 
2190   *rewritten = true;
2191 
2192   return Status::OK();
2193 }
2194 
LiftOutsideCompilationOnlyArgsFromIfNode(Graph * g,Node * if_node,FunctionLibraryDefinition * fld,int * lifted_arg_count,bool * rewritten)2195 Status LiftOutsideCompilationOnlyArgsFromIfNode(Graph* g, Node* if_node,
2196                                                 FunctionLibraryDefinition* fld,
2197                                                 int* lifted_arg_count,
2198                                                 bool* rewritten) {
2199   *rewritten = false;
2200   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2201                       FindArgsToLiftForIfNode(*if_node, fld));
2202   if (args_to_lift.empty()) return Status::OK();
2203 
2204   std::vector<DataType> dtypes;
2205   TF_RETURN_IF_ERROR(GetNodeAttr(if_node->def(), "Tin", &dtypes));
2206 
2207   absl::flat_hash_map<int, int> index_mapping;
2208   int new_index = 0;
2209   for (int i = 0; i < dtypes.size(); i++) {
2210     if (!args_to_lift.contains(i)) {
2211       index_mapping[i] = new_index;
2212       ++new_index;
2213     }
2214   }
2215 
2216   // For each lifted arg, add an outside compilation Identity node to send
2217   // it to host.
2218   absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2219   TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2220       args_to_lift, /*arg_to_input_edge_offset=*/1, g, if_node,
2221       &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2222 
2223   TF_ASSIGN_OR_RETURN(
2224       std::unique_ptr<FunctionBody> then_branch_fbody,
2225       InstantiateAssociatedFunction(*if_node, "then_branch", fld));
2226 
2227   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2228       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2229       then_branch_fbody.get()));
2230 
2231   FunctionDef rewritten_then_branch_fdef;
2232   TF_RETURN_IF_ERROR(GraphToFunctionDef(
2233       *(then_branch_fbody->graph), then_branch_fbody->fdef.signature().name(),
2234       &rewritten_then_branch_fdef));
2235   TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2236       then_branch_fbody->fdef.signature().name(), rewritten_then_branch_fdef));
2237 
2238   TF_ASSIGN_OR_RETURN(
2239       std::unique_ptr<FunctionBody> else_branch_fbody,
2240       InstantiateAssociatedFunction(*if_node, "else_branch", fld));
2241 
2242   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2243       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2244       else_branch_fbody.get()));
2245 
2246   FunctionDef rewritten_else_branch_fdef;
2247   TF_RETURN_IF_ERROR(GraphToFunctionDef(
2248       *(else_branch_fbody->graph), else_branch_fbody->fdef.signature().name(),
2249       &rewritten_else_branch_fdef));
2250   TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2251       else_branch_fbody->fdef.signature().name(), rewritten_else_branch_fdef));
2252 
2253   // Remove edges from lifted args to If node, and change "Tin" attr of the
2254   // If node.
2255   TF_RETURN_IF_ERROR(CleanUpInEdges(
2256       index_mapping, /*arg_to_input_edge_offset=*/1, g, if_node));
2257   TF_RETURN_IF_ERROR(
2258       UpdateTypeAttribute(index_mapping, "Tin", dtypes, if_node));
2259 
2260   *rewritten = true;
2261 
2262   return Status::OK();
2263 }
2264 
LiftOutsideCompilationOnlyArgsFromCallNode(Graph * g,Node * call_node,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,int * lifted_arg_count,bool * rewritten)2265 Status LiftOutsideCompilationOnlyArgsFromCallNode(
2266     Graph* g, Node* call_node, FunctionLibraryRuntime* flr,
2267     FunctionLibraryDefinition* fld, int* lifted_arg_count, bool* rewritten) {
2268   *rewritten = false;
2269 
2270   // Instantiate the function.
2271   NameAttrList func;
2272   if (fld->Contains(call_node->type_string())) {
2273     func.set_name(call_node->type_string());
2274     *func.mutable_attr() = call_node->def().attr();
2275   } else if (call_node->IsPartitionedCall()) {
2276     TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &func));
2277   } else {
2278     TF_RET_CHECK(call_node->type_string() ==
2279                  FunctionLibraryDefinition::kGradientOp);
2280     func.set_name(FunctionLibraryDefinition::kGradientOp);
2281     *func.mutable_attr() = call_node->def().attr();
2282   }
2283   FunctionLibraryRuntime::Handle handle;
2284   TF_RETURN_IF_ERROR(
2285       flr->Instantiate(func.name(), AttrSlice(&func.attr()), &handle));
2286   auto cleanup_handle = gtl::MakeCleanup(
2287       [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2288   const FunctionBody* fbody = flr->GetFunctionBody(handle);
2289 
2290   // Find _Arg nodes to lift.
2291   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2292                       FindArgsToLiftForCallNode(call_node, *fbody));
2293   if (args_to_lift.empty()) return Status::OK();
2294 
2295   std::vector<DataType> dtypes;
2296   dtypes = std::vector<DataType>(call_node->input_types().begin(),
2297                                  call_node->input_types().end());
2298 
2299   absl::flat_hash_map<int, int> index_mapping =
2300       ArgIndexMapping(dtypes.size(), args_to_lift);
2301 
2302   // For each lifted arg, add an outside compilation Identity node to send
2303   // it to host.
2304   absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2305   TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2306       args_to_lift, /*arg_to_input_edge_offset=*/0, g, call_node,
2307       &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2308 
2309   // Remove _Arg nodes.
2310   TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2311       args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2312       fbody));
2313 
2314   // Store rewritten function as a new function, because the original function
2315   // might be defined by user and we should not modify it.
2316   FunctionDef rewritten_fdef;
2317   TF_RETURN_IF_ERROR(GraphToFunctionDef(
2318       *(fbody->graph), fbody->fdef.signature().name(), &rewritten_fdef));
2319   string new_func_name =
2320       fld->UniqueFunctionName(fbody->fdef.signature().name());
2321   rewritten_fdef.mutable_signature()->set_name(new_func_name);
2322   TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
2323 
2324   // Remove edges from lifted args to call node.
2325   TF_RETURN_IF_ERROR(CleanUpInEdges(
2326       index_mapping, /*arg_to_input_edge_offset=*/0, g, call_node));
2327 
2328   // Rewrite the call node to use the rewritten function.
2329   NodeDef node_def;
2330   node_def.set_name(g->NewName(call_node->name()));
2331   node_def.set_op(new_func_name);
2332   if (call_node->IsPartitionedCall()) {
2333     NameAttrList f;
2334     TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2335     *node_def.mutable_attr() = f.attr();
2336   } else if (fld->Contains(call_node->type_string())) {
2337     *node_def.mutable_attr() = call_node->def().attr();
2338   } else {
2339     TF_RET_CHECK(call_node->type_string() ==
2340                  FunctionLibraryDefinition::kGradientOp);
2341     *node_def.mutable_attr() = call_node->def().attr();
2342     node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2343   }
2344   TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2345 
2346   *rewritten = true;
2347 
2348   return Status::OK();
2349 }
2350 
2351 // Lifts outside compilation only _Arg nodes out of If/While/function nodes.
LiftOutsideCompilationOnlyArgs(Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,int * lifted_arg_count,bool * rewritten)2352 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
2353                                       FunctionLibraryDefinition* fld,
2354                                       int* lifted_arg_count, bool* rewritten) {
2355   *rewritten = false;
2356 
2357   // Handle deeper functional nodes first.
2358   std::vector<Node*> while_nodes, if_nodes, call_nodes;
2359   for (Node* n : g->op_nodes()) {
2360     if (HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
2361       continue;
2362     }
2363 
2364     if (n->IsWhileNode()) {
2365       TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
2366                           InstantiateAssociatedFunction(*n, "body", fld));
2367       bool func_rewritten = false;
2368       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2369           *body_fbody, flr, fld, lifted_arg_count,
2370           /*new_func_name=*/absl::nullopt, &func_rewritten));
2371       *rewritten = *rewritten || func_rewritten;
2372 
2373       while_nodes.push_back(n);
2374     } else if (n->IsIfNode()) {
2375       TF_ASSIGN_OR_RETURN(
2376           std::unique_ptr<FunctionBody> then_branch_fbody,
2377           InstantiateAssociatedFunction(*n, "then_branch", fld));
2378       bool func_rewritten = false;
2379       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2380           *then_branch_fbody, flr, fld, lifted_arg_count,
2381           /*new_func_name=*/absl::nullopt, &func_rewritten));
2382       *rewritten |= func_rewritten;
2383 
2384       TF_ASSIGN_OR_RETURN(
2385           std::unique_ptr<FunctionBody> else_branch_fbody,
2386           InstantiateAssociatedFunction(*n, "else_branch", fld));
2387       func_rewritten = false;
2388       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2389           *else_branch_fbody, flr, fld, lifted_arg_count,
2390           /*new_func_name=*/absl::nullopt, &func_rewritten));
2391       *rewritten |= func_rewritten;
2392 
2393       if_nodes.push_back(n);
2394     } else if (IsFunctionCall(*fld, *n)) {
2395       // Function call nodes need to be rewritten, so handle them later.
2396       call_nodes.push_back(n);
2397     }
2398   }
2399 
2400   std::vector<Node*> rewritten_call_nodes;
2401   for (Node* call_node : call_nodes) {
2402     if (call_node->IsPartitionedCall()) {
2403       std::unique_ptr<FunctionBody> function_fbody;
2404       TF_ASSIGN_OR_RETURN(function_fbody,
2405                           InstantiateAssociatedFunction(*call_node, "f", fld));
2406       bool func_rewritten = false;
2407       string new_func_name =
2408           fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2409       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2410           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2411           &func_rewritten));
2412       if (func_rewritten) {
2413         NameAttrList f;
2414         TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2415         f.set_name(new_func_name);
2416         call_node->ClearAttr("f");
2417         call_node->AddAttr("f", f);
2418       }
2419 
2420       *rewritten |= func_rewritten;
2421       rewritten_call_nodes.push_back(call_node);
2422     } else if (fld->Contains(call_node->type_string())) {
2423       std::unique_ptr<FunctionBody> function_fbody;
2424       const FunctionDef* fdef = fld->Find(call_node->type_string());
2425       TF_RET_CHECK(fdef);
2426       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, call_node->attrs(), fld,
2427                                                  &function_fbody));
2428       bool func_rewritten = false;
2429       string new_func_name =
2430           fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2431       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2432           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2433           &func_rewritten));
2434       if (func_rewritten) {
2435         NodeDef node_def;
2436         node_def.set_name(g->NewName(call_node->name()));
2437         node_def.set_op(new_func_name);
2438         *node_def.mutable_attr() = call_node->def().attr();
2439         TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2440       }
2441 
2442       *rewritten |= func_rewritten;
2443       rewritten_call_nodes.push_back(call_node);
2444     } else {
2445       TF_RET_CHECK(call_node->type_string() ==
2446                    FunctionLibraryDefinition::kGradientOp);
2447       FunctionLibraryRuntime::Handle handle;
2448       TF_RETURN_IF_ERROR(flr->Instantiate(call_node->type_string(),
2449                                           call_node->attrs(), &handle));
2450       auto cleanup_handle = gtl::MakeCleanup(
2451           [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2452       bool func_rewritten = false;
2453       string new_func_name = fld->UniqueFunctionName(
2454           absl::StrCat(call_node->name(), "_lift_args"));
2455       const FunctionBody* function_fbody = flr->GetFunctionBody(handle);
2456       TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2457           *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2458           &func_rewritten));
2459       if (func_rewritten) {
2460         NodeDef node_def;
2461         node_def.set_name(g->NewName(call_node->name()));
2462         node_def.set_op(new_func_name);
2463         *node_def.mutable_attr() = call_node->def().attr();
2464         node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2465         TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2466       }
2467 
2468       *rewritten |= func_rewritten;
2469       rewritten_call_nodes.push_back(call_node);
2470     }
2471   }
2472 
2473   for (Node* n : while_nodes) {
2474     bool node_rewritten = false;
2475     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromWhileNode(
2476         g, n, fld, lifted_arg_count, &node_rewritten));
2477     *rewritten = *rewritten || node_rewritten;
2478   }
2479 
2480   for (Node* n : if_nodes) {
2481     bool node_rewritten = false;
2482     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromIfNode(
2483         g, n, fld, lifted_arg_count, &node_rewritten));
2484     *rewritten = *rewritten || node_rewritten;
2485   }
2486 
2487   for (Node* n : rewritten_call_nodes) {
2488     bool node_rewritten = false;
2489     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromCallNode(
2490         g, n, flr, fld, lifted_arg_count, &node_rewritten));
2491     *rewritten = *rewritten || node_rewritten;
2492   }
2493 
2494   if (*rewritten) {
2495     VLOG(4) << DumpGraphToFile("after_lifting_args", *g, fld);
2496   }
2497 
2498   return Status::OK();
2499 }
2500 
2501 }  // namespace
2502 
Encapsulate(std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def)2503 /*static*/ Status EncapsulateTPUComputationsPass::Encapsulate(
2504     std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
2505   // Check for undeclared outputs before Encapsulation, so we can give a better
2506   // error message.
2507   // TODO(phawkins): merge this with the encapsulation code to avoid the extra
2508   // O(n) pass over the edges.
2509   for (const Edge* e : (*graph)->edges()) {
2510     if (!e->IsControlEdge() &&
2511         e->src()->attrs().Find(kTPUReplicateAttr) != nullptr &&
2512         e->src()->attrs().Find(kOutsideCompilationAttr) == nullptr &&
2513         e->dst()->attrs().Find(kTPUReplicateAttr) == nullptr &&
2514         e->dst()->type_string() != kTPUReplicatedOutput) {
2515       return errors::InvalidArgument(
2516           "Undeclared output of TPU computation. A common cause of this error "
2517           "is variable initializers that depend on the TPU computation. Edge: ",
2518           FormatNodeForError(*e->src()), ":", e->src_output(), " -> ",
2519           FormatNodeForError(*e->dst()), ":", e->dst_input());
2520     }
2521   }
2522 
2523   RemoveUnusedTPUReplicatedInputs(graph->get());
2524 
2525   TF_RETURN_IF_ERROR(RenameClustersWithDuplicatedNames(graph->get()));
2526 
2527   TF_RETURN_IF_ERROR(
2528       PerformStaticShapeInferenceBeforeEncapsulation(graph->get()));
2529 
2530   auto output = absl::make_unique<Graph>((*graph)->op_registry());
2531   TF_RETURN_WITH_CONTEXT_IF_ERROR(
2532       EncapsulateSubgraphsInFunctions(
2533           kTPUReplicateAttr, **graph, RewriteSubgraph,
2534           /*reuse_existing_functions=*/true, &output, flib_def),
2535       "EncapsulateTPUComputationsPass failed");
2536   graph->swap(output);
2537 
2538   return Status::OK();
2539 }
2540 
BuildTPUReplicateOps(Graph * graph)2541 /*static*/ Status EncapsulateTPUComputationsPass::BuildTPUReplicateOps(
2542     Graph* graph) {
2543   // Finds all of the replicate function calls, to avoid mutating the graph
2544   // while iterating.
2545   std::vector<Node*> replicate_nodes;
2546   std::vector<Node*> guarantee_const_nodes;
2547   for (Node* n : graph->nodes()) {
2548     string name;
2549     if (TryGetNodeAttr(n->attrs(), kTPUReplicateAttr, &name) &&
2550         !TryGetNodeAttr(n->attrs(), kOutsideCompilationAttr, &name)) {
2551       replicate_nodes.push_back(n);
2552     } else if (n->type_string() == "GuaranteeConst") {
2553       guarantee_const_nodes.push_back(n);
2554     }
2555   }
2556 
2557   // Replace any GuaranteeConst nodes with Identity nodes. These nodes have now
2558   // served their purpose and have no runtime effect, except increasing
2559   // inference latency due to executor overhead. Subsequent rewrites will remove
2560   // the Identity nodes.
2561   for (Node* n : guarantee_const_nodes) {
2562     std::vector<std::pair<Node*, int>> predecessors;
2563     for (const Edge* e : n->in_edges()) {
2564       predecessors.emplace_back(e->src(), e->src_output());
2565     }
2566     std::vector<std::pair<Node*, int>> successors;
2567     for (const Edge* e : n->out_edges()) {
2568       successors.emplace_back(e->dst(), e->dst_input());
2569     }
2570     NodeDef ndef;
2571     ndef.set_name(n->name());
2572     ndef.set_op("Identity");
2573     ndef.set_device(n->requested_device());
2574     MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
2575     AddNodeAttr("T", n->output_type(0), &ndef);
2576 
2577     graph->RemoveNode(n);
2578     Status s;
2579     Node* id_node = graph->AddNode(ndef, &s);
2580     TF_RETURN_IF_ERROR(s);
2581 
2582     for (const auto& pred : predecessors) {
2583       if (pred.second < 0) {
2584         graph->AddControlEdge(pred.first, id_node);
2585       } else {
2586         graph->AddEdge(pred.first, pred.second, id_node, 0);
2587       }
2588     }
2589     for (const auto& succ : successors) {
2590       if (succ.second < 0) {
2591         graph->AddControlEdge(id_node, succ.first);
2592       } else {
2593         graph->AddEdge(id_node, 0, succ.first, succ.second);
2594       }
2595     }
2596   }
2597 
2598   // Replaces each replicate function call together with its neighboring
2599   // TPUReplicatedInput/TPUReplicatedOutput nodes with a TPUReplicate node.
2600   for (Node* replicate : replicate_nodes) {
2601     int num_replicas;
2602     TF_RETURN_IF_ERROR(
2603         GetNodeAttr(replicate->attrs(), "num_replicas", &num_replicas));
2604     int variable_start_index;
2605     TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "_variable_start_index",
2606                                    &variable_start_index));
2607     int guaranteed_const_start_index;
2608     TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(),
2609                                    "_guaranteed_const_start_index",
2610                                    &guaranteed_const_start_index));
2611 
2612     if (HasNodeAttr(replicate->def(), "use_tpu")) {
2613       bool use_tpu;
2614       TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "use_tpu", &use_tpu));
2615       if (!use_tpu) {
2616         LOG(WARNING) << "use_tpu=false attr on a TPUReplicate node is ignored.";
2617       }
2618     }
2619 
2620     std::vector<const Edge*> in_edges;
2621     TF_RETURN_IF_ERROR(replicate->input_edges(&in_edges));
2622 
2623     // Counts the number of replicated, non-replicated, and variable inputs.
2624     int pos = 0;
2625     std::vector<int> mirrored_variable_indices;
2626     int distributed_var_start_index = 0;
2627     while (pos < in_edges.size() &&
2628            in_edges[pos]->src()->type_string() == kTPUReplicatedInput) {
2629       // Checks that each TPUReplicatedInput node has the correct number of
2630       // replicas.
2631       int input_num_replicas;
2632       TF_RETURN_IF_ERROR(
2633           GetNodeAttr(in_edges[pos]->src()->attrs(), "N", &input_num_replicas));
2634 
2635       bool is_mirrored_variable;
2636       CHECK(GetNodeAttr(in_edges[pos]->src()->attrs(), "is_mirrored_variable",
2637                         &is_mirrored_variable)
2638                 .ok());
2639       if (is_mirrored_variable) {
2640         mirrored_variable_indices.push_back(pos);
2641       }
2642 
2643       bool is_packed = false;
2644       GetNodeAttr(in_edges[pos]->src()->attrs(), "is_packed", &is_packed)
2645           .IgnoreError();
2646 
2647       bool is_distributed_variable =
2648           is_packed && (in_edges[pos]->src()->output_type(
2649                             in_edges[pos]->src_output()) == DT_RESOURCE);
2650 
2651       if (!is_distributed_variable && input_num_replicas != num_replicas) {
2652         return errors::InvalidArgument(
2653             "Mismatched number of replicas. Computation has ", num_replicas,
2654             " replicas, input '", FormatNodeForError(*in_edges[pos]->src()),
2655             "' has ", input_num_replicas, " replicas.");
2656       }
2657 
2658       if (!is_distributed_variable) {
2659         if (distributed_var_start_index < pos) {
2660           return errors::InvalidArgument(
2661               "Expect a distributed resource after index ",
2662               distributed_var_start_index,
2663               ", but got a replicated resource at index ", pos);
2664         } else {
2665           ++distributed_var_start_index;
2666         }
2667       }
2668       ++pos;
2669     }
2670     const int num_replicated_inputs = distributed_var_start_index;
2671     const int num_distributed_vars = pos - num_replicated_inputs;
2672 
2673     const int num_variables =
2674         std::max(0, guaranteed_const_start_index - variable_start_index);
2675 
2676     const int num_guaranteed_constants =
2677         in_edges.size() - guaranteed_const_start_index;
2678     TF_RET_CHECK(num_guaranteed_constants >= 0);
2679 
2680     VLOG(1) << "Replicate node '" << replicate->name() << "'"
2681             << " input edges: " << in_edges.size()
2682             << " num_replicated_inputs: " << num_replicated_inputs
2683             << " num_distributed_vars: " << num_distributed_vars
2684             << " num_variables: " << num_variables
2685             << " num_guaranteed_constants: " << num_guaranteed_constants
2686             << " num_mirrored_variables: " << mirrored_variable_indices.size();
2687 
2688     const int num_broadcast_inputs =
2689         in_edges.size() - (num_replicated_inputs + num_distributed_vars +
2690                            num_variables + num_guaranteed_constants);
2691     TF_RET_CHECK(num_broadcast_inputs >= 0);
2692 
2693     const int num_inputs = num_replicated_inputs * num_replicas +
2694                            num_distributed_vars + num_broadcast_inputs +
2695                            num_guaranteed_constants + num_variables;
2696 
2697     std::vector<Node*> nodes_to_remove = {replicate};
2698 
2699     // Data and control inputs to the new TPUReplicate node.
2700     std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
2701     gtl::FlatSet<Node*> control_inputs;
2702 
2703     AddControlInputs(*replicate, &control_inputs);
2704 
2705     // Replicated inputs. Adds the inputs from the TPUReplicatedInput inputs,
2706     // in replica-major order. See the comments in
2707     // distributed_tpu_rewrite_pass.h for a description of the argument order.
2708     DataTypeVector replicated_input_types(num_replicated_inputs * num_replicas +
2709                                           num_distributed_vars);
2710 
2711     // Inputs with is_distributed_variable = false.
2712     for (int i = 0; i < num_replicated_inputs; ++i) {
2713       std::vector<const Edge*> replica_in_edges;
2714       TF_RETURN_IF_ERROR(in_edges[i]->src()->input_edges(&replica_in_edges));
2715       for (int replica = 0; replica < num_replicas; ++replica) {
2716         int pos = replica * num_replicated_inputs + i;
2717         const Edge* edge = replica_in_edges[replica];
2718         data_inputs[pos] = {edge->src(), edge->src_output()};
2719         replicated_input_types[pos] = EdgeType(edge);
2720       }
2721       AddControlInputs(*in_edges[i]->src(), &control_inputs);
2722       nodes_to_remove.push_back(in_edges[i]->src());
2723     }
2724 
2725     // Inputs with is_distributed_variable = true.
2726     for (int i = 0; i < num_distributed_vars; ++i) {
2727       int pos = num_replicas * num_replicated_inputs + i;
2728       std::vector<const Edge*> replica_in_edges;
2729       TF_RETURN_IF_ERROR(
2730           in_edges[num_replicated_inputs + i]->src()->input_edges(
2731               &replica_in_edges));
2732       TF_RET_CHECK(replica_in_edges.size() == 1);
2733       const Edge* edge = replica_in_edges[0];
2734       data_inputs[pos] = {edge->src(), edge->src_output()};
2735       replicated_input_types[pos] = EdgeType(edge);
2736       AddControlInputs(*in_edges[num_replicated_inputs + i]->src(),
2737                        &control_inputs);
2738       nodes_to_remove.push_back(in_edges[num_replicated_inputs + i]->src());
2739     }
2740 
2741     // Appends the broadcast inputs.
2742     DataTypeVector broadcast_input_types(num_broadcast_inputs);
2743     for (int i = 0; i < num_broadcast_inputs; ++i) {
2744       int pos = num_replicas * num_replicated_inputs + num_distributed_vars + i;
2745       const Edge* edge =
2746           in_edges[num_replicated_inputs + num_distributed_vars + i];
2747       data_inputs[pos] = {edge->src(), edge->src_output()};
2748       broadcast_input_types[i] = EdgeType(edge);
2749     }
2750 
2751     // Appends the variable inputs.
2752     for (int i = 0; i < num_variables; ++i) {
2753       int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
2754                 num_broadcast_inputs + i;
2755       const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
2756                                   num_broadcast_inputs + i];
2757       data_inputs[pos] = {edge->src(), edge->src_output()};
2758     }
2759 
2760     DataTypeVector guaranteed_constant_types(num_guaranteed_constants);
2761     for (int i = 0; i < num_guaranteed_constants; ++i) {
2762       int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
2763                 num_broadcast_inputs + num_variables + i;
2764       const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
2765                                   num_broadcast_inputs + num_variables + i];
2766       data_inputs[pos] = {edge->src(), edge->src_output()};
2767       guaranteed_constant_types[i] = EdgeType(edge);
2768     }
2769 
2770     // Outputs. All outputs from a replicated computation are replicated.
2771     const int num_outputs = replicate->output_types().size();
2772     gtl::FlatSet<Node*> control_outputs;
2773     std::vector<Node*> replicated_outputs(num_outputs);
2774     for (const Edge* e : replicate->out_edges()) {
2775       if (e->IsControlEdge()) {
2776         control_outputs.insert(e->dst());
2777       } else {
2778         TF_RET_CHECK(e->src_output() < num_outputs);
2779         TF_RET_CHECK(e->dst()->type_string() == kTPUReplicatedOutput)
2780             << e->DebugString();
2781         TF_RET_CHECK(e->dst()->output_types().size() == num_replicas);
2782         replicated_outputs[e->src_output()] = e->dst();
2783         nodes_to_remove.push_back(e->dst());
2784 
2785         AddControlOutputs(*e->dst(), &control_outputs);
2786       }
2787     }
2788 
2789     // Flattens the edges outgoing from the TPUReplicatedOutput nodes in
2790     // replica-major order.
2791     std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_replicas *
2792                                                                  num_outputs);
2793     DataTypeVector output_types(num_replicas * num_outputs);
2794     for (int i = 0; i < num_outputs; ++i) {
2795       std::vector<std::vector<const Edge*>> replica_out_edges(num_replicas);
2796       TF_RET_CHECK(replicated_outputs[i] != nullptr);
2797       for (const Edge* e : replicated_outputs[i]->out_edges()) {
2798         TF_RET_CHECK(!e->IsControlEdge());
2799         replica_out_edges[e->src_output()].push_back(e);
2800       }
2801 
2802       for (int replica = 0; replica < num_replicas; ++replica) {
2803         const int pos = replica * num_outputs + i;
2804         for (const Edge* edge : replica_out_edges[replica]) {
2805           data_outputs[pos].push_back({edge->dst(), edge->dst_input()});
2806         }
2807         output_types[pos] = replicated_outputs[i]->input_type(0);
2808       }
2809     }
2810 
2811     // TODO(b/79092708): Consolidate and cleanup to avoid TPU specialization.
2812     NodeDef def;
2813     def.set_name(replicate->name());
2814     def.set_op("_TPUReplicate");
2815     MergeDebugInfo(NodeDebugInfo(replicate->def()), &def);
2816     NameAttrList computation;
2817     computation.set_name(replicate->type_string());
2818     AddNodeAttr("computation", computation, &def);
2819     for (const auto& attr : replicate->attrs()) {
2820       def.mutable_attr()->insert(attr);
2821     }
2822     AddNodeAttr("Tinputs", replicated_input_types, &def);
2823     AddNodeAttr("Tbroadcast_inputs", broadcast_input_types, &def);
2824     AddNodeAttr("NumVariables", num_variables, &def);
2825     AddNodeAttr("Tguaranteed_constants", guaranteed_constant_types, &def);
2826     AddNodeAttr("output_types", output_types, &def);
2827     AddNodeAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
2828                 mirrored_variable_indices, &def);
2829     AddNodeAttr("num_distributed_variables", num_distributed_vars, &def);
2830 
2831     for (Node* node : nodes_to_remove) {
2832       VLOG(2) << "Deleting node " << node->DebugString();
2833       // Ensure that we do not attempt to add control edges to nodes that are
2834       // deleted.
2835       control_inputs.erase(node);
2836       control_outputs.erase(node);
2837       graph->RemoveNode(node);
2838     }
2839 
2840     Status status;
2841     Node* tpu_replicate = graph->AddNode(def, &status);
2842     if (!status.ok()) {
2843       return status;
2844     }
2845     for (int i = 0; i < data_inputs.size(); ++i) {
2846       graph->AddEdge(data_inputs[i].first, data_inputs[i].second, tpu_replicate,
2847                      i);
2848     }
2849     for (Node* n : control_inputs) {
2850       graph->AddControlEdge(n, tpu_replicate);
2851     }
2852     for (int i = 0; i < data_outputs.size(); ++i) {
2853       for (const auto& successor : data_outputs[i]) {
2854         graph->AddEdge(tpu_replicate, i, successor.first, successor.second);
2855       }
2856     }
2857     for (Node* n : control_outputs) {
2858       graph->AddControlEdge(tpu_replicate, n);
2859     }
2860   }
2861   return Status::OK();
2862 }
2863 
Run(const GraphOptimizationPassOptions & options)2864 Status EncapsulateTPUComputationsPass::Run(
2865     const GraphOptimizationPassOptions& options) {
2866   VLOG(1) << "EncapsulateTPUComputations(): "
2867           << DumpGraphToFile("encapsulate_tpu_computations_before",
2868                              **options.graph, options.flib_def);
2869 
2870   TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
2871   VLOG(1) << "EncapsulateTPUComputations() half-way: "
2872           << DumpGraphToFile("encapsulate_tpu_computations_halfway",
2873                              **options.graph, options.flib_def);
2874 
2875   TF_RETURN_IF_ERROR(BuildTPUReplicateOps(options.graph->get()));
2876   VLOG(1) << "EncapsulateTPUComputations() finished: "
2877           << DumpGraphToFile("encapsulate_tpu_computations_after",
2878                              **options.graph, options.flib_def);
2879   return Status::OK();
2880 }
2881 
ProcessHeadTailOutsideCompilation(const string & outside_compilation_attr_name,int * lifted_arg_count,std::unordered_map<string,XlaClusterInfo> * clusters,Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld)2882 Status ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
2883     const string& outside_compilation_attr_name, int* lifted_arg_count,
2884     std::unordered_map<string, XlaClusterInfo>* clusters, Graph* g,
2885     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
2886   // Gather a list of pivots by cluster so we can easily look them up.
2887   absl::node_hash_map<string, Node*> pivots;
2888   string cluster_name;
2889   for (Node* node : g->nodes()) {
2890     if (TryGetNodeAttr(node->attrs(), kPivotForClusterAttr, &cluster_name)) {
2891       pivots[cluster_name] = node;
2892     }
2893   }
2894   for (auto& iter : *clusters) {
2895     // Find pivot node for this XLA cluster.
2896     Node* pivot_node = pivots[iter.first];
2897 
2898     // Instantiate XLA computation function.
2899     string xla_func_name = iter.second.func_name_attrs.name();
2900     std::unique_ptr<FunctionBody> xla_fbody;
2901     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
2902         *fld->Find(xla_func_name),
2903         AttrSlice(&iter.second.func_name_attrs.attr()), fld, &xla_fbody));
2904     Graph* xla_graph = xla_fbody->graph;
2905 
2906     // Make sure all nodes can be traced from sink node.
2907     FixupSourceAndSinkEdges(xla_graph);
2908 
2909     // We create Identity nodes for all _Arg/_Retval nodes in XLA computation.
2910     // Remove those Identity nodes to simplify furthur processing.
2911     TF_RETURN_IF_ERROR(RemoveIdentityNodesForArgRetval(xla_graph));
2912 
2913     bool rewritten;
2914     TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
2915         xla_graph, flr, fld, lifted_arg_count, &rewritten));
2916 
2917     // Move head outside compilation to host.
2918     TF_RETURN_IF_ERROR(MoveHeadOutsideCompilationToHost(
2919         outside_compilation_attr_name, iter.second.func_name_attrs.name(),
2920         iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));
2921 
2922     // Move tail outside compilation to host.
2923     TF_RETURN_IF_ERROR(MoveTailOutsideCompilationToHost(
2924         outside_compilation_attr_name, iter.second.func_name_attrs.name(),
2925         iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));
2926 
2927     // Replace outside compilation only _Arg nodes with Placeholder nodes.
2928     TF_RETURN_IF_ERROR(ReplaceArgUsedByOutsideCompilationWithPlaceholder(
2929         outside_compilation_attr_name, xla_func_name, g, xla_graph,
2930         iter.second.node));
2931 
2932     // There might be direct data edges between _Arg node and _Retval node in
2933     // `xla_graph`. Remove those edges to avoid back-and-forth data transfer
2934     // between host and XLA.
2935     TF_RETURN_IF_ERROR(RemoveEdgesBetweenArgAndRetval(
2936         iter.second.func_name_attrs.name(), g, xla_graph, iter.second.node));
2937 
2938     // After `MoveHeadOutsideCompilationToHost`, there might be unused XLA
2939     // inputs. Remove them.
2940     TF_RETURN_IF_ERROR(RemoveUnusedXlaInput(iter.second.func_name_attrs.name(),
2941                                             g, xla_graph, iter.second.node));
2942 
2943     // After `MoveTailOutsideCompilationToHost`, there might be unused XLA
2944     // outputs. Remove them.
2945     TF_RETURN_IF_ERROR(RemoveUnusedXlaOutput(iter.second.func_name_attrs.name(),
2946                                              g, xla_graph, iter.second.node));
2947 
2948     // Replace original function.
2949     FunctionDef replace_fdef;
2950     TF_RETURN_IF_ERROR(
2951         GraphToFunctionDef(*xla_graph, xla_func_name, &replace_fdef));
2952     TF_RETURN_IF_ERROR(fld->ReplaceFunction(xla_func_name, replace_fdef));
2953 
2954     FixupSourceAndSinkEdges(g);
2955   }
2956 
2957   return Status::OK();
2958 }
2959 
Run(const GraphOptimizationPassOptions & options)2960 Status ExtractOutsideCompilationPass::Run(
2961     const GraphOptimizationPassOptions& options) {
2962   const auto* config =
2963       (options.session_options ? &options.session_options->config : nullptr);
2964   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
2965       new ProcessFunctionLibraryRuntime(
2966           /*device_mgr=*/nullptr, options.session_options->env,
2967           /*config=*/config, TF_GRAPH_DEF_VERSION, options.flib_def,
2968           config ? config->graph_options().optimizer_options()
2969                  : OptimizerOptions()));
2970   FunctionLibraryRuntime* flr =
2971       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
2972 
2973   // Find XLA compile ops and their corresponding FunctionDefs.
2974   static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
2975       new std::map<string, string>{
2976           {"_TPUReplicate", "computation"},
2977       };
2978   std::unordered_map<string, XlaClusterInfo> clusters;
2979   int lifted_arg_count = 0;
2980   for (Node* n : (*options.graph)->nodes()) {
2981     auto iter = kNodeTypeToFunctionAttrMapping->find(n->type_string());
2982     if (iter == kNodeTypeToFunctionAttrMapping->end()) {
2983       continue;
2984     }
2985 
2986     string xla_cluster_name = n->name();
2987 
2988     string func_attr = iter->second;
2989     NameAttrList func;
2990     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
2991 
2992     std::vector<string> core_list;
2993     TF_RETURN_IF_ERROR(
2994         GetNodeAttr(n->attrs(), "host_compute_core", &core_list));
2995     std::map<string, int> host_compute_core;
2996     TF_RETURN_IF_ERROR(ParseHostComputeCoreList(core_list, &host_compute_core));
2997 
2998     clusters.emplace(xla_cluster_name, XlaClusterInfo{xla_cluster_name, func, n,
2999                                                       host_compute_core});
3000   }
3001   TF_RETURN_IF_ERROR(ProcessHeadTailOutsideCompilation(
3002       kOutsideCompilationAttr, &lifted_arg_count, &clusters,
3003       options.graph->get(), flr, options.flib_def));
3004   bool modified;
3005   TF_RETURN_IF_ERROR(ExtractOutsideCompilation(
3006       kTPUReplicateAttr, kOutsideCompilationAttr, clusters,
3007       options.graph->get(), flr, options.flib_def, &modified));
3008   if (modified) {
3009     TF_RETURN_IF_ERROR(
3010         PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
3011   }
3012 
3013   return Status::OK();
3014 }
3015 
3016 }  // namespace tensorflow
3017