1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
22 #include "tensorflow/compiler/jit/encapsulate_util.h"
23 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
24 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph_to_functiondef.h"
29 #include "tensorflow/core/framework/node_def_builder.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/platform/macros.h"
36 #include "tensorflow/core/util/dump_graph.h"
37 #include "tensorflow/stream_executor/lib/statusor.h"
38 
39 namespace tensorflow {
40 
41 namespace {
42 
43 // Control return mapping function for outside compilation host graphs.
44 // All nodes with kXlaHasHostTransfer attribute are control outputs.
HostGraphControlRetMapping(const Node * n)45 absl::optional<string> HostGraphControlRetMapping(const Node* n) {
46   if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
47     return n->name();
48   }
49   return absl::nullopt;
50 }
51 
52 // Add a key placeholder node to the graph. The key placeholder node will be
53 // used as input for XlaRecvAtHost/XlaSendFromHost nodes.
AddHostComputeKeyPlaceholder(const string & xla_cluster_name,Graph * g)54 xla::StatusOr<Node*> AddHostComputeKeyPlaceholder(
55     const string& xla_cluster_name, Graph* g) {
56   NodeDef key_def;
57   NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
58                          "Placeholder");
59   builder.Attr("dtype", DT_STRING);
60   builder.Attr("shape", PartialTensorShape({2}));
61   builder.Attr("_host_compute_call_node", xla_cluster_name);
62   Status s = builder.Finalize(&key_def);
63   if (!s.ok()) return s;
64 
65   Node* n = g->AddNode(key_def, &s);
66   if (!s.ok()) return s;
67   return n;
68 }
69 
70 // Returns if the node is a XLA computation key placeholder.
IsKeyPlaceholderNode(const Node & n)71 bool IsKeyPlaceholderNode(const Node& n) {
72   return n.type_string() == "Placeholder" &&
73          absl::EndsWith(n.name(), "_key_placeholder");
74 }
75 
76 // Returns nodes with given type.
GatherNodesWithType(const Graph & g,const string & type)77 std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
78   std::vector<Node*> result;
79   for (Node* n : g.nodes()) {
80     if (n->type_string() == type) {
81       result.push_back(n);
82     }
83   }
84   return result;
85 }
86 
87 // Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
GetArgDataTypes(const std::vector<Node * > & arg_nodes,std::vector<DataType> * recv_at_host_dtypes)88 Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
89                        std::vector<DataType>* recv_at_host_dtypes) {
90   recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
91   for (auto* n : arg_nodes) {
92     int index;
93     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
94     DataType dtype;
95     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
96     (*recv_at_host_dtypes)[index] = dtype;
97   }
98   for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) {
99     if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
100       return errors::Internal("Cannot get datatype for input ", i);
101     }
102   }
103   return Status::OK();
104 }
105 
106 // Builds XlaRecvAtHost node.
BuildRecvAtHostNode(Graph * g,const string & oc_cluster_name,const std::vector<DataType> & recv_at_host_dtypes,Node * key_placeholder)107 xla::StatusOr<Node*> BuildRecvAtHostNode(
108     Graph* g, const string& oc_cluster_name,
109     const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
110   NodeDefBuilder recv_at_host_builder(
111       absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
112       "_XlaRecvAtHost");
113   NodeDef recv_at_host_def;
114   recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
115   // The correct device_ordinal will be inserted during replication in a
116   // subsequent rewrite.
117   AttrValue device_ordinal_value;
118   device_ordinal_value.set_placeholder("_device_ordinal");
119   recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
120   recv_at_host_builder.Attr(
121       "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
122   recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
123   recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
124   TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
125   Status s;
126   Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s);
127   TF_RETURN_IF_ERROR(s);
128   return recv_at_host_node;
129 }
130 
// Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
//
// Gathers every "_Arg" node in `g`, stores their data types (positioned by the
// "index" attribute) in `recv_at_host_dtypes`, builds one _XlaRecvAtHost node,
// and rewires each consumer of an _Arg node so that it reads output `index` of
// the new node instead. Finally `key_placeholder` is connected to input 0 of
// the new node. Returns the new node, or the first error encountered.
xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use out nodes for source node, instead of traversing all
  // nodes.
  std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
  TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * recv_at_host_node,
      BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
                          key_placeholder));
  for (auto* n : arg_nodes) {
    // The _Arg's "index" attribute selects which output of RecvAtHost takes
    // over for it.
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Record out edges and remove `n` before adding those edges to RecvAtHost.
    // This is to avoid multiple producers.
    std::vector<OutEdgeInfo> out_edge_info;
    for (auto edge : n->out_edges()) {
      out_edge_info.push_back(
          {edge->dst(), edge->src_output(), edge->dst_input()});
    }
    g->RemoveNode(n);
    // Re-create every recorded edge with RecvAtHost as the source. Control
    // edges stay control edges; data edges now originate from output `index`.
    for (const OutEdgeInfo& edge : out_edge_info) {
      if (edge.dst_input == Graph::kControlSlot) {
        g->AddControlEdge(recv_at_host_node, edge.dst);
      } else {
        g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
      }
    }

    // Rewrite dst nodes because their input changed.
    for (int i = 0, end = out_edge_info.size(); i < end; i++) {
      const OutEdgeInfo edge = out_edge_info[i];
      // Control edges are not listed among the NodeDef data inputs patched
      // below, so there is nothing to rewrite for them.
      if (edge.dst_input == Graph::kControlSlot) {
        continue;
      }

      // Point the consumer's NodeDef input at "<recv_at_host>:<index>" and
      // swap the consumer for a node built from the updated NodeDef.
      Node* dst = edge.dst;
      NodeDef new_def = dst->def();
      *new_def.mutable_input(edge.dst_input) =
          absl::StrCat(recv_at_host_node->name(), ":", index);
      TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));

      // Other edges might have `dst` as dst node as well. Update those edges
      // with `dst_replace`.
      // (The inner `end` intentionally shadows the outer one; both hold
      // out_edge_info.size().)
      for (int j = i + 1, end = out_edge_info.size(); j < end; j++) {
        if (out_edge_info[j].dst == dst) {
          out_edge_info[j].dst = dst_replace;
        }
      }
    }
  }
  // Connect the key placeholder's output 0 to RecvAtHost's input 0.
  g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
  return recv_at_host_node;
}
187 
188 // Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
GetRetDataTypes(const std::vector<Node * > & ret_nodes,std::vector<DataType> * send_from_host_dtypes)189 Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
190                        std::vector<DataType>* send_from_host_dtypes) {
191   send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
192   for (auto* n : ret_nodes) {
193     int index;
194     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
195     DataType dtype;
196     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
197     (*send_from_host_dtypes)[index] = dtype;
198   }
199   for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) {
200     if ((*send_from_host_dtypes)[i] == DT_INVALID) {
201       return errors::Internal("Cannot get datatype for output ", i);
202     }
203   }
204   return Status::OK();
205 }
206 
207 // Builds XlaSendFromHost node.
BuildSendFromHostNode(Graph * g,const string & oc_cluster_name,const std::vector<Node * > & ret_nodes,const std::vector<DataType> & send_from_host_dtypes,Node * key_placeholder)208 xla::StatusOr<Node*> BuildSendFromHostNode(
209     Graph* g, const string& oc_cluster_name,
210     const std::vector<Node*>& ret_nodes,
211     const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) {
212   NodeDefBuilder send_from_host_builder(
213       absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
214       "_XlaSendFromHost");
215   NodeDef send_from_host_def;
216   send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
217   // The correct device_ordinal will be inserted during replication in a
218   // subsequent rewrite.
219   AttrValue device_ordinal_value;
220   device_ordinal_value.set_placeholder("_device_ordinal");
221   send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
222   send_from_host_builder.Attr(
223       "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
224   send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
225   std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size());
226   for (auto* n : ret_nodes) {
227     int index;
228     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
229     const int num_dtypes = send_from_host_dtypes.size();
230     if (index < 0 || index >= num_dtypes) {
231       return errors::Internal("Invalid _Retval index: ", index);
232     }
233     for (auto edge : n->in_edges()) {
234       inputs[index] =
235           NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
236                                   edge->src()->output_type(edge->src_output())};
237     }
238   }
239   send_from_host_builder.Input(inputs);
240   send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
241   TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
242   Status s;
243   Node* send_from_host_node = g->AddNode(send_from_host_def, &s);
244   TF_RETURN_IF_ERROR(s);
245   return send_from_host_node;
246 }
247 
248 // Builds XlaSendFromHost node, and replaces all _Retval nodes with it.
ReplaceRetNodesWithSendFromHostNode(Graph * g,const string & oc_cluster_name,std::vector<DataType> * send_from_host_dtypes,Node * key_placeholder)249 xla::StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
250     Graph* g, const string& oc_cluster_name,
251     std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
252   // TODO(b/77601805): use in nodes for sink node, instead of traversing all
253   // nodes.
254   std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
255   TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
256   TF_ASSIGN_OR_RETURN(
257       Node * send_from_host_node,
258       BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
259                             *send_from_host_dtypes, key_placeholder));
260   for (auto* n : ret_nodes) {
261     int index;
262     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
263     for (auto edge : n->in_edges()) {
264       if (edge->src_output() == Graph::kControlSlot) {
265         g->AddControlEdge(edge->src(), send_from_host_node);
266       } else {
267         g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
268       }
269     }
270     g->RemoveNode(n);
271   }
272   g->AddEdge(key_placeholder, 0, send_from_host_node,
273              send_from_host_dtypes->size());
274   return send_from_host_node;
275 }
276 
277 // Returns input shapes (excluding key placeholder) for `send_from_host_node`
278 // if they are all fully defined; absl::nullopt otherwise.
GetInferredInputShapes(int num_inputs,Node * send_from_host_node)279 absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
280     int num_inputs, Node* send_from_host_node) {
281   std::vector<PartialTensorShape> results(num_inputs);
282   for (int i = 0; i < num_inputs; i++) {
283     const Edge* e;
284     if (!send_from_host_node->input_edge(i, &e).ok()) {
285       return absl::nullopt;
286     }
287 
288     std::vector<PartialTensorShape> shapes;
289     if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
290              .ok()) {
291       return absl::nullopt;
292     }
293 
294     const PartialTensorShape shape = shapes[e->src_output()];
295     if (!shape.IsFullyDefined()) {
296       return absl::nullopt;
297     }
298 
299     results[e->dst_input()] = shape;
300   }
301   return results;
302 }
303 
// Returns the name of the XlaHostCompute node for outside compilation cluster
// `original_oc_name`: "outside_compilation_<original_oc_name>_host_compute".
string host_compute_node_name(const string& original_oc_name) {
  string node_name = absl::StrCat("outside_compilation_", original_oc_name,
                                  "_host_compute");
  return node_name;
}
308 
309 // Builds XlaHostCompute NodeDef from the outside compilation call node.
BuildXlaHostComputeNodeDef(const Node * call_node,const std::map<string,int> & host_compute_core,const absl::flat_hash_map<string,std::vector<string>> & cluster_deps)310 xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
311     const Node* call_node, const std::map<string, int>& host_compute_core,
312     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
313   string original_oc_name;
314   TF_RETURN_IF_ERROR(GetNodeAttr(
315       call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
316   NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
317                                       "XlaHostCompute");
318   // In XlaCompiler, if XlaHostCompute node is in a function call node and that
319   // function is inlined, name of the XlaHostCompute node will be changed. So
320   // we cannot rely on node name; use an attribute instead.
321   host_compute_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
322                             host_compute_builder.node_name());
323 
324   // Copy all attributes.
325   for (const auto& attr : call_node->attrs()) {
326     host_compute_builder.Attr(attr.first, attr.second);
327   }
328 
329   // Populate tpu_core assignment.
330   const auto iter = host_compute_core.find(original_oc_name);
331   if (iter != host_compute_core.end()) {
332     int core = iter->second;
333     host_compute_builder.Attr("tpu_core", core);
334   }
335 
336   // Set input tokens and other outside compilation clusters that current
337   // cluster depends in `kXlaTokenArgNodeName`. This is needed because when
338   // outside compilation subgraphs are encapsulated and moved to host graph,
339   // control/data edges between them will only be reflected in host graph.
340   // From XLA's perspective, two originally dependent clusters are no longer
341   // connected, which makes them look like they can be scheduled for execution
342   // in arbitrary order even though in fact they must be executed in order
343   // according to their host-side graph dependency. This can cause deadlock.
344   // Therefore, we hint XLA what the correct ordering of these clusters should
345   // be to avoid deadlocks.
346   std::vector<string> xla_token_input_nodes;
347   xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
348   auto cluster_deps_it = cluster_deps.find(original_oc_name);
349   if (cluster_deps_it != cluster_deps.end()) {
350     for (const auto& dep : cluster_deps_it->second) {
351       xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
352     }
353   }
354   host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
355 
356   // Populate inputs.
357   std::vector<DataType> input_dtypes;
358   TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
359   std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size());
360   for (auto e : call_node->in_edges()) {
361     if (e->IsControlEdge()) {
362       continue;
363     }
364 
365     const int input_dtypes_size = input_dtypes.size();
366     if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) {
367       return errors::Internal("Invalid dst_input: ", e->dst_input());
368     }
369     inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
370         e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
371   }
372   host_compute_builder.Input(inputs);
373 
374   NodeDef new_def;
375   TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
376   return new_def;
377 }
378 
379 // Replace outside compilation function call node with XlaHostCompute node.
ReplaceOutsideCompilationCallNode(Graph * g,Node * call_node,const std::map<string,int> & host_compute_core,const absl::flat_hash_map<string,std::vector<string>> & cluster_deps)380 TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
381     Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
382     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
383   // Build XlaHostCompute NodeDef.
384   TF_ASSIGN_OR_RETURN(
385       NodeDef node_def,
386       BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
387   TF_ASSIGN_OR_RETURN(Node * host_compute_node,
388                       ReplaceNode(g, call_node, node_def));
389   VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
390 
391   return host_compute_node;
392 }
393 
394 // Resets "_device_ordinal" attr to placeholder value for related nodes
395 // (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes
396 // containing XlaRecvAtHost/XlaSendFromHost).
ResetDeviceOrdinalToPlaceholderValue(Graph * g)397 Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
398   AttrValue device_ordinal_value;
399   device_ordinal_value.set_placeholder("_device_ordinal");
400   for (Node* n : g->nodes()) {
401     if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
402       continue;
403     }
404 
405     if (n->type_string() == "_XlaRecvAtHost" ||
406         n->type_string() == "_XlaSendFromHost") {
407       n->ClearAttr("device_ordinal");
408       n->AddAttr("device_ordinal", device_ordinal_value);
409     } else if (n->IsIfNode()) {
410       for (const string& attr_name :
411            std::vector<string>{"then_branch", "else_branch"}) {
412         NameAttrList branch_func;
413         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
414         (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
415         n->ClearAttr(attr_name);
416         n->AddAttr(attr_name, branch_func);
417       }
418     } else if (n->IsWhileNode()) {
419       for (const string& attr_name : std::vector<string>{"cond", "body"}) {
420         NameAttrList branch_func;
421         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
422         (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
423         n->ClearAttr(attr_name);
424         n->AddAttr(attr_name, branch_func);
425       }
426     } else if (HasNodeAttr(n->def(), "_device_ordinal")) {
427       // Function call node containing outside compilation.
428       n->ClearAttr("_device_ordinal");
429       n->AddAttr("_device_ordinal", device_ordinal_value);
430     } else {
431       return errors::Internal("Unknown node marked with ",
432                               kXlaHasHostTransferAttrName, ": ",
433                               n->DebugString());
434     }
435   }
436   return Status::OK();
437 }
438 
439 // Cheap check to tell whether FunctionDef contains a lifted argument.
HasLiftedArgs(const FunctionDef & function_def)440 bool HasLiftedArgs(const FunctionDef& function_def) {
441   return absl::c_any_of(function_def.node_def(), [](const NodeDef& node_def) {
442     return (node_def.op() == "Placeholder" &&
443             node_def.attr().find(kXlaLiftedArgOutsideCompilationAttrName) !=
444                 node_def.attr().end());
445   });
446 }
447 
448 // Find lifted arguments in a function body and their corresponding outside
449 // compilation nodes.
450 xla::StatusOr<std::vector<std::pair<Node*, Node*>>>
LiftedArgsAndOutsideCompilationNodesInFunctionBody(const FunctionBody & function_body,const std::unordered_map<string,Node * > & outside_compilation_attr_to_node)451 LiftedArgsAndOutsideCompilationNodesInFunctionBody(
452     const FunctionBody& function_body,
453     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node) {
454   std::vector<std::pair<Node*, Node*>>
455       lifted_arg_nodes_and_outside_compilation_nodes;
456   for (Node* n : function_body.graph->op_nodes()) {
457     string oc_cluster;
458     if (n->type_string() == "Placeholder" &&
459         GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName,
460                     &oc_cluster)
461             .ok()) {
462       TF_RET_CHECK(outside_compilation_attr_to_node.find(oc_cluster) !=
463                    outside_compilation_attr_to_node.end());
464       lifted_arg_nodes_and_outside_compilation_nodes.emplace_back(
465           n, outside_compilation_attr_to_node.at(oc_cluster));
466     }
467   }
468   return lifted_arg_nodes_and_outside_compilation_nodes;
469 }
470 
471 // Append lifted args' types to functional control flow node's `type_attr_name`
472 // attribute.
UpdateTypesAttribute(const std::vector<std::pair<Node *,Node * >> & lifted_arg_nodes_and_outside_compilation_nodes,const string & type_attr_name,Node * n)473 xla::StatusOr<std::vector<DataType>> UpdateTypesAttribute(
474     const std::vector<std::pair<Node*, Node*>>&
475         lifted_arg_nodes_and_outside_compilation_nodes,
476     const string& type_attr_name, Node* n) {
477   std::vector<DataType> data_types;
478   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types));
479   for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
480     Node* outside_compilation_node = pair.second;
481     DataType data_type;
482     TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
483                  outside_compilation_node->type_string() == "Placeholder");
484     if (outside_compilation_node->IsIdentity()) {
485       TF_RETURN_IF_ERROR(
486           GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
487     } else {
488       TF_RETURN_IF_ERROR(
489           GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
490     }
491     data_types.push_back(data_type);
492   }
493   n->ClearAttr(type_attr_name);
494   n->AddAttr(type_attr_name, data_types);
495 
496   return data_types;
497 }
498 
499 // Add edges from lifted outside compilation argument nodes to `n` in Graph `g`.
AddEdgesFromOutsideCompilationNodes(const int original_arg_count,const int arg_to_input_edge_offset,const std::vector<DataType> & data_types,const std::vector<Node * > & outside_compilation_nodes,Graph * g,Node * n)500 void AddEdgesFromOutsideCompilationNodes(
501     const int original_arg_count, const int arg_to_input_edge_offset,
502     const std::vector<DataType>& data_types,
503     const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
504   // Add edges from outside compilation nodes to While node.
505   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
506     Node* outside_compilation_node =
507         outside_compilation_nodes[i - original_arg_count];
508     g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
509   }
510 }
511 
512 // Construct _Arg that maps to lifted outside compilation argument node input.
AddOutsideCompilationInputArgToFunctionBody(const FunctionBody & function_body,const int arg_idx,const DataType & data_type)513 xla::StatusOr<Node*> AddOutsideCompilationInputArgToFunctionBody(
514     const FunctionBody& function_body, const int arg_idx,
515     const DataType& data_type) {
516   NodeDefBuilder arg_builder(absl::StrCat("arg_", arg_idx), "_Arg");
517   arg_builder.Attr("T", data_type);
518   arg_builder.Attr("index", arg_idx);
519   NodeDef arg_def;
520   TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
521 
522   Status s;
523   Node* arg_node = function_body.graph->AddNode(arg_def, &s);
524   TF_RETURN_IF_ERROR(s);
525   return arg_node;
526 }
527 
528 // Add _Retval node that matches newly added `arg_node` and connect `arg_node`
529 // to it.
AddMatchingRetvalNode(const FunctionBody & function_body,const int arg_idx,const DataType & data_type,Node * arg_node)530 Status AddMatchingRetvalNode(const FunctionBody& function_body,
531                              const int arg_idx, const DataType& data_type,
532                              Node* arg_node) {
533   NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval");
534   ret_builder.Attr("T", data_type);
535   ret_builder.Attr("index", arg_idx);
536   ret_builder.Input(arg_node->name(), 0, data_type);
537   NodeDef ret_def;
538   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
539   Status s;
540   Node* ret_node = function_body.graph->AddNode(ret_def, &s);
541   TF_RETURN_IF_ERROR(s);
542   function_body.graph->AddEdge(arg_node, 0, ret_node, 0);
543 
544   return Status::OK();
545 }
546 
ReplaceLiftedArgNodePlaceholderWithArg(const FunctionBody & function_body,const int original_arg_count,const int arg_idx,const std::vector<Node * > & lifted_arg_nodes,Node * arg_node)547 void ReplaceLiftedArgNodePlaceholderWithArg(
548     const FunctionBody& function_body, const int original_arg_count,
549     const int arg_idx, const std::vector<Node*>& lifted_arg_nodes,
550     Node* arg_node) {
551   Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count];
552   // This might happen because lifted_arg_node only exists in one branch of an
553   // If node, and we are handling the other branch.
554   if (!lifted_arg_node) {
555     return;
556   }
557 
558   for (const Edge* e : lifted_arg_node->out_edges()) {
559     if (e->IsControlEdge()) {
560       function_body.graph->AddControlEdge(arg_node, e->dst());
561     } else {
562       function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input());
563     }
564   }
565   function_body.graph->RemoveNode(lifted_arg_node);
566 }
567 
568 // Adds function def to function definition library and update the function
569 // callsite operation `callsite_node` to invoke new function instead.
AddFunctionWithNewName(const std::string & new_name,const std::string & func_attr_name,const FunctionDef & function_def,NameAttrList * func_attr,Node * callsite_node,FunctionLibraryDefinition * fld)570 Status AddFunctionWithNewName(const std::string& new_name,
571                               const std::string& func_attr_name,
572                               const FunctionDef& function_def,
573                               NameAttrList* func_attr, Node* callsite_node,
574                               FunctionLibraryDefinition* fld) {
575   TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
576   func_attr->set_name(new_name);
577   callsite_node->ClearAttr(func_attr_name);
578   callsite_node->AddAttr(func_attr_name, *func_attr);
579   return Status::OK();
580 }
581 
582 // Reconnect outside compilation lifted arguments in a functional While node to
583 // its outside compilation tensor sources.
PostprocessLiftedArgsForWhile(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)584 Status PostprocessLiftedArgsForWhile(
585     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
586     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
587   TF_RET_CHECK(n->IsWhileNode());
588 
589   // Check if there is any lifted args in body function.
590   NameAttrList body_func;
591   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "body", &body_func));
592   const FunctionDef* body_function_def = fld->Find(body_func.name());
593   TF_RET_CHECK(body_function_def);
594 
595   if (!HasLiftedArgs(*body_function_def)) {
596     return Status::OK();
597   }
598 
599   // Gather all lifted args.
600   std::unique_ptr<FunctionBody> body_function_body;
601   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_function_def,
602                                              AttrSlice(&body_func.attr()), fld,
603                                              &body_function_body));
604 
605   int original_arg_count = body_function_body->arg_nodes.size();
606 
607   TF_ASSIGN_OR_RETURN(
608       auto lifted_arg_nodes_and_outside_compilation_nodes,
609       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
610           *body_function_body, outside_compilation_attr_to_node));
611 
612   // Append lifted args' types to While node's T attribute.
613   TF_ASSIGN_OR_RETURN(
614       std::vector<DataType> data_types,
615       UpdateTypesAttribute(lifted_arg_nodes_and_outside_compilation_nodes, "T",
616                            n));
617 
618   // Add edges from outside compilation nodes to While node.
619   std::vector<Node*> outside_compilation_nodes;
620   std::transform(
621       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
622       lifted_arg_nodes_and_outside_compilation_nodes.end(),
623       std::back_inserter(outside_compilation_nodes),
624       [](const std::pair<Node*, Node*>& pair) { return pair.second; });
625   AddEdgesFromOutsideCompilationNodes(original_arg_count,
626                                       /*arg_to_input_edge_offset=*/0,
627                                       data_types, outside_compilation_nodes, g,
628                                       n);
629 
630   // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg
631   // nodes with the new _Arg nodes.
632   std::vector<Node*> lifted_arg_nodes;
633   std::transform(
634       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
635       lifted_arg_nodes_and_outside_compilation_nodes.end(),
636       std::back_inserter(lifted_arg_nodes),
637       [](const std::pair<Node*, Node*>& pair) { return pair.first; });
638   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
639     TF_ASSIGN_OR_RETURN(Node * arg_node,
640                         AddOutsideCompilationInputArgToFunctionBody(
641                             *body_function_body, i, data_types[i]));
642 
643     TF_RETURN_IF_ERROR(
644         AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node));
645 
646     ReplaceLiftedArgNodePlaceholderWithArg(
647         *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
648   }
649 
650   const auto new_body_function_name =
651       fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
652   FunctionDef rewritten_body_function_def;
653   TF_RETURN_IF_ERROR(GraphToFunctionDef(
654       *body_function_body->graph, new_body_function_name,
655       HostGraphControlRetMapping, &rewritten_body_function_def));
656   TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
657                                             rewritten_body_function_def,
658                                             &body_func, n, fld));
659 
660   // In cond_graph, just add new _Arg nodes.
661   NameAttrList cond_func;
662   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "cond", &cond_func));
663   const FunctionDef* cond_function_def = fld->Find(cond_func.name());
664   TF_RET_CHECK(cond_function_def);
665   std::unique_ptr<FunctionBody> cond_function_body;
666   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_function_def,
667                                              AttrSlice(&cond_func.attr()), fld,
668                                              &cond_function_body));
669 
670   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
671     xla::StatusOr<Node*> arg_node_or =
672         AddOutsideCompilationInputArgToFunctionBody(*cond_function_body, i,
673                                                     data_types[i]);
674     TF_RETURN_IF_ERROR(arg_node_or.status());
675   }
676 
677   const auto new_cond_function_name =
678       fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
679   FunctionDef rewritten_cond_function_def;
680   TF_RETURN_IF_ERROR(GraphToFunctionDef(
681       *cond_function_body->graph, new_cond_function_name,
682       HostGraphControlRetMapping, &rewritten_cond_function_def));
683   TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
684                                             rewritten_cond_function_def,
685                                             &cond_func, n, fld));
686   return Status::OK();
687 }
688 
PostprocessLiftedArgsForIf(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)689 Status PostprocessLiftedArgsForIf(
690     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
691     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
692   TF_RET_CHECK(n->IsIfNode());
693 
694   NameAttrList then_branch_func;
695   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
696   const FunctionDef* then_branch_function_def =
697       fld->Find(then_branch_func.name());
698   TF_RET_CHECK(then_branch_function_def);
699 
700   NameAttrList else_branch_func;
701   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func));
702   const FunctionDef* else_branch_function_def =
703       fld->Find(else_branch_func.name());
704   TF_RET_CHECK(else_branch_function_def);
705 
706   // Nothing to do if neither branch contains any lifted arguments.
707   if (!HasLiftedArgs(*then_branch_function_def) &&
708       !HasLiftedArgs(*else_branch_function_def)) {
709     return Status::OK();
710   }
711 
712   std::unique_ptr<FunctionBody> then_branch_function_body;
713   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
714       *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld,
715       &then_branch_function_body));
716 
717   std::unique_ptr<FunctionBody> else_branch_function_body;
718   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
719       *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld,
720       &else_branch_function_body));
721 
722   // Then and else branches have same argument count and argument data types.
723   int original_arg_count = then_branch_function_body->arg_nodes.size();
724 
725   TF_ASSIGN_OR_RETURN(
726       auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes,
727       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
728           *then_branch_function_body, outside_compilation_attr_to_node));
729 
730   TF_ASSIGN_OR_RETURN(
731       auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes,
732       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
733           *else_branch_function_body, outside_compilation_attr_to_node));
734 
735   // Merge lifted args from then and else branches.
736   std::vector<Node*> outside_compilation_nodes;
737   std::vector<Node*> then_branch_lifted_arg_nodes;
738   for (const auto& pair :
739        then_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
740     outside_compilation_nodes.push_back(pair.second);
741     then_branch_lifted_arg_nodes.push_back(pair.first);
742   }
743   for (const auto& pair :
744        else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
745     if (std::find(outside_compilation_nodes.begin(),
746                   outside_compilation_nodes.end(),
747                   pair.second) == outside_compilation_nodes.end()) {
748       outside_compilation_nodes.push_back(pair.second);
749       // Then branch does not contain this lifted arg. Add an empty item to
750       // then_branch_lifted_arg_nodes.
751       then_branch_lifted_arg_nodes.push_back(nullptr);
752     }
753   }
754   // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes.
755   std::vector<Node*> else_branch_lifted_arg_nodes(
756       outside_compilation_nodes.size());
757   for (const auto& pair :
758        else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
759     auto iter = std::find(outside_compilation_nodes.begin(),
760                           outside_compilation_nodes.end(), pair.second);
761     TF_RET_CHECK(iter != outside_compilation_nodes.end());
762     int index = iter - outside_compilation_nodes.begin();
763     else_branch_lifted_arg_nodes[index] = pair.first;
764   }
765 
766   // Append lifted args' types to If node's Tin attribute.
767   std::vector<DataType> data_types;
768   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types));
769   for (Node* n : outside_compilation_nodes) {
770     data_types.push_back(n->output_type(0));
771   }
772   n->ClearAttr("Tin");
773   n->AddAttr("Tin", data_types);
774 
775   // Add edges from outside compilation nodes to If node. If node's input #0
776   // is predicate input, input #1 maps to _Arg #0 of branch functions, thus
777   // arg_to_input_edge_offset is set to 1.
778   AddEdgesFromOutsideCompilationNodes(original_arg_count,
779                                       /*arg_to_input_edge_offset=*/1,
780                                       data_types, outside_compilation_nodes, g,
781                                       n);
782 
783   for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
784     TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
785                         AddOutsideCompilationInputArgToFunctionBody(
786                             *then_branch_function_body, i, data_types[i]));
787 
788     ReplaceLiftedArgNodePlaceholderWithArg(
789         *then_branch_function_body, original_arg_count, i,
790         then_branch_lifted_arg_nodes, then_branch_arg_node);
791 
792     TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node,
793                         AddOutsideCompilationInputArgToFunctionBody(
794                             *else_branch_function_body, i, data_types[i]));
795 
796     ReplaceLiftedArgNodePlaceholderWithArg(
797         *else_branch_function_body, original_arg_count, i,
798         else_branch_lifted_arg_nodes, else_branch_arg_node);
799   }
800 
801   const auto new_then_function_name = fld->UniqueFunctionName(
802       absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
803   FunctionDef rewritten_then_branch_function_def;
804   TF_RETURN_IF_ERROR(GraphToFunctionDef(
805       *then_branch_function_body->graph, new_then_function_name,
806       HostGraphControlRetMapping, &rewritten_then_branch_function_def));
807   TF_RETURN_IF_ERROR(AddFunctionWithNewName(
808       new_then_function_name, "then_branch", rewritten_then_branch_function_def,
809       &then_branch_func, n, fld));
810 
811   const auto new_else_function_name = fld->UniqueFunctionName(
812       absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
813   FunctionDef rewritten_else_branch_function_def;
814   TF_RETURN_IF_ERROR(GraphToFunctionDef(
815       *else_branch_function_body->graph, new_else_function_name,
816       HostGraphControlRetMapping, &rewritten_else_branch_function_def));
817   TF_RETURN_IF_ERROR(AddFunctionWithNewName(
818       new_else_function_name, "else_branch", rewritten_else_branch_function_def,
819       &else_branch_func, n, fld));
820   return Status::OK();
821 }
822 
PostprocessLiftedArgsForCall(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)823 Status PostprocessLiftedArgsForCall(
824     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
825     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
826   const FunctionDef* fdef = fld->Find(n->type_string());
827   TF_RET_CHECK(fdef);
828 
829   // Nothing to do if the function does not contain any lifted arguments.
830   if (!HasLiftedArgs(*fdef)) {
831     return Status::OK();
832   }
833 
834   std::unique_ptr<FunctionBody> fbody;
835   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody));
836 
837   int original_arg_count = fbody->arg_nodes.size();
838 
839   TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes,
840                       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
841                           *fbody, outside_compilation_attr_to_node));
842 
843   // Append lifted args' types to call node's input data types.
844   std::vector<DataType> data_types(n->input_types().begin(),
845                                    n->input_types().end());
846   for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
847     Node* outside_compilation_node = pair.second;
848     DataType data_type;
849     TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
850                  outside_compilation_node->type_string() == "Placeholder");
851     if (outside_compilation_node->IsIdentity()) {
852       TF_RETURN_IF_ERROR(
853           GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
854     } else {
855       TF_RETURN_IF_ERROR(
856           GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
857     }
858     data_types.push_back(data_type);
859   }
860 
861   std::vector<Node*> lifted_arg_nodes;
862   std::transform(
863       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
864       lifted_arg_nodes_and_outside_compilation_nodes.end(),
865       std::back_inserter(lifted_arg_nodes),
866       [](const std::pair<Node*, Node*>& pair) { return pair.first; });
867   for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
868     TF_ASSIGN_OR_RETURN(
869         Node * arg_node,
870         AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));
871 
872     ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i,
873                                            lifted_arg_nodes, arg_node);
874   }
875 
876   FunctionDef rewritten_fdef;
877   TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
878                                         HostGraphControlRetMapping,
879                                         &rewritten_fdef));
880   const auto new_function_name =
881       fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
882   rewritten_fdef.mutable_signature()->set_name(new_function_name);
883   TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
884 
885   // We need to recreate the node. Otherwise TF will not know n->num_inputs()
886   // has increased.
887   NodeDef node_def = n->def();
888 
889   // Function name is represented via the Op's type. Reset the op type to new
890   // function def name;
891   *node_def.mutable_op() = new_function_name;
892 
893   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
894     Node* outside_compilation_node =
895         lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
896             .second;
897     node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0));
898   }
899   TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def));
900 
901   // Add edges from outside compilation nodes to call node.
902   std::vector<Node*> outside_compilation_nodes;
903   std::transform(
904       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
905       lifted_arg_nodes_and_outside_compilation_nodes.end(),
906       std::back_inserter(outside_compilation_nodes),
907       [](const std::pair<Node*, Node*>& pair) { return pair.second; });
908   AddEdgesFromOutsideCompilationNodes(original_arg_count,
909                                       /*arg_to_input_edge_offset=*/0,
910                                       data_types, outside_compilation_nodes, g,
911                                       n);
912 
913   return Status::OK();
914 }
915 
916 // Creates a mapping from outside compilation cluster name to lifted argument
917 // placeholder.
OutsideCompilationAttrToNode(const Graph & g)918 xla::StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
919     const Graph& g) {
920   std::unordered_map<string, Node*> outside_compilation_attr_to_node;
921   for (Node* n : g.op_nodes()) {
922     bool is_lifted_arg;
923     string outside_compilation_attr;
924     if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
925         TryGetNodeAttr(n->def(), "_xla_outside_compilation",
926                        &outside_compilation_attr)) {
927       TF_RET_CHECK(is_lifted_arg);
928       TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
929       outside_compilation_attr_to_node[outside_compilation_attr] = n;
930     }
931   }
932 
933   return outside_compilation_attr_to_node;
934 }
935 
PostprocessLiftedArgs(Graph * g,FunctionLibraryDefinition * fld)936 Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
937   TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node,
938                       OutsideCompilationAttrToNode(*g));
939 
940   std::vector<Node*> call_nodes;
941   for (Node* n : g->op_nodes()) {
942     if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
943       continue;
944     }
945 
946     if (n->IsWhileNode()) {
947       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
948           outside_compilation_attr_to_node, g, n, fld));
949     }
950 
951     if (n->IsIfNode()) {
952       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
953           outside_compilation_attr_to_node, g, n, fld));
954     }
955 
956     // Outside compilation host side function call will always be direct
957     // function call nodes.
958     // Function call nodes need to be handled separately because we rewrite
959     // nodes in `PostprocessLiftedArgsForCall`.
960     if (fld->Contains(n->type_string())) {
961       call_nodes.push_back(n);
962     }
963   }
964 
965   for (Node* n : call_nodes) {
966     TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall(
967         outside_compilation_attr_to_node, g, n, fld));
968   }
969 
970   return Status::OK();
971 }
972 
973 // For an XLA computation, builds host side graph given all outside compilation
974 // graphs inside it. The host side graph contains:
975 // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and
976 //    XlaSendFromHost to this sequencer node, so all outside compilation nodes
977 //    will be executed *before* this sequencer).
978 // 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
979 //    replace this node with compilation result node.
980 // 3) all outside compilation graphs.
ConstructHostGraph(const string & xla_cluster_name,const string & outside_compilation_attr_name,const std::vector<string> & outside_compilation_host_graphs,FunctionLibraryDefinition * fld,std::unique_ptr<Graph> * host_graph)981 Status ConstructHostGraph(
982     const string& xla_cluster_name, const string& outside_compilation_attr_name,
983     const std::vector<string>& outside_compilation_host_graphs,
984     FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
985   host_graph->reset(new Graph(fld));
986 
987   // Create sequencer node in host graph.
988   NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
989                                    "NoOp");
990   sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name);
991   NodeDef sequencer_def;
992   TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
993   Status s;
994   Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s);
995   TF_RETURN_IF_ERROR(s);
996 
997   // Create key placeholder in host graph.
998   TF_ASSIGN_OR_RETURN(
999       Node * key_placeholder,
1000       AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));
1001 
1002   // For each outside compilation graph, copy them to host graph with the
1003   // following changes:
1004   // a) Use key_placeholder in host graph instead of its own.
1005   // b) Add control edge from host transfer nodes (XlaRecvAtHost,
1006   //    XlaSendFromHost, If/While nodes containing
1007   //    XlaRecvAtHost/XlaSendFromHost) to sequencer node.
1008   // c) Clear node_def.device(), so device placer won't get confused.
1009   for (const string& host_func : outside_compilation_host_graphs) {
1010     VLOG(4) << "Expanding host graph " << host_func;
1011     // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1012     // value after we expanded all host graphs. We cannot just use placeholder
1013     // value here because FunctionDef instantiation does not allow placeholder
1014     // value for attributes.
1015     AttrValue device_ordinal_attr;
1016     device_ordinal_attr.set_i(0);
1017     protobuf::Map<string, AttrValue> attrs;
1018     attrs["_device_ordinal"] = device_ordinal_attr;
1019     std::unique_ptr<FunctionBody> host_fbody;
1020     const FunctionDef* host_fdef = fld->Find(host_func);
1021     TF_RET_CHECK(host_fdef);
1022     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
1023                                                fld, &host_fbody));
1024 
1025     // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
1026     // reachable from sink node so all nodes will be copied.
1027     // TODO(b/77601805): consolidate copy graph functions.
1028     FixupSourceAndSinkEdges(host_fbody->graph);
1029 
1030     std::map<const Node*, Node*> node_map;
1031     node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
1032     node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
1033     Status s;
1034     ReverseDFS(
1035         *host_fbody->graph, /*enter=*/nullptr,
1036         [&](const Node* n) {
1037           if (!s.ok()) {
1038             return;
1039           }
1040 
1041           Node* copy;
1042           if (node_map.find(n) != node_map.end()) {
1043             // Already copied this node.
1044             copy = node_map.at(n);
1045           } else if (IsKeyPlaceholderNode(*n)) {
1046             // Change a).
1047             copy = key_placeholder;
1048             node_map[n] = copy;
1049           } else {
1050             // Copy the node.
1051             NodeDef copy_def = n->def();
1052             // Change c).
1053             copy_def.clear_device();
1054             copy = (*host_graph)->AddNode(copy_def, &s);
1055             if (!s.ok()) {
1056               return;
1057             }
1058             node_map[n] = copy;
1059           }
1060 
1061           // Only handle input edges. Output edges will be added later as
1062           // its output nodes' input edges.
1063           for (auto e : n->in_edges()) {
1064             if (node_map.find(e->src()) == node_map.end()) {
1065               s = errors::Internal("Cannot find node image for ",
1066                                    e->src()->DebugString());
1067               return;
1068             }
1069             (*host_graph)
1070                 ->AddEdge(node_map[e->src()], e->src_output(), copy,
1071                           e->dst_input());
1072           }
1073 
1074           // Change b).
1075           if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
1076             (*host_graph)->AddControlEdge(copy, sequencer);
1077           }
1078         },
1079         NodeComparatorID());
1080 
1081     if (!s.ok()) {
1082       return s;
1083     }
1084   }
1085   // Reset "_device_ordinal" to placeholder value.
1086   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(host_graph->get()));
1087 
1088   // sequencer and key_placeholder might be dead nodes. Prune them if necessary.
1089   // - sequencer should be pruned iff it has no input control edges from
1090   //   RecvAtHost/SendFromHost. If it has input control edge, we connect it to
1091   //   sink node so it won't be pruned.
1092   // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
1093   //   We don't need to do anything special.
1094   if (!sequencer->in_edges().empty()) {
1095     (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
1096   }
1097   PruneForReverseReachability(
1098       host_graph->get(),
1099       std::unordered_set<const Node*>{(*host_graph)->sink_node()});
1100 
1101   // Postprocess edges between different outside compilations.
1102   TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
1103       host_graph->get(), outside_compilation_attr_name));
1104 
1105   // Postprocess lifted arg nodes.
1106   TF_RETURN_IF_ERROR(PostprocessLiftedArgs(host_graph->get(), fld));
1107 
1108   if (VLOG_IS_ON(4)) {
1109     DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
1110                                  xla_cluster_name),
1111                     **host_graph, fld);
1112   }
1113 
1114   return Status::OK();
1115 }
1116 
1117 // Expand XLA computation's outside compilation host side graph into main graph.
1118 // Add a control edge between sequencer node and the XLA computation node.
ExpandHostGraphIntoMainGraph(Graph * main_graph,FunctionLibraryDefinition * fld,const string & host_graph_func_name,Node * xla_computation_node,Node * pivot_node)1119 Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
1120                                     FunctionLibraryDefinition* fld,
1121                                     const string& host_graph_func_name,
1122                                     Node* xla_computation_node,
1123                                     Node* pivot_node) {
1124   // Temporarily use "0" as "_device_ordinal". It will be rewritten with the
1125   // correct value in a later pass. We cannot just use placeholder value here
1126   // because FunctionDef instantiation does not allow placeholder value for
1127   // attributes.
1128   AttrValue device_ordinal_attr;
1129   device_ordinal_attr.set_i(0);
1130   protobuf::Map<string, AttrValue> attrs;
1131   attrs["_device_ordinal"] = device_ordinal_attr;
1132   std::unique_ptr<FunctionBody> fbody;
1133   const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
1134   TF_RET_CHECK(host_graph_func);
1135   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
1136                                              AttrSlice(&attrs), fld, &fbody));
1137   Graph* host_graph = fbody->graph;
1138 
1139   // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
1140   // reachable from sink node so all nodes will be copied.
1141   // TODO(b/77601805): consolidate copy graph functions.
1142   FixupSourceAndSinkEdges(host_graph);
1143 
1144   // Copy all nodes.
1145   std::map<const Node*, Node*> node_map;
1146   if (pivot_node) {
1147     node_map[host_graph->source_node()] = pivot_node;
1148   } else {
1149     node_map[host_graph->source_node()] = main_graph->source_node();
1150   }
1151   node_map[host_graph->sink_node()] = main_graph->sink_node();
1152   Status s = Status::OK();
1153   auto copy_node_fn = [&](const Node* n) {
1154     if (!s.ok()) {
1155       return;
1156     }
1157 
1158     Node* copy;
1159     if (node_map.find(n) != node_map.end()) {
1160       // Already copied this node.
1161       copy = node_map.at(n);
1162     } else {
1163       // Copy the node.
1164       NodeDef copy_def = n->def();
1165       copy = main_graph->AddNode(copy_def, &s);
1166       if (!s.ok()) {
1167         return;
1168       }
1169       node_map[n] = copy;
1170     }
1171 
1172     // Only handle input edges. Output edges will be added later as its output
1173     // nodes' input edges.
1174     for (auto e : n->in_edges()) {
1175       if (node_map.find(e->src()) == node_map.end()) {
1176         s = errors::Internal("Cannot find node image for ",
1177                              e->src()->DebugString());
1178         return;
1179       }
1180       main_graph->AddEdge(node_map[e->src()], e->src_output(), copy,
1181                           e->dst_input());
1182     }
1183 
1184     // Add control edge from sequencer to XLA computation node.
1185     if (copy->type_string() == "NoOp" &&
1186         HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) {
1187       main_graph->AddControlEdge(copy, xla_computation_node);
1188     }
1189   };
1190   ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
1191   return s;
1192 }
1193 
1194 // Rewrites shape inference graph for outside compilation:
1195 // 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
1196 //    `host_graph`. Because we might still have outside compilation to outside
1197 //    compilation placeholder nodes in shape inference graph, which will prevent
1198 //    us from inferring XlaSendFromHost shape. But in `host_graph`, we already
1199 //    removed those placeholder nodes.
1200 // 2) Remove control edges.
1201 // 3) Prune nodes that are not useful for shape inference.
RewriteShapeInferenceGraph(const string & shape_inference_graph_name,Graph * host_graph,Node * pivot_node,FunctionLibraryDefinition * fld)1202 Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
1203                                   Graph* host_graph, Node* pivot_node,
1204                                   FunctionLibraryDefinition* fld) {
1205   // Use "0" as "_device_ordinal". It does not matter for shape inference.
1206   AttrValue device_ordinal_attr;
1207   device_ordinal_attr.set_i(0);
1208   protobuf::Map<string, AttrValue> attrs;
1209   attrs["_device_ordinal"] = device_ordinal_attr;
1210   std::unique_ptr<FunctionBody> fbody;
1211   const FunctionDef* shape_inference_graph =
1212       fld->Find(shape_inference_graph_name);
1213   TF_RET_CHECK(shape_inference_graph);
1214   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
1215                                              AttrSlice(&attrs), fld, &fbody));
1216   Graph* g = fbody->graph;
1217 
1218   // Find SendFromHost node.
1219   Node* send_from_host = nullptr;
1220   for (Node* n : g->nodes()) {
1221     if (n->type_string() == "_XlaSendFromHost") {
1222       send_from_host = n;
1223       break;
1224     }
1225   }
1226   if (!send_from_host) {
1227     return errors::Internal("Shape inference graph ",
1228                             shape_inference_graph_name,
1229                             " does not have _XlaSendFromHost node.");
1230   }
1231 
1232   // See if the SendFromHost node exists in `host_graph`.
1233   Node* send_node_in_host_graph = nullptr;
1234   for (Node* n : host_graph->nodes()) {
1235     if (n->name() == send_from_host->name()) {
1236       send_node_in_host_graph = n;
1237       break;
1238     }
1239   }
1240   if (send_node_in_host_graph) {
1241     // This is an "top-level" outside compilation. Clear the graph, and copy
1242     // SendFromHost and all its predecessors from `host_graph`.
1243     std::vector<Node*> nodes;
1244     for (Node* n : g->op_nodes()) {
1245       nodes.push_back(n);
1246     }
1247     for (Node* n : nodes) {
1248       g->RemoveNode(n);
1249     }
1250     Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
1251     // Reverse DFS from send_from_host_main_graph, and stop at start_node.
1252     struct Visit {
1253       Node* n;
1254       bool is_exiting;
1255     };
1256     std::vector<Visit> stack{{send_node_in_host_graph, false}};
1257     std::map<Node*, Node*> node_map;
1258     node_map[host_graph->source_node()] = g->source_node();
1259     while (!stack.empty()) {
1260       Visit& curr = stack.back();
1261       if (curr.is_exiting) {
1262         if (node_map.find(curr.n) == node_map.end()) {
1263           Node* copy = g->CopyNode(curr.n);
1264           if (curr.n != start_node) {
1265             for (const Edge* e : curr.n->in_edges()) {
1266               auto node_iter = node_map.find(e->src());
1267               if (node_iter == node_map.end()) {
1268                 return errors::Internal("Cannot find node image for ",
1269                                         e->src()->DebugString());
1270               }
1271               g->AddEdge(node_iter->second, e->src_output(), copy,
1272                          e->dst_input());
1273             }
1274           }
1275           node_map[curr.n] = copy;
1276         }
1277         stack.pop_back();
1278       } else {
1279         curr.is_exiting = true;
1280         if (curr.n != start_node) {
1281           for (const Edge* e : curr.n->in_edges()) {
1282             if (node_map.find(e->src()) != node_map.end()) {
1283               continue;
1284             }
1285             stack.push_back({e->src(), false});
1286           }
1287         }
1288       }
1289     }
1290 
1291     send_from_host = node_map[send_node_in_host_graph];
1292   } else {
1293     // This is an outside compilation generated for If/While/gradient/etc.
1294     // It will be enough for shape inference. Leave `g` unchanged.
1295   }
1296 
1297   // Control edges are not useful for shape inference. Remove them.
1298   for (auto e : g->edges()) {
1299     if (e->IsControlEdge()) {
1300       g->RemoveEdge(e);
1301     }
1302   }
1303 
1304   // Nodes that are not reverse reachable from SendFromHost are not useful for
1305   // shape inference. Prune them.
1306   PruneForReverseReachability(g,
1307                               std::unordered_set<const Node*>{send_from_host});
1308 
1309   if (VLOG_IS_ON(4)) {
1310     DumpGraphToFile(shape_inference_graph_name, *g, fld);
1311   }
1312 
1313   // Replace original shape inference graph.
1314   FunctionDef fdef_replace;
1315   TF_RETURN_IF_ERROR(
1316       GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace));
1317   TF_RETURN_IF_ERROR(
1318       fld->ReplaceFunction(shape_inference_graph_name, fdef_replace));
1319 
1320   return Status::OK();
1321 }
1322 
1323 // Builds XlaSendToHost node which sends cond predicate to host.
BuildSendIfPredNode(const string & name,const string & host_transfer_key,Node * pred_node,Graph * g)1324 TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> BuildSendIfPredNode(
1325     const string& name, const string& host_transfer_key, Node* pred_node,
1326     Graph* g) {
1327   NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
1328   send_pred_builder.Attr("Tinput", DT_BOOL);
1329   send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
1330   send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
1331                          std::vector<string>{kXlaTokenArgNodeName});
1332   send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name);
1333   send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
1334   NodeDef send_pred_def;
1335   TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
1336   Status s;
1337   Node* send_pred_node = g->AddNode(send_pred_def, &s);
1338   TF_RETURN_IF_ERROR(s);
1339   g->AddEdge(pred_node, 0, send_pred_node, 0);
1340   return send_pred_node;
1341 }
1342 
1343 // Replaces key placeholder node with an _Arg node.
ReplaceKeyPlaceholderWithArgNode(const string & xla_cluster_name,const string & func_name,FunctionLibraryDefinition * fld)1344 Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
1345                                         const string& func_name,
1346                                         FunctionLibraryDefinition* fld) {
1347   // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1348   // value after rewriting.
1349   AttrValue device_ordinal_attr;
1350   device_ordinal_attr.set_i(0);
1351   protobuf::Map<string, AttrValue> attrs;
1352   attrs["_device_ordinal"] = device_ordinal_attr;
1353   std::unique_ptr<FunctionBody> fbody;
1354   const FunctionDef* func = fld->Find(func_name);
1355   TF_RETURN_IF_ERROR(
1356       FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
1357   Graph* g = fbody->graph;
1358 
1359   // Find or create the key placeholder node.
1360   Node* key_placeholder = nullptr;
1361   for (Node* n : g->nodes()) {
1362     if (IsKeyPlaceholderNode(*n)) {
1363       key_placeholder = n;
1364       break;
1365     }
1366   }
1367   if (!key_placeholder) {
1368     TF_ASSIGN_OR_RETURN(key_placeholder,
1369                         AddHostComputeKeyPlaceholder(xla_cluster_name, g));
1370   }
1371 
1372   // Build the _Arg node, and replace key placeholder node with it.
1373   NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
1374   arg_builder.Attr("T", DT_STRING);
1375   arg_builder.Attr("index", 0);
1376   NodeDef arg_def;
1377   TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
1378   TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
1379 
1380   // Reset "_device_ordinal" to placeholder value.
1381   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
1382 
1383   FunctionDef replace_fdef;
1384   TF_RETURN_IF_ERROR(GraphToFunctionDef(
1385       *g, func_name, HostGraphControlRetMapping, &replace_fdef));
1386   TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
1387   return Status::OK();
1388 }
1389 
1390 // Builds host side graph for If node.
BuildHostGraphForIfNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & if_node_name,const string & host_transfer_key,const string & host_graph_func_name,FunctionLibraryDefinition * fld,const string & then_branch_host_func_name,const string & else_branch_host_func_name)1391 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
1392     const string& xla_cluster_attr_name,
1393     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1394     const string& if_node_name, const string& host_transfer_key,
1395     const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1396     const string& then_branch_host_func_name,
1397     const string& else_branch_host_func_name) {
1398   Graph host_graph(fld);
1399   string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
1400   AttrValue device_ordinal_value;
1401   device_ordinal_value.set_placeholder("_device_ordinal");
1402 
1403   // Step 1: add key placeholder node.
1404   TF_ASSIGN_OR_RETURN(
1405       Node * key_placeholder,
1406       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1407 
1408   // Step 2: build XlaRecvAtHost node to recv predicate.
1409   NodeDefBuilder recv_pred_builder(
1410       absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
1411   recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1412   recv_pred_builder.Attr("key", host_transfer_key);
1413   recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1414   recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1415   recv_pred_builder.Attr(outside_compilation_attr_name,
1416                          outside_compilation_name);
1417   recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1418   recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
1419   NodeDef recv_pred_def;
1420   TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1421   Status s;
1422   Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
1423   TF_RETURN_IF_ERROR(s);
1424   host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
1425 
1426   // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
1427   // placeholder with an _Arg node.
1428   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1429       xla_cluster_name, then_branch_host_func_name, fld));
1430   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1431       xla_cluster_name, else_branch_host_func_name, fld));
1432 
1433   // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
1434   NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
1435   if_builder.Attr("Tcond", DT_BOOL);
1436   if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
1437   if_builder.Attr("Tout", std::vector<DataType>{});
1438   NameAttrList host_then_branch, host_else_branch;
1439   host_then_branch.set_name(then_branch_host_func_name);
1440   (*host_then_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1441   host_else_branch.set_name(else_branch_host_func_name);
1442   (*host_else_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1443   if_builder.Attr("then_branch", host_then_branch);
1444   if_builder.Attr("else_branch", host_else_branch);
1445   if_builder.Attr(kXlaHasHostTransferAttrName, true);
1446   if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1447   if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1448   if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1449   std::vector<NodeDefBuilder::NodeOut> if_inputs{
1450       {key_placeholder->name(), 0, DT_STRING}};
1451   if_builder.Input(if_inputs);
1452   NodeDef if_def;
1453   TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
1454   Node* if_node = host_graph.AddNode(if_def, &s);
1455   TF_RETURN_IF_ERROR(s);
1456   host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
1457   host_graph.AddEdge(key_placeholder, 0, if_node, 1);
1458 
1459   // Convert `host_graph` to function.
1460   FunctionDef oc_host_graph_fdef;
1461   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1462                                         &oc_host_graph_fdef));
1463   if (fld->Find(host_graph_func_name)) {
1464     TF_RETURN_IF_ERROR(
1465         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1466   } else {
1467     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1468   }
1469 
1470   return Status::OK();
1471 }
1472 
1473 // Rewrites loop cond to add a node which sends loop cond to host.
AddSendLoopPredToLoopCond(const string & cond_xla_func_name,const string & host_transfer_key,NameAttrList * loop_cond_func,FunctionLibraryDefinition * fld,Node * while_node)1474 TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
1475     const string& cond_xla_func_name, const string& host_transfer_key,
1476     NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
1477     Node* while_node) {
1478   // Instantiate the loop cond function.
1479   std::unique_ptr<FunctionBody> fbody;
1480   const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
1481   TF_RET_CHECK(loop_cond_fdef);
1482   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1483       *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
1484   Graph* g = fbody->graph;
1485 
1486   // Find the _Retval node and the loop cond node.
1487   Node* ret_node = nullptr;
1488   for (Node* n : g->nodes()) {
1489     if (n->type_string() == "_Retval") {
1490       if (ret_node) {
1491         return errors::Internal("Multiple return node for loop cond function ",
1492                                 loop_cond_func->name(), ": ",
1493                                 ret_node->DebugString(), " and ",
1494                                 n->DebugString());
1495       } else {
1496         ret_node = n;
1497       }
1498     }
1499   }
1500   if (!ret_node) {
1501     return errors::Internal("No _Retval node for loop cond function ",
1502                             loop_cond_func->name());
1503   }
1504   Node* loop_cond;
1505   TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
1506 
1507   // Build the XlaSendToHost node.
1508   NodeDefBuilder send_loop_cond_builder(
1509       absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
1510   send_loop_cond_builder.Attr("Tinput", DT_BOOL);
1511   send_loop_cond_builder.Attr("key",
1512                               absl::StrCat(host_transfer_key, "_dtoh_0"));
1513   send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
1514                               std::vector<string>{kXlaTokenArgNodeName});
1515   send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
1516                               send_loop_cond_builder.node_name());
1517   send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
1518   NodeDef send_loop_cond_def;
1519   TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
1520   Status s;
1521   Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
1522   TF_RETURN_IF_ERROR(s);
1523   g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
1524 
1525   // Replace original function if loop_cond_func already has been re-written
1526   // for outside compilation.
1527   FunctionDef replace_fdef;
1528   if (loop_cond_func->name() == cond_xla_func_name) {
1529     TF_RETURN_IF_ERROR(
1530         GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
1531     TF_RETURN_IF_ERROR(
1532         fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
1533   } else {
1534     // If original while cond function has not been modified, add a new function
1535     // with send loop predicated added and update the while node callsite
1536     // operation.
1537     const auto new_name = fld->UniqueFunctionName(
1538         absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
1539     TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
1540     TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
1541     loop_cond_func->set_name(new_name);
1542     while_node->ClearAttr("cond");
1543     while_node->AddAttr("cond", *loop_cond_func);
1544   }
1545 
1546   return Status::OK();
1547 }
1548 
1549 // Rewrites while loop cond function for host.
RewriteHostWhileLoopCond(const string & cond_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1550 Status RewriteHostWhileLoopCond(
1551     const string& cond_host_func_name, const string& while_node_name,
1552     const string& host_transfer_key, const string& xla_cluster_attr_name,
1553     const string& xla_cluster_name, const string& outside_compilation_attr_name,
1554     const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1555   // Replace key placeholder node with _Arg node.
1556   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1557       xla_cluster_name, cond_host_func_name, fld));
1558 
1559   // Instantiate cond function.
1560   AttrValue device_ordinal_temp_value;
1561   device_ordinal_temp_value.set_i(0);
1562   protobuf::Map<string, AttrValue> attrs;
1563   attrs["_device_ordinal"] = device_ordinal_temp_value;
1564   std::unique_ptr<FunctionBody> cond_fbody;
1565   const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
1566   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
1567                                              fld, &cond_fbody));
1568   Graph* cond_graph = cond_fbody->graph;
1569   Node* key_arg = nullptr;
1570   for (Node* n : cond_graph->nodes()) {
1571     if (n->type_string() == "_Arg") {
1572       key_arg = n;
1573     }
1574   }
1575   if (!key_arg) {
1576     return errors::Internal(
1577         "No _Arg node found for host compute key in function ",
1578         cond_host_func_name);
1579   }
1580 
1581   // Add an XlaRecvAtHost node to use as cond function return value.
1582   NodeDefBuilder recv_pred_builder(
1583       absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
1584   recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1585   recv_pred_builder.Attr("key", host_transfer_key);
1586   AttrValue device_ordinal_value;
1587   device_ordinal_value.set_placeholder("_device_ordinal");
1588   recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1589   recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1590   recv_pred_builder.Attr(outside_compilation_attr_name,
1591                          outside_compilation_name);
1592   recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1593   recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
1594   NodeDef recv_pred_def;
1595   TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1596   Status s;
1597   Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
1598   TF_RETURN_IF_ERROR(s);
1599   cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
1600   NodeDefBuilder ret_builder(
1601       absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
1602   ret_builder.Attr("T", DT_BOOL);
1603   ret_builder.Attr("index", 0);
1604   ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1605   NodeDef ret_def;
1606   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1607   Node* ret_node = cond_graph->AddNode(ret_def, &s);
1608   TF_RETURN_IF_ERROR(s);
1609   cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
1610 
1611   // Reset device_ordinal to placeholder value.
1612   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
1613 
1614   // Replace original function.
1615   FunctionDef cond_replace_fdef;
1616   TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_graph, cond_host_func_name,
1617                                         HostGraphControlRetMapping,
1618                                         &cond_replace_fdef));
1619   TF_RETURN_IF_ERROR(
1620       fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
1621 
1622   return Status::OK();
1623 }
1624 
1625 // Rewrites while loop body function for host.
RewriteHostWhileLoopBody(const string & body_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1626 Status RewriteHostWhileLoopBody(
1627     const string& body_host_func_name, const string& while_node_name,
1628     const string& host_transfer_key, const string& xla_cluster_attr_name,
1629     const string& xla_cluster_name, const string& outside_compilation_attr_name,
1630     const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1631   // Replace key placeholder node with _Arg node.
1632   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1633       xla_cluster_name, body_host_func_name, fld));
1634 
1635   // Instantiate body function.
1636   AttrValue device_ordinal_temp_value;
1637   device_ordinal_temp_value.set_i(0);
1638   protobuf::Map<string, AttrValue> attrs;
1639   attrs["_device_ordinal"] = device_ordinal_temp_value;
1640   std::unique_ptr<FunctionBody> body_fbody;
1641   const FunctionDef* body_host_func = fld->Find(body_host_func_name);
1642   TF_RET_CHECK(body_host_func);
1643   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
1644                                              fld, &body_fbody));
1645   Graph* body_graph = body_fbody->graph;
1646   Node* key_arg = nullptr;
1647   for (Node* n : body_graph->nodes()) {
1648     if (n->type_string() == "_Arg") {
1649       key_arg = n;
1650     }
1651   }
1652   if (!key_arg) {
1653     return errors::Internal(
1654         "No _Arg node found for host compute key in function ",
1655         body_host_func_name);
1656   }
1657 
1658   // Add a _Retval node to loop body.
1659   NodeDefBuilder ret_builder(
1660       absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
1661   ret_builder.Attr("T", DT_STRING);
1662   ret_builder.Attr("index", 0);
1663   ret_builder.Input(key_arg->name(), 0, DT_STRING);
1664   NodeDef ret_def;
1665   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1666   Status s;
1667   Node* ret_node = body_graph->AddNode(ret_def, &s);
1668   TF_RETURN_IF_ERROR(s);
1669   body_graph->AddEdge(key_arg, 0, ret_node, 0);
1670 
1671   // Reset device_ordinal to placeholder value.
1672   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
1673 
1674   // Replace original function.
1675   FunctionDef body_replace_fdef;
1676   TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_graph, body_host_func_name,
1677                                         HostGraphControlRetMapping,
1678                                         &body_replace_fdef));
1679   TF_RETURN_IF_ERROR(
1680       fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
1681 
1682   return Status::OK();
1683 }
1684 
1685 // Builds host side graph for while node.
BuildHostGraphForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & while_node_name,const string & host_transfer_key,const string & host_graph_func_name,FunctionLibraryDefinition * fld,const string & cond_host_func_name,const string & body_host_func_name)1686 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
1687     const string& xla_cluster_attr_name,
1688     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1689     const string& while_node_name, const string& host_transfer_key,
1690     const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1691     const string& cond_host_func_name, const string& body_host_func_name) {
1692   Graph host_graph(fld);
1693   string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
1694 
1695   // Step 1: add key placeholder node.
1696   TF_ASSIGN_OR_RETURN(
1697       Node * key_placeholder,
1698       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1699 
1700   // Step 2: rewrite cond function.
1701   TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
1702       cond_host_func_name, while_node_name, host_transfer_key,
1703       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1704       outside_compilation_name, fld));
1705 
1706   // Step 3: rewrite body function.
1707   TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
1708       body_host_func_name, while_node_name, host_transfer_key,
1709       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1710       outside_compilation_name, fld));
1711 
1712   // Step 4: build While node.
1713   NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
1714                                "While");
1715   while_builder.Attr("T", std::vector<DataType>{DT_STRING});
1716   NameAttrList func;
1717   AttrValue device_ordinal_value;
1718   device_ordinal_value.set_placeholder("_device_ordinal");
1719   (*func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1720   func.set_name(cond_host_func_name);
1721   while_builder.Attr("cond", func);
1722   func.set_name(body_host_func_name);
1723   while_builder.Attr("body", func);
1724   while_builder.Attr(kXlaHasHostTransferAttrName, true);
1725   while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1726   while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1727   // Make sure loop body of i-th iteration happens before loop cond of (i+1)-th
1728   // iteration.
1729   while_builder.Attr("parallel_iterations", 1);
1730   std::vector<NodeDefBuilder::NodeOut> while_inputs{
1731       {key_placeholder->name(), 0, DT_STRING}};
1732   while_builder.Input(while_inputs);
1733   NodeDef while_def;
1734   TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
1735   Status s;
1736   Node* while_node = host_graph.AddNode(while_def, &s);
1737   TF_RETURN_IF_ERROR(s);
1738   host_graph.AddEdge(key_placeholder, 0, while_node, 0);
1739 
1740   // Convert `host_graph` to function.
1741   FunctionDef oc_host_graph_fdef;
1742   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1743                                         &oc_host_graph_fdef));
1744   if (fld->Find(host_graph_func_name)) {
1745     TF_RETURN_IF_ERROR(
1746         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1747   } else {
1748     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1749   }
1750 
1751   return Status::OK();
1752 }
1753 
1754 // Builds host graph for func call nodes.
BuildHostGraphForFuncCallNode(const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & func_call_node_name,const string & func_call_host_func_name,const string & host_graph_func_name,FunctionLibraryDefinition * fld)1755 Status BuildHostGraphForFuncCallNode(
1756     const string& xla_cluster_attr_name, const string& xla_cluster_name,
1757     const string& outside_compilation_attr_name,
1758     const string& func_call_node_name, const string& func_call_host_func_name,
1759     const string& host_graph_func_name, FunctionLibraryDefinition* fld) {
1760   Graph host_graph(fld);
1761   AttrValue device_ordinal_value;
1762   device_ordinal_value.set_placeholder("_device_ordinal");
1763 
1764   // Step 1: add key placeholder node.
1765   TF_ASSIGN_OR_RETURN(
1766       Node * key_placeholder,
1767       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1768 
1769   // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg
1770   // node.
1771   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1772       xla_cluster_name, func_call_host_func_name, fld));
1773 
1774   // Step 3: build a function call node with `host_func_name`, with
1775   // `key_placeholder` as input.
1776   NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
1777                               func_call_host_func_name, fld);
1778   call_builder.Input(key_placeholder->name(), 0, DT_STRING);
1779   call_builder.Attr("_device_ordinal", device_ordinal_value);
1780   call_builder.Attr(kXlaHasHostTransferAttrName, true);
1781   call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1782   call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
1783   NodeDef call_def;
1784   TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
1785   Status s;
1786   Node* call_node = host_graph.AddNode(call_def, &s);
1787   TF_RETURN_IF_ERROR(s);
1788   host_graph.AddEdge(key_placeholder, 0, call_node, 0);
1789 
1790   // Convert `host_graph` to function.
1791   FunctionDef oc_host_graph_fdef;
1792   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1793                                         HostGraphControlRetMapping,
1794                                         &oc_host_graph_fdef));
1795   if (fld->Find(host_graph_func_name)) {
1796     TF_RETURN_IF_ERROR(
1797         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1798   } else {
1799     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1800   }
1801 
1802   return Status::OK();
1803 }
1804 
ExtractOutsideCompilationForFuncCallNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1805 TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
1806     const string& xla_cluster_attr_name,
1807     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1808     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1809     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1810     std::vector<string>* host_graphs,
1811     std::vector<string>* shape_inference_graphs,
1812     bool* has_outside_compilation) {
1813   bool func_has_outside_compilation = false;
1814   NameAttrList func;
1815   if (fld->Contains(n->type_string())) {
1816     func.set_name(n->type_string());
1817     typedef protobuf::Map<string, AttrValue> AttrMap;
1818     *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
1819   } else if (n->IsPartitionedCall()) {
1820     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
1821   } else {
1822     TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
1823     func.set_name(FunctionLibraryDefinition::kGradientOp);
1824     *func.mutable_attr() = n->def().attr();
1825   }
1826   string canonical_func_name;
1827   if (func.name() == FunctionLibraryDefinition::kGradientOp) {
1828     NameAttrList forward_func;
1829     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func));
1830     canonical_func_name = absl::StrCat("gradient_", forward_func.name());
1831   } else {
1832     canonical_func_name = func.name();
1833   }
1834   string new_func_name = absl::StrCat(canonical_func_name, "_oc");
1835   string host_func_name =
1836       absl::StrCat("oc_func_call_host_", canonical_func_name);
1837   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1838       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1839       func, new_func_name, host_func_name, host_compute_core, flr, fld,
1840       shape_inference_graphs, &func_has_outside_compilation));
1841 
1842   // If the function call does not have outside compilation, nothing to do.
1843   if (!func_has_outside_compilation) {
1844     return Status::OK();
1845   }
1846 
1847   *has_outside_compilation = true;
1848 
1849   // Change `n` to call the new function directly.
1850   auto replace_builder =
1851       absl::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
1852   std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
1853   for (const Edge* e : n->in_edges()) {
1854     if (e->IsControlEdge()) {
1855       continue;
1856     }
1857 
1858     const bool input_size_check =
1859         e->dst_input() < static_cast<int>(inputs.size());
1860     TF_RET_CHECK(e->dst_input() >= 0 && input_size_check);
1861     inputs[e->dst_input()] =
1862         NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1863                                 e->src()->output_type(e->src_output())};
1864   }
1865   for (const auto& input : inputs) {
1866     replace_builder->Input(input);
1867   }
1868   for (const auto& attr : n->attrs()) {
1869     replace_builder->Attr(attr.first, attr.second);
1870   }
1871   auto replace_def = absl::make_unique<NodeDef>();
1872   TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
1873   TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
1874   replace->AddAttr(kXlaTokenInputNodesAttrName,
1875                    std::vector<string>{kXlaTokenArgNodeName});
1876   replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name());
1877 
1878   // Build host side graph for the function call.
1879   string oc_host_graph_name =
1880       absl::StrCat("oc_func_host_graph_", replace->name());
1881   TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
1882       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1883       replace->name(), host_func_name, oc_host_graph_name, fld));
1884 
1885   // Record the host graph.
1886   host_graphs->push_back(oc_host_graph_name);
1887 
1888   return Status::OK();
1889 }
1890 
ExtractOutsideCompilationForIfNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1891 Status ExtractOutsideCompilationForIfNode(
1892     const string& xla_cluster_attr_name,
1893     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1894     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1895     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1896     std::vector<string>* host_graphs,
1897     std::vector<string>* shape_inference_graphs,
1898     bool* has_outside_compilation) {
1899   // Instantiate "then_branch" and "else_branch".
1900   NameAttrList then_branch, else_branch;
1901   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
1902   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
1903 
1904   // Extract outside compilation for then_branch and else_branch.
1905   bool then_branch_has_outside_compilation = false;
1906   bool else_branch_has_outside_compilation = false;
1907   string then_branch_host_func_name =
1908              absl::StrCat("oc_then_branch_host_if_", then_branch.name()),
1909          else_branch_host_func_name =
1910              absl::StrCat("oc_else_branch_host_if_", else_branch.name());
1911   string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
1912          else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
1913   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1914       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1915       then_branch, then_branch_xla_func_name, then_branch_host_func_name,
1916       host_compute_core, flr, fld, shape_inference_graphs,
1917       &then_branch_has_outside_compilation));
1918   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1919       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1920       else_branch, else_branch_xla_func_name, else_branch_host_func_name,
1921       host_compute_core, flr, fld, shape_inference_graphs,
1922       &else_branch_has_outside_compilation));
1923 
1924   // If then/else branch do not have outside compilation, nothing to do.
1925   if (!then_branch_has_outside_compilation &&
1926       !else_branch_has_outside_compilation) {
1927     return Status::OK();
1928   }
1929 
1930   *has_outside_compilation = true;
1931 
1932   // Change If node to call the new functions.
1933   if (then_branch_has_outside_compilation) {
1934     then_branch.set_name(then_branch_xla_func_name);
1935     n->ClearAttr("then_branch");
1936     n->AddAttr("then_branch", then_branch);
1937   }
1938   if (else_branch_has_outside_compilation) {
1939     else_branch.set_name(else_branch_xla_func_name);
1940     n->ClearAttr("else_branch");
1941     n->AddAttr("else_branch", else_branch);
1942   }
1943   n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
1944 
1945   string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
1946 
1947   // XLA computation: add a SendToHost node to send cond predicate.
1948   Node* pred_node;
1949   TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
1950   TF_ASSIGN_OR_RETURN(
1951       Node * send_pred_node,
1952       BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
1953                           host_transfer_key, pred_node, g));
1954   n->AddAttr(kXlaTokenInputNodesAttrName,
1955              std::vector<string>{send_pred_node->name()});
1956 
1957   // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
1958   // visit If node after `send_pred_node`, thus the token output for
1959   // `send_pred_node` has been generated.
1960   g->AddControlEdge(send_pred_node, n);
1961 
1962   // Build host side graph for the "If" node.
1963   // If then/else branch does not have outside compilation, we won't build host
1964   // graph for the branch. But here we need a host graph for both branches, so
1965   // we need to create a no-op host graph.
1966   if (!then_branch_has_outside_compilation) {
1967     std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
1968     std::vector<string> then_branch_host_graphs;
1969     TF_RETURN_IF_ERROR(ConstructHostGraph(
1970         xla_cluster_name, outside_compilation_attr_name,
1971         then_branch_host_graphs, fld, &then_branch_host_graph));
1972     FunctionDef then_branch_host_fdef;
1973     TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
1974                                           then_branch_host_func_name,
1975                                           &then_branch_host_fdef));
1976     if (fld->Find(then_branch_host_func_name)) {
1977       TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
1978                                               then_branch_host_fdef));
1979     } else {
1980       TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
1981     }
1982   }
1983   if (!else_branch_has_outside_compilation) {
1984     std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
1985     std::vector<string> else_branch_host_graphs;
1986     TF_RETURN_IF_ERROR(ConstructHostGraph(
1987         xla_cluster_name, outside_compilation_attr_name,
1988         else_branch_host_graphs, fld, &else_branch_host_graph));
1989     FunctionDef else_branch_host_fdef;
1990     TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
1991                                           else_branch_host_func_name,
1992                                           &else_branch_host_fdef));
1993     if (fld->Find(else_branch_host_func_name)) {
1994       TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
1995                                               else_branch_host_fdef));
1996     } else {
1997       TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
1998     }
1999   }
2000   string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
2001   TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
2002       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2003       n->name(), host_transfer_key, oc_host_graph_name, fld,
2004       then_branch_host_func_name, else_branch_host_func_name));
2005   host_graphs->push_back(oc_host_graph_name);
2006 
2007   return Status::OK();
2008 }
2009 
ExtractOutsideCompilationForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2010 Status ExtractOutsideCompilationForWhileNode(
2011     const string& xla_cluster_attr_name,
2012     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2013     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
2014     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2015     std::vector<string>* host_graphs,
2016     std::vector<string>* shape_inference_graphs,
2017     bool* has_outside_compilation) {
2018   // Instantiate "cond" and "body".
2019   NameAttrList cond, body;
2020   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
2021   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
2022 
2023   // Extract outside compilation for cond and body.
2024   bool cond_has_outside_compilation = false;
2025   bool body_has_outside_compilation = false;
2026   string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()),
2027          body_host_func_name = absl::StrCat("oc_body_host_while_", body.name());
2028   string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
2029          body_xla_func_name = absl::StrCat(body.name(), "_oc");
2030   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2031       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2032       cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
2033       fld, shape_inference_graphs, &cond_has_outside_compilation));
2034   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2035       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2036       body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
2037       fld, shape_inference_graphs, &body_has_outside_compilation));
2038 
2039   // If cond/body do not have outside compilation, nothing to do.
2040   if (!cond_has_outside_compilation && !body_has_outside_compilation) {
2041     return Status::OK();
2042   }
2043 
2044   *has_outside_compilation = true;
2045 
2046   // Change While node to call the new functions.
2047   if (cond_has_outside_compilation) {
2048     cond.set_name(cond_xla_func_name);
2049     n->ClearAttr("cond");
2050     n->AddAttr("cond", cond);
2051   }
2052   if (body_has_outside_compilation) {
2053     body.set_name(body_xla_func_name);
2054     n->ClearAttr("body");
2055     n->AddAttr("body", body);
2056   }
2057   n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
2058 
2059   string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
2060 
2061   // XLA computation: rewrite cond function to add a SendToHost node to send
2062   // loop predicate.
2063   TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
2064       cond_xla_func_name, host_transfer_key, &cond, fld, n));
2065   n->AddAttr(kXlaTokenInputNodesAttrName,
2066              std::vector<string>{kXlaTokenArgNodeName});
2067 
2068   // Build host side graph for the "While" node.
2069   if (!cond_has_outside_compilation) {
2070     std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
2071     std::vector<string> host_graphs;
2072     TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2073                                           outside_compilation_attr_name,
2074                                           host_graphs, fld, &cond_host_graph));
2075     FunctionDef cond_host_fdef;
2076     TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
2077                                           &cond_host_fdef));
2078     if (fld->Find(cond_host_func_name)) {
2079       TF_RETURN_IF_ERROR(
2080           fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
2081     } else {
2082       TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
2083     }
2084   }
2085   if (!body_has_outside_compilation) {
2086     std::unique_ptr<Graph> body_host_graph(new Graph(fld));
2087     std::vector<string> host_graphs;
2088     TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2089                                           outside_compilation_attr_name,
2090                                           host_graphs, fld, &body_host_graph));
2091     FunctionDef body_host_fdef;
2092     TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
2093                                           &body_host_fdef));
2094     if (fld->Find(body_host_func_name)) {
2095       TF_RETURN_IF_ERROR(
2096           fld->ReplaceFunction(body_host_func_name, body_host_fdef));
2097     } else {
2098       TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
2099     }
2100   }
2101   string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
2102   TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
2103       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2104       n->name(), host_transfer_key, oc_host_graph_name, fld,
2105       cond_host_func_name, body_host_func_name));
2106   host_graphs->push_back(oc_host_graph_name);
2107 
2108   return Status::OK();
2109 }
2110 
ExtractOutsideCompilationForNodesWithAssociatedFunctions(Graph * g,const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2111 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2112     Graph* g, const string& xla_cluster_attr_name,
2113     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2114     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2115     FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
2116     std::vector<string>* shape_inference_graphs,
2117     bool* has_outside_compilation) {
2118   std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
2119   for (Node* n : g->nodes()) {
2120     if (n->IsIfNode()) {
2121       if_nodes.push_back(n);
2122     } else if (n->IsWhileNode()) {
2123       while_nodes.push_back(n);
2124     } else if (IsFunctionCall(*fld, *n)) {
2125       func_call_nodes.push_back(n);
2126     }
2127   }
2128 
2129   for (Node* n : func_call_nodes) {
2130     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
2131         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2132         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2133         has_outside_compilation));
2134   }
2135 
2136   for (Node* n : if_nodes) {
2137     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
2138         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2139         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2140         has_outside_compilation));
2141   }
2142 
2143   for (Node* n : while_nodes) {
2144     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
2145         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2146         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2147         has_outside_compilation));
2148   }
2149 
2150   return Status::OK();
2151 }
2152 
CopyOutsideCompilationConstNodes(Graph * g,const string & outside_compilation_attr_name)2153 Status CopyOutsideCompilationConstNodes(
2154     Graph* g, const string& outside_compilation_attr_name) {
2155   for (Node* n : g->op_nodes()) {
2156     if (!n->IsConstant() ||
2157         !HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2158       continue;
2159     }
2160 
2161     std::vector<const Edge*> out_edges(n->out_edges().begin(),
2162                                        n->out_edges().end());
2163     bool has_non_oc_output = false;
2164     for (const Edge* e : out_edges) {
2165       if (!e->IsControlEdge() &&
2166           !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2167         has_non_oc_output = true;
2168         break;
2169       }
2170     }
2171     if (!has_non_oc_output) {
2172       continue;
2173     }
2174 
2175     NodeDef copy_def = n->def();
2176     copy_def.set_name(g->NewName(n->name()));
2177     copy_def.mutable_attr()->erase(outside_compilation_attr_name);
2178     Status s;
2179     Node* copy_node = g->AddNode(copy_def, &s);
2180     TF_RETURN_IF_ERROR(s);
2181     for (const Edge* e : n->in_edges()) {
2182       if (e->IsControlEdge()) {
2183         g->AddControlEdge(e->src(), copy_node);
2184       }
2185     }
2186     for (const Edge* e : out_edges) {
2187       if (!e->IsControlEdge() &&
2188           !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2189         Node* dst = e->dst();
2190         int dst_input = e->dst_input();
2191         g->RemoveEdge(e);
2192         g->AddEdge(copy_node, 0, dst, dst_input);
2193       }
2194     }
2195   }
2196 
2197   return Status::OK();
2198 }
2199 
2200 }  // namespace
2201 
operator ()(const std::vector<OutputTensor> & arg_source_tensors,std::unique_ptr<Graph> * graph,std::vector<int> * input_permutation,std::vector<int> * output_permutation,NodeDef * node_def)2202 Status RewriteOutsideCompilationSubgraphFn::operator()(
2203     const std::vector<OutputTensor>& arg_source_tensors,
2204     std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
2205     std::vector<int>* output_permutation, NodeDef* node_def) {
2206   string old_name = node_def->op();
2207   string new_name =
2208       absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name);
2209   node_def->set_op(new_name);
2210   node_def->set_name(new_name);
2211 
2212   // Later we will run PruneForReverseReachability(), so make sure all original
2213   // nodes are reachable from sink node and won't be removed.
2214   FixupSourceAndSinkEdges(graph->get());
2215 
2216   // Step 1: create a key placeholder node.
2217   TF_ASSIGN_OR_RETURN(
2218       Node * key_placeholder,
2219       AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get()));
2220 
2221   // Step 2: build RecvAtHost node, and replace all _Arg nodes with it.
2222   std::vector<DataType> recv_at_host_dtypes;
2223   TF_ASSIGN_OR_RETURN(
2224       Node * recv_at_host_node,
2225       ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name,
2226                                         &recv_at_host_dtypes, key_placeholder));
2227 
2228   // Step 3: build SendFromHost node, and replace all _Retval nodes with it.
2229   std::vector<DataType> send_from_host_dtypes;
2230   TF_ASSIGN_OR_RETURN(
2231       Node * send_from_host_node,
2232       ReplaceRetNodesWithSendFromHostNode(
2233           graph->get(), new_name, &send_from_host_dtypes, key_placeholder));
2234 
2235   // Step 4: add XLA cluster and outside compilation attr.
2236   for (Node* n : (*graph)->nodes()) {
2237     if (IsKeyPlaceholderNode(*n)) {
2238       continue;
2239     }
2240 
2241     n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_);
2242     n->AddAttr(outside_compilation_attr_name_, old_name);
2243   }
2244 
2245   // Check whether we have all input shapes for XlaSendFromHost. If we do, we
2246   // will set `shapes` attr for the call node; otherwise we will save the
2247   // shape inference graph and set `shape_inference_graph` for the call node.
2248   absl::optional<std::vector<PartialTensorShape>> shapes =
2249       GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node);
2250   for (Node* n : (*graph)->nodes()) {
2251     n->ClearAttr(kXlaInferredShapesAttrName);
2252   }
2253 
2254   // Step 5: add control edges for originally XLA <-> outside compilation
2255   // control edges.
2256   for (Node* n : (*graph)->nodes()) {
2257     if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) {
2258       (*graph)->AddControlEdge(n, send_from_host_node);
2259       n->ClearAttr(kXlaConnectedToXlaComputationAttrName);
2260     }
2261     if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) {
2262       (*graph)->AddControlEdge(recv_at_host_node, n);
2263       n->ClearAttr(kXlaConnectedFromXlaComputationAttrName);
2264     }
2265   }
2266 
2267   // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune
2268   // them if necessary.
2269   // - RecvAtHost should be pruned iff it has no output data/control edges. If
2270   //   it has any output edge, it will be reverse reachable from sink node. We
2271   //   don't need to do anything special.
2272   // - SendFromHost should be pruned iff it has no input data/control edges. If
2273   //   it has input edges other than key_placeholder, we connect it to sink
2274   //   node so it won't be pruned.
2275   // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned.
2276   //   We don't need to do anything special.
2277   if (send_from_host_node->in_edges().size() > 1) {
2278     (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node());
2279   }
2280   PruneForReverseReachability(
2281       graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()});
2282 
2283   // Step 7: add necessary attributes to function call node, so we can replace
2284   // it with HostCompute node later.
2285   AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
2286   if (shapes) {
2287     NameAttrList shape_inference_graph;
2288     AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2289     AddNodeAttr("shapes", *shapes, node_def);
2290   } else {
2291     string shape_inference_func_name =
2292         absl::StrCat("_outside_compilation_shape_inference_", new_name);
2293     NameAttrList shape_inference_graph;
2294     shape_inference_graph.set_name(shape_inference_func_name);
2295     AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2296     AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def);
2297   }
2298   AddNodeAttr("ancestors", std::vector<string>{}, node_def);
2299   AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def);
2300   AddNodeAttr("Toutputs", send_from_host_dtypes, node_def);
2301   AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def);
2302 
2303   return Status::OK();
2304 }
2305 
ExtractOutsideCompilationForFunction(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const NameAttrList & func_name_attrs,const string & new_func_name,const string & host_graph_func_name,const std::map<string,int> & host_compute_core,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2306 Status ExtractOutsideCompilationForFunction(
2307     const string& xla_cluster_attr_name,
2308     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2309     const NameAttrList& func_name_attrs, const string& new_func_name,
2310     const string& host_graph_func_name,
2311     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2312     FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
2313     bool* has_outside_compilation) {
2314   // Convert the function to graph.
2315   const string& func_name = func_name_attrs.name();
2316   FunctionLibraryRuntime::Handle handle;
2317   TF_RETURN_IF_ERROR(
2318       flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle));
2319   Status ret_status = Status::OK();
2320   auto cleanup_handle = gtl::MakeCleanup([&]() {
2321     auto s = flr->ReleaseHandle(handle);
2322     if (!s.ok()) {
2323       ret_status.Update(s);
2324     }
2325   });
2326   const FunctionBody* fbody = flr->GetFunctionBody(handle);
2327 
2328   // Check if we have outside compilation nodes.
2329   *has_outside_compilation = false;
2330   for (Node* n : fbody->graph->nodes()) {
2331     if (HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2332       *has_outside_compilation = true;
2333       break;
2334     }
2335   }
2336   // We cannot early return here, because we might have outside compilation in
2337   // If/While function body.
2338 
2339   if (VLOG_IS_ON(4)) {
2340     DumpGraphToFile(
2341         absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
2342         *fbody->graph, fld);
2343   }
2344 
2345   std::unique_ptr<Graph> graph_out;
2346   std::vector<string> outside_compilation_host_graphs;
2347   std::vector<string> shape_inference_graphs_to_rewrite;
2348   if (*has_outside_compilation) {
2349     // Copy outside compilation Const nodes with non outside compilation users.
2350     TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
2351         fbody->graph, outside_compilation_attr_name));
2352 
2353     // Find dependencies between outside compilation clusters.
2354     TF_ASSIGN_OR_RETURN(auto cluster_deps,
2355                         OutsideCompilationClusterDependencies(
2356                             fbody->graph, outside_compilation_attr_name));
2357 
2358     // Preprocess edges between different outside compilations. They will be
2359     // restored in `ConstructHostGraph()`.
2360     TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
2361         fbody->graph, outside_compilation_attr_name));
2362 
2363     // Encapsulate outside_compilation cluster into function call node.
2364     auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
2365         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2366         new_func_name);
2367     TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
2368         outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
2369         /*reuse_existing_functions=*/true, &graph_out, fld));
2370 
2371     // Replace outside_compilation function nodes with HostCompute ops.
2372     std::vector<Node*> outside_compilation_nodes;
2373     for (Node* n : graph_out->nodes()) {
2374       if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
2375         outside_compilation_nodes.push_back(n);
2376         outside_compilation_host_graphs.push_back(n->name());
2377 
2378         // If we could not infer shapes for XlaSendFromHost inputs statically,
2379         // we will set the "shape_inference_graph" attribute. In that case, copy
2380         // outside compilation subgraph as shape inference graph in `fld`.
2381         auto shape_inference_graph = absl::make_unique<NameAttrList>();
2382         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
2383                                        shape_inference_graph.get()));
2384         if (!shape_inference_graph->name().empty()) {
2385           shape_inference_graphs->push_back(shape_inference_graph->name());
2386           shape_inference_graphs_to_rewrite.push_back(
2387               shape_inference_graph->name());
2388 
2389           const FunctionDef* xla_fdef = fld->Find(n->name());
2390           if (!xla_fdef) {
2391             return errors::Internal("Cannot find XLA function ", n->name());
2392           }
2393           auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
2394           shape_inference_fdef->mutable_signature()->set_name(
2395               shape_inference_graph->name());
2396           if (fld->Find(shape_inference_graph->name())) {
2397             TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2398                 shape_inference_graph->name(), *shape_inference_fdef));
2399           } else {
2400             TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
2401           }
2402         }
2403       }
2404     }
2405     std::map<string, Node*> host_compute_nodes;
2406     for (Node* n : outside_compilation_nodes) {
2407       auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
2408           graph_out.get(), n, host_compute_core, *cluster_deps);
2409       TF_RETURN_IF_ERROR(host_compute_node_or.status());
2410       Node* host_compute_node = host_compute_node_or.ValueOrDie();
2411       host_compute_nodes[host_compute_node->name()] = host_compute_node;
2412     }
2413     // For XlaHostCompute nodes with dependencies, add control edges between
2414     // them so XlaCompiler can handle them in correct order.
2415     for (const auto& iter : host_compute_nodes) {
2416       Node* host_compute_node = iter.second;
2417       std::vector<string> token_input_node_names;
2418       TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
2419                                      kXlaTokenInputNodesAttrName,
2420                                      &token_input_node_names));
2421       for (const string& node_name : token_input_node_names) {
2422         if (node_name == kXlaTokenArgNodeName) {
2423           continue;
2424         }
2425 
2426         auto iter = host_compute_nodes.find(node_name);
2427         TF_RET_CHECK(iter != host_compute_nodes.end());
2428         graph_out->AddControlEdge(iter->second, host_compute_node);
2429       }
2430     }
2431   }
2432 
2433   // Handle nodes with associated functions.
2434   Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
2435   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2436       g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2437       host_compute_core, flr, fld, &outside_compilation_host_graphs,
2438       shape_inference_graphs, has_outside_compilation));
2439 
2440   if (*has_outside_compilation) {
2441     // Construct host graph.
2442     std::unique_ptr<Graph> host_graph;
2443     TF_RETURN_IF_ERROR(
2444         ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
2445                            outside_compilation_host_graphs, fld, &host_graph));
2446     auto host_graph_fdef = absl::make_unique<FunctionDef>();
2447     TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
2448                                           HostGraphControlRetMapping,
2449                                           host_graph_fdef.get()));
2450     if (fld->Find(host_graph_func_name)) {
2451       TF_RETURN_IF_ERROR(
2452           fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
2453     } else {
2454       TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
2455     }
2456 
2457     // Shape inference graphs might contain Placeholder nodes for outside
2458     // compilation to outside compilation edges. Rewrite shape inference graphs
2459     // to remove such nodes.
2460     for (const string& shape_inference_graph :
2461          shape_inference_graphs_to_rewrite) {
2462       TF_RETURN_IF_ERROR(
2463           RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
2464                                      /*pivot_node=*/nullptr, fld));
2465     }
2466 
2467     // Remove the outside compilation graphs from function library.
2468     for (const string& func : outside_compilation_host_graphs) {
2469       TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
2470     }
2471 
2472     // Replace original function.
2473     auto updated_fdef = absl::make_unique<FunctionDef>();
2474     TF_RETURN_IF_ERROR(
2475         GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
2476     updated_fdef->mutable_signature()->set_is_stateful(true);
2477     const FunctionDef* original_fdef = fld->Find(func_name);
2478     if (original_fdef) {
2479       for (const auto& attr : original_fdef->attr()) {
2480         (*updated_fdef->mutable_attr())[attr.first] = attr.second;
2481       }
2482     }
2483     if (fld->Find(new_func_name)) {
2484       TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
2485     } else {
2486       TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
2487     }
2488     if (VLOG_IS_ON(4)) {
2489       DumpGraphToFile(
2490           absl::StrCat("extract_outside_compilation_for_func_after_",
2491                        func_name),
2492           *g, fld);
2493     }
2494   }
2495 
2496   return ret_status;
2497 }
2498 
ExtractOutsideCompilation(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const std::unordered_map<string,XlaClusterInfo> & clusters,Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,bool * modified)2499 Status ExtractOutsideCompilation(
2500     const string& xla_cluster_attr_name,
2501     const string& outside_compilation_attr_name,
2502     const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
2503     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2504     bool* modified) {
2505   if (VLOG_IS_ON(4)) {
2506     DumpGraphToFile("extract_outside_compilation_before", *g, fld);
2507   }
2508 
2509   *modified = false;
2510   auto node_name_index = g->BuildNodeNameIndex();
2511   for (auto& iter : clusters) {
2512     string xla_cluster_name = iter.first;
2513     Node* n = iter.second.node;
2514     auto const& func_name_attrs = iter.second.func_name_attrs;
2515     auto const& host_compute_core = iter.second.host_compute_core;
2516 
2517     std::vector<string> shape_inference_graphs;
2518     bool has_outside_compilation;
2519     string host_graph_func_name =
2520         absl::StrCat("oc_host_graph_", xla_cluster_name);
2521     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2522         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2523         func_name_attrs, func_name_attrs.name(), host_graph_func_name,
2524         host_compute_core, flr, fld, &shape_inference_graphs,
2525         &has_outside_compilation));
2526     *modified |= has_outside_compilation;
2527 
2528     if (has_outside_compilation) {
2529       string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
2530       Node* pivot_node = node_name_index[pivot_name];
2531       TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
2532           g, fld, host_graph_func_name, n, pivot_node));
2533 
2534       TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
2535 
2536       for (const auto& shape_inference_graph_name : shape_inference_graphs) {
2537         TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
2538             shape_inference_graph_name, g, pivot_node, fld));
2539       }
2540     }
2541   }
2542 
2543   if (VLOG_IS_ON(4)) {
2544     DumpGraphToFile("extract_outside_compilation_after", *g, fld);
2545   }
2546   return Status::OK();
2547 }
2548 
2549 }  // namespace tensorflow
2550