1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/match.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
31 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
32 #include "tensorflow/compiler/jit/xla_cluster_util.h"
33 #include "tensorflow/compiler/tf2xla/const_analysis.h"
34 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/core/common_runtime/device_factory.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/optimization_registry.h"
39 #include "tensorflow/core/common_runtime/shape_refiner.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/graph_def_util.h"
42 #include "tensorflow/core/framework/graph_to_functiondef.h"
43 #include "tensorflow/core/framework/node_def_builder.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/tensor.pb.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/control_flow.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_def_builder.h"
50 #include "tensorflow/core/graph/tensor_id.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/hash/hash.h"
53 #include "tensorflow/core/public/session_options.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 
58 namespace tensorflow {
59 
// Names of node attributes used by the XLA encapsulation code.
// NOTE(review): only kXlaHostTransferSequencerAttr is referenced in the part of
// this file reviewed here; the exact meaning of the other attributes should be
// confirmed against the code that reads or writes them.
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
// Set on the "<subgraph name>_sequencer" NoOp node created by
// Subgraph::MakeSequencingNode(); its value is the subgraph name.
const char* const kXlaHostTransferSequencerAttr =
    "_xla_host_transfer_sequencer";
const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";
66 
SortControlInputs(GraphDef * gdef)67 void SortControlInputs(GraphDef* gdef) {
68   int64 num_nodes = gdef->node_size();
69   for (int64 i = 0; i < num_nodes; ++i) {
70     NodeDef* node = gdef->mutable_node(i);
71     // Stable sort control inputs and leave the order of data inputs unchanged.
72     std::stable_sort(node->mutable_input()->begin(),
73                      node->mutable_input()->end(),
74                      [](const string& a, const string& b) {
75                        bool a_is_control = absl::StartsWith(a, "^");
76                        bool b_is_control = absl::StartsWith(b, "^");
77                        return (!a_is_control && b_is_control) ||
78                               (a_is_control && b_is_control && a < b);
79                      });
80   }
81 }
82 
83 namespace {
84 
AreAllParentsGuaranteedConst(const Node & n,const absl::flat_hash_set<const Node * > & runtime_const_nodes)85 bool AreAllParentsGuaranteedConst(
86     const Node& n,
87     const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
88   if (n.type_string() == "GuaranteeConst") {
89     // If the current node is itself a cast-to-const, no need
90     // to look at the incoming edges.
91     return true;
92   }
93 
94   bool all_parents_const = true;
95   bool atleast_one_non_control_edge = false;
96   for (const Edge* in : n.in_edges()) {
97     atleast_one_non_control_edge =
98         atleast_one_non_control_edge || !in->IsControlEdge();
99     if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
100       all_parents_const = false;
101       break;
102     }
103   }
104   return all_parents_const && atleast_one_non_control_edge;
105 }
106 
MarkGuaranteedConstants(const Graph & graph,const std::vector<std::pair<const Node *,Node * >> & src_arg_pairs)107 void MarkGuaranteedConstants(
108     const Graph& graph,
109     const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
110   absl::flat_hash_set<const Node*> guaranteed_const_nodes;
111   std::vector<const Node*> srcs;
112   srcs.reserve(src_arg_pairs.size());
113   for (const auto& src_arg : src_arg_pairs) {
114     srcs.push_back(src_arg.first);
115   }
116   ReverseDFSFrom(
117       graph, srcs, /*enter=*/nullptr,
118       /*leave=*/[&guaranteed_const_nodes](const Node* n) {
119         // TODO(vinuraja): Doesn't work in the presence of loops.
120         if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
121           guaranteed_const_nodes.insert(n);
122         }
123       });
124 
125   for (auto& src_arg : src_arg_pairs) {
126     if (guaranteed_const_nodes.count(src_arg.first) != 0) {
127       VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
128       src_arg.second->AddAttr("_is_guaranteed_constant", true);
129     }
130   }
131 }
132 
133 struct OutputInputTensorPairHasher {
operator ()tensorflow::__anona26b85820211::OutputInputTensorPairHasher134   uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
135     return Hash64Combine(OutputTensor::Hash()(s.first),
136                          InputTensor::Hash()(s.second));
137   }
138 };
139 
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
// Op type names. kArgOp and kRetValOp are the op types given to the nodes that
// Subgraph::RecordArg() and Subgraph::RecordResult() add to a subgraph.
static const char* const kArgOp = "_Arg";
static const char* const kRetValOp = "_Retval";
// NOTE(review): the three names below are not referenced in the part of the
// file reviewed here -- presumably used further down; confirm before removing.
static const char* const kHostComputeOp = "XlaHostCompute";
static const char* const kSendFromHostOp = "_XlaSendFromHost";
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
147 
148 class Encapsulator {
149  public:
Encapsulator(string group_attribute,Graph const * graph_in)150   Encapsulator(string group_attribute, Graph const* graph_in)
151       : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
152 
153   // Find subgraphs marked with 'group_attribute', and build a new
154   // subgraph, one for each value of 'group_attribute'.
155   Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);
156 
157   // Build a FunctionDef for each subgraph, and add it 'library'. The values of
158   // the 'group_attribute' annotations become the function names.
159   // If 'reuse_existing_functions' is set, use an existing function with the
160   // same name, if any.
161   // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
162   // function conversion.
163   Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
164                            bool reuse_existing_functions,
165                            FunctionLibraryDefinition* library);
166 
167   // Write a copy of the input graph to 'graph_out', where the subgraphs are
168   // replaced with calls to the new functions.
169   Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);
170 
171  private:
172   // A subgraph of the input, all marked with a common 'group_attribute'
173   // value.
174   //
175   // In the following simple example, A, B, ..., E are nodes in the original
176   // graph. The group attributes g are each shown as either 0 or empty.
177   //
178   //  A  -->  B  -->  C  -->  D  -->  E
179   //  g:      g:0     g:0     g:0     g:
180   //
181   // The example is rewritten to two graphs; one on the host and one to be
182   // compiled. The host graph is as follows.
183   //
184   //  A  -->  Call  -->  E
185   //
186   // The compiled cluster is as follows.
187   //
188   //  Arg  --> B  --> C  --> D --> Retval
189   class Subgraph {
190    public:
191     // Creates a graph to build the subgraph in, if it doesn't already exist,
192     // using the same op registry and versions as graph_in.
193     Node* MakeNodeImage(const Graph* graph_in, Node* node);
194 
195     // Returns the graph the subgraph is being built in.
196     Graph* GetGraph() const;
197 
198     // Builds a FunctionDef, and adds it to 'library'. The value of the
199     // 'group_attribute' annotations becomes the function name.  If
200     // 'reuse_existing_functions' is set, use an existing function with the same
201     // name, if any.  If 'rewrite_subgraph_fn' is set, it is applied to the
202     // subgraph before function conversion.
203     Status BuildFunctionDef(const string& name_in,
204                             const RewriteSubgraphFn& rewrite_subgraph_fn,
205                             bool reuse_existing_functions,
206                             FunctionLibraryDefinition* library);
207 
208     // Adds the function call node to graph_out.
209     Status AddFunctionCallNode(
210         const std::unordered_map<const Node*, Node*>& node_images,
211         Graph* graph_out);
212 
213     // Returns the Node that the inputs and outputs of the function should be
214     // wired up to.
215     Node* GetCallNode() const;
216 
217     // Returns the index of the arg that the dst of edge should connect to.
218     int GetArgIndexForEdge(const Edge* edge) const;
219 
220     // Returns the index of the result that the src of edge should connect to.
221     int GetResultIndexForEdge(const Edge* edge) const;
222 
223     // Creates an _Arg node for the src node of edge, and add its index to
224     // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
225     // and adds the edge within the subgraph from the _Arg node to the image of
226     // the dst node.
227     Status RecordArg(const Edge* edge,
228                      const std::unordered_map<const Node*, Node*>& node_images,
229                      std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
230 
231     // Records the src of the given edge as a control result of the graph.
232     // Used during graph to function conversion to tie control results to
233     // the function signature.
234     Status RecordControlResult(
235         const Edge* edge,
236         const std::unordered_map<const Node*, Node*>& node_images);
237 
238     // Creates a _Retval node for the src node of edge, and add it to results_,
239     // if none exists yet. If a new _Retval node is created, also adds the edge
240     // within the subgraph from the src to the _Retval node.
241     Status RecordResult(
242         const Edge* edge,
243         const std::unordered_map<const Node*, Node*>& node_images);
244 
245     // Creates the sequencer node if it doesn't exist, adding it to graph_out.
246     Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
247 
248     // If there is a sequencer node, adds a control edge from the sequencer to
249     // the call node.
250     void ConnectSequencerToCallNode(Graph* graph_out);
251 
252     Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
253 
254    private:
255     // The subgraph extracted from the input graph, suitable for being turned
256     // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
257     // returned by _Retval nodes.
258     std::unique_ptr<Graph> graph_;
259 
260     // Which device are these nodes on? Used to assign a device to the call
261     // node.
262     string device_;
263 
264     // NodeDef for the function call node.
265     NodeDef call_node_def_;
266 
267     // Name that is used for the call node. This may not be
268     // call_node_def_.name() if the client supplies a rewrite lambda.
269     string function_def_name_;
270 
271     // Placeholder node simulating the host compute key in the output graph.
272     // Not owned.
273     Node* host_compute_key_placeholder_ = nullptr;
274 
275     // Function call node in the output graph. Not owned.
276     Node* call_node_;
277 
278     // Maps from source (producer node/slot) and destination
279     // (consumer node/slot) tensors in the input graph to _Arg numbers in
280     // the subgraph. The source map is one-to-one, whereas the dest map may be
281     // many-to-one.
282     std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
283     std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
284 
285     // The arguments to the subgraph, in order.
286     std::vector<Node*> args_;
287 
288     // Map from source tensor in the input graph to result #.
289     std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
290 
291     // Set of node names that are the source of a control output of the
292     // subgraph. We store strings here so that we can tolerate nodes being
293     // removed from the graph.
294     absl::flat_hash_set<string> control_output_nodes_;
295 
296     // NoOp node in the output graph that is sequenced after the call node.
297     Node* sequencer_ = nullptr;
298   };
299 
300   // Returns the key attribute associated with a node in attr. Sets either
301   // result to the empty string if the respective attribute is not found.
302   Status GetFunctionNameAttr(Node const* node, string* attr) const;
303 
304   // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
305   // subgraphs for data edges that cross subgraph boundaries.
306   Status CopySubgraphEdges(
307       const std::unordered_map<const Node*, Node*>& node_images,
308       std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
309 
310   // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
311   Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
312 
313   // Copies all nodes that aren't in a compiled subgraph to the output graph.
314   Status CopyNodesToOutputGraph(
315       Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
316 
317   // Adds function call nodes for each compiled subgraph.
318   Status AddFunctionCallNodes(
319       const std::unordered_map<const Node*, Node*>& node_images,
320       Graph* graph_out);
321 
322   // Finds the image of an edge source in the output graph. If the edge crosses
323   // a subgraph boundary it is the output of a call node, otherwise it is a node
324   // in the output graph.
325   Status FindOutputImageOfEdgeSrc(
326       const string& src_func_id, const string& dst_func_id,
327       const std::unordered_map<const Node*, Node*>& node_images,
328       const Node* original_src_node, Node** src_image);
329 
330   // Finds an edge source slot in the output graph. If the edge crosses a
331   // subgraph boundary it is a slot on the output of a call node, otherwise it
332   // is a slot on a node in the output graph.
333   int FindOutputSlotOfEdgeSrc(const string& src_func_id,
334                               const string& dst_func_id,
335                               const Edge* edge);
336 
337   // Finds the image of an edge destination in the output graph. If the edge
338   // crosses a subgraph boundary it is the input of a call node, otherwise it is
339   // a node in the output graph.
340   Status FindOutputImageOfEdgeDst(
341       const string& src_func_id, const string& dst_func_id,
342       const std::unordered_map<const Node*, Node*>& node_images,
343       const Node* original_dst_node, Node** dst_image);
344 
345   // Finds an edge destination slot in the output graph. If the edge crosses a
346   // subgraph boundary it is a slot on the input of a call node, otherwise it is
347   // a slot on a node in the output graph.
348   int FindOutputSlotOfEdgeDst(const string& src_func_id,
349                               const string& dst_func_id,
350                               const Edge* edge);
351 
352   // Copies a single edge to the output graph. The edge is either entirely
353   // within the output graph, or crosses into or out of a compiled subgraph.
354   Status CopyEdgeToOutputGraph(
355       const Edge* edge, const string& src_func_id, const string& dst_func_id,
356       const std::unordered_map<const Node*, Node*>& node_images,
357       Graph* graph_out,
358       std::unordered_set<std::pair<OutputTensor, InputTensor>,
359                          OutputInputTensorPairHasher>* edges_added);
360 
361   // Adds all edges to the output graph.
362   Status AddEdgesToOutputGraph(
363       const std::unordered_map<const Node*, Node*>& node_images,
364       Graph* graph_out);
365 
366   // Makes a copy of graph containing only nodes that are ancestors of at least
367   // one node in send_from_host_nodes and store it in pruned_graph. On exit
368   // nodes_images contains a mapping from nodes in graph to nodes in
369   // pruned_graph. All functions in the copied graph are inlined.
370   Status MakePrunedGraphCopyAndInline(
371       const Graph& graph, const std::vector<Node*>& sink_nodes,
372       std::unique_ptr<Graph>* pruned_graph,
373       std::unordered_map<const Node*, Node*>* node_images,
374       FunctionLibraryDefinition* library);
375 
376   const string group_attribute_;
377   const Graph* graph_in_;
378 
379   std::unordered_map<string, Subgraph> subgraphs_;
380 
381   TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
382 };
383 
384 namespace {
385 
386 // Return in 'sorted' a topological sort of clusters according to the
387 // dependencies encoded in ancestors. clusters is the list of all clusters
388 // including clusters that are not present in the ancestors map. has_successors
389 // is the set of clusters that are ancestors of some other cluster.
TopologicalClusterSort(const std::unordered_set<string> & clusters,const std::unordered_set<string> & has_successors,const std::unordered_map<string,std::unordered_set<string>> & ancestors,std::vector<string> * sorted)390 void TopologicalClusterSort(
391     const std::unordered_set<string>& clusters,
392     const std::unordered_set<string>& has_successors,
393     const std::unordered_map<string, std::unordered_set<string>>& ancestors,
394     std::vector<string>* sorted) {
395   // The nodes are placed in 'sorted' in topological order.
396   sorted->clear();
397   // We don't use the standard DFS because we are not operating on Node*
398   // objects.
399   struct Work {
400     string cluster;
401     bool leave;
402   };
403   std::set<string> visited;
404   std::vector<Work> stack;
405   // Seed the processing list with clusters that have no successors.
406   for (const auto& cluster : clusters) {
407     if (has_successors.find(cluster) == has_successors.end()) {
408       stack.push_back({cluster, false});
409     }
410   }
411   while (!stack.empty()) {
412     const Work item = stack.back();
413     stack.pop_back();
414     if (item.leave) {
415       sorted->push_back(item.cluster);
416       continue;
417     }
418 
419     if (visited.find(item.cluster) != visited.end()) continue;
420     visited.insert(item.cluster);
421 
422     stack.push_back({item.cluster, true});
423     const auto& iter = ancestors.find(item.cluster);
424     if (iter != ancestors.end()) {
425       for (const auto& ancestor : iter->second) {
426         stack.push_back({ancestor, false});
427       }
428     }
429   }
430   CHECK(sorted->size() == clusters.size());
431 }
432 
433 }  // namespace
434 
GetCallNode() const435 Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
436 
GetArgIndexForEdge(const Edge * edge) const437 int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
438   return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
439 }
440 
GetResultIndexForEdge(const Edge * edge) const441 int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
442   return results_.at(OutputTensor(edge->src(), edge->src_output()));
443 }
444 
MakeNodeImage(const Graph * graph_in,Node * node)445 Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
446   if (!graph_) {
447     graph_.reset(new Graph(graph_in->op_registry()));
448     graph_->set_versions(graph_in->versions());
449   }
450 
451   // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
452   // determined. In case of hard placement, ensure all the encapsulated nodes
453   // have the same requested device, which in turn will be the requested device
454   // for the entire encapsulated subgraph. In case of soft placement, use a
455   // deterministic approach to fill in the requested device. Handle co-location
456   // constraints similarly if they exist.
457   if (device_.empty()) {
458     device_ = node->assigned_device_name().empty()
459                   ? node->requested_device()
460                   : node->assigned_device_name();
461   }
462 
463   return graph_->CopyNode(node);
464 }
465 
GetGraph() const466 Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
467 
RecordArg(const Edge * edge,const std::unordered_map<const Node *,Node * > & node_images,std::vector<std::pair<const Node *,Node * >> * src_arg_pairs)468 Status Encapsulator::Subgraph::RecordArg(
469     const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
470     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
471   Node* src_node = edge->src();
472   int src_slot = edge->src_output();
473   std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
474   bool inserted;
475   std::tie(iter, inserted) = args_by_src_.emplace(
476       OutputTensor(src_node, src_slot), args_by_src_.size());
477   int arg_index = iter->second;
478   if (inserted) {
479     NodeDef arg_def;
480     NodeDefBuilder builder(
481         absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
482         NodeDebugInfo(src_node->def()));
483     DataType dtype = edge->dst()->input_type(edge->dst_input());
484     builder.Attr("T", dtype);
485     builder.Attr("index", arg_index);
486     Status s = builder.Finalize(&arg_def);
487     if (!s.ok()) return s;
488 
489     Node* arg = graph_->AddNode(arg_def, &s);
490     if (!s.ok()) return s;
491 
492     src_arg_pairs->push_back({src_node, arg});
493     args_.push_back(arg);
494   }
495   Node* dst_node = edge->dst();
496   Node* dst_image = node_images.at(dst_node);
497   int dst_slot = edge->dst_input();
498   args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
499   graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
500   return Status::OK();
501 }
502 
RecordControlResult(const Edge * edge,const std::unordered_map<const Node *,Node * > & node_images)503 Status Encapsulator::Subgraph::RecordControlResult(
504     const Edge* edge,
505     const std::unordered_map<const Node*, Node*>& node_images) {
506   Node* src_node = edge->src();
507   Node* src_image = node_images.at(src_node);
508   control_output_nodes_.insert(src_image->name());
509   return Status::OK();
510 }
511 
RecordResult(const Edge * edge,const std::unordered_map<const Node *,Node * > & node_images)512 Status Encapsulator::Subgraph::RecordResult(
513     const Edge* edge,
514     const std::unordered_map<const Node*, Node*>& node_images) {
515   Node* src_node = edge->src();
516   Node* src_image = node_images.at(src_node);
517   int src_slot = edge->src_output();
518   std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
519   bool inserted;
520   std::tie(iter, inserted) =
521       results_.emplace(OutputTensor(src_node, src_slot), results_.size());
522   int ret_index = iter->second;
523   if (inserted) {
524     NodeDef ret_def;
525     NodeDefBuilder builder(
526         absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
527         NodeDebugInfo(src_node->def()));
528     DataType dtype = src_node->output_type(src_slot);
529     builder.Attr("T", dtype);
530     builder.Attr("index", ret_index);
531     builder.Input(src_image->name(), src_slot, dtype);
532     Status s = builder.Finalize(&ret_def);
533     if (!s.ok()) return s;
534     Node* ret = graph_->AddNode(ret_def, &s);
535     if (!s.ok()) return s;
536 
537     graph_->AddEdge(src_image, src_slot, ret, 0);
538   }
539   return Status::OK();
540 }
541 
MakeSequencingNode(const string & subgraph_name,Graph * graph_out)542 Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
543                                                   Graph* graph_out) {
544   if (sequencer_ == nullptr) {
545     NodeDef seq_def;
546     // TODO(shikharagarwal): What source node should we use for errors?
547     NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
548     builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
549     builder.Device(device_);
550     Status s = builder.Finalize(&seq_def);
551     if (!s.ok()) return s;
552 
553     sequencer_ = graph_out->AddNode(seq_def, &s);
554     if (!s.ok()) return s;
555   }
556   return Status::OK();
557 }
558 
ConnectSequencerToCallNode(Graph * graph_out)559 void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
560   if (sequencer_ != nullptr) {
561     VLOG(2) << "ConnectSequencerToCallNode";
562     graph_out->AddControlEdge(sequencer_, call_node_,
563                               /* allow_duplicates= */ true);
564   }
565 }
566 
BuildFunctionDef(const string & name_in,const RewriteSubgraphFn & rewrite_subgraph_fn,bool reuse_existing_functions,FunctionLibraryDefinition * library)567 Status Encapsulator::Subgraph::BuildFunctionDef(
568     const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
569     bool reuse_existing_functions, FunctionLibraryDefinition* library) {
570   // name_in is copied here because name may be modified below if
571   // rewrite_subgraph_fn is true.
572   string name = name_in;
573   call_node_def_.set_op(name);
574   call_node_def_.set_name(name);
575   call_node_def_.set_device(device_);
576 
577   if (rewrite_subgraph_fn) {
578     std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
579     for (const auto& arg : args_by_src_) {
580       arg_source_tensors.at(arg.second) = arg.first;
581     }
582     // Initialize the input and output permutations to the identity.
583     std::vector<int> input_permutation(args_by_src_.size());
584     std::iota(input_permutation.begin(), input_permutation.end(), 0);
585     std::vector<int> output_permutation(results_.size());
586     std::iota(output_permutation.begin(), output_permutation.end(), 0);
587 
588     TF_RETURN_IF_ERROR(
589         rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
590                             &output_permutation, &call_node_def_));
591 
592     // Apply the input/output permutations to the 'args_by_...' and 'results_'
593     // mappings, so when we build edges in BuildOutputGraph() we
594     // connect them to the right input/output positions.
595     if (input_permutation.size() != args_by_src_.size()) {
596       return errors::InvalidArgument("Input permutation has incorrect size.");
597     }
598     if (output_permutation.size() != results_.size()) {
599       return errors::InvalidArgument("Output permutation has incorrect size.");
600     }
601     for (auto& arg : args_by_src_) {
602       arg.second = input_permutation[arg.second];
603     }
604     for (auto& arg : args_by_dst_) {
605       arg.second = input_permutation[arg.second];
606     }
607     for (auto& result : results_) {
608       result.second = output_permutation[result.second];
609     }
610 
611     name = call_node_def_.op();
612   }
613 
614   function_def_name_ = name;
615 
616   FunctionDef fdef;
617   auto lookup = [this](const Node* node) -> absl::optional<string> {
618     if (control_output_nodes_.contains(node->name())) {
619       return absl::make_optional(node->name());
620     }
621     return absl::nullopt;
622   };
623   // Verify that the graph has well-formed control flow structure.
624   std::vector<ControlFlowInfo> dummy;
625   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
626   TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));
627 
628   if (VLOG_IS_ON(1)) {
629     VLOG(2) << "Build function def " << name;
630     DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
631                     library);
632     DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
633   }
634 
635   const FunctionDef* original_fdef = library->Find(name);
636   if (!reuse_existing_functions || original_fdef == nullptr) {
637     TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
638   } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
639     TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
640   }
641   return Status::OK();
642 }
643 
ReplaceFunctionDef(FunctionLibraryDefinition * library)644 Status Encapsulator::Subgraph::ReplaceFunctionDef(
645     FunctionLibraryDefinition* library) {
646   const string& name = function_def_name_;
647 
648   FunctionDef fdef;
649   TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
650 
651   if (VLOG_IS_ON(1)) {
652     VLOG(2) << "Replace function def " << name;
653     DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
654                     *graph_, library);
655     DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
656                           fdef);
657   }
658 
659   TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
660   return Status::OK();
661 }
662 
AddFunctionCallNode(const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out)663 Status Encapsulator::Subgraph::AddFunctionCallNode(
664     const std::unordered_map<const Node*, Node*>& node_images,
665     Graph* graph_out) {
666   Status s;
667   call_node_ = graph_out->AddNode(call_node_def_, &s);
668   if (!s.ok()) return s;
669 
670   // Copy the assigned device and the key_annotation over.
671   call_node_->set_assigned_device_name(device_);
672 
673   return Status::OK();
674 }
675 
GetFunctionNameAttr(Node const * node,string * attr) const676 Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
677   AttrSlice attrs = node->attrs();
678   attr->clear();
679   for (const auto& node_attr : attrs) {
680     if (node_attr.first == group_attribute_) {
681       TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
682       *attr = node_attr.second.s();
683       break;
684     }
685   }
686   return Status::OK();
687 }
688 
// A node belongs to a subgraph exactly when its function id (the value of the
// group attribute) is non-empty.
bool IsInSubgraph(const string& func_id) {
  const bool has_function_id = !func_id.empty();
  return has_function_id;
}
690 
CopySubgraphNodes(std::unordered_map<const Node *,Node * > * node_images)691 Status Encapsulator::CopySubgraphNodes(
692     std::unordered_map<const Node*, Node*>* node_images) {
693   for (Node* node : graph_in_->op_nodes()) {
694     string func_id;
695     TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
696     if (!IsInSubgraph(func_id)) continue;
697 
698     Subgraph& subgraph = subgraphs_[func_id];
699     Node* image = subgraph.MakeNodeImage(graph_in_, node);
700     image->ClearAttr(group_attribute_);
701     (*node_images)[node] = image;
702   }
703   return Status::OK();
704 }
705 
CopySubgraphEdges(const std::unordered_map<const Node *,Node * > & node_images,std::vector<std::pair<const Node *,Node * >> * src_arg_pairs)706 Status Encapsulator::CopySubgraphEdges(
707     const std::unordered_map<const Node*, Node*>& node_images,
708     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
709   for (const Edge* edge : graph_in_->edges()) {
710     string src_func_id;
711     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
712     string dst_func_id;
713     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
714     Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
715     Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
716 
717     // Copy edges that are local to a subgraph.
718     if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
719         src_func_id == dst_func_id) {
720       Graph* g = subgraphs_[src_func_id].GetGraph();
721       if (edge->IsControlEdge()) {
722         g->AddControlEdge(src_image, dst_image,
723                           /* allow_duplicates= */ true);
724       } else {
725         g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
726       }
727       continue;
728     }
729 
730     // Record 'src' as an output of its subgraph, if applicable.
731     if (IsInSubgraph(src_func_id)) {
732       if (!edge->IsControlEdge()) {
733         DataType dtype = edge->src()->output_type(edge->src_output());
734         if (IsRefType(dtype)) {
735           return errors::InvalidArgument(
736               "Ref Tensors (e.g., Variables) are not supported as results: "
737               "tensor ",
738               edge->src()->name(), ":", edge->src_output());
739         }
740       }
741 
742       Subgraph& src_subgraph = subgraphs_[src_func_id];
743       if (edge->IsControlEdge()) {
744         TF_RETURN_IF_ERROR(src_subgraph.RecordControlResult(edge, node_images));
745       } else {
746         TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
747       }
748     }
749 
750     // Record 'dst' as an input of its subgraph, if applicable.
751     if (IsInSubgraph(dst_func_id)) {
752       // Look at the type of the destination not the source, since Ref output
753       // Tensors can be automatically cast to non-Ref Tensors at the
754       // destination.
755       if (!edge->IsControlEdge()) {
756         DataType dtype = edge->dst()->input_type(edge->dst_input());
757         if (IsRefType(dtype)) {
758           return errors::InvalidArgument(
759               "Ref Tensors (e.g., Variables) are not supported as args: "
760               "tensor ",
761               edge->src()->name(), ":", edge->src_output());
762         }
763       }
764 
765       Subgraph& dst_subgraph = subgraphs_[dst_func_id];
766       // Ignore control edges entering the subgraph. We will lift them onto
767       // the enclosing call operators in BuildOutputGraph().
768       if (!edge->IsControlEdge()) {
769         TF_RETURN_IF_ERROR(
770             dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
771       }
772     }
773   }
774   return Status::OK();
775 }
776 
SplitIntoSubgraphs(FunctionLibraryDefinition * library)777 Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
778   Status s;
779 
780   // Map from input graph nodes to subgraph nodes.
781   std::unordered_map<const Node*, Node*> node_images;
782 
783   // Each entry of src_arg_pairs is a pair whose first element is a node in the
784   // original graph that has an output edge in the subgraph, and whose second
785   // element is the arg node in the subgraph that it sends to. The vector will
786   // be filled in below in AddArgs.
787   std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
788 
789   TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
790   TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
791   MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
792 
793   for (auto& entry : subgraphs_) {
794     Subgraph& subgraph = entry.second;
795     FixupSourceAndSinkEdges(subgraph.GetGraph());
796   }
797 
798   if (VLOG_IS_ON(1)) {
799     // Dump subgraphs.
800     for (auto& entry : subgraphs_) {
801       DumpGraphToFile(
802           absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
803           *entry.second.GetGraph(), library);
804     }
805   }
806 
807   return s;
808 }
809 
BuildFunctionDefs(const RewriteSubgraphFn & rewrite_subgraph_fn,bool reuse_existing_functions,FunctionLibraryDefinition * library)810 Status Encapsulator::BuildFunctionDefs(
811     const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
812     FunctionLibraryDefinition* library) {
813   for (auto& subgraph_entry : subgraphs_) {
814     string name = subgraph_entry.first;
815     Subgraph& subgraph = subgraph_entry.second;
816     TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
817         name, rewrite_subgraph_fn, reuse_existing_functions, library));
818   }
819   return Status::OK();
820 }
821 
CopyNodesToOutputGraph(Graph * graph_out,std::unordered_map<const Node *,Node * > * node_images)822 Status Encapsulator::CopyNodesToOutputGraph(
823     Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
824   for (Node* node : graph_in_->op_nodes()) {
825     string func_id;
826     TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
827 
828     // Don't copy nodes that are going to be encapsulated.
829     if (IsInSubgraph(func_id)) continue;
830 
831     Node* image = graph_out->CopyNode(node);
832     (*node_images)[node] = image;
833   }
834   (*node_images)[graph_in_->source_node()] = graph_out->source_node();
835   (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
836   return Status::OK();
837 }
838 
AddFunctionCallNodes(const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out)839 Status Encapsulator::AddFunctionCallNodes(
840     const std::unordered_map<const Node*, Node*>& node_images,
841     Graph* graph_out) {
842   for (auto& subgraph_entry : subgraphs_) {
843     TF_RETURN_IF_ERROR(
844         subgraph_entry.second.AddFunctionCallNode(node_images, graph_out));
845   }
846   return Status::OK();
847 }
848 
FindOutputImageOfEdgeSrc(const string & src_func_id,const string & dst_func_id,const std::unordered_map<const Node *,Node * > & node_images,const Node * original_src_node,Node ** src_image)849 Status Encapsulator::FindOutputImageOfEdgeSrc(
850     const string& src_func_id, const string& dst_func_id,
851     const std::unordered_map<const Node*, Node*>& node_images,
852     const Node* original_src_node, Node** src_image) {
853   if (IsInSubgraph(src_func_id)) {
854     // The edge is from a subgraph to a regular node in the output graph so
855     // use the subgraph's call node output.
856     *src_image = subgraphs_.at(src_func_id).GetCallNode();
857   } else {
858     // The source of the edge is in the output graph so use the node image in
859     // the output graph.
860     *src_image = node_images.at(original_src_node);
861   }
862   return Status::OK();
863 }
864 
FindOutputSlotOfEdgeSrc(const string & src_func_id,const string & dst_func_id,const Edge * edge)865 int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
866                                           const string& dst_func_id,
867                                           const Edge* edge) {
868   if (IsInSubgraph(src_func_id)) {
869     const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
870     // 'src' is in a subgraph and 'dst' is a regular node in the output
871     // graph. Use the corresponding call output instead.
872     return src_subgraph.GetResultIndexForEdge(edge);
873   } else {
874     // The source of the edge is in the output graph so use the regular edge
875     // slot.
876     return edge->src_output();
877   }
878 }
879 
FindOutputImageOfEdgeDst(const string & src_func_id,const string & dst_func_id,const std::unordered_map<const Node *,Node * > & node_images,const Node * original_dst_node,Node ** dst_image)880 Status Encapsulator::FindOutputImageOfEdgeDst(
881     const string& src_func_id, const string& dst_func_id,
882     const std::unordered_map<const Node*, Node*>& node_images,
883     const Node* original_dst_node, Node** dst_image) {
884   if (IsInSubgraph(dst_func_id)) {
885     // The edge is to a subgraph from a regular node in the output graph so
886     // use the subgraph's call node input.
887     *dst_image = subgraphs_.at(dst_func_id).GetCallNode();
888   } else {
889     // The destination of the edge is in the output graph so use the node image
890     // in the output graph.
891     *dst_image = node_images.at(original_dst_node);
892   }
893   return Status::OK();
894 }
895 
FindOutputSlotOfEdgeDst(const string & src_func_id,const string & dst_func_id,const Edge * edge)896 int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
897                                           const string& dst_func_id,
898                                           const Edge* edge) {
899   if (IsInSubgraph(dst_func_id)) {
900     const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
901       // 'dst' is in a subgraph and 'src' is a regular node in the output
902       // graph. Use the corresponding call input instead.
903       return dst_subgraph.GetArgIndexForEdge(edge);
904   } else {
905     // The destination of the edge is in the output graph so use the regular
906     // edge slot.
907     return edge->dst_input();
908   }
909 }
910 
CopyEdgeToOutputGraph(const Edge * edge,const string & src_func_id,const string & dst_func_id,const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out,std::unordered_set<std::pair<OutputTensor,InputTensor>,OutputInputTensorPairHasher> * edges_added)911 Status Encapsulator::CopyEdgeToOutputGraph(
912     const Edge* edge, const string& src_func_id, const string& dst_func_id,
913     const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
914     std::unordered_set<std::pair<OutputTensor, InputTensor>,
915                        OutputInputTensorPairHasher>* edges_added) {
916   Node* src_image;
917   TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
918       src_func_id, dst_func_id, node_images, edge->src(), &src_image));
919   Node* dst_image;
920   TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
921       src_func_id, dst_func_id, node_images, edge->dst(), &dst_image));
922 
923   // If this is a control edge then copy it and return. Lift control edges onto
924   // the enclosing call operator.
925   if (edge->IsControlEdge()) {
926     // Add the control edge, if we have not already added it, using the images
927     // determined above (potentially call operators or RecvAtHost/SendFromHost).
928     if (edges_added
929             ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
930             .second) {
931       graph_out->AddControlEdge(src_image, dst_image,
932                                 /* allow_duplicates= */ true);
933     }
934 
935     return Status::OK();
936   }
937 
938   int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge);
939 
940   int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge);
941 
942   // Add the edge, if we have not already added it.
943   if (edges_added
944           ->emplace(OutputTensor(src_image, src_output),
945                     InputTensor(dst_image, dst_input))
946           .second) {
947     graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
948   }
949   return Status::OK();
950 }
951 
AddEdgesToOutputGraph(const std::unordered_map<const Node *,Node * > & node_images,Graph * graph_out)952 Status Encapsulator::AddEdgesToOutputGraph(
953     const std::unordered_map<const Node*, Node*>& node_images,
954     Graph* graph_out) {
955   // Set of edges already added to the output graph, represented as (src, dst)
956   // pairs. We use the set to deduplicate edges; multiple edges in the input
957   // graph may map to one edge in the output graph.
958   std::unordered_set<std::pair<OutputTensor, InputTensor>,
959                      OutputInputTensorPairHasher>
960       edges_added;
961 
962   for (const Edge* edge : graph_in_->edges()) {
963     string src_func_id;
964     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
965     string dst_func_id;
966     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
967 
968     // Ignore edges that are strictly contained within one subgraph, unless
969     // we are constructing parallel check graphs.
970     if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
971         src_func_id == dst_func_id) {
972       continue;
973     }
974 
975     // We have an edge that crosses a cluster boundary or is entirely within the
976     // unclustered graph.
977     TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
978         edge, src_func_id, dst_func_id, node_images, graph_out, &edges_added));
979   }
980 
981   for (auto& subgraph_entry : subgraphs_) {
982     Subgraph& subgraph = subgraph_entry.second;
983     subgraph.ConnectSequencerToCallNode(graph_out);
984   }
985 
986   return Status::OK();
987 }
988 
989 namespace {
990 
991 // Adds a dummy Const node to graph_out. The "constant" has the type of
992 // data_type and the shape indicated in 'shape'. The dummy node is not a valid
993 // Const node because it does not have any value defined, but this doesn't
994 // matter because it will only be used subsequently for shape inference. (It
995 // would be possible to add a switch statement over data_type to create a value
996 // for the constant, but that would entail maintaining the logic as new types
997 // are added, and is not necessary.) If the node being replaced was within a
998 // control flow frame, adds appropriate Enter nodes so that the use of the Const
999 // is well-formed.
AddDummyShapedNode(const Node * src_node,int src_port,const std::vector<ControlFlowInfo> & control_flow_info,const TensorShapeProto & shape,Graph * graph_out)1000 Node* AddDummyShapedNode(const Node* src_node, int src_port,
1001                          const std::vector<ControlFlowInfo>& control_flow_info,
1002                          const TensorShapeProto& shape, Graph* graph_out) {
1003   DataType data_type = src_node->output_type(src_port);
1004   TensorProto dummy_proto;
1005   dummy_proto.set_dtype(data_type);
1006   *dummy_proto.mutable_tensor_shape() = shape;
1007   // Don't set any value field in the proto, since it is only going to be used
1008   // for shape inference.
1009 
1010   GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
1011   NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
1012                            options.op_registry());
1013   node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
1014   Node* node = options.FinalizeBuilder(&node_builder);
1015   // Add any Enter nodes required to bring the constant to the correct control
1016   // flow frame.
1017   while (!control_flow_info[src_node->id()].frame_name.empty()) {
1018     NodeDebugInfo debug_info(*src_node);
1019     NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
1020                               options.op_registry(), &debug_info);
1021     enter_builder.Attr("frame_name",
1022                        control_flow_info[src_node->id()].frame_name);
1023     enter_builder.Attr("is_constant", true);
1024     enter_builder.Input(node, 0);
1025     Node* enter_node = options.FinalizeBuilder(&enter_builder);
1026     // Adopt the new Enter node as the value in the current frame.
1027     node = enter_node;
1028     // Recurse to the parent frame to see if more Enter nodes need to be added.
1029     src_node = control_flow_info[src_node->id()].parent_frame;
1030   }
1031   return node;
1032 }
1033 
1034 }  // namespace
1035 
MakePrunedGraphCopyAndInline(const Graph & graph,const std::vector<Node * > & sink_nodes,std::unique_ptr<Graph> * pruned_graph,std::unordered_map<const Node *,Node * > * node_images,FunctionLibraryDefinition * library)1036 Status Encapsulator::MakePrunedGraphCopyAndInline(
1037     const Graph& graph, const std::vector<Node*>& sink_nodes,
1038     std::unique_ptr<Graph>* pruned_graph,
1039     std::unordered_map<const Node*, Node*>* node_images,
1040     FunctionLibraryDefinition* library) {
1041   // First copy all ancestor nodes of sink_nodes into a new graph.
1042   pruned_graph->reset(new Graph(library));
1043   (*pruned_graph)->set_versions(graph.versions());
1044   ReverseDFSFrom(graph, sink_nodes,
1045                  /*enter=*/nullptr,
1046                  /*leave=*/[&](Node* n) {
1047                    if (!n->IsSource()) {
1048                      Node* copied = (*pruned_graph)->CopyNode(n);
1049                      node_images->emplace(n, copied);
1050                    }
1051                  });
1052 
1053   // Add all the edges between copied nodes.
1054   for (auto entry : *node_images) {
1055     const Node* orig = entry.first;
1056     Node* image = entry.second;
1057     for (const Edge* out_edge : orig->out_edges()) {
1058       auto iter = node_images->find(out_edge->dst());
1059       if (iter != node_images->end()) {
1060         // The source and destination are both in the copied graph.
1061         (*pruned_graph)
1062             ->AddEdge(image, out_edge->src_output(), iter->second,
1063                       out_edge->dst_input());
1064       }
1065     }
1066   }
1067 
1068   // Find all the function call nodes, and inline them.
1069   std::vector<Node*> function_nodes;
1070   for (auto node : (*pruned_graph)->nodes()) {
1071     const OpRegistrationData* op_reg_data;
1072     TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
1073     if (op_reg_data->is_function_op) {
1074       function_nodes.push_back(node);
1075     }
1076   }
1077   for (auto node : function_nodes) {
1078     VLOG(2) << "Inlining function " << node->name();
1079     const FunctionDef* fdef = library->Find(node->type_string());
1080     if (fdef == nullptr) {
1081       return errors::Internal("Failed to find function ", node->type_string(),
1082                               " in function library.");
1083     }
1084     std::unique_ptr<FunctionBody> fbody;
1085     TF_RETURN_IF_ERROR(
1086         FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
1087 
1088     InlineFunctionBodyOptions inline_opts;
1089     TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
1090                                           fbody.get(), inline_opts));
1091   }
1092 
1093   return Status::OK();
1094 }
1095 
BuildOutputGraph(Graph * graph_out,FunctionLibraryDefinition * library)1096 Status Encapsulator::BuildOutputGraph(Graph* graph_out,
1097                                       FunctionLibraryDefinition* library) {
1098   // Map from nodes in the input graph to nodes in the output graph.
1099   std::unordered_map<const Node*, Node*> node_images;
1100 
1101   TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
1102   TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
1103   TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out));
1104 
1105   return Status::OK();
1106 }
1107 
1108 }  // anonymous namespace
1109 
EncapsulateSubgraphsInFunctions(string group_attribute,const Graph & graph_in,const RewriteSubgraphFn & rewrite_subgraph_fn,bool reuse_existing_functions,std::unique_ptr<Graph> * graph_out,FunctionLibraryDefinition * library)1110 Status EncapsulateSubgraphsInFunctions(
1111     string group_attribute, const Graph& graph_in,
1112     const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
1113     std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
1114   Encapsulator encapsulator(std::move(group_attribute),
1115                             &graph_in);
1116   TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
1117 
1118   TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
1119       rewrite_subgraph_fn, reuse_existing_functions, library));
1120 
1121   std::unique_ptr<Graph> out(new Graph(library));
1122   out->set_versions(graph_in.versions());
1123   TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library));
1124 
1125   *graph_out = std::move(out);
1126   return Status::OK();
1127 }
1128 
1129 // Finds the types of the _Arg nodes, indexed by position.
GetArgTypes(const Graph & graph,DataTypeVector * types)1130 static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
1131   for (Node* n : graph.op_nodes()) {
1132     if (n->type_string() == kArgOp) {
1133       int index;
1134       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1135       const int num_types = types->size();
1136       if (index < 0 || index >= num_types) {
1137         return errors::InvalidArgument("Invalid argument number");
1138       }
1139       (*types)[index] = n->output_type(0);
1140     }
1141   }
1142   return Status::OK();
1143 }
1144 
1145 // Renumber the indices of _Arg nodes in a graph, according to
1146 // 'permutation' that maps old indices to new indices.
RenumberArguments(Graph * graph,const std::vector<int> & permutation)1147 static Status RenumberArguments(Graph* graph,
1148                                 const std::vector<int>& permutation) {
1149   for (Node* n : graph->op_nodes()) {
1150     if (n->type_string() == kArgOp) {
1151       int index;
1152       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1153       const int permutation_size = permutation.size();
1154       if (index < 0 || index >= permutation_size) {
1155         return errors::InvalidArgument("Invalid argument number");
1156       }
1157       n->AddAttr("index", permutation[index]);
1158     }
1159   }
1160   return Status::OK();
1161 }
1162 
Run(const GraphOptimizationPassOptions & options)1163 Status EncapsulateSubgraphsPass::Run(
1164     const GraphOptimizationPassOptions& options) {
1165   VLOG(1) << "EncapsulateSubgraphsPass::Run";
1166   if (VLOG_IS_ON(1)) {
1167     DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
1168                     options.flib_def);
1169   }
1170 
1171   std::unique_ptr<Graph> graph_out;
1172   FunctionLibraryDefinition* const library = options.flib_def;
1173 
1174   // Constant folding below might need to run part of the function to compute
1175   // constants. Create an FunctionLibraryRuntime with a single CPU device
1176   // that can run the part of the function.
1177   // NOTE: If this turns out to be slow, we can cache the FLRs keyed by
1178   // `options`.
1179   SessionOptions session_options;
1180   auto* device_count = session_options.config.mutable_device_count();
1181   device_count->insert({"CPU", 1});
1182   std::vector<std::unique_ptr<Device>> devices;
1183 
1184   DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
1185   if (!cpu_factory) {
1186     return errors::NotFound(
1187         "CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
1188   }
1189   TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
1190       session_options, "/job:localhost/replica:0/task:0", &devices));
1191   if (devices.empty()) {
1192     return errors::NotFound(
1193         "Failed to create a CPU device for EncapsulateSubgraphsPass");
1194   }
1195 
1196   std::unique_ptr<DeviceMgr> device_mgr =
1197       absl::make_unique<StaticDeviceMgr>(std::move(devices));
1198   const auto* config = &options.session_options->config;
1199   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1200       new ProcessFunctionLibraryRuntime(
1201           device_mgr.get(), options.session_options->env,
1202           /*config=*/config, TF_GRAPH_DEF_VERSION, library,
1203           config->graph_options().optimizer_options()));
1204   FunctionLibraryRuntime* flr =
1205       pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
1206   if (flr == nullptr) {
1207     return errors::Internal(
1208         "Failed to create and retrieve function library runtime to run "
1209         "constant folding");
1210   }
1211 
1212   auto rewrite_subgraph =
1213       [flr](const std::vector<OutputTensor>& arg_source_tensors,
1214             std::unique_ptr<Graph>* subgraph,
1215             std::vector<int>* input_permutation,
1216             std::vector<int>* output_permutation, NodeDef* node) {
1217         // Optimize the subgraph.
1218         // Do not constant fold nodes that output DT_VARIANT type tensors.
1219         // XLA does not support Const nodes of Variant type since it needs
1220         // to know the original ops to be able to compile them to the relevant
1221         // XLA form.
1222         // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
1223         // the form:
1224         //                          Const
1225         //                            |
1226         // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
1227         //                                                  |
1228         //                                        (Discard popped list)
1229         //
1230         // Would have been reduced to "Const -> Op" without this filter.
1231         // However since we are only allowed to specify the filter at the "Node"
1232         // level there is no good way to allow the above behavior. So we
1233         // disallow any sort of constant folding on Variant nodes for now.
1234         bool disable_constant_folding =
1235             GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding;
1236         auto cf_consider_fn = [disable_constant_folding](const Node* n) {
1237           if (disable_constant_folding) return false;
1238           for (const auto& output_arg : n->op_def().output_arg()) {
1239             if (output_arg.type() == DT_VARIANT) {
1240               return false;
1241             }
1242           }
1243           return true;
1244         };
1245         GraphOptimizer::Options graph_optimizer_options;
1246         graph_optimizer_options.cf_consider_fn = cf_consider_fn;
1247         OptimizeGraph(flr, subgraph, graph_optimizer_options);
1248 
1249         const int num_args = input_permutation->size();
1250         std::vector<bool> const_args(num_args);
1251         TF_RETURN_IF_ERROR(
1252             BackwardsConstAnalysis(**subgraph, &const_args,
1253                                    /*compile_time_const_nodes=*/nullptr, flr));
1254 
1255         DataTypeVector arg_types(num_args);
1256         TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
1257 
1258         // Compute a permutation of the arguments such that the constant
1259         // arguments are first.
1260         const int num_consts =
1261             std::count(const_args.begin(), const_args.end(), true);
1262 
1263         const int num_resources =
1264             std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
1265         const int num_nonconsts = num_args - num_resources - num_consts;
1266         if (num_nonconsts < 0) {
1267           return errors::Internal("num_nonconsts should be >= 0, was ",
1268                                   num_nonconsts);
1269         }
1270 
1271         int const_pos = 0;
1272         int arg_pos = num_consts;
1273         int resource_pos = num_consts + num_nonconsts;
1274         for (int i = 0; i < num_args; ++i) {
1275           if (const_args[i]) {
1276             if (arg_types[i] == DT_RESOURCE) {
1277               return errors::Internal(
1278                   "Resource arguments cannot be constant (argument ", i, ")");
1279             }
1280             (*input_permutation)[i] = const_pos;
1281             ++const_pos;
1282           } else if (arg_types[i] == DT_RESOURCE) {
1283             (*input_permutation)[i] = resource_pos;
1284             ++resource_pos;
1285           } else {
1286             (*input_permutation)[i] = arg_pos;
1287             ++arg_pos;
1288           }
1289         }
1290 
1291         // Renumber argument nodes in the graph.
1292         TF_RETURN_IF_ERROR(
1293             RenumberArguments(subgraph->get(), *input_permutation));
1294 
1295         // TODO(phawkins): add a forward is-constant analysis, similarly split
1296         // outputs into host-memory constants and device-memory non-constants.
1297 
1298         AddNodeAttr(kXlaCompiledKernelAttr, true, node);
1299         AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
1300         AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
1301         return Status::OK();
1302       };
1303 
1304   TF_RETURN_WITH_CONTEXT_IF_ERROR(
1305       EncapsulateSubgraphsInFunctions(
1306           kXlaClusterAttr, **options.graph, rewrite_subgraph,
1307           /*reuse_existing_functions=*/false, &graph_out, library),
1308       "EncapsulateSubgraphsPass failed");
1309 
1310   if (VLOG_IS_ON(1)) {
1311     DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
1312                     options.flib_def);
1313   }
1314 
1315   *options.graph = std::move(graph_out);
1316   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
1317                       GetNodesRelatedToRefVariables(**options.graph, flr));
1318   for (Node* node : (*options.graph)->nodes()) {
1319     bool has_ref_vars = ref_related_nodes.contains(node);
1320     node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
1321     VLOG(3) << "Has ref vars = " << has_ref_vars
1322             << ", node: " << node->def().SerializeAsString();
1323   }
1324   return Status::OK();
1325 }
1326 
IsXlaCompiledKernel(const Node & node)1327 bool IsXlaCompiledKernel(const Node& node) {
1328   bool is_compiled = false;
1329   bool has_compilation_attr =
1330       TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
1331       is_compiled;
1332   return has_compilation_attr ? is_compiled : false;
1333 }
1334 
1335 }  // namespace tensorflow
1336