/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

#include <string>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
// Checks whether a TF node can be compiled or not. "Recursive" as in, for call
// and functional while nodes, it recursively checks whether the callee
// functions can be compiled.
class RecursiveCompilabilityChecker {
 public:
  // Contains a node name and a function name. If the node is not inside a
  // function body, the function name is an empty string.
  struct StackFrame {
    std::string name;
    std::string function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  // Contains information about an uncompilable node inside a function body.
  struct UncompilableNodeInfo {
    std::string name;
    // A list representing a stack trace from the highest-level node, in
    // increasing call depth, down to the immediate node that fails the
    // compilability checker (see the illustration after this struct).
    std::vector<StackFrame> stack_trace;
    std::string uncompilable_reason;
  };
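  // For illustration only (a hypothetical trace with made-up node and
  // function names): if a top-level node `f_call` invokes function `f`, `f`
  // calls function `g` through node `g_call`, and a node `bad` inside `g` is
  // rejected, the recorded stack trace would read, outermost to innermost:
  //
  //   {name: "f_call", function_name: ""}
  //   {name: "g_call", function_name: "f"}
  //   {name: "bad",    function_name: "g"}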
  // Aggregates information about what kinds of ops are allowed.
  struct OperationFilter {  // TODO(lzr): Add AllowEverything() helper.
    // Whether resource variable ops are allowed in callees. We do not allow
    // resource variable ops in called functions (either as direct TF calls or
    // as higher order control flow ops) because we do not yet model their
    // memory effects in jit/resource_operation_safety_analysis.
    bool allow_resource_ops_in_called_functions = false;

    // Whether Stack operations are allowed. We avoid auto-clustering Stack
    // operations in general because we do not support snapshotting them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_stack_ops = false;

    // Whether TensorArray operations are allowed. We avoid auto-clustering
    // TensorArray operations in general because we do not support snapshotting
    // them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_tensor_array_ops = false;

    // Whether stateful RNG ops are allowed. XLA's RNG does not have the same
    // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
    // auto-clustering stateful RNG ops.
    bool allow_stateful_rng_ops = false;

    // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
    // to cluster ControlTrigger because of how we use deadness analysis.
    bool allow_control_trigger = false;

    // Whether it is okay to "cluster" Assert and CheckNumerics by simply
    // removing them (they're not removed during clustering, but their
    // XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so
    // that the user is not surprised when XLA is implicitly enabled. If the
    // user explicitly specifies to use XLA, it is fine to resort to a dummy
    // implementation. Currently Assert and CheckNumerics ops have dummy XLA
    // implementations.
    bool allow_eliding_assert_and_checknumerics_ops = false;

    // Whether ops that produce or consume DT_VARIANT values are allowed. We
    // don't auto-cluster these ops because we don't yet support live-in or
    // live-out DT_VARIANT values.
    bool allow_ops_producing_or_consuming_variant = false;

    // Whether ops known to be slow on XLA-GPU should be considered compilable.
    bool allow_slow_ops = false;

    // Whether ops known to have numerical accuracy issues should be considered
    // compilable.
    bool allow_inaccurate_ops = false;

    // Require the function to be always compilable, regardless of whether some
    // control flow branches might be dead for a given input.
    bool require_always_compilable = false;

    // Whether string constants are compilable.
    bool allow_string_consts = true;
  };

  RecursiveCompilabilityChecker(OperationFilter op_filter,
                                DeviceType jit_device_type)
      : op_filter_(std::move(op_filter)),
        jit_device_type_(std::move(jit_device_type)) {}

  using UncompilableNodesMap =
      std::map<std::string,
               std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;

  // Returns a map where the key is the function identifier (short debug
  // string) of the function encapsulating the uncompilable nodes, and the
  // value is a pair of the function's NameAttrList and a vector of
  // uncompilable node info. When an uncompilable node is not inside any
  // function call node, the key is the ShortDebugString() of an empty
  // NameAttrList.
  //
  // Also, when `node` is inside a function body, users can set
  // `node_stack_trace` to provide additional context for `node`'s
  // placement within the outermost graph.
  UncompilableNodesMap FindUncompilableNodes(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      const std::vector<StackFrame>* node_stack_trace = nullptr) const;
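  // Example (a hypothetical usage sketch, not a prescribed pattern; `node` and
  // `flr` stand for a Node and a FunctionLibraryRuntime* obtained elsewhere):
  //
  //   RecursiveCompilabilityChecker checker(
  //       RecursiveCompilabilityChecker::OperationFilter{},
  //       DeviceType(DEVICE_GPU_XLA_JIT));
  //   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable =
  //       checker.FindUncompilableNodes(node, flr);
  //   for (const auto& entry : uncompilable) {
  //     const NameAttrList& function = entry.second.first;
  //     for (const auto& node_info : entry.second.second) {
  //       VLOG(1) << "In " << function.name() << ": " << node_info.name
  //               << " is uncompilable: " << node_info.uncompilable_reason;
  //     }
  //   }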
  // Returns a map where the key is the function identifier (short debug
  // string) of the function encapsulating the uncompilable nodes, and the
  // value is a pair of the function's NameAttrList and a vector of
  // uncompilable node info. When an uncompilable node is not inside any
  // function call node, the key is the ShortDebugString() of an empty
  // NameAttrList.
  //
  // Also, when `node` is inside a function body, users can set
  // `node_stack_trace` to provide additional context for `node`'s
  // placement within the outermost graph.
  UncompilableNodesMap FindUncompilableNodes(
      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
      const std::vector<StackFrame>* node_stack_trace = nullptr) const;

  // Returns true if `node` can be compiled by XLA.
  bool IsCompilableNode(const Node& node,
                        FunctionLibraryRuntime* lib_runtime) const {
    std::vector<StackFrameView> stack_trace;
    stack_trace.emplace_back(StackFrameView{node.name(), ""});
    return IsCompilableNode(node, lib_runtime, &stack_trace);
  }

  // Returns true if `call_def` can be compiled by XLA. It is assumed that
  // `call_def` is a call operation.
  bool IsCompilableCall(const NodeDef& call_def,
                        FunctionLibraryRuntime* lib_runtime) {
    std::vector<StackFrameView> stack_trace;
    stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
    return IsCompilableCall(call_def, lib_runtime, &stack_trace);
  }

  // Returns true if XLA supports this Op, but we don't want to cluster it
  // (e.g. due to performance or correctness concerns).
  bool OpIsInaccurate(const Node& node) const;
  bool OpIsSlow(const Node& node) const;

 private:
  struct StackFrameView {
    absl::string_view name;
    absl::string_view function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  bool IsCompilableNode(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableCall(
      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
                      std::vector<StackFrameView>* stack_trace,
                      NameAttrList* encapsulating_function,
                      UncompilableNodesMap* uncompilable_nodes) const;
  bool IsCompilableWhile(const Node& while_node,
                         FunctionLibraryRuntime* lib_runtime,
                         std::vector<StackFrameView>* stack_trace,
                         NameAttrList* encapsulating_function,
                         UncompilableNodesMap* uncompilable_nodes) const;
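  // Note (orientation only, based on the standard TF functional control flow
  // ops rather than anything defined in this header): the If/While/Case checks
  // in this section recurse into the functions attached to the node's
  // attributes, i.e. "then_branch"/"else_branch" for If, "cond"/"body" for
  // While, and "branches" for Case; see the .cc file for the authoritative
  // attr handling.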
  // Tests whether 'case_node' is compilable. Every operator in all branches
  // must be compilable.
  bool IsCompilableCase(const Node& case_node,
                        FunctionLibraryRuntime* lib_runtime,
                        std::vector<StackFrameView>* stack_trace,
                        NameAttrList* encapsulating_function,
                        UncompilableNodesMap* uncompilable_nodes) const;

  // Returns the compilability of the node def retrieved from `node`'s
  // attribute with name `attr_name`.
  bool ExtractNodeDefAndCheckCompilability(
      const Node& node, const std::string& attr_name,
      const std::string& call_name, NameAttrList* encapsulating_function,
      FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      UncompilableNodesMap* uncompilable_nodes) const;

  bool IsStackOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
  }

  bool IsTensorArrayOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info &&
           op_info->resource_kind() == XlaResourceKind::kTensorArray;
  }

  bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
    return op_name == "Assert" || op_name == "CheckNumerics";
  }

  bool IsStatefulRandomOp(absl::string_view op_name) const {
    return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
           op_name == "RandomUniformInt" ||
           op_name == "RandomStandardNormal" ||
           op_name == "TruncatedNormal" || op_name == "Multinomial";
  }

  bool OpProducesOrConsumesVariant(const Node& node) const {
    auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
    return absl::c_any_of(node.input_types(), is_variant) ||
           absl::c_any_of(node.output_types(), is_variant);
  }

  bool HasXLAKernel(const Node& node,
                    string* uncompilable_reason = nullptr) const;

  static void MaybeMarkUncompilableNode(
      const absl::string_view reason,
      const std::vector<StackFrameView>& stack_trace,
      NameAttrList* encapsulating_function,
      UncompilableNodesMap* uncompilable_nodes_map);

  // Make sure we don't recurse infinitely on recursive functions.
  const size_t kMaxRecursionDepth = 50;

  const OperationFilter op_filter_;
  const DeviceType jit_device_type_;
};

RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration);

// Given a FunctionLibraryRuntime and a `function`, returns this function's
// body in `fbody` as well as the indices of its constant and resource
// arguments. `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices);

// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);
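// Example (a hypothetical sketch tying the helper above and the memory-type
// queries below together; `flr` and `function` stand for a
// FunctionLibraryRuntime* and a NameAttrList obtained elsewhere):
//
//   const FunctionBody* fbody;
//   std::vector<int> constant_arg_indices;
//   std::vector<int> resource_arg_indices;
//   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
//       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
//   // `fbody` stays owned by `flr`; the index vectors come back sorted.
//   MemoryTypeVector input_mtypes =
//       GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
//   MemoryTypeVector output_mtypes = GetOutputMemoryTypes(fbody);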
// Returns memory types for the input.
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
// indices corresponding to constant and resource arguments respectively.
//
// One might wonder about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. `Add`, that expects its inputs in device memory. Here is how it
// works now.
//
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to an XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices);

// Returns output memory types.
//
// The XlaLaunch kernel keeps all outputs (including constants, which it
// copies) in device memory, except for resources.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody);

// Checks whether the graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_