1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ 17 #define TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ 18 19 #include "absl/types/optional.h" 20 #include "tensorflow/compiler/jit/encapsulate_util.h" 21 #include "tensorflow/compiler/xla/status_macros.h" 22 #include "tensorflow/core/graph/graph.h" 23 24 namespace tensorflow { 25 26 // Rewrite function for outside compilation subgraphs. It will perform the 27 // following steps: 28 // 29 // 1. Add a XLA computation key placeholder node (it will be used as input for 30 // XlaRecvAtHost and XlaSendFromHost); 31 // 2. Replace all _Arg nodes with one single XlaRecvAtHost node; 32 // 3. Replace all _Retval nodes with one single XlaSendFromHost node; 33 // 4. Mark all nodes except key placeholder with attr `xla_cluster_attr_name` 34 // and `outside_compilation_attr_name`; 35 // 5. For nodes marked with attr kXlaConnectedToXlaComputationAttrName, add a 36 // control edge from the node to XlaSendFromHost; for nodes marked with attr 37 // kXlaConnectedFromXlaComputationAttrName, add a control edge from 38 // XlaRecvAtHost node to the node; 39 // 6. Try pruning XlaRecvAtHost/XlaSendFromHost/key placeholder node. 40 // 7. Add necessary attributes to `node_def`, so we can replace it with a 41 // XlaHostCompute node later. If all input shapes for XlaSendFromHost are 42 // known, "shapes" attr will be set to the list of input shapes; otherwise 43 // "shape_inference_graph" attr will be set to shape inference function name. 44 class RewriteOutsideCompilationSubgraphFn { 45 public: RewriteOutsideCompilationSubgraphFn(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & new_function_name)46 RewriteOutsideCompilationSubgraphFn( 47 const string& xla_cluster_attr_name, 48 const string& outside_compilation_attr_name, 49 const string& xla_cluster_name, const string& new_function_name) 50 : xla_cluster_attr_name_(xla_cluster_attr_name), 51 outside_compilation_attr_name_(outside_compilation_attr_name), 52 xla_cluster_name_(xla_cluster_name), 53 new_function_name_(new_function_name) {} 54 55 Status operator()(const std::vector<OutputTensor>&, 56 std::unique_ptr<Graph>* graph, 57 std::vector<int>* input_permutation, 58 std::vector<int>* output_permutation, NodeDef* node_def); 59 60 private: 61 string xla_cluster_attr_name_; 62 string outside_compilation_attr_name_; 63 string xla_cluster_name_; 64 string new_function_name_; 65 }; 66 67 // For an XLA computation function, replace all outside compilations with 68 // XlaHostCompute nodes. Each outside compilation subgraph will be rewritten by 69 // `RewriteOutsideCompilationSubgraphFn`, and they will be merged into one 70 // single host side graph (`host_graph`). 71 // 72 // xla_cluster_attr_name and outside_compilation_attr_name: attr name for XLA 73 // computation and outside compilation. Required for 74 // `RewriteOutsideCompilationSubgraphFn`. 75 // xla_cluster_name: XLA cluster name for this XLA computation. We need it 76 // because XLA cluster name might be different from `func_name`. 77 // func_name_attrs: they will be used to instantiate the XLA computation func. 78 // new_func_name: new function name for rewritten XLA computation func. 79 // host_compute_core: mapping from outside compilation cluster name to XLA 80 // device assignment. 81 // fld: FunctionLibraryDefinition object. 82 // host_graph: Graph object to store host side graph for all outside 83 // compilations within this XLA computation func. If there is no outside 84 // compilation, it will be empty. 85 // shape_inference_graphs: a list of outside compilation shape inference 86 // function names. These functions need to be rewritten later. 87 // has_outside_compilation: a bool indicating whether this function has any 88 // outside compilation nodes. 89 Status ExtractOutsideCompilationForFunction( 90 const string& xla_cluster_attr_name, 91 const string& outside_compilation_attr_name, const string& xla_cluster_name, 92 const NameAttrList& func_name_attrs, const string& new_func_name, 93 const string& host_graph_func_name, 94 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr, 95 FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs, 96 bool* has_outside_compilation); 97 98 // Rewrites XLA computation in `clusters` to replace outside compilation nodes 99 // with XlaHostCompute, and moves those outside compilations into `g`. If shapes 100 // of outside compilation outputs cannot be determined now, we will store shape 101 // inference graph into `fld`. 102 Status ExtractOutsideCompilation( 103 const string& xla_cluster_attr_name, 104 const string& outside_compilation_attr_name, 105 const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g, 106 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, 107 bool* modified); 108 109 } // namespace tensorflow 110 111 #endif // TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ 112