1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17
18 #include <functional>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/match.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
31 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
32 #include "tensorflow/compiler/jit/xla_cluster_util.h"
33 #include "tensorflow/compiler/tf2xla/const_analysis.h"
34 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/core/common_runtime/device_factory.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/optimization_registry.h"
39 #include "tensorflow/core/common_runtime/shape_refiner.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/graph_def_util.h"
42 #include "tensorflow/core/framework/graph_to_functiondef.h"
43 #include "tensorflow/core/framework/node_def_builder.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/tensor.pb.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/control_flow.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_def_builder.h"
50 #include "tensorflow/core/graph/tensor_id.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/hash/hash.h"
53 #include "tensorflow/core/public/session_options.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 #include "tensorflow/core/util/dump_graph.h"
57
58 namespace tensorflow {
59
// Definitions of the attribute-name string constants exported by this file.
// Each constant holds the literal key of a node attribute.
// NOTE(review): the exact meaning of the first three and the last attribute is
// defined by their users elsewhere -- confirm before relying on the names.
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
// Set on the NoOp "sequencer" node created by
// Encapsulator::Subgraph::MakeSequencingNode; its value is the subgraph name.
const char* const kXlaHostTransferSequencerAttr =
    "_xla_host_transfer_sequencer";
const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";
66
SortControlInputs(GraphDef * gdef)67 void SortControlInputs(GraphDef* gdef) {
68 int64 num_nodes = gdef->node_size();
69 for (int64 i = 0; i < num_nodes; ++i) {
70 NodeDef* node = gdef->mutable_node(i);
71 // Stable sort control inputs and leave the order of data inputs unchanged.
72 std::stable_sort(node->mutable_input()->begin(),
73 node->mutable_input()->end(),
74 [](const string& a, const string& b) {
75 bool a_is_control = absl::StartsWith(a, "^");
76 bool b_is_control = absl::StartsWith(b, "^");
77 return (!a_is_control && b_is_control) ||
78 (a_is_control && b_is_control && a < b);
79 });
80 }
81 }
82
83 namespace {
84
AreAllParentsGuaranteedConst(const Node & n,const absl::flat_hash_set<const Node * > & runtime_const_nodes)85 bool AreAllParentsGuaranteedConst(
86 const Node& n,
87 const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
88 if (n.type_string() == "GuaranteeConst") {
89 // If the current node is itself a cast-to-const, no need
90 // to look at the incoming edges.
91 return true;
92 }
93
94 bool all_parents_const = true;
95 bool atleast_one_non_control_edge = false;
96 for (const Edge* in : n.in_edges()) {
97 atleast_one_non_control_edge =
98 atleast_one_non_control_edge || !in->IsControlEdge();
99 if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
100 all_parents_const = false;
101 break;
102 }
103 }
104 return all_parents_const && atleast_one_non_control_edge;
105 }
106
MarkGuaranteedConstants(const Graph & graph,const std::vector<std::pair<const Node *,Node * >> & src_arg_pairs)107 void MarkGuaranteedConstants(
108 const Graph& graph,
109 const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
110 absl::flat_hash_set<const Node*> guaranteed_const_nodes;
111 std::vector<const Node*> srcs;
112 srcs.reserve(src_arg_pairs.size());
113 for (const auto& src_arg : src_arg_pairs) {
114 srcs.push_back(src_arg.first);
115 }
116 ReverseDFSFrom(
117 graph, srcs, /*enter=*/nullptr,
118 /*leave=*/[&guaranteed_const_nodes](const Node* n) {
119 // TODO(vinuraja): Doesn't work in the presence of loops.
120 if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
121 guaranteed_const_nodes.insert(n);
122 }
123 });
124
125 for (auto& src_arg : src_arg_pairs) {
126 if (guaranteed_const_nodes.count(src_arg.first) != 0) {
127 VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
128 src_arg.second->AddAttr("_is_guaranteed_constant", true);
129 }
130 }
131 }
132
133 struct OutputInputTensorPairHasher {
operator ()tensorflow::__anona26b85820211::OutputInputTensorPairHasher134 uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
135 return Hash64Combine(OutputTensor::Hash()(s.first),
136 InputTensor::Hash()(s.second));
137 }
138 };
139
// Op type names used by this file.
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
//
// `constexpr const char*` has type `const char* const`; a namespace-scope
// constexpr variable has internal linkage, so no `static` is needed.
constexpr const char* kArgOp = "_Arg";
constexpr const char* kRetValOp = "_Retval";
constexpr const char* kHostComputeOp = "XlaHostCompute";
constexpr const char* kSendFromHostOp = "_XlaSendFromHost";
constexpr const char* kRecvAtHostOp = "_XlaRecvAtHost";
147
148 class Encapsulator {
149 public:
Encapsulator(string group_attribute,Graph const * graph_in)150 Encapsulator(string group_attribute, Graph const* graph_in)
151 : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
152
153 // Find subgraphs marked with 'group_attribute', and build a new
154 // subgraph, one for each value of 'group_attribute'.
155 Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);
156
157 // Build a FunctionDef for each subgraph, and add it 'library'. The values of
158 // the 'group_attribute' annotations become the function names.
159 // If 'reuse_existing_functions' is set, use an existing function with the
160 // same name, if any.
161 // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
162 // function conversion.
163 Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
164 bool reuse_existing_functions,
165 FunctionLibraryDefinition* library);
166
167 // Write a copy of the input graph to 'graph_out', where the subgraphs are
168 // replaced with calls to the new functions.
169 Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);
170
171 private:
172 // A subgraph of the input, all marked with a common 'group_attribute'
173 // value.
174 //
175 // In the following simple example, A, B, ..., E are nodes in the original
176 // graph. The group attributes g are each shown as either 0 or empty.
177 //
178 // A --> B --> C --> D --> E
179 // g: g:0 g:0 g:0 g:
180 //
181 // The example is rewritten to two graphs; one on the host and one to be
182 // compiled. The host graph is as follows.
183 //
184 // A --> Call --> E
185 //
186 // The compiled cluster is as follows.
187 //
188 // Arg --> B --> C --> D --> Retval
189 class Subgraph {
190 public:
191 // Creates a graph to build the subgraph in, if it doesn't already exist,
192 // using the same op registry and versions as graph_in.
193 Node* MakeNodeImage(const Graph* graph_in, Node* node);
194
195 // Returns the graph the subgraph is being built in.
196 Graph* GetGraph() const;
197
198 // Builds a FunctionDef, and adds it to 'library'. The value of the
199 // 'group_attribute' annotations becomes the function name. If
200 // 'reuse_existing_functions' is set, use an existing function with the same
201 // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the
202 // subgraph before function conversion.
203 Status BuildFunctionDef(const string& name_in,
204 const RewriteSubgraphFn& rewrite_subgraph_fn,
205 bool reuse_existing_functions,
206 FunctionLibraryDefinition* library);
207
208 // Adds the function call node to graph_out.
209 Status AddFunctionCallNode(
210 const std::unordered_map<const Node*, Node*>& node_images,
211 Graph* graph_out);
212
213 // Returns the Node that the inputs and outputs of the function should be
214 // wired up to.
215 Node* GetCallNode() const;
216
217 // Returns the index of the arg that the dst of edge should connect to.
218 int GetArgIndexForEdge(const Edge* edge) const;
219
220 // Returns the index of the result that the src of edge should connect to.
221 int GetResultIndexForEdge(const Edge* edge) const;
222
223 // Creates an _Arg node for the src node of edge, and add its index to
224 // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
225 // and adds the edge within the subgraph from the _Arg node to the image of
226 // the dst node.
227 Status RecordArg(const Edge* edge,
228 const std::unordered_map<const Node*, Node*>& node_images,
229 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
230
231 // Records the src of the given edge as a control result of the graph.
232 // Used during graph to function conversion to tie control results to
233 // the function signature.
234 Status RecordControlResult(
235 const Edge* edge,
236 const std::unordered_map<const Node*, Node*>& node_images);
237
238 // Creates a _Retval node for the src node of edge, and add it to results_,
239 // if none exists yet. If a new _Retval node is created, also adds the edge
240 // within the subgraph from the src to the _Retval node.
241 Status RecordResult(
242 const Edge* edge,
243 const std::unordered_map<const Node*, Node*>& node_images);
244
245 // Creates the sequencer node if it doesn't exist, adding it to graph_out.
246 Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
247
248 // If there is a sequencer node, adds a control edge from the sequencer to
249 // the call node.
250 void ConnectSequencerToCallNode(Graph* graph_out);
251
252 Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
253
254 private:
255 // The subgraph extracted from the input graph, suitable for being turned
256 // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
257 // returned by _Retval nodes.
258 std::unique_ptr<Graph> graph_;
259
260 // Which device are these nodes on? Used to assign a device to the call
261 // node.
262 string device_;
263
264 // NodeDef for the function call node.
265 NodeDef call_node_def_;
266
267 // Name that is used for the call node. This may not be
268 // call_node_def_.name() if the client supplies a rewrite lambda.
269 string function_def_name_;
270
271 // Placeholder node simulating the host compute key in the output graph.
272 // Not owned.
273 Node* host_compute_key_placeholder_ = nullptr;
274
275 // Function call node in the output graph. Not owned.
276 Node* call_node_;
277
278 // Maps from source (producer node/slot) and destination
279 // (consumer node/slot) tensors in the input graph to _Arg numbers in
280 // the subgraph. The source map is one-to-one, whereas the dest map may be
281 // many-to-one.
282 std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
283 std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
284
285 // The arguments to the subgraph, in order.
286 std::vector<Node*> args_;
287
288 // Map from source tensor in the input graph to result #.
289 std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
290
291 // Set of node names that are the source of a control output of the
292 // subgraph. We store strings here so that we can tolerate nodes being
293 // removed from the graph.
294 absl::flat_hash_set<string> control_output_nodes_;
295
296 // NoOp node in the output graph that is sequenced after the call node.
297 Node* sequencer_ = nullptr;
298 };
299
300 // Returns the key attribute associated with a node in attr. Sets either
301 // result to the empty string if the respective attribute is not found.
302 Status GetFunctionNameAttr(Node const* node, string* attr) const;
303
304 // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
305 // subgraphs for data edges that cross subgraph boundaries.
306 Status CopySubgraphEdges(
307 const std::unordered_map<const Node*, Node*>& node_images,
308 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
309
310 // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
311 Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
312
313 // Copies all nodes that aren't in a compiled subgraph to the output graph.
314 Status CopyNodesToOutputGraph(
315 Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
316
317 // Adds function call nodes for each compiled subgraph.
318 Status AddFunctionCallNodes(
319 const std::unordered_map<const Node*, Node*>& node_images,
320 Graph* graph_out);
321
322 // Finds the image of an edge source in the output graph. If the edge crosses
323 // a subgraph boundary it is the output of a call node, otherwise it is a node
324 // in the output graph.
325 Status FindOutputImageOfEdgeSrc(
326 const string& src_func_id, const string& dst_func_id,
327 const std::unordered_map<const Node*, Node*>& node_images,
328 const Node* original_src_node, Node** src_image);
329
330 // Finds an edge source slot in the output graph. If the edge crosses a
331 // subgraph boundary it is a slot on the output of a call node, otherwise it
332 // is a slot on a node in the output graph.
333 int FindOutputSlotOfEdgeSrc(const string& src_func_id,
334 const string& dst_func_id,
335 const Edge* edge);
336
337 // Finds the image of an edge destination in the output graph. If the edge
338 // crosses a subgraph boundary it is the input of a call node, otherwise it is
339 // a node in the output graph.
340 Status FindOutputImageOfEdgeDst(
341 const string& src_func_id, const string& dst_func_id,
342 const std::unordered_map<const Node*, Node*>& node_images,
343 const Node* original_dst_node, Node** dst_image);
344
345 // Finds an edge destination slot in the output graph. If the edge crosses a
346 // subgraph boundary it is a slot on the input of a call node, otherwise it is
347 // a slot on a node in the output graph.
348 int FindOutputSlotOfEdgeDst(const string& src_func_id,
349 const string& dst_func_id,
350 const Edge* edge);
351
352 // Copies a single edge to the output graph. The edge is either entirely
353 // within the output graph, or crosses into or out of a compiled subgraph.
354 Status CopyEdgeToOutputGraph(
355 const Edge* edge, const string& src_func_id, const string& dst_func_id,
356 const std::unordered_map<const Node*, Node*>& node_images,
357 Graph* graph_out,
358 std::unordered_set<std::pair<OutputTensor, InputTensor>,
359 OutputInputTensorPairHasher>* edges_added);
360
361 // Adds all edges to the output graph.
362 Status AddEdgesToOutputGraph(
363 const std::unordered_map<const Node*, Node*>& node_images,
364 Graph* graph_out);
365
366 // Makes a copy of graph containing only nodes that are ancestors of at least
367 // one node in send_from_host_nodes and store it in pruned_graph. On exit
368 // nodes_images contains a mapping from nodes in graph to nodes in
369 // pruned_graph. All functions in the copied graph are inlined.
370 Status MakePrunedGraphCopyAndInline(
371 const Graph& graph, const std::vector<Node*>& sink_nodes,
372 std::unique_ptr<Graph>* pruned_graph,
373 std::unordered_map<const Node*, Node*>* node_images,
374 FunctionLibraryDefinition* library);
375
376 const string group_attribute_;
377 const Graph* graph_in_;
378
379 std::unordered_map<string, Subgraph> subgraphs_;
380
381 TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
382 };
383
384 namespace {
385
386 // Return in 'sorted' a topological sort of clusters according to the
387 // dependencies encoded in ancestors. clusters is the list of all clusters
388 // including clusters that are not present in the ancestors map. has_successors
389 // is the set of clusters that are ancestors of some other cluster.
TopologicalClusterSort(const std::unordered_set<string> & clusters,const std::unordered_set<string> & has_successors,const std::unordered_map<string,std::unordered_set<string>> & ancestors,std::vector<string> * sorted)390 void TopologicalClusterSort(
391 const std::unordered_set<string>& clusters,
392 const std::unordered_set<string>& has_successors,
393 const std::unordered_map<string, std::unordered_set<string>>& ancestors,
394 std::vector<string>* sorted) {
395 // The nodes are placed in 'sorted' in topological order.
396 sorted->clear();
397 // We don't use the standard DFS because we are not operating on Node*
398 // objects.
399 struct Work {
400 string cluster;
401 bool leave;
402 };
403 std::set<string> visited;
404 std::vector<Work> stack;
405 // Seed the processing list with clusters that have no successors.
406 for (const auto& cluster : clusters) {
407 if (has_successors.find(cluster) == has_successors.end()) {
408 stack.push_back({cluster, false});
409 }
410 }
411 while (!stack.empty()) {
412 const Work item = stack.back();
413 stack.pop_back();
414 if (item.leave) {
415 sorted->push_back(item.cluster);
416 continue;
417 }
418
419 if (visited.find(item.cluster) != visited.end()) continue;
420 visited.insert(item.cluster);
421
422 stack.push_back({item.cluster, true});
423 const auto& iter = ancestors.find(item.cluster);
424 if (iter != ancestors.end()) {
425 for (const auto& ancestor : iter->second) {
426 stack.push_back({ancestor, false});
427 }
428 }
429 }
430 CHECK(sorted->size() == clusters.size());
431 }
432
433 } // namespace
434
GetCallNode() const435 Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
436
GetArgIndexForEdge(const Edge * edge) const437 int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
438 return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
439 }
440
GetResultIndexForEdge(const Edge * edge) const441 int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
442 return results_.at(OutputTensor(edge->src(), edge->src_output()));
443 }
444
MakeNodeImage(const Graph * graph_in,Node * node)445 Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
446 if (!graph_) {
447 graph_.reset(new Graph(graph_in->op_registry()));
448 graph_->set_versions(graph_in->versions());
449 }
450
451 // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
452 // determined. In case of hard placement, ensure all the encapsulated nodes
453 // have the same requested device, which in turn will be the requested device
454 // for the entire encapsulated subgraph. In case of soft placement, use a
455 // deterministic approach to fill in the requested device. Handle co-location
456 // constraints similarly if they exist.
457 if (device_.empty()) {
458 device_ = node->assigned_device_name().empty()
459 ? node->requested_device()
460 : node->assigned_device_name();
461 }
462
463 return graph_->CopyNode(node);
464 }
465
GetGraph() const466 Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
467
RecordArg(const Edge * edge,const std::unordered_map<const Node *,Node * > & node_images,std::vector<std::pair<const Node *,Node * >> * src_arg_pairs)468 Status Encapsulator::Subgraph::RecordArg(
469 const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
470 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
471 Node* src_node = edge->src();
472 int src_slot = edge->src_output();
473 std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
474 bool inserted;
475 std::tie(iter, inserted) = args_by_src_.emplace(
476 OutputTensor(src_node, src_slot), args_by_src_.size());
477 int arg_index = iter->second;
478 if (inserted) {
479 NodeDef arg_def;
480 NodeDefBuilder builder(
481 absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
482 NodeDebugInfo(src_node->def()));
483 DataType dtype = edge->dst()->input_type(edge->dst_input());
484 builder.Attr("T", dtype);
485 builder.Attr("index", arg_index);
486 Status s = builder.Finalize(&arg_def);
487 if (!s.ok()) return s;
488
489 Node* arg = graph_->AddNode(arg_def, &s);
490 if (!s.ok()) return s;
491
492 src_arg_pairs->push_back({src_node, arg});
493 args_.push_back(arg);
494 }
495 Node* dst_node = edge->dst();
496 Node* dst_image = node_images.at(dst_node);
497 int dst_slot = edge->dst_input();
498 args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
499 graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
500 return Status::OK();
501 }
502
RecordControlResult(const Edge * edge,const std::unordered_map<const Node *,Node * > & node_images)503 Status Encapsulator::Subgraph::RecordControlResult(
504 const Edge* edge,
505 const std::unordered_map<const Node*, Node*>& node_images) {
506 Node* src_node = edge->src();
507 Node* src_image = node_images.at(src_node);
508 control_output_nodes_.insert(src_image->name());
509 return Status::OK();
510 }
511
RecordResult(const Edge * edge,const std::unordered_map<const Node *,Node * > & node_images)512 Status Encapsulator::Subgraph::RecordResult(
513 const Edge* edge,
514 const std::unordered_map<const Node*, Node*>& node_images) {
515 Node* src_node = edge->src();
516 Node* src_image = node_images.at(src_node);
517 int src_slot = edge->src_output();
518 std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
519 bool inserted;
520 std::tie(iter, inserted) =
521 results_.emplace(OutputTensor(src_node, src_slot), results_.size());
522 int ret_index = iter->second;
523 if (inserted) {
524 NodeDef ret_def;
525 NodeDefBuilder builder(
526 absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
527 NodeDebugInfo(src_node->def()));
528 DataType dtype = src_node->output_type(src_slot);
529 builder.Attr("T", dtype);
530 builder.Attr("index", ret_index);
531 builder.Input(src_image->name(), src_slot, dtype);
532 Status s = builder.Finalize(&ret_def);
533 if (!s.ok()) return s;
534 Node* ret = graph_->AddNode(ret_def, &s);
535 if (!s.ok()) return s;
536
537 graph_->AddEdge(src_image, src_slot, ret, 0);
538 }
539 return Status::OK();
540 }
541
MakeSequencingNode(const string & subgraph_name,Graph * graph_out)542 Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
543 Graph* graph_out) {
544 if (sequencer_ == nullptr) {
545 NodeDef seq_def;
546 // TODO(shikharagarwal): What source node should we use for errors?
547 NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
548 builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
549 builder.Device(device_);
550 Status s = builder.Finalize(&seq_def);
551 if (!s.ok()) return s;
552
553 sequencer_ = graph_out->AddNode(seq_def, &s);
554 if (!s.ok()) return s;
555 }
556 return Status::OK();
557 }
558
ConnectSequencerToCallNode(Graph * graph_out)559 void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
560 if (sequencer_ != nullptr) {
561 VLOG(2) << "ConnectSequencerToCallNode";
562 graph_out->AddControlEdge(sequencer_, call_node_,
563 /* allow_duplicates= */ true);
564 }
565 }
566
BuildFunctionDef(const string & name_in,const RewriteSubgraphFn & rewrite_subgraph_fn,bool reuse_existing_functions,FunctionLibraryDefinition * library)567 Status Encapsulator::Subgraph::BuildFunctionDef(
568 const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
569 bool reuse_existing_functions, FunctionLibraryDefinition* library) {
570 // name_in is copied here because name may be modified below if
571 // rewrite_subgraph_fn is true.
572 string name = name_in;
573 call_node_def_.set_op(name);
574 call_node_def_.set_name(name);
575 call_node_def_.set_device(device_);
576
577 if (rewrite_subgraph_fn) {
578 std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
579 for (const auto& arg : args_by_src_) {
580 arg_source_tensors.at(arg.second) = arg.first;
581 }
582 // Initialize the input and output permutations to the identity.
583 std::vector<int> input_permutation(args_by_src_.size());
584 std::iota(input_permutation.begin(), input_permutation.end(), 0);
585 std::vector<int> output_permutation(results_.size());
586 std::iota(output_permutation.begin(), output_permutation.end(), 0);
587
588 TF_RETURN_IF_ERROR(
589 rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
590 &output_permutation, &call_node_def_));
591
592 // Apply the input/output permutations to the 'args_by_...' and 'results_'
593 // mappings, so when we build edges in BuildOutputGraph() we
594 // connect them to the right input/output positions.
595 if (input_permutation.size() != args_by_src_.size()) {
596 return errors::InvalidArgument("Input permutation has incorrect size.");
597 }
598 if (output_permutation.size() != results_.size()) {
599 return errors::InvalidArgument("Output permutation has incorrect size.");
600 }
601 for (auto& arg : args_by_src_) {
602 arg.second = input_permutation[arg.second];
603 }
604 for (auto& arg : args_by_dst_) {
605 arg.second = input_permutation[arg.second];
606 }
607 for (auto& result : results_) {
608 result.second = output_permutation[result.second];
609 }
610
611 name = call_node_def_.op();
612 }
613
614 function_def_name_ = name;
615
616 FunctionDef fdef;
617 auto lookup = [this](const Node* node) -> absl::optional<string> {
618 if (control_output_nodes_.contains(node->name())) {
619 return absl::make_optional(node->name());
620 }
621 return absl::nullopt;
622 };
623 // Verify that the graph has well-formed control flow structure.
624 std::vector<ControlFlowInfo> dummy;
625 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
626 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));
627
628 if (VLOG_IS_ON(1)) {
629 VLOG(2) << "Build function def " << name;
630 DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
631 library);
632 DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
633 }
634
635 const FunctionDef* original_fdef = library->Find(name);
636 if (!reuse_existing_functions || original_fdef == nullptr) {
637 TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
638 } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
639 TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
640 }
641 return Status::OK();
642 }
643
ReplaceFunctionDef(FunctionLibraryDefinition * library)644 Status Encapsulator::Subgraph::ReplaceFunctionDef(
645 FunctionLibraryDefinition* library) {
646 const string& name = function_def_name_;
647
648 FunctionDef fdef;
649 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
650
651 if (VLOG_IS_ON(1)) {
652 VLOG(2) << "Replace function def " << name;
653 DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
654 *graph_, library);
655 DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
656 fdef);
657 }
658
659 TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
660 return Status::OK();
661 }
662
AddFunctionCallNode(const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out)663 Status Encapsulator::Subgraph::AddFunctionCallNode(
664 const std::unordered_map<const Node*, Node*>& node_images,
665 Graph* graph_out) {
666 Status s;
667 call_node_ = graph_out->AddNode(call_node_def_, &s);
668 if (!s.ok()) return s;
669
670 // Copy the assigned device and the key_annotation over.
671 call_node_->set_assigned_device_name(device_);
672
673 return Status::OK();
674 }
675
GetFunctionNameAttr(Node const * node,string * attr) const676 Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
677 AttrSlice attrs = node->attrs();
678 attr->clear();
679 for (const auto& node_attr : attrs) {
680 if (node_attr.first == group_attribute_) {
681 TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
682 *attr = node_attr.second.s();
683 break;
684 }
685 }
686 return Status::OK();
687 }
688
IsInSubgraph(const string & func_id)689 bool IsInSubgraph(const string& func_id) { return !func_id.empty(); }
690
CopySubgraphNodes(std::unordered_map<const Node *,Node * > * node_images)691 Status Encapsulator::CopySubgraphNodes(
692 std::unordered_map<const Node*, Node*>* node_images) {
693 for (Node* node : graph_in_->op_nodes()) {
694 string func_id;
695 TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
696 if (!IsInSubgraph(func_id)) continue;
697
698 Subgraph& subgraph = subgraphs_[func_id];
699 Node* image = subgraph.MakeNodeImage(graph_in_, node);
700 image->ClearAttr(group_attribute_);
701 (*node_images)[node] = image;
702 }
703 return Status::OK();
704 }
705
CopySubgraphEdges(const std::unordered_map<const Node *,Node * > & node_images,std::vector<std::pair<const Node *,Node * >> * src_arg_pairs)706 Status Encapsulator::CopySubgraphEdges(
707 const std::unordered_map<const Node*, Node*>& node_images,
708 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
709 for (const Edge* edge : graph_in_->edges()) {
710 string src_func_id;
711 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
712 string dst_func_id;
713 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
714 Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
715 Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
716
717 // Copy edges that are local to a subgraph.
718 if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
719 src_func_id == dst_func_id) {
720 Graph* g = subgraphs_[src_func_id].GetGraph();
721 if (edge->IsControlEdge()) {
722 g->AddControlEdge(src_image, dst_image,
723 /* allow_duplicates= */ true);
724 } else {
725 g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
726 }
727 continue;
728 }
729
730 // Record 'src' as an output of its subgraph, if applicable.
731 if (IsInSubgraph(src_func_id)) {
732 if (!edge->IsControlEdge()) {
733 DataType dtype = edge->src()->output_type(edge->src_output());
734 if (IsRefType(dtype)) {
735 return errors::InvalidArgument(
736 "Ref Tensors (e.g., Variables) are not supported as results: "
737 "tensor ",
738 edge->src()->name(), ":", edge->src_output());
739 }
740 }
741
742 Subgraph& src_subgraph = subgraphs_[src_func_id];
743 if (edge->IsControlEdge()) {
744 TF_RETURN_IF_ERROR(src_subgraph.RecordControlResult(edge, node_images));
745 } else {
746 TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
747 }
748 }
749
750 // Record 'dst' as an input of its subgraph, if applicable.
751 if (IsInSubgraph(dst_func_id)) {
752 // Look at the type of the destination not the source, since Ref output
753 // Tensors can be automatically cast to non-Ref Tensors at the
754 // destination.
755 if (!edge->IsControlEdge()) {
756 DataType dtype = edge->dst()->input_type(edge->dst_input());
757 if (IsRefType(dtype)) {
758 return errors::InvalidArgument(
759 "Ref Tensors (e.g., Variables) are not supported as args: "
760 "tensor ",
761 edge->src()->name(), ":", edge->src_output());
762 }
763 }
764
765 Subgraph& dst_subgraph = subgraphs_[dst_func_id];
766 // Ignore control edges entering the subgraph. We will lift them onto
767 // the enclosing call operators in BuildOutputGraph().
768 if (!edge->IsControlEdge()) {
769 TF_RETURN_IF_ERROR(
770 dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
771 }
772 }
773 }
774 return Status::OK();
775 }
776
SplitIntoSubgraphs(FunctionLibraryDefinition * library)777 Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
778 Status s;
779
780 // Map from input graph nodes to subgraph nodes.
781 std::unordered_map<const Node*, Node*> node_images;
782
783 // Each entry of src_arg_pairs is a pair whose first element is a node in the
784 // original graph that has an output edge in the subgraph, and whose second
785 // element is the arg node in the subgraph that it sends to. The vector will
786 // be filled in below in AddArgs.
787 std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
788
789 TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
790 TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
791 MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
792
793 for (auto& entry : subgraphs_) {
794 Subgraph& subgraph = entry.second;
795 FixupSourceAndSinkEdges(subgraph.GetGraph());
796 }
797
798 if (VLOG_IS_ON(1)) {
799 // Dump subgraphs.
800 for (auto& entry : subgraphs_) {
801 DumpGraphToFile(
802 absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
803 *entry.second.GetGraph(), library);
804 }
805 }
806
807 return s;
808 }
809
BuildFunctionDefs(const RewriteSubgraphFn & rewrite_subgraph_fn,bool reuse_existing_functions,FunctionLibraryDefinition * library)810 Status Encapsulator::BuildFunctionDefs(
811 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
812 FunctionLibraryDefinition* library) {
813 for (auto& subgraph_entry : subgraphs_) {
814 string name = subgraph_entry.first;
815 Subgraph& subgraph = subgraph_entry.second;
816 TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
817 name, rewrite_subgraph_fn, reuse_existing_functions, library));
818 }
819 return Status::OK();
820 }
821
CopyNodesToOutputGraph(Graph * graph_out,std::unordered_map<const Node *,Node * > * node_images)822 Status Encapsulator::CopyNodesToOutputGraph(
823 Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
824 for (Node* node : graph_in_->op_nodes()) {
825 string func_id;
826 TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
827
828 // Don't copy nodes that are going to be encapsulated.
829 if (IsInSubgraph(func_id)) continue;
830
831 Node* image = graph_out->CopyNode(node);
832 (*node_images)[node] = image;
833 }
834 (*node_images)[graph_in_->source_node()] = graph_out->source_node();
835 (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
836 return Status::OK();
837 }
838
AddFunctionCallNodes(const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out)839 Status Encapsulator::AddFunctionCallNodes(
840 const std::unordered_map<const Node*, Node*>& node_images,
841 Graph* graph_out) {
842 for (auto& subgraph_entry : subgraphs_) {
843 TF_RETURN_IF_ERROR(
844 subgraph_entry.second.AddFunctionCallNode(node_images, graph_out));
845 }
846 return Status::OK();
847 }
848
FindOutputImageOfEdgeSrc(const string & src_func_id,const string & dst_func_id,const std::unordered_map<const Node *,Node * > & node_images,const Node * original_src_node,Node ** src_image)849 Status Encapsulator::FindOutputImageOfEdgeSrc(
850 const string& src_func_id, const string& dst_func_id,
851 const std::unordered_map<const Node*, Node*>& node_images,
852 const Node* original_src_node, Node** src_image) {
853 if (IsInSubgraph(src_func_id)) {
854 // The edge is from a subgraph to a regular node in the output graph so
855 // use the subgraph's call node output.
856 *src_image = subgraphs_.at(src_func_id).GetCallNode();
857 } else {
858 // The source of the edge is in the output graph so use the node image in
859 // the output graph.
860 *src_image = node_images.at(original_src_node);
861 }
862 return Status::OK();
863 }
864
FindOutputSlotOfEdgeSrc(const string & src_func_id,const string & dst_func_id,const Edge * edge)865 int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
866 const string& dst_func_id,
867 const Edge* edge) {
868 if (IsInSubgraph(src_func_id)) {
869 const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
870 // 'src' is in a subgraph and 'dst' is a regular node in the output
871 // graph. Use the corresponding call output instead.
872 return src_subgraph.GetResultIndexForEdge(edge);
873 } else {
874 // The source of the edge is in the output graph so use the regular edge
875 // slot.
876 return edge->src_output();
877 }
878 }
879
FindOutputImageOfEdgeDst(const string & src_func_id,const string & dst_func_id,const std::unordered_map<const Node *,Node * > & node_images,const Node * original_dst_node,Node ** dst_image)880 Status Encapsulator::FindOutputImageOfEdgeDst(
881 const string& src_func_id, const string& dst_func_id,
882 const std::unordered_map<const Node*, Node*>& node_images,
883 const Node* original_dst_node, Node** dst_image) {
884 if (IsInSubgraph(dst_func_id)) {
885 // The edge is to a subgraph from a regular node in the output graph so
886 // use the subgraph's call node input.
887 *dst_image = subgraphs_.at(dst_func_id).GetCallNode();
888 } else {
889 // The destination of the edge is in the output graph so use the node image
890 // in the output graph.
891 *dst_image = node_images.at(original_dst_node);
892 }
893 return Status::OK();
894 }
895
FindOutputSlotOfEdgeDst(const string & src_func_id,const string & dst_func_id,const Edge * edge)896 int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
897 const string& dst_func_id,
898 const Edge* edge) {
899 if (IsInSubgraph(dst_func_id)) {
900 const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
901 // 'dst' is in a subgraph and 'src' is a regular node in the output
902 // graph. Use the corresponding call input instead.
903 return dst_subgraph.GetArgIndexForEdge(edge);
904 } else {
905 // The destination of the edge is in the output graph so use the regular
906 // edge slot.
907 return edge->dst_input();
908 }
909 }
910
CopyEdgeToOutputGraph(const Edge * edge,const string & src_func_id,const string & dst_func_id,const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out,std::unordered_set<std::pair<OutputTensor,InputTensor>,OutputInputTensorPairHasher> * edges_added)911 Status Encapsulator::CopyEdgeToOutputGraph(
912 const Edge* edge, const string& src_func_id, const string& dst_func_id,
913 const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
914 std::unordered_set<std::pair<OutputTensor, InputTensor>,
915 OutputInputTensorPairHasher>* edges_added) {
916 Node* src_image;
917 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
918 src_func_id, dst_func_id, node_images, edge->src(), &src_image));
919 Node* dst_image;
920 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
921 src_func_id, dst_func_id, node_images, edge->dst(), &dst_image));
922
923 // If this is a control edge then copy it and return. Lift control edges onto
924 // the enclosing call operator.
925 if (edge->IsControlEdge()) {
926 // Add the control edge, if we have not already added it, using the images
927 // determined above (potentially call operators or RecvAtHost/SendFromHost).
928 if (edges_added
929 ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
930 .second) {
931 graph_out->AddControlEdge(src_image, dst_image,
932 /* allow_duplicates= */ true);
933 }
934
935 return Status::OK();
936 }
937
938 int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge);
939
940 int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge);
941
942 // Add the edge, if we have not already added it.
943 if (edges_added
944 ->emplace(OutputTensor(src_image, src_output),
945 InputTensor(dst_image, dst_input))
946 .second) {
947 graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
948 }
949 return Status::OK();
950 }
951
AddEdgesToOutputGraph(const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out)952 Status Encapsulator::AddEdgesToOutputGraph(
953 const std::unordered_map<const Node*, Node*>& node_images,
954 Graph* graph_out) {
955 // Set of edges already added to the output graph, represented as (src, dst)
956 // pairs. We use the set to deduplicate edges; multiple edges in the input
957 // graph may map to one edge in the output graph.
958 std::unordered_set<std::pair<OutputTensor, InputTensor>,
959 OutputInputTensorPairHasher>
960 edges_added;
961
962 for (const Edge* edge : graph_in_->edges()) {
963 string src_func_id;
964 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
965 string dst_func_id;
966 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
967
968 // Ignore edges that are strictly contained within one subgraph, unless
969 // we are constructing parallel check graphs.
970 if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
971 src_func_id == dst_func_id) {
972 continue;
973 }
974
975 // We have an edge that crosses a cluster boundary or is entirely within the
976 // unclustered graph.
977 TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
978 edge, src_func_id, dst_func_id, node_images, graph_out, &edges_added));
979 }
980
981 for (auto& subgraph_entry : subgraphs_) {
982 Subgraph& subgraph = subgraph_entry.second;
983 subgraph.ConnectSequencerToCallNode(graph_out);
984 }
985
986 return Status::OK();
987 }
988
989 namespace {
990
991 // Adds a dummy Const node to graph_out. The "constant" has the type of
992 // data_type and the shape indicated in 'shape'. The dummy node is not a valid
993 // Const node because it does not have any value defined, but this doesn't
994 // matter because it will only be used subsequently for shape inference. (It
995 // would be possible to add a switch statement over data_type to create a value
996 // for the constant, but that would entail maintaining the logic as new types
997 // are added, and is not necessary.) If the node being replaced was within a
998 // control flow frame, adds appropriate Enter nodes so that the use of the Const
999 // is well-formed.
AddDummyShapedNode(const Node * src_node,int src_port,const std::vector<ControlFlowInfo> & control_flow_info,const TensorShapeProto & shape,Graph * graph_out)1000 Node* AddDummyShapedNode(const Node* src_node, int src_port,
1001 const std::vector<ControlFlowInfo>& control_flow_info,
1002 const TensorShapeProto& shape, Graph* graph_out) {
1003 DataType data_type = src_node->output_type(src_port);
1004 TensorProto dummy_proto;
1005 dummy_proto.set_dtype(data_type);
1006 *dummy_proto.mutable_tensor_shape() = shape;
1007 // Don't set any value field in the proto, since it is only going to be used
1008 // for shape inference.
1009
1010 GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
1011 NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
1012 options.op_registry());
1013 node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
1014 Node* node = options.FinalizeBuilder(&node_builder);
1015 // Add any Enter nodes required to bring the constant to the correct control
1016 // flow frame.
1017 while (!control_flow_info[src_node->id()].frame_name.empty()) {
1018 NodeDebugInfo debug_info(*src_node);
1019 NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
1020 options.op_registry(), &debug_info);
1021 enter_builder.Attr("frame_name",
1022 control_flow_info[src_node->id()].frame_name);
1023 enter_builder.Attr("is_constant", true);
1024 enter_builder.Input(node, 0);
1025 Node* enter_node = options.FinalizeBuilder(&enter_builder);
1026 // Adopt the new Enter node as the value in the current frame.
1027 node = enter_node;
1028 // Recurse to the parent frame to see if more Enter nodes need to be added.
1029 src_node = control_flow_info[src_node->id()].parent_frame;
1030 }
1031 return node;
1032 }
1033
1034 } // namespace
1035
MakePrunedGraphCopyAndInline(const Graph & graph,const std::vector<Node * > & sink_nodes,std::unique_ptr<Graph> * pruned_graph,std::unordered_map<const Node *,Node * > * node_images,FunctionLibraryDefinition * library)1036 Status Encapsulator::MakePrunedGraphCopyAndInline(
1037 const Graph& graph, const std::vector<Node*>& sink_nodes,
1038 std::unique_ptr<Graph>* pruned_graph,
1039 std::unordered_map<const Node*, Node*>* node_images,
1040 FunctionLibraryDefinition* library) {
1041 // First copy all ancestor nodes of sink_nodes into a new graph.
1042 pruned_graph->reset(new Graph(library));
1043 (*pruned_graph)->set_versions(graph.versions());
1044 ReverseDFSFrom(graph, sink_nodes,
1045 /*enter=*/nullptr,
1046 /*leave=*/[&](Node* n) {
1047 if (!n->IsSource()) {
1048 Node* copied = (*pruned_graph)->CopyNode(n);
1049 node_images->emplace(n, copied);
1050 }
1051 });
1052
1053 // Add all the edges between copied nodes.
1054 for (auto entry : *node_images) {
1055 const Node* orig = entry.first;
1056 Node* image = entry.second;
1057 for (const Edge* out_edge : orig->out_edges()) {
1058 auto iter = node_images->find(out_edge->dst());
1059 if (iter != node_images->end()) {
1060 // The source and destination are both in the copied graph.
1061 (*pruned_graph)
1062 ->AddEdge(image, out_edge->src_output(), iter->second,
1063 out_edge->dst_input());
1064 }
1065 }
1066 }
1067
1068 // Find all the function call nodes, and inline them.
1069 std::vector<Node*> function_nodes;
1070 for (auto node : (*pruned_graph)->nodes()) {
1071 const OpRegistrationData* op_reg_data;
1072 TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
1073 if (op_reg_data->is_function_op) {
1074 function_nodes.push_back(node);
1075 }
1076 }
1077 for (auto node : function_nodes) {
1078 VLOG(2) << "Inlining function " << node->name();
1079 const FunctionDef* fdef = library->Find(node->type_string());
1080 if (fdef == nullptr) {
1081 return errors::Internal("Failed to find function ", node->type_string(),
1082 " in function library.");
1083 }
1084 std::unique_ptr<FunctionBody> fbody;
1085 TF_RETURN_IF_ERROR(
1086 FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
1087
1088 InlineFunctionBodyOptions inline_opts;
1089 TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
1090 fbody.get(), inline_opts));
1091 }
1092
1093 return Status::OK();
1094 }
1095
BuildOutputGraph(Graph * graph_out,FunctionLibraryDefinition * library)1096 Status Encapsulator::BuildOutputGraph(Graph* graph_out,
1097 FunctionLibraryDefinition* library) {
1098 // Map from nodes in the input graph to nodes in the output graph.
1099 std::unordered_map<const Node*, Node*> node_images;
1100
1101 TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
1102 TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
1103 TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out));
1104
1105 return Status::OK();
1106 }
1107
1108 } // anonymous namespace
1109
EncapsulateSubgraphsInFunctions(string group_attribute,const Graph & graph_in,const RewriteSubgraphFn & rewrite_subgraph_fn,bool reuse_existing_functions,std::unique_ptr<Graph> * graph_out,FunctionLibraryDefinition * library)1110 Status EncapsulateSubgraphsInFunctions(
1111 string group_attribute, const Graph& graph_in,
1112 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
1113 std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
1114 Encapsulator encapsulator(std::move(group_attribute),
1115 &graph_in);
1116 TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
1117
1118 TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
1119 rewrite_subgraph_fn, reuse_existing_functions, library));
1120
1121 std::unique_ptr<Graph> out(new Graph(library));
1122 out->set_versions(graph_in.versions());
1123 TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library));
1124
1125 *graph_out = std::move(out);
1126 return Status::OK();
1127 }
1128
1129 // Finds the types of the _Arg nodes, indexed by position.
GetArgTypes(const Graph & graph,DataTypeVector * types)1130 static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
1131 for (Node* n : graph.op_nodes()) {
1132 if (n->type_string() == kArgOp) {
1133 int index;
1134 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1135 const int num_types = types->size();
1136 if (index < 0 || index >= num_types) {
1137 return errors::InvalidArgument("Invalid argument number");
1138 }
1139 (*types)[index] = n->output_type(0);
1140 }
1141 }
1142 return Status::OK();
1143 }
1144
1145 // Renumber the indices of _Arg nodes in a graph, according to
1146 // 'permutation' that maps old indices to new indices.
RenumberArguments(Graph * graph,const std::vector<int> & permutation)1147 static Status RenumberArguments(Graph* graph,
1148 const std::vector<int>& permutation) {
1149 for (Node* n : graph->op_nodes()) {
1150 if (n->type_string() == kArgOp) {
1151 int index;
1152 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1153 const int permutation_size = permutation.size();
1154 if (index < 0 || index >= permutation_size) {
1155 return errors::InvalidArgument("Invalid argument number");
1156 }
1157 n->AddAttr("index", permutation[index]);
1158 }
1159 }
1160 return Status::OK();
1161 }
1162
Run(const GraphOptimizationPassOptions & options)1163 Status EncapsulateSubgraphsPass::Run(
1164 const GraphOptimizationPassOptions& options) {
1165 VLOG(1) << "EncapsulateSubgraphsPass::Run";
1166 if (VLOG_IS_ON(1)) {
1167 DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
1168 options.flib_def);
1169 }
1170
1171 std::unique_ptr<Graph> graph_out;
1172 FunctionLibraryDefinition* const library = options.flib_def;
1173
1174 // Constant folding below might need to run part of the function to compute
1175 // constants. Create an FunctionLibraryRuntime with a single CPU device
1176 // that can run the part of the function.
1177 // NOTE: If this turns out to be slow, we can cache the FLRs keyed by
1178 // `options`.
1179 SessionOptions session_options;
1180 auto* device_count = session_options.config.mutable_device_count();
1181 device_count->insert({"CPU", 1});
1182 std::vector<std::unique_ptr<Device>> devices;
1183
1184 DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
1185 if (!cpu_factory) {
1186 return errors::NotFound(
1187 "CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
1188 }
1189 TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
1190 session_options, "/job:localhost/replica:0/task:0", &devices));
1191 if (devices.empty()) {
1192 return errors::NotFound(
1193 "Failed to create a CPU device for EncapsulateSubgraphsPass");
1194 }
1195
1196 std::unique_ptr<DeviceMgr> device_mgr =
1197 absl::make_unique<StaticDeviceMgr>(std::move(devices));
1198 const auto* config = &options.session_options->config;
1199 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1200 new ProcessFunctionLibraryRuntime(
1201 device_mgr.get(), options.session_options->env,
1202 /*config=*/config, TF_GRAPH_DEF_VERSION, library,
1203 config->graph_options().optimizer_options()));
1204 FunctionLibraryRuntime* flr =
1205 pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
1206 if (flr == nullptr) {
1207 return errors::Internal(
1208 "Failed to create and retrieve function library runtime to run "
1209 "constant folding");
1210 }
1211
1212 auto rewrite_subgraph =
1213 [flr](const std::vector<OutputTensor>& arg_source_tensors,
1214 std::unique_ptr<Graph>* subgraph,
1215 std::vector<int>* input_permutation,
1216 std::vector<int>* output_permutation, NodeDef* node) {
1217 // Optimize the subgraph.
1218 // Do not constant fold nodes that output DT_VARIANT type tensors.
1219 // XLA does not support Const nodes of Variant type since it needs
1220 // to know the original ops to be able to compile them to the relevant
1221 // XLA form.
1222 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
1223 // the form:
1224 // Const
1225 // |
1226 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
1227 // |
1228 // (Discard popped list)
1229 //
1230 // Would have been reduced to "Const -> Op" without this filter.
1231 // However since we are only allowed to specify the filter at the "Node"
1232 // level there is no good way to allow the above behavior. So we
1233 // disallow any sort of constant folding on Variant nodes for now.
1234 bool disable_constant_folding =
1235 GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding;
1236 auto cf_consider_fn = [disable_constant_folding](const Node* n) {
1237 if (disable_constant_folding) return false;
1238 for (const auto& output_arg : n->op_def().output_arg()) {
1239 if (output_arg.type() == DT_VARIANT) {
1240 return false;
1241 }
1242 }
1243 return true;
1244 };
1245 GraphOptimizer::Options graph_optimizer_options;
1246 graph_optimizer_options.cf_consider_fn = cf_consider_fn;
1247 OptimizeGraph(flr, subgraph, graph_optimizer_options);
1248
1249 const int num_args = input_permutation->size();
1250 std::vector<bool> const_args(num_args);
1251 TF_RETURN_IF_ERROR(
1252 BackwardsConstAnalysis(**subgraph, &const_args,
1253 /*compile_time_const_nodes=*/nullptr, flr));
1254
1255 DataTypeVector arg_types(num_args);
1256 TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
1257
1258 // Compute a permutation of the arguments such that the constant
1259 // arguments are first.
1260 const int num_consts =
1261 std::count(const_args.begin(), const_args.end(), true);
1262
1263 const int num_resources =
1264 std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
1265 const int num_nonconsts = num_args - num_resources - num_consts;
1266 if (num_nonconsts < 0) {
1267 return errors::Internal("num_nonconsts should be >= 0, was ",
1268 num_nonconsts);
1269 }
1270
1271 int const_pos = 0;
1272 int arg_pos = num_consts;
1273 int resource_pos = num_consts + num_nonconsts;
1274 for (int i = 0; i < num_args; ++i) {
1275 if (const_args[i]) {
1276 if (arg_types[i] == DT_RESOURCE) {
1277 return errors::Internal(
1278 "Resource arguments cannot be constant (argument ", i, ")");
1279 }
1280 (*input_permutation)[i] = const_pos;
1281 ++const_pos;
1282 } else if (arg_types[i] == DT_RESOURCE) {
1283 (*input_permutation)[i] = resource_pos;
1284 ++resource_pos;
1285 } else {
1286 (*input_permutation)[i] = arg_pos;
1287 ++arg_pos;
1288 }
1289 }
1290
1291 // Renumber argument nodes in the graph.
1292 TF_RETURN_IF_ERROR(
1293 RenumberArguments(subgraph->get(), *input_permutation));
1294
1295 // TODO(phawkins): add a forward is-constant analysis, similarly split
1296 // outputs into host-memory constants and device-memory non-constants.
1297
1298 AddNodeAttr(kXlaCompiledKernelAttr, true, node);
1299 AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
1300 AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
1301 return Status::OK();
1302 };
1303
1304 TF_RETURN_WITH_CONTEXT_IF_ERROR(
1305 EncapsulateSubgraphsInFunctions(
1306 kXlaClusterAttr, **options.graph, rewrite_subgraph,
1307 /*reuse_existing_functions=*/false, &graph_out, library),
1308 "EncapsulateSubgraphsPass failed");
1309
1310 if (VLOG_IS_ON(1)) {
1311 DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
1312 options.flib_def);
1313 }
1314
1315 *options.graph = std::move(graph_out);
1316 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
1317 GetNodesRelatedToRefVariables(**options.graph, flr));
1318 for (Node* node : (*options.graph)->nodes()) {
1319 bool has_ref_vars = ref_related_nodes.contains(node);
1320 node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
1321 VLOG(3) << "Has ref vars = " << has_ref_vars
1322 << ", node: " << node->def().SerializeAsString();
1323 }
1324 return Status::OK();
1325 }
1326
IsXlaCompiledKernel(const Node & node)1327 bool IsXlaCompiledKernel(const Node& node) {
1328 bool is_compiled = false;
1329 bool has_compilation_attr =
1330 TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
1331 is_compiled;
1332 return has_compilation_attr ? is_compiled : false;
1333 }
1334
1335 } // namespace tensorflow
1336