/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

namespace {

// Control return mapping function for outside compilation host graphs.
// All nodes with kXlaHasHostTransfer attribute are control outputs.
absl::optional<string> HostGraphControlRetMapping(const Node* n) {
  if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
    return n->name();
  }
  return absl::nullopt;
}

// Add a key placeholder node to the graph. The key placeholder node will be
// used as input for XlaRecvAtHost/XlaSendFromHost nodes.
xla::StatusOr<Node*> AddHostComputeKeyPlaceholder(
    const string& xla_cluster_name, Graph* g) {
  NodeDef key_def;
  NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
                         "Placeholder");
  builder.Attr("dtype", DT_STRING);
  builder.Attr("shape", PartialTensorShape({2}));
  builder.Attr("_host_compute_call_node", xla_cluster_name);
  Status s = builder.Finalize(&key_def);
  if (!s.ok()) return s;

  Node* n = g->AddNode(key_def, &s);
  if (!s.ok()) return s;
  return n;
}

// Returns whether the node is an XLA computation key placeholder.
bool IsKeyPlaceholderNode(const Node& n) {
  return n.type_string() == "Placeholder" &&
         absl::EndsWith(n.name(), "_key_placeholder");
}

// Returns nodes with given type.
std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
  std::vector<Node*> result;
  for (Node* n : g.nodes()) {
    if (n->type_string() == type) {
      result.push_back(n);
    }
  }
  return result;
}

// Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
                       std::vector<DataType>* recv_at_host_dtypes) {
  recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
  for (auto* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
    (*recv_at_host_dtypes)[index] = dtype;
  }
  for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) {
    if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
      return errors::Internal("Cannot get datatype for input ", i);
    }
  }
  return Status::OK();
}

// Builds XlaRecvAtHost node.
xla::StatusOr<Node*> BuildRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
  NodeDefBuilder recv_at_host_builder(
      absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
      "_XlaRecvAtHost");
  NodeDef recv_at_host_def;
  recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
  // The correct device_ordinal will be inserted during replication in a
  // subsequent rewrite.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");
  recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
  recv_at_host_builder.Attr(
      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
  recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
  recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
  TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
  Status s;
  Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s);
  TF_RETURN_IF_ERROR(s);
  return recv_at_host_node;
}

// Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
xla::StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use out nodes for source node, instead of traversing all
  // nodes.
  std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
  TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * recv_at_host_node,
      BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
                          key_placeholder));
  for (auto* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Record out edges and remove `n` before adding those edges to RecvAtHost.
    // This is to avoid multiple producers.
    std::vector<OutEdgeInfo> out_edge_info;
    for (auto edge : n->out_edges()) {
      out_edge_info.push_back(
          {edge->dst(), edge->src_output(), edge->dst_input()});
    }
    g->RemoveNode(n);
    for (const OutEdgeInfo& edge : out_edge_info) {
      if (edge.dst_input == Graph::kControlSlot) {
        g->AddControlEdge(recv_at_host_node, edge.dst);
      } else {
        g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
      }
    }

    // Rewrite dst nodes because their input changed.
    for (int i = 0, end = out_edge_info.size(); i < end; i++) {
      const OutEdgeInfo edge = out_edge_info[i];
      if (edge.dst_input == Graph::kControlSlot) {
        continue;
      }

      Node* dst = edge.dst;
      NodeDef new_def = dst->def();
      *new_def.mutable_input(edge.dst_input) =
          absl::StrCat(recv_at_host_node->name(), ":", index);
      TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));

      // Other edges might have `dst` as dst node as well. Update those edges
      // with `dst_replace`.
      for (int j = i + 1, end = out_edge_info.size(); j < end; j++) {
        if (out_edge_info[j].dst == dst) {
          out_edge_info[j].dst = dst_replace;
        }
      }
    }
  }
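  // Wire the key placeholder into input 0 of the new RecvAtHost node; the
  // matching NodeDef input was declared in BuildRecvAtHostNode().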
  g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
  return recv_at_host_node;
}

// Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
                       std::vector<DataType>* send_from_host_dtypes) {
  send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
    (*send_from_host_dtypes)[index] = dtype;
  }
  for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) {
    if ((*send_from_host_dtypes)[i] == DT_INVALID) {
      return errors::Internal("Cannot get datatype for output ", i);
    }
  }
  return Status::OK();
}

// Builds XlaSendFromHost node.
xla::StatusOr<Node*> BuildSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    const std::vector<Node*>& ret_nodes,
    const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) {
  NodeDefBuilder send_from_host_builder(
      absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
      "_XlaSendFromHost");
  NodeDef send_from_host_def;
  send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
  // The correct device_ordinal will be inserted during replication in a
  // subsequent rewrite.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");
  send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
  send_from_host_builder.Attr(
      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
  send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
  std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size());
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    const int num_dtypes = send_from_host_dtypes.size();
    if (index < 0 || index >= num_dtypes) {
      return errors::Internal("Invalid _Retval index: ", index);
    }
    for (auto edge : n->in_edges()) {
      inputs[index] =
          NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
                                  edge->src()->output_type(edge->src_output())};
    }
  }
  send_from_host_builder.Input(inputs);
  send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
  TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
  Status s;
  Node* send_from_host_node = g->AddNode(send_from_host_def, &s);
  TF_RETURN_IF_ERROR(s);
  return send_from_host_node;
}

// Builds XlaSendFromHost node, and replaces all _Retval nodes with it.
xla::StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use in nodes for sink node, instead of traversing all
  // nodes.
  std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
  TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * send_from_host_node,
      BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
                            *send_from_host_dtypes, key_placeholder));
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    for (auto edge : n->in_edges()) {
      if (edge->src_output() == Graph::kControlSlot) {
        g->AddControlEdge(edge->src(), send_from_host_node);
      } else {
        g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
      }
    }
    g->RemoveNode(n);
  }
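  // The key placeholder feeds the last input of SendFromHost, right after the
  // data inputs declared in `send_from_host_dtypes`.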
  g->AddEdge(key_placeholder, 0, send_from_host_node,
             send_from_host_dtypes->size());
  return send_from_host_node;
}

// Returns input shapes (excluding key placeholder) for `send_from_host_node`
// if they are all fully defined; absl::nullopt otherwise.
absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
    int num_inputs, Node* send_from_host_node) {
  std::vector<PartialTensorShape> results(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    const Edge* e;
    if (!send_from_host_node->input_edge(i, &e).ok()) {
      return absl::nullopt;
    }

    std::vector<PartialTensorShape> shapes;
    if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
             .ok()) {
      return absl::nullopt;
    }

    const PartialTensorShape shape = shapes[e->src_output()];
    if (!shape.IsFullyDefined()) {
      return absl::nullopt;
    }

    results[e->dst_input()] = shape;
  }
  return results;
}

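// Returns the name of the XlaHostCompute node built for the given outside
// compilation cluster.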
string host_compute_node_name(const string& original_oc_name) {
  return absl::StrCat("outside_compilation_", original_oc_name,
                      "_host_compute");
}

// Builds XlaHostCompute NodeDef from the outside compilation call node.
xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
    const Node* call_node, const std::map<string, int>& host_compute_core,
    const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
  string original_oc_name;
  TF_RETURN_IF_ERROR(GetNodeAttr(
      call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
  NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
                                      "XlaHostCompute");
  // In XlaCompiler, if XlaHostCompute node is in a function call node and that
  // function is inlined, name of the XlaHostCompute node will be changed. So
  // we cannot rely on node name; use an attribute instead.
  host_compute_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
                            host_compute_builder.node_name());

  // Copy all attributes.
  for (const auto& attr : call_node->attrs()) {
    host_compute_builder.Attr(attr.first, attr.second);
  }

  // Populate tpu_core assignment.
  const auto iter = host_compute_core.find(original_oc_name);
  if (iter != host_compute_core.end()) {
    int core = iter->second;
    host_compute_builder.Attr("tpu_core", core);
  }

  // Set input tokens and other outside compilation clusters that the current
  // cluster depends on in `kXlaTokenArgNodeName`. This is needed because when
  // outside compilation subgraphs are encapsulated and moved to host graph,
  // control/data edges between them will only be reflected in host graph.
  // From XLA's perspective, two originally dependent clusters are no longer
  // connected, which makes them look like they can be scheduled for execution
  // in arbitrary order even though in fact they must be executed in order
  // according to their host-side graph dependency. This can cause deadlock.
  // Therefore, we hint XLA what the correct ordering of these clusters should
  // be to avoid deadlocks.
  std::vector<string> xla_token_input_nodes;
  xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
  auto cluster_deps_it = cluster_deps.find(original_oc_name);
  if (cluster_deps_it != cluster_deps.end()) {
    for (const auto& dep : cluster_deps_it->second) {
      xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
    }
  }
  host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);

  // Populate inputs.
  std::vector<DataType> input_dtypes;
  TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
  std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size());
  for (auto e : call_node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    const int input_dtypes_size = input_dtypes.size();
    if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) {
      return errors::Internal("Invalid dst_input: ", e->dst_input());
    }
    inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
        e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
  }
  host_compute_builder.Input(inputs);

  NodeDef new_def;
  TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
  return new_def;
}

// Replace outside compilation function call node with XlaHostCompute node.
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
    Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
    const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
  // Build XlaHostCompute NodeDef.
  TF_ASSIGN_OR_RETURN(
      NodeDef node_def,
      BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
  TF_ASSIGN_OR_RETURN(Node * host_compute_node,
                      ReplaceNode(g, call_node, node_def));
  VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();

  return host_compute_node;
}

// Resets "_device_ordinal" attr to placeholder value for related nodes
// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes
// containing XlaRecvAtHost/XlaSendFromHost).
Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");
  for (Node* n : g->nodes()) {
    if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
      continue;
    }

    if (n->type_string() == "_XlaRecvAtHost" ||
        n->type_string() == "_XlaSendFromHost") {
      n->ClearAttr("device_ordinal");
      n->AddAttr("device_ordinal", device_ordinal_value);
    } else if (n->IsIfNode()) {
      for (const string& attr_name :
           std::vector<string>{"then_branch", "else_branch"}) {
        NameAttrList branch_func;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
        (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
        n->ClearAttr(attr_name);
        n->AddAttr(attr_name, branch_func);
      }
    } else if (n->IsWhileNode()) {
      for (const string& attr_name : std::vector<string>{"cond", "body"}) {
        NameAttrList branch_func;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
        (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
        n->ClearAttr(attr_name);
        n->AddAttr(attr_name, branch_func);
      }
    } else if (HasNodeAttr(n->def(), "_device_ordinal")) {
      // Function call node containing outside compilation.
      n->ClearAttr("_device_ordinal");
      n->AddAttr("_device_ordinal", device_ordinal_value);
    } else {
      return errors::Internal("Unknown node marked with ",
                              kXlaHasHostTransferAttrName, ": ",
                              n->DebugString());
    }
  }
  return Status::OK();
}

// Cheap check to tell whether FunctionDef contains a lifted argument.
bool HasLiftedArgs(const FunctionDef& function_def) {
  return absl::c_any_of(function_def.node_def(), [](const NodeDef& node_def) {
    return (node_def.op() == "Placeholder" &&
            node_def.attr().find(kXlaLiftedArgOutsideCompilationAttrName) !=
                node_def.attr().end());
  });
}

// Find lifted arguments in a function body and their corresponding outside
// compilation nodes.
xla::StatusOr<std::vector<std::pair<Node*, Node*>>>
LiftedArgsAndOutsideCompilationNodesInFunctionBody(
    const FunctionBody& function_body,
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node) {
  std::vector<std::pair<Node*, Node*>>
      lifted_arg_nodes_and_outside_compilation_nodes;
  for (Node* n : function_body.graph->op_nodes()) {
    string oc_cluster;
    if (n->type_string() == "Placeholder" &&
        GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName,
                    &oc_cluster)
            .ok()) {
      TF_RET_CHECK(outside_compilation_attr_to_node.find(oc_cluster) !=
                   outside_compilation_attr_to_node.end());
      lifted_arg_nodes_and_outside_compilation_nodes.emplace_back(
          n, outside_compilation_attr_to_node.at(oc_cluster));
    }
  }
  return lifted_arg_nodes_and_outside_compilation_nodes;
}

// Append lifted args' types to functional control flow node's `type_attr_name`
// attribute.
xla::StatusOr<std::vector<DataType>> UpdateTypesAttribute(
    const std::vector<std::pair<Node*, Node*>>&
        lifted_arg_nodes_and_outside_compilation_nodes,
    const string& type_attr_name, Node* n) {
  std::vector<DataType> data_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types));
  for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
    Node* outside_compilation_node = pair.second;
    DataType data_type;
    TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
                 outside_compilation_node->type_string() == "Placeholder");
    if (outside_compilation_node->IsIdentity()) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
    } else {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
    }
    data_types.push_back(data_type);
  }
  n->ClearAttr(type_attr_name);
  n->AddAttr(type_attr_name, data_types);

  return data_types;
}

// Add edges from lifted outside compilation argument nodes to `n` in Graph `g`.
void AddEdgesFromOutsideCompilationNodes(
    const int original_arg_count, const int arg_to_input_edge_offset,
    const std::vector<DataType>& data_types,
    const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
  // Add edges from outside compilation nodes to `n`.
  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    Node* outside_compilation_node =
        outside_compilation_nodes[i - original_arg_count];
    g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
  }
}

// Construct _Arg that maps to lifted outside compilation argument node input.
xla::StatusOr<Node*> AddOutsideCompilationInputArgToFunctionBody(
    const FunctionBody& function_body, const int arg_idx,
    const DataType& data_type) {
  NodeDefBuilder arg_builder(absl::StrCat("arg_", arg_idx), "_Arg");
  arg_builder.Attr("T", data_type);
  arg_builder.Attr("index", arg_idx);
  NodeDef arg_def;
  TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));

  Status s;
  Node* arg_node = function_body.graph->AddNode(arg_def, &s);
  TF_RETURN_IF_ERROR(s);
  return arg_node;
}

// Add _Retval node that matches newly added `arg_node` and connect `arg_node`
// to it.
Status AddMatchingRetvalNode(const FunctionBody& function_body,
                             const int arg_idx, const DataType& data_type,
                             Node* arg_node) {
  NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval");
  ret_builder.Attr("T", data_type);
  ret_builder.Attr("index", arg_idx);
  ret_builder.Input(arg_node->name(), 0, data_type);
  NodeDef ret_def;
  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
  Status s;
  Node* ret_node = function_body.graph->AddNode(ret_def, &s);
  TF_RETURN_IF_ERROR(s);
  function_body.graph->AddEdge(arg_node, 0, ret_node, 0);

  return Status::OK();
}

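// Redirects all out edges of the lifted-arg placeholder to the newly created
// _Arg node, then removes the placeholder from the function body graph.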
void ReplaceLiftedArgNodePlaceholderWithArg(
    const FunctionBody& function_body, const int original_arg_count,
    const int arg_idx, const std::vector<Node*>& lifted_arg_nodes,
    Node* arg_node) {
  Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count];
  // This might happen because lifted_arg_node only exists in one branch of an
  // If node, and we are handling the other branch.
  if (!lifted_arg_node) {
    return;
  }

  for (const Edge* e : lifted_arg_node->out_edges()) {
    if (e->IsControlEdge()) {
      function_body.graph->AddControlEdge(arg_node, e->dst());
    } else {
      function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input());
    }
  }
  function_body.graph->RemoveNode(lifted_arg_node);
}

// Adds function def to function definition library and updates the function
// callsite operation `callsite_node` to invoke the new function instead.
Status AddFunctionWithNewName(const std::string& new_name,
                              const std::string& func_attr_name,
                              const FunctionDef& function_def,
                              NameAttrList* func_attr, Node* callsite_node,
                              FunctionLibraryDefinition* fld) {
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
  func_attr->set_name(new_name);
  callsite_node->ClearAttr(func_attr_name);
  callsite_node->AddAttr(func_attr_name, *func_attr);
  return Status::OK();
}

// Reconnect outside compilation lifted arguments in a functional While node to
// its outside compilation tensor sources.
Status PostprocessLiftedArgsForWhile(
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
    Graph* g, Node* n, FunctionLibraryDefinition* fld) {
  TF_RET_CHECK(n->IsWhileNode());

  // Check if there are any lifted args in the body function.
  NameAttrList body_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "body", &body_func));
  const FunctionDef* body_function_def = fld->Find(body_func.name());
  TF_RET_CHECK(body_function_def);

  if (!HasLiftedArgs(*body_function_def)) {
    return Status::OK();
  }

  // Gather all lifted args.
  std::unique_ptr<FunctionBody> body_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_function_def,
                                             AttrSlice(&body_func.attr()), fld,
                                             &body_function_body));

  int original_arg_count = body_function_body->arg_nodes.size();

  TF_ASSIGN_OR_RETURN(
      auto lifted_arg_nodes_and_outside_compilation_nodes,
      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
          *body_function_body, outside_compilation_attr_to_node));

  // Append lifted args' types to While node's T attribute.
  TF_ASSIGN_OR_RETURN(
      std::vector<DataType> data_types,
      UpdateTypesAttribute(lifted_arg_nodes_and_outside_compilation_nodes, "T",
                           n));

  // Add edges from outside compilation nodes to While node.
  std::vector<Node*> outside_compilation_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(outside_compilation_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.second; });
  AddEdgesFromOutsideCompilationNodes(original_arg_count,
                                      /*arg_to_input_edge_offset=*/0,
                                      data_types, outside_compilation_nodes, g,
                                      n);

  // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg
  // nodes with the new _Arg nodes.
  std::vector<Node*> lifted_arg_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(lifted_arg_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.first; });
  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    TF_ASSIGN_OR_RETURN(Node * arg_node,
                        AddOutsideCompilationInputArgToFunctionBody(
                            *body_function_body, i, data_types[i]));

    TF_RETURN_IF_ERROR(
        AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node));

    ReplaceLiftedArgNodePlaceholderWithArg(
        *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
  }

  const auto new_body_function_name =
      fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_body_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *body_function_body->graph, new_body_function_name,
      HostGraphControlRetMapping, &rewritten_body_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
                                            rewritten_body_function_def,
                                            &body_func, n, fld));

  // In cond_graph, just add new _Arg nodes.
  NameAttrList cond_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "cond", &cond_func));
  const FunctionDef* cond_function_def = fld->Find(cond_func.name());
  TF_RET_CHECK(cond_function_def);
  std::unique_ptr<FunctionBody> cond_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_function_def,
                                             AttrSlice(&cond_func.attr()), fld,
                                             &cond_function_body));

  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    xla::StatusOr<Node*> arg_node_or =
        AddOutsideCompilationInputArgToFunctionBody(*cond_function_body, i,
                                                    data_types[i]);
    TF_RETURN_IF_ERROR(arg_node_or.status());
  }

  const auto new_cond_function_name =
      fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_cond_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *cond_function_body->graph, new_cond_function_name,
      HostGraphControlRetMapping, &rewritten_cond_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
                                            rewritten_cond_function_def,
                                            &cond_func, n, fld));
  return Status::OK();
}

Status PostprocessLiftedArgsForIf(
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
    Graph* g, Node* n, FunctionLibraryDefinition* fld) {
  TF_RET_CHECK(n->IsIfNode());

  NameAttrList then_branch_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
  const FunctionDef* then_branch_function_def =
      fld->Find(then_branch_func.name());
  TF_RET_CHECK(then_branch_function_def);

  NameAttrList else_branch_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func));
  const FunctionDef* else_branch_function_def =
      fld->Find(else_branch_func.name());
  TF_RET_CHECK(else_branch_function_def);

  // Nothing to do if neither branch contains any lifted arguments.
  if (!HasLiftedArgs(*then_branch_function_def) &&
      !HasLiftedArgs(*else_branch_function_def)) {
    return Status::OK();
  }

  std::unique_ptr<FunctionBody> then_branch_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld,
      &then_branch_function_body));

  std::unique_ptr<FunctionBody> else_branch_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld,
      &else_branch_function_body));

  // Then and else branches have the same argument count and argument data
  // types.
  int original_arg_count = then_branch_function_body->arg_nodes.size();

  TF_ASSIGN_OR_RETURN(
      auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes,
      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
          *then_branch_function_body, outside_compilation_attr_to_node));

  TF_ASSIGN_OR_RETURN(
      auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes,
      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
          *else_branch_function_body, outside_compilation_attr_to_node));

  // Merge lifted args from then and else branches.
  std::vector<Node*> outside_compilation_nodes;
  std::vector<Node*> then_branch_lifted_arg_nodes;
  for (const auto& pair :
       then_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
    outside_compilation_nodes.push_back(pair.second);
    then_branch_lifted_arg_nodes.push_back(pair.first);
  }
  for (const auto& pair :
       else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
    if (std::find(outside_compilation_nodes.begin(),
                  outside_compilation_nodes.end(),
                  pair.second) == outside_compilation_nodes.end()) {
      outside_compilation_nodes.push_back(pair.second);
      // Then branch does not contain this lifted arg. Add an empty item to
      // then_branch_lifted_arg_nodes.
      then_branch_lifted_arg_nodes.push_back(nullptr);
    }
  }
  // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes.
  std::vector<Node*> else_branch_lifted_arg_nodes(
      outside_compilation_nodes.size());
  for (const auto& pair :
       else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
    auto iter = std::find(outside_compilation_nodes.begin(),
                          outside_compilation_nodes.end(), pair.second);
    TF_RET_CHECK(iter != outside_compilation_nodes.end());
    int index = iter - outside_compilation_nodes.begin();
    else_branch_lifted_arg_nodes[index] = pair.first;
  }

  // Append lifted args' types to If node's Tin attribute.
  std::vector<DataType> data_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types));
  for (Node* n : outside_compilation_nodes) {
    data_types.push_back(n->output_type(0));
  }
  n->ClearAttr("Tin");
  n->AddAttr("Tin", data_types);

  // Add edges from outside compilation nodes to If node. If node's input #0
  // is predicate input, input #1 maps to _Arg #0 of branch functions, thus
  // arg_to_input_edge_offset is set to 1.
  AddEdgesFromOutsideCompilationNodes(original_arg_count,
                                      /*arg_to_input_edge_offset=*/1,
                                      data_types, outside_compilation_nodes, g,
                                      n);

  for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
    TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
                        AddOutsideCompilationInputArgToFunctionBody(
                            *then_branch_function_body, i, data_types[i]));

    ReplaceLiftedArgNodePlaceholderWithArg(
        *then_branch_function_body, original_arg_count, i,
        then_branch_lifted_arg_nodes, then_branch_arg_node);

    TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node,
                        AddOutsideCompilationInputArgToFunctionBody(
                            *else_branch_function_body, i, data_types[i]));

    ReplaceLiftedArgNodePlaceholderWithArg(
        *else_branch_function_body, original_arg_count, i,
        else_branch_lifted_arg_nodes, else_branch_arg_node);
  }

  const auto new_then_function_name = fld->UniqueFunctionName(
      absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_then_branch_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *then_branch_function_body->graph, new_then_function_name,
      HostGraphControlRetMapping, &rewritten_then_branch_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(
      new_then_function_name, "then_branch", rewritten_then_branch_function_def,
      &then_branch_func, n, fld));

  const auto new_else_function_name = fld->UniqueFunctionName(
      absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_else_branch_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *else_branch_function_body->graph, new_else_function_name,
      HostGraphControlRetMapping, &rewritten_else_branch_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(
      new_else_function_name, "else_branch", rewritten_else_branch_function_def,
      &else_branch_func, n, fld));
  return Status::OK();
}

Status PostprocessLiftedArgsForCall(
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
    Graph* g, Node* n, FunctionLibraryDefinition* fld) {
  const FunctionDef* fdef = fld->Find(n->type_string());
  TF_RET_CHECK(fdef);

  // Nothing to do if the function does not contain any lifted arguments.
  if (!HasLiftedArgs(*fdef)) {
    return Status::OK();
  }

  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody));

  int original_arg_count = fbody->arg_nodes.size();

  TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes,
                      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
                          *fbody, outside_compilation_attr_to_node));

  // Append lifted args' types to call node's input data types.
  std::vector<DataType> data_types(n->input_types().begin(),
                                   n->input_types().end());
  for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
    Node* outside_compilation_node = pair.second;
    DataType data_type;
    TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
                 outside_compilation_node->type_string() == "Placeholder");
    if (outside_compilation_node->IsIdentity()) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
    } else {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
    }
    data_types.push_back(data_type);
  }

  std::vector<Node*> lifted_arg_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(lifted_arg_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.first; });
  for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
    TF_ASSIGN_OR_RETURN(
        Node * arg_node,
        AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));

    ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i,
                                           lifted_arg_nodes, arg_node);
  }

  FunctionDef rewritten_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
                                        HostGraphControlRetMapping,
                                        &rewritten_fdef));
  const auto new_function_name =
      fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
  rewritten_fdef.mutable_signature()->set_name(new_function_name);
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));

  // We need to recreate the node. Otherwise TF will not know n->num_inputs()
  // has increased.
  NodeDef node_def = n->def();

  // Function name is represented via the Op's type. Reset the op type to the
  // new function def name.
  *node_def.mutable_op() = new_function_name;

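  // Append the outside compilation tensors as additional inputs in the
  // NodeDef; the corresponding graph edges are added after ReplaceNode below.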
  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    Node* outside_compilation_node =
        lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
            .second;
    node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0));
  }
  TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def));

  // Add edges from outside compilation nodes to call node.
  std::vector<Node*> outside_compilation_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(outside_compilation_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.second; });
  AddEdgesFromOutsideCompilationNodes(original_arg_count,
                                      /*arg_to_input_edge_offset=*/0,
                                      data_types, outside_compilation_nodes, g,
                                      n);

  return Status::OK();
}

// Creates a mapping from outside compilation cluster name to lifted argument
// placeholder.
xla::StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
    const Graph& g) {
  std::unordered_map<string, Node*> outside_compilation_attr_to_node;
  for (Node* n : g.op_nodes()) {
    bool is_lifted_arg;
    string outside_compilation_attr;
    if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
        TryGetNodeAttr(n->def(), "_xla_outside_compilation",
                       &outside_compilation_attr)) {
      TF_RET_CHECK(is_lifted_arg);
      TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
      outside_compilation_attr_to_node[outside_compilation_attr] = n;
    }
  }

  return outside_compilation_attr_to_node;
}

Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
  TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node,
                      OutsideCompilationAttrToNode(*g));

  std::vector<Node*> call_nodes;
  for (Node* n : g->op_nodes()) {
    if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
      continue;
    }

    if (n->IsWhileNode()) {
      TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
          outside_compilation_attr_to_node, g, n, fld));
    }

    if (n->IsIfNode()) {
      TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
          outside_compilation_attr_to_node, g, n, fld));
    }

    // Outside compilation host-side function calls will always be direct
    // function call nodes.
    // Function call nodes need to be handled separately because we rewrite
    // nodes in `PostprocessLiftedArgsForCall`.
    if (fld->Contains(n->type_string())) {
      call_nodes.push_back(n);
    }
  }

  for (Node* n : call_nodes) {
    TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall(
        outside_compilation_attr_to_node, g, n, fld));
  }

  return Status::OK();
}

// For an XLA computation, builds host side graph given all outside compilation
// graphs inside it. The host side graph contains:
// 1) a "sequencer" node (we will add a control edge between XlaRecvAtHost and
// XlaSendFromHost to this sequencer node, so all outside compilation nodes
// will be executed *before* this sequencer).
// 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
// replace this node with the compilation result node.
// 3) all outside compilation graphs.
Status ConstructHostGraph(
    const string& xla_cluster_name, const string& outside_compilation_attr_name,
    const std::vector<string>& outside_compilation_host_graphs,
    FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
  host_graph->reset(new Graph(fld));

  // Create sequencer node in host graph.
  NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
                                   "NoOp");
  sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name);
  NodeDef sequencer_def;
  TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
  Status s;
  Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s);
  TF_RETURN_IF_ERROR(s);

  // Create key placeholder in host graph.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));

  // For each outside compilation graph, copy it to the host graph with the
  // following changes:
  // a) Use key_placeholder in host graph instead of its own.
  // b) Add control edge from host transfer nodes (XlaRecvAtHost,
  //    XlaSendFromHost, If/While nodes containing
  //    XlaRecvAtHost/XlaSendFromHost) to sequencer node.
  // c) Clear node_def.device(), so device placer won't get confused.
  for (const string& host_func : outside_compilation_host_graphs) {
    VLOG(4) << "Expanding host graph " << host_func;
    // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
    // value after we have expanded all host graphs. We cannot just use
    // placeholder value here because FunctionDef instantiation does not allow
    // placeholder value for attributes.
    AttrValue device_ordinal_attr;
    device_ordinal_attr.set_i(0);
    protobuf::Map<string, AttrValue> attrs;
    attrs["_device_ordinal"] = device_ordinal_attr;
    std::unique_ptr<FunctionBody> host_fbody;
    const FunctionDef* host_fdef = fld->Find(host_func);
    TF_RET_CHECK(host_fdef);
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
                                               fld, &host_fbody));

    // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
    // reachable from sink node so all nodes will be copied.
    // TODO(b/77601805): consolidate copy graph functions.
    FixupSourceAndSinkEdges(host_fbody->graph);

    std::map<const Node*, Node*> node_map;
    node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
    node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
    Status s;
    ReverseDFS(
        *host_fbody->graph, /*enter=*/nullptr,
        [&](const Node* n) {
          if (!s.ok()) {
            return;
          }

          Node* copy;
          if (node_map.find(n) != node_map.end()) {
            // Already copied this node.
            copy = node_map.at(n);
          } else if (IsKeyPlaceholderNode(*n)) {
            // Change a).
            copy = key_placeholder;
            node_map[n] = copy;
          } else {
            // Copy the node.
            NodeDef copy_def = n->def();
            // Change c).
            copy_def.clear_device();
            copy = (*host_graph)->AddNode(copy_def, &s);
            if (!s.ok()) {
              return;
            }
            node_map[n] = copy;
          }

          // Only handle input edges. Output edges will be added later as
          // its output nodes' input edges.
          for (auto e : n->in_edges()) {
            if (node_map.find(e->src()) == node_map.end()) {
              s = errors::Internal("Cannot find node image for ",
                                   e->src()->DebugString());
              return;
            }
            (*host_graph)
                ->AddEdge(node_map[e->src()], e->src_output(), copy,
                          e->dst_input());
          }

          // Change b).
          if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
            (*host_graph)->AddControlEdge(copy, sequencer);
          }
        },
        NodeComparatorID());

    if (!s.ok()) {
      return s;
    }
  }
  // Reset "_device_ordinal" to placeholder value.
  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(host_graph->get()));

  // sequencer and key_placeholder might be dead nodes. Prune them if necessary.
  // - sequencer should be pruned iff it has no input control edges from
  //   RecvAtHost/SendFromHost. If it has an input control edge, we connect it
  //   to the sink node so it won't be pruned.
  // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
  //   We don't need to do anything special.
  if (!sequencer->in_edges().empty()) {
    (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
  }
  PruneForReverseReachability(
      host_graph->get(),
      std::unordered_set<const Node*>{(*host_graph)->sink_node()});

  // Postprocess edges between different outside compilations.
  TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
      host_graph->get(), outside_compilation_attr_name));

  // Postprocess lifted arg nodes.
  TF_RETURN_IF_ERROR(PostprocessLiftedArgs(host_graph->get(), fld));

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
                                 xla_cluster_name),
                    **host_graph, fld);
  }

  return Status::OK();
}

// Expand XLA computation's outside compilation host side graph into main graph.
// Add a control edge between sequencer node and the XLA computation node.
Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
                                    FunctionLibraryDefinition* fld,
                                    const string& host_graph_func_name,
                                    Node* xla_computation_node,
                                    Node* pivot_node) {
  // Temporarily use "0" as "_device_ordinal". It will be rewritten with the
  // correct value in a later pass. We cannot just use placeholder value here
  // because FunctionDef instantiation does not allow placeholder value for
  // attributes.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["_device_ordinal"] = device_ordinal_attr;
  std::unique_ptr<FunctionBody> fbody;
  const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
  TF_RET_CHECK(host_graph_func);
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
                                             AttrSlice(&attrs), fld, &fbody));
  Graph* host_graph = fbody->graph;

  // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
  // reachable from sink node so all nodes will be copied.
  // TODO(b/77601805): consolidate copy graph functions.
  FixupSourceAndSinkEdges(host_graph);

  // Copy all nodes.
  std::map<const Node*, Node*> node_map;
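  // If a pivot node was provided, hang the copied host graph off of it rather
  // than the main graph's source node.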
  if (pivot_node) {
    node_map[host_graph->source_node()] = pivot_node;
  } else {
    node_map[host_graph->source_node()] = main_graph->source_node();
  }
  node_map[host_graph->sink_node()] = main_graph->sink_node();
  Status s = Status::OK();
  auto copy_node_fn = [&](const Node* n) {
    if (!s.ok()) {
      return;
    }

    Node* copy;
    if (node_map.find(n) != node_map.end()) {
      // Already copied this node.
      copy = node_map.at(n);
    } else {
      // Copy the node.
      NodeDef copy_def = n->def();
      copy = main_graph->AddNode(copy_def, &s);
      if (!s.ok()) {
        return;
      }
      node_map[n] = copy;
    }

    // Only handle input edges. Output edges will be added later as its output
    // nodes' input edges.
    for (auto e : n->in_edges()) {
      if (node_map.find(e->src()) == node_map.end()) {
        s = errors::Internal("Cannot find node image for ",
                             e->src()->DebugString());
        return;
      }
      main_graph->AddEdge(node_map[e->src()], e->src_output(), copy,
                          e->dst_input());
    }

    // Add control edge from sequencer to XLA computation node.
    if (copy->type_string() == "NoOp" &&
        HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) {
      main_graph->AddControlEdge(copy, xla_computation_node);
    }
  };
  ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
  return s;
}

// Rewrites shape inference graph for outside compilation:
// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
//    `host_graph`. The shape inference graph might still contain outside
//    compilation to outside compilation placeholder nodes, which would prevent
//    us from inferring the XlaSendFromHost shape; in `host_graph`, those
//    placeholder nodes have already been removed.
// 2) Remove control edges.
// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
                                  Graph* host_graph, Node* pivot_node,
                                  FunctionLibraryDefinition* fld) {
  // Use "0" as "_device_ordinal". It does not matter for shape inference.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["_device_ordinal"] = device_ordinal_attr;
  std::unique_ptr<FunctionBody> fbody;
  const FunctionDef* shape_inference_graph =
      fld->Find(shape_inference_graph_name);
  TF_RET_CHECK(shape_inference_graph);
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
                                             AttrSlice(&attrs), fld, &fbody));
  Graph* g = fbody->graph;

  // Find SendFromHost node.
  Node* send_from_host = nullptr;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "_XlaSendFromHost") {
      send_from_host = n;
      break;
    }
  }
  if (!send_from_host) {
    return errors::Internal("Shape inference graph ",
                            shape_inference_graph_name,
                            " does not have _XlaSendFromHost node.");
  }

  // See if the SendFromHost node exists in `host_graph`.
  Node* send_node_in_host_graph = nullptr;
  for (Node* n : host_graph->nodes()) {
    if (n->name() == send_from_host->name()) {
      send_node_in_host_graph = n;
      break;
    }
  }
  if (send_node_in_host_graph) {
    // This is a "top-level" outside compilation. Clear the graph, and copy
    // SendFromHost and all its predecessors from `host_graph`.
    std::vector<Node*> nodes;
    for (Node* n : g->op_nodes()) {
      nodes.push_back(n);
    }
    for (Node* n : nodes) {
      g->RemoveNode(n);
    }
    Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
    // Reverse DFS from send_node_in_host_graph, and stop at start_node.
    struct Visit {
      Node* n;
      bool is_exiting;
    };
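    // Iterative post-order DFS over in-edges: a node is copied only after all
    // of its input nodes have been copied, so its input edges can be added
    // immediately.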
1256 std::vector<Visit> stack{{send_node_in_host_graph, false}};
1257 std::map<Node*, Node*> node_map;
1258 node_map[host_graph->source_node()] = g->source_node();
1259 while (!stack.empty()) {
1260 Visit& curr = stack.back();
1261 if (curr.is_exiting) {
1262 if (node_map.find(curr.n) == node_map.end()) {
1263 Node* copy = g->CopyNode(curr.n);
1264 if (curr.n != start_node) {
1265 for (const Edge* e : curr.n->in_edges()) {
1266 auto node_iter = node_map.find(e->src());
1267 if (node_iter == node_map.end()) {
1268 return errors::Internal("Cannot find node image for ",
1269 e->src()->DebugString());
1270 }
1271 g->AddEdge(node_iter->second, e->src_output(), copy,
1272 e->dst_input());
1273 }
1274 }
1275 node_map[curr.n] = copy;
1276 }
1277 stack.pop_back();
1278 } else {
1279 curr.is_exiting = true;
1280 if (curr.n != start_node) {
1281 for (const Edge* e : curr.n->in_edges()) {
1282 if (node_map.find(e->src()) != node_map.end()) {
1283 continue;
1284 }
1285 stack.push_back({e->src(), false});
1286 }
1287 }
1288 }
1289 }
1290
1291 send_from_host = node_map[send_node_in_host_graph];
1292 } else {
1293 // This is an outside compilation generated for If/While/gradient/etc.
1294 // It will be enough for shape inference. Leave `g` unchanged.
1295 }
1296
1297 // Control edges are not useful for shape inference. Remove them.
1298 for (auto e : g->edges()) {
1299 if (e->IsControlEdge()) {
1300 g->RemoveEdge(e);
1301 }
1302 }
1303
1304 // Nodes that are not reverse reachable from SendFromHost are not useful for
1305 // shape inference. Prune them.
1306 PruneForReverseReachability(g,
1307 std::unordered_set<const Node*>{send_from_host});
1308
1309 if (VLOG_IS_ON(4)) {
1310 DumpGraphToFile(shape_inference_graph_name, *g, fld);
1311 }
1312
1313 // Replace original shape inference graph.
1314 FunctionDef fdef_replace;
1315 TF_RETURN_IF_ERROR(
1316 GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace));
1317 TF_RETURN_IF_ERROR(
1318 fld->ReplaceFunction(shape_inference_graph_name, fdef_replace));
1319
1320 return Status::OK();
1321 }
1322
1323 // Builds XlaSendToHost node which sends cond predicate to host.
BuildSendIfPredNode(const string & name,const string & host_transfer_key,Node * pred_node,Graph * g)1324 TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> BuildSendIfPredNode(
1325 const string& name, const string& host_transfer_key, Node* pred_node,
1326 Graph* g) {
1327 NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
1328 send_pred_builder.Attr("Tinput", DT_BOOL);
1329 send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
1330 send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
1331 std::vector<string>{kXlaTokenArgNodeName});
1332 send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name);
1333 send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
1334 NodeDef send_pred_def;
1335 TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
1336 Status s;
1337 Node* send_pred_node = g->AddNode(send_pred_def, &s);
1338 TF_RETURN_IF_ERROR(s);
1339 g->AddEdge(pred_node, 0, send_pred_node, 0);
1340 return send_pred_node;
1341 }
1342
1343 // Replaces key placeholder node with an _Arg node.
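// In the rewritten function the key is received as the first argument, i.e.
// (illustrative sketch):
//
//   Placeholder("<xla_cluster_name>_key_placeholder")
//     ==>  _Arg(T = DT_STRING, index = 0, name = "key_arg")
//
// "_device_ordinal" is temporarily pinned to 0 for instantiation and reset to
// a placeholder value before the function is written back.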
1344 Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
1345 const string& func_name,
1346 FunctionLibraryDefinition* fld) {
1347 // Temporarily use "0" as "_device_ordinal". It will be reset to the
1348 // placeholder value after rewriting.
1349 AttrValue device_ordinal_attr;
1350 device_ordinal_attr.set_i(0);
1351 protobuf::Map<string, AttrValue> attrs;
1352 attrs["_device_ordinal"] = device_ordinal_attr;
1353 std::unique_ptr<FunctionBody> fbody;
1354 const FunctionDef* func = fld->Find(func_name);
1355 TF_RETURN_IF_ERROR(
1356 FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
1357 Graph* g = fbody->graph;
1358
1359 // Find or create the key placeholder node.
1360 Node* key_placeholder = nullptr;
1361 for (Node* n : g->nodes()) {
1362 if (IsKeyPlaceholderNode(*n)) {
1363 key_placeholder = n;
1364 break;
1365 }
1366 }
1367 if (!key_placeholder) {
1368 TF_ASSIGN_OR_RETURN(key_placeholder,
1369 AddHostComputeKeyPlaceholder(xla_cluster_name, g));
1370 }
1371
1372 // Build the _Arg node, and replace key placeholder node with it.
1373 NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
1374 arg_builder.Attr("T", DT_STRING);
1375 arg_builder.Attr("index", 0);
1376 NodeDef arg_def;
1377 TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
1378 TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
1379
1380 // Reset "_device_ordinal" to placeholder value.
1381 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
1382
1383 FunctionDef replace_fdef;
1384 TF_RETURN_IF_ERROR(GraphToFunctionDef(
1385 *g, func_name, HostGraphControlRetMapping, &replace_fdef));
1386 TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
1387 return Status::OK();
1388 }
1389
1390 // Builds host side graph for If node.
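// The generated host graph has roughly this shape (illustrative sketch):
//
//   key_placeholder --+--> _XlaRecvAtHost (receives the predicate)
//                     |                |
//                     |                v
//                     +------------> If(then_branch = then_branch_host_func,
//                                       else_branch = else_branch_host_func)
//
// where the If node takes the received predicate as its cond input and the
// key placeholder as its single data input.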
1391 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
1392 const string& xla_cluster_attr_name,
1393 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1394 const string& if_node_name, const string& host_transfer_key,
1395 const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1396 const string& then_branch_host_func_name,
1397 const string& else_branch_host_func_name) {
1398 Graph host_graph(fld);
1399 string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
1400 AttrValue device_ordinal_value;
1401 device_ordinal_value.set_placeholder("_device_ordinal");
1402
1403 // Step 1: add key placeholder node.
1404 TF_ASSIGN_OR_RETURN(
1405 Node * key_placeholder,
1406 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1407
1408 // Step 2: build XlaRecvAtHost node to recv predicate.
1409 NodeDefBuilder recv_pred_builder(
1410 absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
1411 recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1412 recv_pred_builder.Attr("key", host_transfer_key);
1413 recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1414 recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1415 recv_pred_builder.Attr(outside_compilation_attr_name,
1416 outside_compilation_name);
1417 recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1418 recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
1419 NodeDef recv_pred_def;
1420 TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1421 Status s;
1422 Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
1423 TF_RETURN_IF_ERROR(s);
1424 host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
1425
1426 // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
1427 // placeholder with an _Arg node.
1428 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1429 xla_cluster_name, then_branch_host_func_name, fld));
1430 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1431 xla_cluster_name, else_branch_host_func_name, fld));
1432
1433 // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
1434 NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
1435 if_builder.Attr("Tcond", DT_BOOL);
1436 if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
1437 if_builder.Attr("Tout", std::vector<DataType>{});
1438 NameAttrList host_then_branch, host_else_branch;
1439 host_then_branch.set_name(then_branch_host_func_name);
1440 (*host_then_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1441 host_else_branch.set_name(else_branch_host_func_name);
1442 (*host_else_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1443 if_builder.Attr("then_branch", host_then_branch);
1444 if_builder.Attr("else_branch", host_else_branch);
1445 if_builder.Attr(kXlaHasHostTransferAttrName, true);
1446 if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1447 if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1448 if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1449 std::vector<NodeDefBuilder::NodeOut> if_inputs{
1450 {key_placeholder->name(), 0, DT_STRING}};
1451 if_builder.Input(if_inputs);
1452 NodeDef if_def;
1453 TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
1454 Node* if_node = host_graph.AddNode(if_def, &s);
1455 TF_RETURN_IF_ERROR(s);
1456 host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
1457 host_graph.AddEdge(key_placeholder, 0, if_node, 1);
1458
1459 // Convert `host_graph` to function.
1460 FunctionDef oc_host_graph_fdef;
1461 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1462 &oc_host_graph_fdef));
1463 if (fld->Find(host_graph_func_name)) {
1464 TF_RETURN_IF_ERROR(
1465 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1466 } else {
1467 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1468 }
1469
1470 return Status::OK();
1471 }
1472
1473 // Rewrites loop cond to add a node which sends loop cond to host.
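// Roughly, inside the cond function this adds (illustrative sketch):
//
//   loop_cond:0 --+--> _Retval                                   (unchanged)
//                 |
//                 +--> XlaSendToHost(key = "<host_transfer_key>_dtoh_0")
//
// If `loop_cond_func` is not the rewritten XLA cond function (e.g. it is
// shared), the change is written to a fresh "<name>_send_pred_added_..."
// function and the While node's "cond" attribute is repointed to it.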
1474 TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
1475 const string& cond_xla_func_name, const string& host_transfer_key,
1476 NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
1477 Node* while_node) {
1478 // Instantiate the loop cond function.
1479 std::unique_ptr<FunctionBody> fbody;
1480 const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
1481 TF_RET_CHECK(loop_cond_fdef);
1482 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1483 *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
1484 Graph* g = fbody->graph;
1485
1486 // Find the _Retval node and the loop cond node.
1487 Node* ret_node = nullptr;
1488 for (Node* n : g->nodes()) {
1489 if (n->type_string() == "_Retval") {
1490 if (ret_node) {
1491 return errors::Internal("Multiple return node for loop cond function ",
1492 loop_cond_func->name(), ": ",
1493 ret_node->DebugString(), " and ",
1494 n->DebugString());
1495 } else {
1496 ret_node = n;
1497 }
1498 }
1499 }
1500 if (!ret_node) {
1501 return errors::Internal("No _Retval node for loop cond function ",
1502 loop_cond_func->name());
1503 }
1504 Node* loop_cond;
1505 TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
1506
1507 // Build the XlaSendToHost node.
1508 NodeDefBuilder send_loop_cond_builder(
1509 absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
1510 send_loop_cond_builder.Attr("Tinput", DT_BOOL);
1511 send_loop_cond_builder.Attr("key",
1512 absl::StrCat(host_transfer_key, "_dtoh_0"));
1513 send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
1514 std::vector<string>{kXlaTokenArgNodeName});
1515 send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
1516 send_loop_cond_builder.node_name());
1517 send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
1518 NodeDef send_loop_cond_def;
1519 TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
1520 Status s;
1521 Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
1522 TF_RETURN_IF_ERROR(s);
1523 g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
1524
1525 // Replace the original function if `loop_cond_func` has already been
1526 // rewritten for outside compilation.
1527 FunctionDef replace_fdef;
1528 if (loop_cond_func->name() == cond_xla_func_name) {
1529 TF_RETURN_IF_ERROR(
1530 GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
1531 TF_RETURN_IF_ERROR(
1532 fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
1533 } else {
1534 // If the original while cond function has not been modified, add a new
1535 // function with the send-loop-predicate node added, and update the While
1536 // node's "cond" attribute to call it.
1537 const auto new_name = fld->UniqueFunctionName(
1538 absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
1539 TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
1540 TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
1541 loop_cond_func->set_name(new_name);
1542 while_node->ClearAttr("cond");
1543 while_node->AddAttr("cond", *loop_cond_func);
1544 }
1545
1546 return Status::OK();
1547 }
1548
1549 // Rewrites while loop cond function for host.
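// The rewritten host cond function is roughly (illustrative sketch):
//
//   key _Arg ---> _XlaRecvAtHost(key = host_transfer_key, Toutputs = [DT_BOOL])
//                       |
//                       v
//                   _Retval(T = DT_BOOL, index = 0)
//
// i.e. the host side simply receives the predicate value sent from the XLA
// side's loop cond computation.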
1550 Status RewriteHostWhileLoopCond(
1551 const string& cond_host_func_name, const string& while_node_name,
1552 const string& host_transfer_key, const string& xla_cluster_attr_name,
1553 const string& xla_cluster_name, const string& outside_compilation_attr_name,
1554 const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1555 // Replace key placeholder node with _Arg node.
1556 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1557 xla_cluster_name, cond_host_func_name, fld));
1558
1559 // Instantiate cond function.
1560 AttrValue device_ordinal_temp_value;
1561 device_ordinal_temp_value.set_i(0);
1562 protobuf::Map<string, AttrValue> attrs;
1563 attrs["_device_ordinal"] = device_ordinal_temp_value;
1564 std::unique_ptr<FunctionBody> cond_fbody;
1565 const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
1566 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
1567 fld, &cond_fbody));
1568 Graph* cond_graph = cond_fbody->graph;
1569 Node* key_arg = nullptr;
1570 for (Node* n : cond_graph->nodes()) {
1571 if (n->type_string() == "_Arg") {
1572 key_arg = n;
1573 }
1574 }
1575 if (!key_arg) {
1576 return errors::Internal(
1577 "No _Arg node found for host compute key in function ",
1578 cond_host_func_name);
1579 }
1580
1581 // Add an XlaRecvAtHost node to use as cond function return value.
1582 NodeDefBuilder recv_pred_builder(
1583 absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
1584 recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1585 recv_pred_builder.Attr("key", host_transfer_key);
1586 AttrValue device_ordinal_value;
1587 device_ordinal_value.set_placeholder("_device_ordinal");
1588 recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1589 recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1590 recv_pred_builder.Attr(outside_compilation_attr_name,
1591 outside_compilation_name);
1592 recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1593 recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
1594 NodeDef recv_pred_def;
1595 TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1596 Status s;
1597 Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
1598 TF_RETURN_IF_ERROR(s);
1599 cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
1600 NodeDefBuilder ret_builder(
1601 absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
1602 ret_builder.Attr("T", DT_BOOL);
1603 ret_builder.Attr("index", 0);
1604 ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1605 NodeDef ret_def;
1606 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1607 Node* ret_node = cond_graph->AddNode(ret_def, &s);
1608 TF_RETURN_IF_ERROR(s);
1609 cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
1610
1611 // Reset device_ordinal to placeholder value.
1612 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
1613
1614 // Replace original function.
1615 FunctionDef cond_replace_fdef;
1616 TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_graph, cond_host_func_name,
1617 HostGraphControlRetMapping,
1618 &cond_replace_fdef));
1619 TF_RETURN_IF_ERROR(
1620 fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
1621
1622 return Status::OK();
1623 }
1624
1625 // Rewrites while loop body function for host.
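// The rewritten host body function simply passes the key through, roughly
// (illustrative sketch):
//
//   key _Arg ---> _Retval(T = DT_STRING, index = 0)
//
// so that the host While loop carries the DT_STRING key as its only loop
// variable.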
1626 Status RewriteHostWhileLoopBody(
1627 const string& body_host_func_name, const string& while_node_name,
1628 const string& host_transfer_key, const string& xla_cluster_attr_name,
1629 const string& xla_cluster_name, const string& outside_compilation_attr_name,
1630 const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1631 // Replace key placeholder node with _Arg node.
1632 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1633 xla_cluster_name, body_host_func_name, fld));
1634
1635 // Instantiate body function.
1636 AttrValue device_ordinal_temp_value;
1637 device_ordinal_temp_value.set_i(0);
1638 protobuf::Map<string, AttrValue> attrs;
1639 attrs["_device_ordinal"] = device_ordinal_temp_value;
1640 std::unique_ptr<FunctionBody> body_fbody;
1641 const FunctionDef* body_host_func = fld->Find(body_host_func_name);
1642 TF_RET_CHECK(body_host_func);
1643 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
1644 fld, &body_fbody));
1645 Graph* body_graph = body_fbody->graph;
1646 Node* key_arg = nullptr;
1647 for (Node* n : body_graph->nodes()) {
1648 if (n->type_string() == "_Arg") {
1649 key_arg = n;
1650 }
1651 }
1652 if (!key_arg) {
1653 return errors::Internal(
1654 "No _Arg node found for host compute key in function ",
1655 body_host_func_name);
1656 }
1657
1658 // Add a _Retval node to loop body.
1659 NodeDefBuilder ret_builder(
1660 absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
1661 ret_builder.Attr("T", DT_STRING);
1662 ret_builder.Attr("index", 0);
1663 ret_builder.Input(key_arg->name(), 0, DT_STRING);
1664 NodeDef ret_def;
1665 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1666 Status s;
1667 Node* ret_node = body_graph->AddNode(ret_def, &s);
1668 TF_RETURN_IF_ERROR(s);
1669 body_graph->AddEdge(key_arg, 0, ret_node, 0);
1670
1671 // Reset device_ordinal to placeholder value.
1672 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
1673
1674 // Replace original function.
1675 FunctionDef body_replace_fdef;
1676 TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_graph, body_host_func_name,
1677 HostGraphControlRetMapping,
1678 &body_replace_fdef));
1679 TF_RETURN_IF_ERROR(
1680 fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
1681
1682 return Status::OK();
1683 }
1684
1685 // Builds host side graph for while node.
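// The generated host graph is roughly (illustrative sketch):
//
//   key_placeholder ---> While(T = [DT_STRING],
//                              cond = rewritten cond_host_func,
//                              body = rewritten body_host_func,
//                              parallel_iterations = 1)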
1686 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
1687 const string& xla_cluster_attr_name,
1688 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1689 const string& while_node_name, const string& host_transfer_key,
1690 const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1691 const string& cond_host_func_name, const string& body_host_func_name) {
1692 Graph host_graph(fld);
1693 string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
1694
1695 // Step 1: add key placeholder node.
1696 TF_ASSIGN_OR_RETURN(
1697 Node * key_placeholder,
1698 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1699
1700 // Step 2: rewrite cond function.
1701 TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
1702 cond_host_func_name, while_node_name, host_transfer_key,
1703 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1704 outside_compilation_name, fld));
1705
1706 // Step 3: rewrite body function.
1707 TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
1708 body_host_func_name, while_node_name, host_transfer_key,
1709 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1710 outside_compilation_name, fld));
1711
1712 // Step 4: build While node.
1713 NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
1714 "While");
1715 while_builder.Attr("T", std::vector<DataType>{DT_STRING});
1716 NameAttrList func;
1717 AttrValue device_ordinal_value;
1718 device_ordinal_value.set_placeholder("_device_ordinal");
1719 (*func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1720 func.set_name(cond_host_func_name);
1721 while_builder.Attr("cond", func);
1722 func.set_name(body_host_func_name);
1723 while_builder.Attr("body", func);
1724 while_builder.Attr(kXlaHasHostTransferAttrName, true);
1725 while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1726 while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1727 // Make sure the loop body of the i-th iteration runs before the loop cond
1728 // of the (i+1)-th iteration.
1729 while_builder.Attr("parallel_iterations", 1);
1730 std::vector<NodeDefBuilder::NodeOut> while_inputs{
1731 {key_placeholder->name(), 0, DT_STRING}};
1732 while_builder.Input(while_inputs);
1733 NodeDef while_def;
1734 TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
1735 Status s;
1736 Node* while_node = host_graph.AddNode(while_def, &s);
1737 TF_RETURN_IF_ERROR(s);
1738 host_graph.AddEdge(key_placeholder, 0, while_node, 0);
1739
1740 // Convert `host_graph` to function.
1741 FunctionDef oc_host_graph_fdef;
1742 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1743 &oc_host_graph_fdef));
1744 if (fld->Find(host_graph_func_name)) {
1745 TF_RETURN_IF_ERROR(
1746 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1747 } else {
1748 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1749 }
1750
1751 return Status::OK();
1752 }
1753
1754 // Builds host graph for func call nodes.
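// The generated host graph is roughly (illustrative sketch):
//
//   key_placeholder ---> oc_call_<func_call_node>(func_call_host_func)
//
// where the call node forwards the key to the rewritten host function and
// keeps "_device_ordinal" as a placeholder attribute.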
1755 Status BuildHostGraphForFuncCallNode(
1756 const string& xla_cluster_attr_name, const string& xla_cluster_name,
1757 const string& outside_compilation_attr_name,
1758 const string& func_call_node_name, const string& func_call_host_func_name,
1759 const string& host_graph_func_name, FunctionLibraryDefinition* fld) {
1760 Graph host_graph(fld);
1761 AttrValue device_ordinal_value;
1762 device_ordinal_value.set_placeholder("_device_ordinal");
1763
1764 // Step 1: add key placeholder node.
1765 TF_ASSIGN_OR_RETURN(
1766 Node * key_placeholder,
1767 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1768
1769 // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg
1770 // node.
1771 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1772 xla_cluster_name, func_call_host_func_name, fld));
1773
1774 // Step 3: build a function call node with `host_func_name`, with
1775 // `key_placeholder` as input.
1776 NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
1777 func_call_host_func_name, fld);
1778 call_builder.Input(key_placeholder->name(), 0, DT_STRING);
1779 call_builder.Attr("_device_ordinal", device_ordinal_value);
1780 call_builder.Attr(kXlaHasHostTransferAttrName, true);
1781 call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1782 call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
1783 NodeDef call_def;
1784 TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
1785 Status s;
1786 Node* call_node = host_graph.AddNode(call_def, &s);
1787 TF_RETURN_IF_ERROR(s);
1788 host_graph.AddEdge(key_placeholder, 0, call_node, 0);
1789
1790 // Convert `host_graph` to function.
1791 FunctionDef oc_host_graph_fdef;
1792 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1793 HostGraphControlRetMapping,
1794 &oc_host_graph_fdef));
1795 if (fld->Find(host_graph_func_name)) {
1796 TF_RETURN_IF_ERROR(
1797 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1798 } else {
1799 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1800 }
1801
1802 return Status::OK();
1803 }
1804
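// Extracts outside compilation from the function called by `n`. If outside
// compilation is found, rewrites `n` to call the newly generated XLA function
// and builds a host side graph wrapping the corresponding host function.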
1805 TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
1806 const string& xla_cluster_attr_name,
1807 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1808 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1809 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1810 std::vector<string>* host_graphs,
1811 std::vector<string>* shape_inference_graphs,
1812 bool* has_outside_compilation) {
1813 bool func_has_outside_compilation = false;
1814 NameAttrList func;
1815 if (fld->Contains(n->type_string())) {
1816 func.set_name(n->type_string());
1817 typedef protobuf::Map<string, AttrValue> AttrMap;
1818 *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
1819 } else if (n->IsPartitionedCall()) {
1820 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
1821 } else {
1822 TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
1823 func.set_name(FunctionLibraryDefinition::kGradientOp);
1824 *func.mutable_attr() = n->def().attr();
1825 }
1826 string canonical_func_name;
1827 if (func.name() == FunctionLibraryDefinition::kGradientOp) {
1828 NameAttrList forward_func;
1829 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func));
1830 canonical_func_name = absl::StrCat("gradient_", forward_func.name());
1831 } else {
1832 canonical_func_name = func.name();
1833 }
1834 string new_func_name = absl::StrCat(canonical_func_name, "_oc");
1835 string host_func_name =
1836 absl::StrCat("oc_func_call_host_", canonical_func_name);
1837 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1838 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1839 func, new_func_name, host_func_name, host_compute_core, flr, fld,
1840 shape_inference_graphs, &func_has_outside_compilation));
1841
1842 // If the function call does not have outside compilation, nothing to do.
1843 if (!func_has_outside_compilation) {
1844 return Status::OK();
1845 }
1846
1847 *has_outside_compilation = true;
1848
1849 // Change `n` to call the new function directly.
1850 auto replace_builder =
1851 absl::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
1852 std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
1853 for (const Edge* e : n->in_edges()) {
1854 if (e->IsControlEdge()) {
1855 continue;
1856 }
1857
1858 const bool input_size_check =
1859 e->dst_input() < static_cast<int>(inputs.size());
1860 TF_RET_CHECK(e->dst_input() >= 0 && input_size_check);
1861 inputs[e->dst_input()] =
1862 NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1863 e->src()->output_type(e->src_output())};
1864 }
1865 for (const auto& input : inputs) {
1866 replace_builder->Input(input);
1867 }
1868 for (const auto& attr : n->attrs()) {
1869 replace_builder->Attr(attr.first, attr.second);
1870 }
1871 auto replace_def = absl::make_unique<NodeDef>();
1872 TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
1873 TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
1874 replace->AddAttr(kXlaTokenInputNodesAttrName,
1875 std::vector<string>{kXlaTokenArgNodeName});
1876 replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name());
1877
1878 // Build host side graph for the function call.
1879 string oc_host_graph_name =
1880 absl::StrCat("oc_func_host_graph_", replace->name());
1881 TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
1882 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1883 replace->name(), host_func_name, oc_host_graph_name, fld));
1884
1885 // Record the host graph.
1886 host_graphs->push_back(oc_host_graph_name);
1887
1888 return Status::OK();
1889 }
1890
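// Extracts outside compilation from the then/else branches of If node `n`,
// sends the predicate to the host, and builds a host side If graph that
// chooses between the two host branch functions.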
1891 Status ExtractOutsideCompilationForIfNode(
1892 const string& xla_cluster_attr_name,
1893 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1894 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1895 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1896 std::vector<string>* host_graphs,
1897 std::vector<string>* shape_inference_graphs,
1898 bool* has_outside_compilation) {
1899 // Instantiate "then_branch" and "else_branch".
1900 NameAttrList then_branch, else_branch;
1901 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
1902 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
1903
1904 // Extract outside compilation for then_branch and else_branch.
1905 bool then_branch_has_outside_compilation = false;
1906 bool else_branch_has_outside_compilation = false;
1907 string then_branch_host_func_name =
1908 absl::StrCat("oc_then_branch_host_if_", then_branch.name()),
1909 else_branch_host_func_name =
1910 absl::StrCat("oc_else_branch_host_if_", else_branch.name());
1911 string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
1912 else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
1913 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1914 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1915 then_branch, then_branch_xla_func_name, then_branch_host_func_name,
1916 host_compute_core, flr, fld, shape_inference_graphs,
1917 &then_branch_has_outside_compilation));
1918 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1919 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1920 else_branch, else_branch_xla_func_name, else_branch_host_func_name,
1921 host_compute_core, flr, fld, shape_inference_graphs,
1922 &else_branch_has_outside_compilation));
1923
1924 // If neither branch has outside compilation, there is nothing to do.
1925 if (!then_branch_has_outside_compilation &&
1926 !else_branch_has_outside_compilation) {
1927 return Status::OK();
1928 }
1929
1930 *has_outside_compilation = true;
1931
1932 // Change If node to call the new functions.
1933 if (then_branch_has_outside_compilation) {
1934 then_branch.set_name(then_branch_xla_func_name);
1935 n->ClearAttr("then_branch");
1936 n->AddAttr("then_branch", then_branch);
1937 }
1938 if (else_branch_has_outside_compilation) {
1939 else_branch.set_name(else_branch_xla_func_name);
1940 n->ClearAttr("else_branch");
1941 n->AddAttr("else_branch", else_branch);
1942 }
1943 n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
1944
1945 string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
1946
1947 // XLA computation: add a SendToHost node to send cond predicate.
1948 Node* pred_node;
1949 TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
1950 TF_ASSIGN_OR_RETURN(
1951 Node * send_pred_node,
1952 BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
1953 host_transfer_key, pred_node, g));
1954 n->AddAttr(kXlaTokenInputNodesAttrName,
1955 std::vector<string>{send_pred_node->name()});
1956
1957 // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
1958 // visit If node after `send_pred_node`, thus the token output for
1959 // `send_pred_node` has been generated.
1960 g->AddControlEdge(send_pred_node, n);
1961
1962 // Build host side graph for the "If" node.
1963 // If a branch does not have outside compilation, no host graph was built
1964 // for it. But the host side If node needs a host graph for both branches,
1965 // so create a no-op host graph for any such branch.
1966 if (!then_branch_has_outside_compilation) {
1967 std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
1968 std::vector<string> then_branch_host_graphs;
1969 TF_RETURN_IF_ERROR(ConstructHostGraph(
1970 xla_cluster_name, outside_compilation_attr_name,
1971 then_branch_host_graphs, fld, &then_branch_host_graph));
1972 FunctionDef then_branch_host_fdef;
1973 TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
1974 then_branch_host_func_name,
1975 &then_branch_host_fdef));
1976 if (fld->Find(then_branch_host_func_name)) {
1977 TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
1978 then_branch_host_fdef));
1979 } else {
1980 TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
1981 }
1982 }
1983 if (!else_branch_has_outside_compilation) {
1984 std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
1985 std::vector<string> else_branch_host_graphs;
1986 TF_RETURN_IF_ERROR(ConstructHostGraph(
1987 xla_cluster_name, outside_compilation_attr_name,
1988 else_branch_host_graphs, fld, &else_branch_host_graph));
1989 FunctionDef else_branch_host_fdef;
1990 TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
1991 else_branch_host_func_name,
1992 &else_branch_host_fdef));
1993 if (fld->Find(else_branch_host_func_name)) {
1994 TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
1995 else_branch_host_fdef));
1996 } else {
1997 TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
1998 }
1999 }
2000 string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
2001 TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
2002 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2003 n->name(), host_transfer_key, oc_host_graph_name, fld,
2004 then_branch_host_func_name, else_branch_host_func_name));
2005 host_graphs->push_back(oc_host_graph_name);
2006
2007 return Status::OK();
2008 }
2009
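// Extracts outside compilation from the cond/body functions of While node `n`,
// rewrites the cond function to send the loop predicate to the host, and
// builds a host side While graph driven by that predicate.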
2010 Status ExtractOutsideCompilationForWhileNode(
2011 const string& xla_cluster_attr_name,
2012 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2013 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
2014 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2015 std::vector<string>* host_graphs,
2016 std::vector<string>* shape_inference_graphs,
2017 bool* has_outside_compilation) {
2018 // Instantiate "cond" and "body".
2019 NameAttrList cond, body;
2020 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
2021 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
2022
2023 // Extract outside compilation for cond and body.
2024 bool cond_has_outside_compilation = false;
2025 bool body_has_outside_compilation = false;
2026 string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()),
2027 body_host_func_name = absl::StrCat("oc_body_host_while_", body.name());
2028 string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
2029 body_xla_func_name = absl::StrCat(body.name(), "_oc");
2030 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2031 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2032 cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
2033 fld, shape_inference_graphs, &cond_has_outside_compilation));
2034 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2035 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2036 body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
2037 fld, shape_inference_graphs, &body_has_outside_compilation));
2038
2039 // If neither cond nor body has outside compilation, there is nothing to do.
2040 if (!cond_has_outside_compilation && !body_has_outside_compilation) {
2041 return Status::OK();
2042 }
2043
2044 *has_outside_compilation = true;
2045
2046 // Change While node to call the new functions.
2047 if (cond_has_outside_compilation) {
2048 cond.set_name(cond_xla_func_name);
2049 n->ClearAttr("cond");
2050 n->AddAttr("cond", cond);
2051 }
2052 if (body_has_outside_compilation) {
2053 body.set_name(body_xla_func_name);
2054 n->ClearAttr("body");
2055 n->AddAttr("body", body);
2056 }
2057 n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
2058
2059 string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
2060
2061 // XLA computation: rewrite cond function to add a SendToHost node to send
2062 // loop predicate.
2063 TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
2064 cond_xla_func_name, host_transfer_key, &cond, fld, n));
2065 n->AddAttr(kXlaTokenInputNodesAttrName,
2066 std::vector<string>{kXlaTokenArgNodeName});
2067
2068 // Build host side graph for the "While" node.
2069 if (!cond_has_outside_compilation) {
2070 std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
2071 std::vector<string> host_graphs;
2072 TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2073 outside_compilation_attr_name,
2074 host_graphs, fld, &cond_host_graph));
2075 FunctionDef cond_host_fdef;
2076 TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
2077 &cond_host_fdef));
2078 if (fld->Find(cond_host_func_name)) {
2079 TF_RETURN_IF_ERROR(
2080 fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
2081 } else {
2082 TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
2083 }
2084 }
2085 if (!body_has_outside_compilation) {
2086 std::unique_ptr<Graph> body_host_graph(new Graph(fld));
2087 std::vector<string> host_graphs;
2088 TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2089 outside_compilation_attr_name,
2090 host_graphs, fld, &body_host_graph));
2091 FunctionDef body_host_fdef;
2092 TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
2093 &body_host_fdef));
2094 if (fld->Find(body_host_func_name)) {
2095 TF_RETURN_IF_ERROR(
2096 fld->ReplaceFunction(body_host_func_name, body_host_fdef));
2097 } else {
2098 TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
2099 }
2100 }
2101 string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
2102 TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
2103 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2104 n->name(), host_transfer_key, oc_host_graph_name, fld,
2105 cond_host_func_name, body_host_func_name));
2106 host_graphs->push_back(oc_host_graph_name);
2107
2108 return Status::OK();
2109 }
2110
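// Recursively extracts outside compilation from function call nodes, If nodes
// and While nodes in `g`.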
2111 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2112 Graph* g, const string& xla_cluster_attr_name,
2113 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2114 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2115 FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
2116 std::vector<string>* shape_inference_graphs,
2117 bool* has_outside_compilation) {
2118 std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
2119 for (Node* n : g->nodes()) {
2120 if (n->IsIfNode()) {
2121 if_nodes.push_back(n);
2122 } else if (n->IsWhileNode()) {
2123 while_nodes.push_back(n);
2124 } else if (IsFunctionCall(*fld, *n)) {
2125 func_call_nodes.push_back(n);
2126 }
2127 }
2128
2129 for (Node* n : func_call_nodes) {
2130 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
2131 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2132 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2133 has_outside_compilation));
2134 }
2135
2136 for (Node* n : if_nodes) {
2137 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
2138 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2139 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2140 has_outside_compilation));
2141 }
2142
2143 for (Node* n : while_nodes) {
2144 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
2145 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2146 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2147 has_outside_compilation));
2148 }
2149
2150 return Status::OK();
2151 }
2152
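// For each outside compilation Const node that also feeds nodes outside the
// outside compilation cluster, makes a copy without the outside compilation
// attribute and reroutes those non-outside-compilation data edges to the
// copy, so the original Const can be moved to the host while compiled
// consumers keep a local constant.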
2153 Status CopyOutsideCompilationConstNodes(
2154 Graph* g, const string& outside_compilation_attr_name) {
2155 for (Node* n : g->op_nodes()) {
2156 if (!n->IsConstant() ||
2157 !HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2158 continue;
2159 }
2160
2161 std::vector<const Edge*> out_edges(n->out_edges().begin(),
2162 n->out_edges().end());
2163 bool has_non_oc_output = false;
2164 for (const Edge* e : out_edges) {
2165 if (!e->IsControlEdge() &&
2166 !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2167 has_non_oc_output = true;
2168 break;
2169 }
2170 }
2171 if (!has_non_oc_output) {
2172 continue;
2173 }
2174
2175 NodeDef copy_def = n->def();
2176 copy_def.set_name(g->NewName(n->name()));
2177 copy_def.mutable_attr()->erase(outside_compilation_attr_name);
2178 Status s;
2179 Node* copy_node = g->AddNode(copy_def, &s);
2180 TF_RETURN_IF_ERROR(s);
2181 for (const Edge* e : n->in_edges()) {
2182 if (e->IsControlEdge()) {
2183 g->AddControlEdge(e->src(), copy_node);
2184 }
2185 }
2186 for (const Edge* e : out_edges) {
2187 if (!e->IsControlEdge() &&
2188 !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2189 Node* dst = e->dst();
2190 int dst_input = e->dst_input();
2191 g->RemoveEdge(e);
2192 g->AddEdge(copy_node, 0, dst, dst_input);
2193 }
2194 }
2195 }
2196
2197 return Status::OK();
2198 }
2199
2200 } // namespace
2201
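// Rewrites one outside compilation subgraph so that it can run on the host:
// _Arg nodes are replaced with a single _XlaRecvAtHost node and _Retval nodes
// with a single _XlaSendFromHost node (both keyed by a key placeholder), and
// the enclosing function call node is annotated with the attributes
// ("Tinputs", "Toutputs", "key", "shapes"/"shape_inference_graph",
// "ancestors") needed to later replace it with an XlaHostCompute node.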
2202 Status RewriteOutsideCompilationSubgraphFn::operator()(
2203 const std::vector<OutputTensor>& arg_source_tensors,
2204 std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
2205 std::vector<int>* output_permutation, NodeDef* node_def) {
2206 string old_name = node_def->op();
2207 string new_name =
2208 absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name);
2209 node_def->set_op(new_name);
2210 node_def->set_name(new_name);
2211
2212 // Later we will run PruneForReverseReachability(), so make sure all original
2213 // nodes are reachable from sink node and won't be removed.
2214 FixupSourceAndSinkEdges(graph->get());
2215
2216 // Step 1: create a key placeholder node.
2217 TF_ASSIGN_OR_RETURN(
2218 Node * key_placeholder,
2219 AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get()));
2220
2221 // Step 2: build RecvAtHost node, and replace all _Arg nodes with it.
2222 std::vector<DataType> recv_at_host_dtypes;
2223 TF_ASSIGN_OR_RETURN(
2224 Node * recv_at_host_node,
2225 ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name,
2226 &recv_at_host_dtypes, key_placeholder));
2227
2228 // Step 3: build SendFromHost node, and replace all _Retval nodes with it.
2229 std::vector<DataType> send_from_host_dtypes;
2230 TF_ASSIGN_OR_RETURN(
2231 Node * send_from_host_node,
2232 ReplaceRetNodesWithSendFromHostNode(
2233 graph->get(), new_name, &send_from_host_dtypes, key_placeholder));
2234
2235 // Step 4: add XLA cluster and outside compilation attr.
2236 for (Node* n : (*graph)->nodes()) {
2237 if (IsKeyPlaceholderNode(*n)) {
2238 continue;
2239 }
2240
2241 n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_);
2242 n->AddAttr(outside_compilation_attr_name_, old_name);
2243 }
2244
2245 // Check whether we have all input shapes for XlaSendFromHost. If we do, we
2246 // will set `shapes` attr for the call node; otherwise we will save the
2247 // shape inference graph and set `shape_inference_graph` for the call node.
2248 absl::optional<std::vector<PartialTensorShape>> shapes =
2249 GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node);
2250 for (Node* n : (*graph)->nodes()) {
2251 n->ClearAttr(kXlaInferredShapesAttrName);
2252 }
2253
2254 // Step 5: add control edges for originally XLA <-> outside compilation
2255 // control edges.
2256 for (Node* n : (*graph)->nodes()) {
2257 if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) {
2258 (*graph)->AddControlEdge(n, send_from_host_node);
2259 n->ClearAttr(kXlaConnectedToXlaComputationAttrName);
2260 }
2261 if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) {
2262 (*graph)->AddControlEdge(recv_at_host_node, n);
2263 n->ClearAttr(kXlaConnectedFromXlaComputationAttrName);
2264 }
2265 }
2266
2267 // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune
2268 // them if necessary.
2269 // - RecvAtHost should be pruned iff it has no output data/control edges. If
2270 // it has any output edge, it will be reverse reachable from sink node. We
2271 // don't need to do anything special.
2272 // - SendFromHost should be pruned iff it has no input data/control edges. If
2273 // it has input edges other than key_placeholder, we connect it to sink
2274 // node so it won't be pruned.
2275 // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned.
2276 // We don't need to do anything special.
2277 if (send_from_host_node->in_edges().size() > 1) {
2278 (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node());
2279 }
2280 PruneForReverseReachability(
2281 graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()});
2282
2283 // Step 7: add necessary attributes to function call node, so we can replace
2284 // it with HostCompute node later.
2285 AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
2286 if (shapes) {
2287 NameAttrList shape_inference_graph;
2288 AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2289 AddNodeAttr("shapes", *shapes, node_def);
2290 } else {
2291 string shape_inference_func_name =
2292 absl::StrCat("_outside_compilation_shape_inference_", new_name);
2293 NameAttrList shape_inference_graph;
2294 shape_inference_graph.set_name(shape_inference_func_name);
2295 AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2296 AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def);
2297 }
2298 AddNodeAttr("ancestors", std::vector<string>{}, node_def);
2299 AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def);
2300 AddNodeAttr("Toutputs", send_from_host_dtypes, node_def);
2301 AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def);
2302
2303 return Status::OK();
2304 }
2305
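// For the function described by `func_name_attrs`: encapsulates each outside
// compilation cluster into a function call node, replaces those call nodes
// with XlaHostCompute nodes, recurses into If/While/function call nodes,
// writes the XLA-only result to `new_func_name`, and builds the corresponding
// host graph under `host_graph_func_name`.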
2306 Status ExtractOutsideCompilationForFunction(
2307 const string& xla_cluster_attr_name,
2308 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2309 const NameAttrList& func_name_attrs, const string& new_func_name,
2310 const string& host_graph_func_name,
2311 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2312 FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
2313 bool* has_outside_compilation) {
2314 // Convert the function to graph.
2315 const string& func_name = func_name_attrs.name();
2316 FunctionLibraryRuntime::Handle handle;
2317 TF_RETURN_IF_ERROR(
2318 flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle));
2319 Status ret_status = Status::OK();
2320 auto cleanup_handle = gtl::MakeCleanup([&]() {
2321 auto s = flr->ReleaseHandle(handle);
2322 if (!s.ok()) {
2323 ret_status.Update(s);
2324 }
2325 });
2326 const FunctionBody* fbody = flr->GetFunctionBody(handle);
2327
2328 // Check if we have outside compilation nodes.
2329 *has_outside_compilation = false;
2330 for (Node* n : fbody->graph->nodes()) {
2331 if (HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2332 *has_outside_compilation = true;
2333 break;
2334 }
2335 }
2336 // We cannot return early here, because there might be outside compilation
2337 // inside the functions attached to If/While/function call nodes.
2338
2339 if (VLOG_IS_ON(4)) {
2340 DumpGraphToFile(
2341 absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
2342 *fbody->graph, fld);
2343 }
2344
2345 std::unique_ptr<Graph> graph_out;
2346 std::vector<string> outside_compilation_host_graphs;
2347 std::vector<string> shape_inference_graphs_to_rewrite;
2348 if (*has_outside_compilation) {
2349 // Copy outside compilation Const nodes with non-outside-compilation users.
2350 TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
2351 fbody->graph, outside_compilation_attr_name));
2352
2353 // Find dependencies between outside compilation clusters.
2354 TF_ASSIGN_OR_RETURN(auto cluster_deps,
2355 OutsideCompilationClusterDependencies(
2356 fbody->graph, outside_compilation_attr_name));
2357
2358 // Preprocess edges between different outside compilations. They will be
2359 // restored in `ConstructHostGraph()`.
2360 TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
2361 fbody->graph, outside_compilation_attr_name));
2362
2363 // Encapsulate outside_compilation cluster into function call node.
2364 auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
2365 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2366 new_func_name);
2367 TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
2368 outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
2369 /*reuse_existing_functions=*/true, &graph_out, fld));
2370
2371 // Replace outside_compilation function nodes with HostCompute ops.
2372 std::vector<Node*> outside_compilation_nodes;
2373 for (Node* n : graph_out->nodes()) {
2374 if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
2375 outside_compilation_nodes.push_back(n);
2376 outside_compilation_host_graphs.push_back(n->name());
2377
2378 // If we could not infer shapes for the XlaSendFromHost inputs statically,
2379 // the "shape_inference_graph" attribute will have been set. In that case,
2380 // copy the outside compilation subgraph as the shape inference graph in `fld`.
2381 auto shape_inference_graph = absl::make_unique<NameAttrList>();
2382 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
2383 shape_inference_graph.get()));
2384 if (!shape_inference_graph->name().empty()) {
2385 shape_inference_graphs->push_back(shape_inference_graph->name());
2386 shape_inference_graphs_to_rewrite.push_back(
2387 shape_inference_graph->name());
2388
2389 const FunctionDef* xla_fdef = fld->Find(n->name());
2390 if (!xla_fdef) {
2391 return errors::Internal("Cannot find XLA function ", n->name());
2392 }
2393 auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
2394 shape_inference_fdef->mutable_signature()->set_name(
2395 shape_inference_graph->name());
2396 if (fld->Find(shape_inference_graph->name())) {
2397 TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2398 shape_inference_graph->name(), *shape_inference_fdef));
2399 } else {
2400 TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
2401 }
2402 }
2403 }
2404 }
2405 std::map<string, Node*> host_compute_nodes;
2406 for (Node* n : outside_compilation_nodes) {
2407 auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
2408 graph_out.get(), n, host_compute_core, *cluster_deps);
2409 TF_RETURN_IF_ERROR(host_compute_node_or.status());
2410 Node* host_compute_node = host_compute_node_or.ValueOrDie();
2411 host_compute_nodes[host_compute_node->name()] = host_compute_node;
2412 }
2413 // For XlaHostCompute nodes with dependencies, add control edges between
2414 // them so XlaCompiler can handle them in the correct order.
2415 for (const auto& iter : host_compute_nodes) {
2416 Node* host_compute_node = iter.second;
2417 std::vector<string> token_input_node_names;
2418 TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
2419 kXlaTokenInputNodesAttrName,
2420 &token_input_node_names));
2421 for (const string& node_name : token_input_node_names) {
2422 if (node_name == kXlaTokenArgNodeName) {
2423 continue;
2424 }
2425
2426 auto iter = host_compute_nodes.find(node_name);
2427 TF_RET_CHECK(iter != host_compute_nodes.end());
2428 graph_out->AddControlEdge(iter->second, host_compute_node);
2429 }
2430 }
2431 }
2432
2433 // Handle nodes with associated functions.
2434 Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
2435 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2436 g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2437 host_compute_core, flr, fld, &outside_compilation_host_graphs,
2438 shape_inference_graphs, has_outside_compilation));
2439
2440 if (*has_outside_compilation) {
2441 // Construct host graph.
2442 std::unique_ptr<Graph> host_graph;
2443 TF_RETURN_IF_ERROR(
2444 ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
2445 outside_compilation_host_graphs, fld, &host_graph));
2446 auto host_graph_fdef = absl::make_unique<FunctionDef>();
2447 TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
2448 HostGraphControlRetMapping,
2449 host_graph_fdef.get()));
2450 if (fld->Find(host_graph_func_name)) {
2451 TF_RETURN_IF_ERROR(
2452 fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
2453 } else {
2454 TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
2455 }
2456
2457 // Shape inference graphs might contain Placeholder nodes for
2458 // outside-compilation-to-outside-compilation edges. Rewrite the shape
2459 // inference graphs to remove such nodes.
2460 for (const string& shape_inference_graph :
2461 shape_inference_graphs_to_rewrite) {
2462 TF_RETURN_IF_ERROR(
2463 RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
2464 /*pivot_node=*/nullptr, fld));
2465 }
2466
2467 // Remove the outside compilation graphs from function library.
2468 for (const string& func : outside_compilation_host_graphs) {
2469 TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
2470 }
2471
2472 // Replace original function.
2473 auto updated_fdef = absl::make_unique<FunctionDef>();
2474 TF_RETURN_IF_ERROR(
2475 GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
2476 updated_fdef->mutable_signature()->set_is_stateful(true);
2477 const FunctionDef* original_fdef = fld->Find(func_name);
2478 if (original_fdef) {
2479 for (const auto& attr : original_fdef->attr()) {
2480 (*updated_fdef->mutable_attr())[attr.first] = attr.second;
2481 }
2482 }
2483 if (fld->Find(new_func_name)) {
2484 TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
2485 } else {
2486 TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
2487 }
2488 if (VLOG_IS_ON(4)) {
2489 DumpGraphToFile(
2490 absl::StrCat("extract_outside_compilation_for_func_after_",
2491 func_name),
2492 *g, fld);
2493 }
2494 }
2495
2496 return ret_status;
2497 }
2498
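// For each XLA cluster: extracts outside compilation from its function, then
// expands the resulting host graph into the main graph `g` (anchored at the
// cluster's pivot node, if any) and rewrites the recorded shape inference
// graphs.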
2499 Status ExtractOutsideCompilation(
2500 const string& xla_cluster_attr_name,
2501 const string& outside_compilation_attr_name,
2502 const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
2503 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2504 bool* modified) {
2505 if (VLOG_IS_ON(4)) {
2506 DumpGraphToFile("extract_outside_compilation_before", *g, fld);
2507 }
2508
2509 *modified = false;
2510 auto node_name_index = g->BuildNodeNameIndex();
2511 for (auto& iter : clusters) {
2512 string xla_cluster_name = iter.first;
2513 Node* n = iter.second.node;
2514 auto const& func_name_attrs = iter.second.func_name_attrs;
2515 auto const& host_compute_core = iter.second.host_compute_core;
2516
2517 std::vector<string> shape_inference_graphs;
2518 bool has_outside_compilation;
2519 string host_graph_func_name =
2520 absl::StrCat("oc_host_graph_", xla_cluster_name);
2521 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2522 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2523 func_name_attrs, func_name_attrs.name(), host_graph_func_name,
2524 host_compute_core, flr, fld, &shape_inference_graphs,
2525 &has_outside_compilation));
2526 *modified |= has_outside_compilation;
2527
2528 if (has_outside_compilation) {
2529 string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
2530 Node* pivot_node = node_name_index[pivot_name];
2531 TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
2532 g, fld, host_graph_func_name, n, pivot_node));
2533
2534 TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
2535
2536 for (const auto& shape_inference_graph_name : shape_inference_graphs) {
2537 TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
2538 shape_inference_graph_name, g, pivot_node, fld));
2539 }
2540 }
2541 }
2542
2543 if (VLOG_IS_ON(4)) {
2544 DumpGraphToFile("extract_outside_compilation_after", *g, fld);
2545 }
2546 return Status::OK();
2547 }
2548
2549 } // namespace tensorflow
2550