/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

#include <string>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
// Checks whether a TF node can be compiled.  "Recursive" because, for call
// and functional while nodes, it recursively checks whether the callee
// functions can be compiled.
class RecursiveCompilabilityChecker {
 public:
  // Contains the node name and function name. If the node is not inside a
  // function body, the function name is an empty string.
  struct StackFrame {
    std::string name;
    std::string function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  // Contains information about an uncompilable node inside a function body.
  struct UncompilableNodeInfo {
    std::string name;
    // A stack trace from the highest-level node, in increasing call depth,
    // down to the immediate node that fails the compilability checker.
    std::vector<StackFrame> stack_trace;
    std::string uncompilable_reason;
  };

  // Aggregates information about what kinds of ops are allowed.
  struct OperationFilter {  // TODO(lzr): Add AllowEverything() helper.
    // Whether resource variable ops are allowed in callees.  We do not allow
    // resource variable ops in called functions (either as direct TF calls
    // or as higher order control flow ops) because we do not yet model
    // their memory effects in jit/resource_operation_safety_analysis.
    bool allow_resource_ops_in_called_functions = false;

    // Whether Stack operations are allowed.  We avoid auto-clustering Stack
    // operations in general because we do not support snapshotting them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_stack_ops = false;

    // Whether TensorArray operations are allowed.  We avoid auto-clustering
    // TensorArray operations in general because we do not support snapshotting
    // them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_tensor_array_ops = false;

    // Whether stateful RNG ops are allowed.  XLA's RNG does not have the same
    // seeding behavior as TensorFlow's RNG (b/34749654), so we avoid
    // auto-clustering stateful RNG ops.
    bool allow_stateful_rng_ops = false;

    // TODO(b/118970344): Whether ControlTrigger ops are allowed.  It is
    // unsound to cluster ControlTrigger because of how we use deadness
    // analysis.
    bool allow_control_trigger = false;

    // Whether it is okay to "cluster" Assert and CheckNumerics by simply
    // removing them (they're not removed during clustering, but their
    // XlaOpKernel is a no-op kernel).  We avoid auto-clustering these ops so
    // that the user is not surprised when XLA is implicitly enabled. If the
    // user explicitly specifies to use XLA, it is fine to resort to a dummy
    // implementation. Currently Assert and CheckNumerics ops have dummy XLA
    // implementations.
    bool allow_eliding_assert_and_checknumerics_ops = false;

    // Whether ops that produce or consume DT_VARIANT values are allowed.  We
    // don't auto-cluster these ops because we don't yet support live-in or
    // live-out DT_VARIANT values.
    bool allow_ops_producing_or_consuming_variant = false;

    // Whether ops known to be slow on XLA-GPU should be considered compilable.
    bool allow_slow_ops = false;

    // Whether ops known to have numerical accuracy issues should be considered
    // compilable.
    bool allow_inaccurate_ops = false;

    // Require the function to be always compilable, regardless of whether some
    // control flow branches might be dead for a given input.
    bool require_always_compilable = false;

    // Whether string constants are compilable.
    bool allow_string_consts = true;
  };

  RecursiveCompilabilityChecker(OperationFilter op_filter,
                                DeviceType jit_device_type)
      : op_filter_(std::move(op_filter)),
        jit_device_type_(std::move(jit_device_type)) {}
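
  // A minimal usage sketch (illustrative only, not from the original source):
  // construct an OperationFilter, opt in to the op categories the caller is
  // willing to cluster, and build a checker for a JIT device type.  The
  // DEVICE_GPU_XLA_JIT constant comes from xla_op_registry.h.
  //
  //   RecursiveCompilabilityChecker::OperationFilter op_filter;
  //   op_filter.allow_stack_ops = true;         // opt in to Stack ops
  //   op_filter.allow_tensor_array_ops = true;  // opt in to TensorArray ops
  //   RecursiveCompilabilityChecker checker(
  //       op_filter, DeviceType(DEVICE_GPU_XLA_JIT));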

  using UncompilableNodesMap =
      std::map<std::string,
               std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;

  // Returns a map where the key is the function identifier (short debug
  // string) of the function encapsulating the uncompilable nodes, and the
  // value is a pair of the function's NameAttrList and a vector of
  // uncompilable node info. When an uncompilable node is not inside any
  // function call node, the key is the ShortDebugString() of an empty
  // NameAttrList.
  //
  // Also, when `node` is inside a function body, users can set
  // `node_stack_trace` to provide additional context for `node`'s
  // placement within the outermost graph.
  UncompilableNodesMap FindUncompilableNodes(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      const std::vector<StackFrame>* node_stack_trace = nullptr) const;
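
  // A usage sketch (illustrative only; `checker`, `node`, and `lib_runtime`
  // are assumed to be supplied by the caller):
  //
  //   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable =
  //       checker.FindUncompilableNodes(node, lib_runtime);
  //   for (const auto& entry : uncompilable) {
  //     // entry.second.first is the encapsulating function's NameAttrList;
  //     // entry.second.second lists the uncompilable nodes inside it.
  //     for (const auto& info : entry.second.second) {
  //       VLOG(1) << "Node " << info.name
  //               << " is uncompilable: " << info.uncompilable_reason;
  //     }
  //   }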

  // Returns a map where the key is the function identifier (short debug
  // string) of the function encapsulating the uncompilable nodes, and the
  // value is a pair of the function's NameAttrList and a vector of
  // uncompilable node info. When an uncompilable node is not inside any
  // function call node, the key is the ShortDebugString() of an empty
  // NameAttrList.
  //
  // Also, when `call_def` is inside a function body, users can set
  // `node_stack_trace` to provide additional context for `call_def`'s
  // placement within the outermost graph.
  UncompilableNodesMap FindUncompilableNodes(
      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
      const std::vector<StackFrame>* node_stack_trace = nullptr) const;

  // Returns true if `node` can be compiled by XLA.
  bool IsCompilableNode(const Node& node,
                        FunctionLibraryRuntime* lib_runtime) const {
    std::vector<StackFrameView> stack_trace;
    stack_trace.emplace_back(StackFrameView{node.name(), ""});
    return IsCompilableNode(node, lib_runtime, &stack_trace);
  }

  // Returns true if `call_def` can be compiled by XLA.  It is assumed that
  // `call_def` is a call operation.
  bool IsCompilableCall(const NodeDef& call_def,
                        FunctionLibraryRuntime* lib_runtime) {
    std::vector<StackFrameView> stack_trace;
    stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
    return IsCompilableCall(call_def, lib_runtime, &stack_trace);
  }
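
  // A usage sketch (illustrative only; `graph`, `checker`, and `lib_runtime`
  // are assumed to be supplied by the caller) that screens every op node in a
  // graph before attempting to cluster it:
  //
  //   for (Node* n : graph->op_nodes()) {
  //     if (!checker.IsCompilableNode(*n, lib_runtime)) {
  //       VLOG(2) << "Not clustering " << n->name();
  //     }
  //   }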

  // Returns true if XLA supports this Op, but we don't want to cluster it
  // (e.g., due to performance or correctness concerns).
  bool OpIsInaccurate(const Node& node) const;
  bool OpIsSlow(const Node& node) const;

 private:
  struct StackFrameView {
    absl::string_view name;
    absl::string_view function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  bool IsCompilableNode(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableCall(
      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
                      std::vector<StackFrameView>* stack_trace,
                      NameAttrList* encapsulating_function,
                      UncompilableNodesMap* uncompilable_nodes) const;
  bool IsCompilableWhile(const Node& while_node,
                         FunctionLibraryRuntime* lib_runtime,
                         std::vector<StackFrameView>* stack_trace,
                         NameAttrList* encapsulating_function,
                         UncompilableNodesMap* uncompilable_nodes) const;

  // Tests whether 'case_node' is compilable. Every operator in all branches
  // must be compilable.
  bool IsCompilableCase(const Node& case_node,
                        FunctionLibraryRuntime* lib_runtime,
                        std::vector<StackFrameView>* stack_trace,
                        NameAttrList* encapsulating_function,
                        UncompilableNodesMap* uncompilable_nodes) const;

  // Returns the compilability of the node def retrieved from `node`'s
  // attribute named `attr_name`.
  bool ExtractNodeDefAndCheckCompilability(
      const Node& node, const std::string& attr_name,
      const std::string& call_name, NameAttrList* encapsulating_function,
      FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      UncompilableNodesMap* uncompilable_nodes) const;

  bool IsStackOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
  }

  bool IsTensorArrayOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
  }

  bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
    return op_name == "Assert" || op_name == "CheckNumerics";
  }

  bool IsStatefulRandomOp(absl::string_view op_name) const {
    return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
           op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
           op_name == "TruncatedNormal" || op_name == "Multinomial";
  }

  bool OpProducesOrConsumesVariant(const Node& node) const {
    auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
    return absl::c_any_of(node.input_types(), is_variant) ||
           absl::c_any_of(node.output_types(), is_variant);
  }

  bool HasXLAKernel(const Node& node,
                    string* uncompilable_reason = nullptr) const;

  static void MaybeMarkUncompilableNode(
      const absl::string_view reason,
      const std::vector<StackFrameView>& stack_trace,
      NameAttrList* encapsulating_function,
      UncompilableNodesMap* uncompilable_nodes_map);

  // Make sure we don't recurse infinitely on recursive functions.
  const size_t kMaxRecursionDepth = 50;

  const OperationFilter op_filter_;
  const DeviceType jit_device_type_;
};

RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration);
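
// A usage sketch (illustrative only): derive the filter from the device's
// registration, as auto-clustering passes typically do.
// `XlaOpRegistry::GetCompilationDevice` and
// `DeviceRegistration::compilation_device_name` come from xla_op_registry.h;
// `device_type` names the device the cluster would be placed on.
//
//   const XlaOpRegistry::DeviceRegistration* registration = nullptr;
//   if (XlaOpRegistry::GetCompilationDevice(device_type.type(),
//                                           &registration)) {
//     RecursiveCompilabilityChecker checker(
//         CreateOperationFilter(*registration),
//         DeviceType(registration->compilation_device_name));
//   }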

// Given a FunctionLibraryRuntime and a `function`, returns this function's
// body in `fbody` as well as the indices of its constant and resource
// arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order when this function returns.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices);
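
// A usage sketch (illustrative only; `flr` is a FunctionLibraryRuntime* and
// `func` a NameAttrList naming the function):
//
//   const FunctionBody* fbody = nullptr;
//   std::vector<int> constant_arg_indices;
//   std::vector<int> resource_arg_indices;
//   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
//       flr, func, &fbody, &constant_arg_indices, &resource_arg_indices));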

// Given a NodeDef `node_def`, returns true iff `node_def` has the
// kXlaCompileAttr attribute set.
bool CanCreateXlaKernel(const NodeDef& node_def);

// Returns memory types for the inputs.
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
// indices corresponding to constant and resource arguments respectively.
//
// One might wonder about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. `Add`, that expects its inputs in device memory. Here is how it
// works now.
//
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to an XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices);
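
// A usage sketch (illustrative only), continuing from the
// GetBodyAndConstantsAndResources example above; per the comment above,
// constant and resource arguments are expected to come back as HOST_MEMORY
// and everything else as DEVICE_MEMORY:
//
//   MemoryTypeVector input_memory_types = GetInputMemoryTypes(
//       fbody, constant_arg_indices, resource_arg_indices);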

// Returns output memory types.
//
// The XlaLaunch kernel keeps all outputs (including constants, which it
// copies) in device memory, except for resources.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody);

// Checks whether the graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);
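
// A usage sketch (illustrative only; `graph_def` is a GraphDef the caller
// wants to screen before deciding whether the XLA bridge could kick in):
//
//   if (CanTriggerXlaCompilation(graph_def)) {
//     VLOG(1) << "Graph may trigger XLA compilation.";
//   }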

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_