1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // An optimization pass that groups nodes marked with a common
17 // kXlaClusterAttr into functions, and replaces the original nodes by
18 // calls. The calls are annotated with kXlaCompiledKernelAttr.
19 
20 #ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
21 #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
22 
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
27 
28 namespace tensorflow {
29 
// EncapsulateSubgraphs pass takes all the nodes with the same cluster ID
// (the ID is the value of each node's kXlaClusterAttr attribute), puts them
// into a TF function, and replaces the subgraph in the main graph with a call
// to that TF function annotated with kXlaCompiledKernelAttr
// (_XlaCompiledKernel).
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
 public:
  // Overrides GraphOptimizationPass::Run. The graph to rewrite is presumably
  // supplied through `options` (see optimization_registry.h) — confirm there.
  Status Run(const GraphOptimizationPassOptions& options) override;
};
38 
39 // A rewriting function to apply to each subgraph during encapsulation.
40 // 'arg_source_tensors' are the tensors corresponding to the arguments in the
41 // original source graph (*not* 'graph').
42 //
43 // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
44 // 'input_permutation' is a mapping from old argument numbers to new argument
45 // numbers, whereas 'output_permutation' is the same for outputs. Both
46 // 'input_permutation' and 'output_permutation' are initialized to the identity
47 // permutation. 'nodedef' is the NodeDef for the call to the function under
48 // construction, provided to allow additional attributes to be set.
49 // The rewrite may also change the NodeDef's operator name, and that
50 // name will be used as the name of the generated function.
51 typedef std::function<Status(
52     const std::vector<OutputTensor>& arg_source_tensors,
53     std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
54     std::vector<int>* output_permutation, NodeDef* node_def)>
55     RewriteSubgraphFn;
56 
// Transformation that finds subgraphs whose nodes are marked with
// 'group_attribute', splits those subgraphs into functions, and replaces
// the originals with function calls.
//
// 'group_attribute' must be a string-valued attribute that names the new
// functions to introduce.
//
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
//
// If 'reuse_existing_functions' is set, use an existing function with the
// same name, if any.
//
// 'graph_in' is taken by const reference and is not modified. The transformed
// graph is presumably written to '*graph_out', and the generated functions
// added to 'library' — TODO confirm against the implementation.
//
// TODO(phawkins): currently, some information in control edges
// is not preserved. Suppose you have A and B in the main
// graph, C and D in a subgraph. B and C have control deps from A, D has control
// dep from B. Originally D must run after C, post-transformation this
// dependency is lost.
Status EncapsulateSubgraphsInFunctions(
    string group_attribute, const Graph& graph_in,
    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
    std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
79 
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via XlaLaunch operators.
// Defined out of line (extern); only the declaration lives in this header.
extern const char* const kXlaCompiledKernelAttr;

// Does `node` have the kXlaCompiledKernelAttr attribute?
// NOTE(review): the implementation may also require the attribute's value to
// be true, not merely present — confirm before relying on presence alone.
bool IsXlaCompiledKernel(const Node& node);
86 
// Functions produced by the EncapsulateSubgraphs pass have their arguments in
// the following order:
// 1) compile-time constant arguments, in host memory;
// 2) other arguments, in device memory;
// 3) resource variable arguments, in host memory. Note that only the resource
//    Tensor itself is in host memory; the underlying value may be in device
//    memory.
// The functions are annotated with the following attributes, which describe
// how many constant and resource arguments there are (plus a flag recording
// whether the cluster uses reference variables):

// Name of the attribute containing the number of constant arguments.
extern const char* const kXlaNumConstantArgsAttr;

// Name of the attribute containing the number of resource variable arguments.
extern const char* const kXlaNumResourceArgsAttr;

// Name of the attribute indicating whether the cluster has reference
// variables. Unlike the two attributes above, this is not a count.
extern const char* const kXlaHasReferenceVarsAttr;
105 
// Sorts each node's control inputs in `*gdef` by name, modifying it in place.
// This guarantees that two structurally equivalent GraphDefs yield the same
// traversal order over each node's control-input fields.
// TODO(hpucha): Move the utilities to a more appropriate place.
void SortControlInputs(GraphDef* gdef);
111 
112 }  // namespace tensorflow
113 
114 #endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
115