1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
17 #define TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
18 
#include <utility>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/graph.h"
23 
24 namespace tensorflow {
25 
26 // Rewrite function for outside compilation subgraphs. It will perform the
27 // following steps:
28 //
29 // 1. Add a XLA computation key placeholder node (it will be used as input for
30 //    XlaRecvAtHost and XlaSendFromHost);
31 // 2. Replace all _Arg nodes with one single XlaRecvAtHost node;
32 // 3. Replace all _Retval nodes with one single XlaSendFromHost node;
33 // 4. Mark all nodes except key placeholder with attr `xla_cluster_attr_name`
34 //    and `outside_compilation_attr_name`;
35 // 5. For nodes marked with attr kXlaConnectedToXlaComputationAttrName, add a
36 //    control edge from the node to XlaSendFromHost; for nodes marked with attr
37 //    kXlaConnectedFromXlaComputationAttrName, add a control edge from
38 //    XlaRecvAtHost node to the node;
39 // 6. Try pruning XlaRecvAtHost/XlaSendFromHost/key placeholder node.
40 // 7. Add necessary attributes to `node_def`, so we can replace it with a
41 //    XlaHostCompute node later. If all input shapes for XlaSendFromHost are
42 //    known, "shapes" attr will be set to the list of input shapes; otherwise
43 //    "shape_inference_graph" attr will be set to shape inference function name.
44 class RewriteOutsideCompilationSubgraphFn {
45  public:
RewriteOutsideCompilationSubgraphFn(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & new_function_name)46   RewriteOutsideCompilationSubgraphFn(
47       const string& xla_cluster_attr_name,
48       const string& outside_compilation_attr_name,
49       const string& xla_cluster_name, const string& new_function_name)
50       : xla_cluster_attr_name_(xla_cluster_attr_name),
51         outside_compilation_attr_name_(outside_compilation_attr_name),
52         xla_cluster_name_(xla_cluster_name),
53         new_function_name_(new_function_name) {}
54 
55   Status operator()(const std::vector<OutputTensor>&,
56                     std::unique_ptr<Graph>* graph,
57                     std::vector<int>* input_permutation,
58                     std::vector<int>* output_permutation, NodeDef* node_def);
59 
60  private:
61   string xla_cluster_attr_name_;
62   string outside_compilation_attr_name_;
63   string xla_cluster_name_;
64   string new_function_name_;
65 };
66 
67 // For an XLA computation function, replace all outside compilations with
68 // XlaHostCompute nodes. Each outside compilation subgraph will be rewritten by
69 // `RewriteOutsideCompilationSubgraphFn`, and they will be merged into one
70 // single host side graph (`host_graph`).
71 //
72 // xla_cluster_attr_name and outside_compilation_attr_name: attr name for XLA
73 //   computation and outside compilation. Required for
74 //   `RewriteOutsideCompilationSubgraphFn`.
75 // xla_cluster_name: XLA cluster name for this XLA computation. We need it
76 //   because XLA cluster name might be different from `func_name`.
77 // func_name_attrs: they will be used to instantiate the XLA computation func.
78 // new_func_name: new function name for rewritten XLA computation func.
79 // host_compute_core: mapping from outside compilation cluster name to XLA
80 //   device assignment.
81 // fld: FunctionLibraryDefinition object.
82 // host_graph: Graph object to store host side graph for all outside
83 //   compilations within this XLA computation func. If there is no outside
84 //   compilation, it will be empty.
85 // shape_inference_graphs: a list of outside compilation shape inference
86 //   function names. These functions need to be rewritten later.
87 // has_outside_compilation: a bool indicating whether this function has any
88 //   outside compilation nodes.
89 Status ExtractOutsideCompilationForFunction(
90     const string& xla_cluster_attr_name,
91     const string& outside_compilation_attr_name, const string& xla_cluster_name,
92     const NameAttrList& func_name_attrs, const string& new_func_name,
93     const string& host_graph_func_name,
94     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
95     FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
96     bool* has_outside_compilation);
97 
98 // Rewrites XLA computation in `clusters` to replace outside compilation nodes
99 // with XlaHostCompute, and moves those outside compilations into `g`. If shapes
100 // of outside compilation outputs cannot be determined now, we will store shape
101 // inference graph into `fld`.
102 Status ExtractOutsideCompilation(
103     const string& xla_cluster_attr_name,
104     const string& outside_compilation_attr_name,
105     const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
106     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
107     bool* modified);
108 
109 }  // namespace tensorflow
110 
111 #endif  // TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
112