1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file contains some utility functions for encapsulating XLA computation
17 // in host graph and encapsulating outside compilation in XLA computation.
18 
19 #ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
20 #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/stream_executor/lib/statusor.h"
25 
26 namespace tensorflow {
27 
// Name of the node attribute that records the output tensor shapes inferred by
// XLA. The attribute value is a list of PartialTensorShape objects.
extern const char kXlaInferredShapesAttrName[];
31 
32 // Infers output shapes for all nodes in graph `g`. The output shapes will be
33 // stored in node attribute `kXlaInferredShapesAttrName`.
34 //
35 // We have to perform shape inference before encapsulation because after
36 // encapsulation, some nodes will be encapsulated into function call, and shape
37 // inference does not handle function call at the moment.
38 Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g);
39 
// Attribute indicating that some ops in this node's XLA computation have a
// control dependency on this node. Attribute value will always be "true".
extern const char kXlaConnectedToXlaComputationAttrName[];

// Attribute indicating that this node has a control dependency on some ops in
// this node's XLA computation. Attribute value will always be "true".
extern const char kXlaConnectedFromXlaComputationAttrName[];

// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will
// be a string (the original input node name).
extern const char kOutsideCompilationOriginalNodeAttrName[];

// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will
// be an int (the src_output of the original edge).
extern const char kOutsideCompilationSrcOutputAttrName[];

// Attribute indicating that this node has control dependencies on some other
// nodes within the same XLA cluster. Attribute value will be a list of strings
// (node names).
extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];

// Attribute indicating that this node is an outside compilation node which is
// lifted out of an If/While/function node. Attribute value will always be the
// boolean value "true".
extern const char kXlaIsLiftedArgAttrName[];

// Attribute indicating that this node is a Placeholder node for an _Arg node
// lifted out of an If/While/function node. Attribute value will be a string,
// which is the name of the outside compilation cluster sending the lifted arg
// node to the host.
extern const char kXlaLiftedArgOutsideCompilationAttrName[];

// Attribute indicating that this is an IdentityN node receiving inputs for an
// outside compilation Placeholder node (the original outside compilation node
// is moved out of the TPU computation, and a Placeholder node is left there).
// Attribute value will be a string, which is the outside compilation cluster
// name for the outside compilation Placeholder node.
extern const char kXlaOutsideCompilationInputsAttrName[];

// Attribute indicating that this is a Placeholder node for an _Arg node used
// in outside compilation. This node should not be moved out of the XLA
// computation. Attribute value will always be the boolean value "true".
extern const char kXlaIsPlaceholderForArg[];
84 
85 // Information for XLA computation.
86 struct XlaClusterInfo {
87   // Add an explicitly-defined default constructor for this class.
88   //
89   // The compiler may delete the default constructor here because
90   // host_compute_core is a const member whose type (std::map) doesn't
91   // necessarily have a user provided constructor -- while libc++ and
92   // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at
93   // least >= 7.3 does not. See also c++11 [class.ctor] p5.
94   //
95   // TODO(klimek): In c++17 we'll be able to initialize host_compute_core
96   // without losing aggregate initialization, which allows us to get rid of
97   // the constructor definitions again.
XlaClusterInfoXlaClusterInfo98   XlaClusterInfo() {}
XlaClusterInfoXlaClusterInfo99   XlaClusterInfo(const string& cluster_name,
100                  const NameAttrList& func_name_attrs, Node* node,
101                  const std::map<string, int>& host_compute_core)
102       : cluster_name(cluster_name),
103         func_name_attrs(func_name_attrs),
104         node(node),
105         host_compute_core(host_compute_core) {}
106   // XLA cluster name. It might be different from `func_name`.
107   const string cluster_name;
108   // Name and attributes of XLA computation function.
109   const NameAttrList func_name_attrs;
110   // The XLA computation node in the graph.
111   Node* node;
112   // A mapping from outside compilation cluster name to its device assignment.
113   const std::map<string, int> host_compute_core;
114 };
115 
116 // Finds dependencies between outside compilation clusters, including both data
117 // dependencies and control dependencies. cluster_deps maps the name name of an
118 // outside compilation cluster to a set of names of outside compilation clusters
119 // that it depends on.
120 stream_executor::port::StatusOr<
121     std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
122 OutsideCompilationClusterDependencies(
123     const Graph* g, const string& outside_compilation_attr_name);
124 
125 // Preprocesses edges within the same XLA cluster. It will perform the following
126 // operations in order:
127 //
128 // 0.  Remove edges from source node to outside compilation nodes, and edges
129 //     from outside compilation nodes to sink node.
130 // 1a. For edges between different outside compilation clusters, remove the edge
131 //     and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node
132 //     name" to dst node.
133 // 1b. For control edges between outside compilation and its XLA computation,
134 //     add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
135 //     outside compilation node.
136 // 2.  For data edges between different outside compilations, remove the edge
137 //     and create a Placeholder node as dst node's input.
138 Status PreprocessEdgesBetweenOutsideCompilations(
139     Graph* g, const string& outside_compilation_attr_name);
140 
141 // Postprocesses edges within the same XLA cluster. This function reverts what
142 // `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the
143 // following operations in order:
144 //
145 // 1. Remove Placeholder nodes between different outside compilations (created
146 //    in `PreprocessEdgesBetweenOutsideCompilations` step 2).
147 // 2a. Reconnect control edges between different outside compilations (marked by
148 //     `PreprocessEdgesBetweenOutsideCompilations` step 1a).
149 // Notice that control edges marked by
150 // `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here.
151 // They are handled in `RewriteOutsideCompilationSubgraphFn`.
152 Status PostprocessEdgesBetweenOutsideCompilations(
153     Graph* g, const string& outside_compilation_attr_name);
154 }  // namespace tensorflow
155 
156 #endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
157