1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/compilability_check_util.h"
17
18 #include <algorithm>
19 #include <atomic>
20 #include <deque>
21 #include <iterator>
22 #include <limits>
23 #include <string>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <utility>
27
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/strings/string_view.h"
34 #include "tensorflow/compiler/jit/defs.h"
35 #include "tensorflow/compiler/jit/device_util.h"
36 #include "tensorflow/compiler/jit/flags.h"
37 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
38 #include "tensorflow/compiler/jit/xla_activity.pb.h"
39 #include "tensorflow/compiler/jit/xla_activity_listener.h"
40 #include "tensorflow/compiler/jit/xla_cluster_util.h"
41 #include "tensorflow/compiler/tf2xla/const_analysis.h"
42 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
43 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
44 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
45 #include "tensorflow/compiler/xla/statusor.h"
46 #include "tensorflow/compiler/xla/union_find.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/core/common_runtime/function.h"
49 #include "tensorflow/core/common_runtime/graph_constructor.h"
50 #include "tensorflow/core/framework/attr_value.pb.h"
51 #include "tensorflow/core/framework/bounds_check.h"
52 #include "tensorflow/core/framework/graph_def_util.h"
53 #include "tensorflow/core/framework/memory_types.h"
54 #include "tensorflow/core/framework/node_def.pb.h"
55 #include "tensorflow/core/framework/op_kernel.h"
56 #include "tensorflow/core/framework/types.h"
57 #include "tensorflow/core/graph/algorithm.h"
58 #include "tensorflow/core/graph/control_flow.h"
59 #include "tensorflow/core/lib/gtl/cleanup.h"
60 #include "tensorflow/core/lib/strings/stringprintf.h"
61 #include "tensorflow/core/public/version.h"
62 #include "tensorflow/core/util/dump_graph.h"
63
64 namespace tensorflow {
65
66 namespace {
67
HasResourceInput(const Node & node)68 bool HasResourceInput(const Node& node) {
69 return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
70 }
71
LogNotCompilable(const Node & node,absl::string_view reason="")72 void LogNotCompilable(const Node& node, absl::string_view reason = "") {
73 VLOG(3) << "Found uncompilable node " << node.name() << " (op "
74 << node.type_string() << ")" << (reason.empty() ? "" : ": ")
75 << reason;
76 }
77
// Builds a function-call NodeDef from the NameAttrList stored in attribute
// `attr_name` of `node`: the called function's name becomes the op and the
// function's attrs are copied into `node_def`. Returns an error status if the
// attribute is missing or has the wrong type.
Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
                                 NodeDef* node_def) {
  const NameAttrList* name_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr));
  node_def->set_op(name_attr->name());
  *(node_def->mutable_attr()) = name_attr->attr();
  return Status::OK();
}
86
MakeCallNodesFromAttribute(const Node & node,absl::string_view attr_name,absl::string_view call_name)87 xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
88 const Node& node, absl::string_view attr_name,
89 absl::string_view call_name) {
90 std::vector<NameAttrList> attr_lists;
91 TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
92
93 std::vector<NodeDef> out;
94 for (int i = 0; i < attr_lists.size(); i++) {
95 out.emplace_back();
96 NodeDef& inserted = out.back();
97 inserted.set_name(absl::StrCat(call_name, "_", i));
98 inserted.set_op(attr_lists[i].name());
99 *inserted.mutable_attr() = attr_lists[i].attr();
100 }
101 return out;
102 }
103
104 // Utility which searches for values in a sorted list by scanning over it once.
105 // No matter how many times ScanForValue is called, the list is scanned at most
106 // once. However, if a call to ScanForValue skips over a value, that value is
107 // not revisited in future calls to ScanForValue, so callers must take
108 // care to order their calls.
109 //
110 // Useful for merging multiple sorted lists in O(n) time.
111 class SinglePassSearch {
112 public:
113 // Creates a SinglePassSearch object that can be used to search in `values`.
114 // Does not take ownership of `values`. `values` must outlive this.
115 // `values` must be sorted.
SinglePassSearch(absl::Span<int const> values)116 explicit SinglePassSearch(absl::Span<int const> values)
117 : current_index_(0), values_(values) {}
118
119 // Scans forward in the vector looking for "value", updating the internal
120 // position in to the vector.
121 // Returns true iff the vector contains the given value at or after current
122 // position.
123 // Not thread-safe.
ScanForValue(int value)124 bool ScanForValue(int value) {
125 while (current_index_ < values_.size() &&
126 values_[current_index_] <= value) {
127 if (values_[current_index_] == value) {
128 current_index_++;
129 return true;
130 }
131 current_index_++;
132 }
133 return false;
134 }
135
136 private:
137 int current_index_;
138 const absl::Span<int const> values_;
139 };
140
141 } // anonymous namespace
142
// Returns a map of every uncompilable node reachable from `node` (recursing
// into called functions), keyed by the serialized encapsulating function.
// `node_stack_trace`, when provided, seeds the stack trace with the frames of
// the enclosing function calls.
RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
        node_stack_trace) const {
  std::vector<StackFrameView> stack_trace;
  // If `node_stack_trace` is provided, that means `node` is inside
  // a function body, and therefore, arg nodes and retval nodes are
  // not considered uncompilable.
  if (node_stack_trace != nullptr) {
    for (const auto& frame : *node_stack_trace) {
      stack_trace.emplace_back(
          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
    }
  }
  stack_trace.emplace_back(
      StackFrameView{node.name(), "", node.GetStackTrace()});

  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
  // Passing a non-null `uncompilable_nodes` makes IsCompilableNode collect
  // all failures instead of stopping at the first one.
  IsCompilableNode(node, lib_runtime, &stack_trace,
                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
  return uncompilable_nodes;
}
166
// Overload of FindUncompilableNodes for a function-call NodeDef: recursively
// collects every uncompilable node inside the called function, keyed by the
// serialized encapsulating function.
RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
        node_stack_trace) const {
  // If `node_stack_trace` is provided, that means `call_def` is inside
  // a function body, and therefore, arg nodes and retval nodes are
  // not considered uncompilable.
  std::vector<StackFrameView> stack_trace;
  if (node_stack_trace != nullptr) {
    for (const auto& frame : *node_stack_trace) {
      stack_trace.emplace_back(
          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
    }
  }
  // A NodeDef has no per-node stack trace, hence the nullptr frame.
  stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr});

  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
  IsCompilableCall(call_def, lib_runtime, &stack_trace,
                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
  return uncompilable_nodes;
}
189
// Returns true iff an XLA kernel for `node` is registered on the JIT device
// type, and the node is not one of the special cases XLA cannot handle.
// On failure, `*uncompilable_reason` is filled with a human-readable
// explanation.
bool RecursiveCompilabilityChecker::HasXLAKernel(
    const Node& node, string* uncompilable_reason) const {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") {
    *uncompilable_reason =
        "SymbolicGradient should be handled by IsCompilableCall().";
    return false;
  }

  // String constants are only allowed when the op filter opts in.
  if (node.type_string() == "Const") {
    const AttrValue* attr = node.attrs().Find("dtype");
    if (!op_filter_.allow_string_consts && attr != nullptr &&
        attr->type() == DT_STRING) {
      *uncompilable_reason =
          "Const op with type DT_STRING is not supported by XLA.";
      return false;
    }
  }

  // XLA does not offer guaranteed aliasing between the input and output of the
  // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
  // such nodes out of XLA clusters.
  if (HasForwardedRefInput(node)) {
    VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
    *uncompilable_reason = "Identity with unsafe cast.";
    return false;
  }

  // Finally, check the kernel registry for the JIT device type.
  Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
  if (!s.ok()) {
    *uncompilable_reason = s.error_message();
    return false;
  }
  return true;
}
227
228 // Tests whether 'if_node' is compilable. Every operator in the then_branch and
229 // else_branch functions must be compilable for 'if_node' to be compilable.
IsCompilableIf(const Node & if_node,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const230 bool RecursiveCompilabilityChecker::IsCompilableIf(
231 const Node& if_node, FunctionLibraryRuntime* lib_runtime,
232 std::vector<StackFrameView>* stack_trace,
233 NameAttrList* encapsulating_function,
234 RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
235 const {
236 bool is_compilable = true;
237 is_compilable &= ExtractNodeDefAndCheckCompilability(
238 if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
239 stack_trace, uncompilable_nodes);
240 if (!uncompilable_nodes && !is_compilable) return is_compilable;
241
242 is_compilable &= ExtractNodeDefAndCheckCompilability(
243 if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
244 stack_trace, uncompilable_nodes);
245
246 return is_compilable;
247 }
248
IsCompilableCase(const Node & case_node,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const249 bool RecursiveCompilabilityChecker::IsCompilableCase(
250 const Node& case_node, FunctionLibraryRuntime* lib_runtime,
251 std::vector<StackFrameView>* stack_trace,
252 NameAttrList* encapsulating_function,
253 RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
254 const {
255 xla::StatusOr<std::vector<NodeDef>> calls =
256 MakeCallNodesFromAttribute(case_node, "branches", "branch");
257 if (!calls.ok()) {
258 VLOG(2) << "Rejecting node " << case_node.name() << ": "
259 << "missing attribute 'branches'";
260 return false;
261 }
262
263 bool is_compilable = true;
264
265 for (const NodeDef& call : *calls) {
266 is_compilable &=
267 IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
268 uncompilable_nodes);
269 }
270 return is_compilable;
271 }
272
273 // Tests whether 'while_node' is a completely compilable loop.
274 // Every operator in the condition and body functions must be compilable for a
275 // while loop to be compilable.
IsCompilableWhile(const Node & while_node,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const276 bool RecursiveCompilabilityChecker::IsCompilableWhile(
277 const Node& while_node, FunctionLibraryRuntime* lib_runtime,
278 std::vector<StackFrameView>* stack_trace,
279 NameAttrList* encapsulating_function,
280 RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
281 const {
282 bool is_compilable = true;
283 is_compilable &= ExtractNodeDefAndCheckCompilability(
284 while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
285 stack_trace, uncompilable_nodes);
286
287 if (!uncompilable_nodes && !is_compilable) return is_compilable;
288
289 is_compilable &= ExtractNodeDefAndCheckCompilability(
290 while_node, "body", "while_body", encapsulating_function, lib_runtime,
291 stack_trace, uncompilable_nodes);
292
293 return is_compilable;
294 }
295
ExtractNodeDefAndCheckCompilability(const Node & node,const std::string & attr_name,const std::string & call_name,NameAttrList * encapsulating_function,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const296 bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
297 const Node& node, const std::string& attr_name,
298 const std::string& call_name, NameAttrList* encapsulating_function,
299 FunctionLibraryRuntime* lib_runtime,
300 std::vector<StackFrameView>* stack_trace,
301 RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
302 const {
303 NodeDef call;
304 call.set_name(call_name);
305 if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
306 const auto uncompilable_reason = absl::StrCat(
307 "missing '", attr_name, "' attribute from node", node.name());
308 MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
309 encapsulating_function, uncompilable_nodes);
310 VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
311 << ".";
312 return false;
313 }
314 if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
315 uncompilable_nodes)) {
316 VLOG(2) << "Rejecting node " << node.name()
317 << ": can't compile : " << call.op();
318 return false;
319 }
320 return true;
321 }
322
// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  // Bound the recursion depth: each element of `stack_trace` corresponds to
  // one level of nested function calls.
  if (stack_trace->size() > kMaxRecursionDepth) {
    std::string uncompilable_reason = "function depth limit exceeded";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
            << ".";
    return false;
  }

  FunctionLibraryRuntime::Handle handle;
  Status s;
  NameAttrList function;
  s = NameAndAttrsFromFunctionCall(call_def, &function);
  if (s.ok()) {
    s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                                 &handle);
  }
  if (!s.ok()) {
    std::string uncompilable_reason =
        absl::StrCat("could not instantiate call: '", function.name(), "'");
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
            << uncompilable_reason << " : " << s;
    return false;
  }

  // Make sure the instantiated function handle is released on every exit
  // path from here on.
  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  bool is_compilable = true;
  for (const Node* node : fbody->graph->op_nodes()) {
    stack_trace->emplace_back(
        StackFrameView{node->name(), function.name(), node->GetStackTrace()});
    is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                      &function, uncompilable_nodes);
    stack_trace->pop_back();
    // Without a diagnostics sink, stop at the first uncompilable node.
    if (!uncompilable_nodes && !is_compilable) return is_compilable;
  }

  return is_compilable;
}
374
OpIsInaccurate(const Node & node) const375 bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
376 // b/127344411: SelfAdjointEigV2 and Svd precision issues.
377 return node.type_string() == "SelfAdjointEigV2" ||
378 node.type_string() == "Svd";
379 }
380
OpIsSlow(const Node & node) const381 bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
382 // b/128001705: SelfAdjointEigV2 and Svd performance issues.
383 // b/135640736: MatrixInverse performance issues.
384 // b/111271662: MatrixSolve performance issues.
385 // https://github.com/tensorflow/tensorflow/pull/31012:
386 // ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
387 // create convolutions too large for CuDNN to handle.
388 return node.type_string() == "SelfAdjointEigV2" ||
389 node.type_string() == "Svd" || node.type_string() == "Qr" ||
390 node.type_string() == "MatrixInverse" ||
391 node.type_string() == "MatrixSolve" ||
392 node.type_string() == "ResizeNearestNeighbor" ||
393 node.type_string() == "ResizeBilinear" ||
394 node.type_string() == "ResizeBilinearGrad";
395 }
396
// Returns true iff `node` — and, recursively, any function bodies it
// references (calls, While/If/Case branches) — passes every compilability
// check enabled by `op_filter_`. When `uncompilable_nodes` is non-null each
// failure is recorded there with its stack trace; the checks are evaluated
// in a fixed order and the function returns at the first failing one.
bool RecursiveCompilabilityChecker::IsCompilableNode(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  // Depth 1 means `node` is in the top-level (user-visible) graph.
  auto stack_depth = stack_trace->size();
  if (node.IsSource() || node.IsSink()) {
    absl::string_view uncompilable_reason = "source or sink node";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  // _Arg nodes in a top-level function represent feeds and _Retval nodes in a
  // top-level function represent fetches.
  if (stack_depth == 1 &&
      (node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
    absl::string_view uncompilable_reason = "top level _Arg or _Retval";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.attrs().Find("_scoped_allocator") ||
      node.attrs().Find("_forward_from")) {
    // TODO(b/128858118): XLA does not support _scoped_allocator and
    // _forward_from.
    absl::string_view uncompilable_reason =
        "_scoped_allocator or _forward_from attribute";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  string uncompilable_reason;
  // Function calls are checked by recursing into the function body; other
  // nodes must have an XLA kernel registered.
  if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
    if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
                          encapsulating_function, uncompilable_nodes)) {
      LogNotCompilable(node, "unsupported function");
      return false;
    }
  } else if (!HasXLAKernel(node, &uncompilable_reason)) {
    MaybeMarkUncompilableNode(
        absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
        encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  // Control-flow nodes additionally require their branch/body functions to
  // be compilable.
  if (node.IsWhileNode() &&
      !IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function,
                         uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported while");
    return false;
  }

  if (node.IsIfNode() &&
      !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
                      uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported if");
    return false;
  }

  // Case branches are only validated when the filter demands that clustered
  // nodes be always-compilable.
  if (op_filter_.require_always_compilable && node.IsCaseNode() &&
      !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported case");
    return false;
  }

  // The remaining checks are policy knobs from `op_filter_`.
  if (!op_filter_.allow_stateful_rng_ops &&
      IsStatefulRandomOp(node.type_string())) {
    absl::string_view uncompilable_reason = "stateful random op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
    absl::string_view uncompilable_reason = "not allowed control trigger";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
      IsAssertOrCheckNumerics(node.type_string())) {
    absl::string_view uncompilable_reason = "Assert or CheckNumerics";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_ops_producing_or_consuming_variant &&
      OpProducesOrConsumesVariant(node)) {
    absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
    absl::string_view uncompilable_reason = "Stack op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
    absl::string_view uncompilable_reason = "TensorArray op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  // stack_depth > 1 means `node` lives inside a called function, not the
  // top-level graph.
  if (!op_filter_.allow_resource_ops_in_called_functions && stack_depth > 1 &&
      HasResourceInput(node)) {
    absl::string_view uncompilable_reason =
        "resource variable op in called function";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
    absl::string_view uncompilable_reason =
        "operation with numerical accuracy issues";
    BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
    absl::string_view uncompilable_reason = "slow operation";
    BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  return true;
}
557
CreateOperationFilter(const XlaOpRegistry::DeviceRegistration & registration)558 RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
559 const XlaOpRegistry::DeviceRegistration& registration) {
560 RecursiveCompilabilityChecker::OperationFilter op_filter;
561 op_filter.allow_resource_ops_in_called_functions =
562 registration.cluster_resource_variable_ops_unsafely;
563 op_filter.allow_stack_ops = registration.cluster_stack_ops;
564 op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
565 op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
566 op_filter.allow_control_trigger = registration.cluster_control_trigger;
567 op_filter.allow_eliding_assert_and_checknumerics_ops =
568 registration.elide_assert_and_checknumerics;
569 op_filter.allow_ops_producing_or_consuming_variant =
570 registration.cluster_variant_ops;
571 op_filter.allow_slow_ops = registration.cluster_slow_ops;
572 op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
573 return op_filter;
574 }
575
// Records `reason` for the node at the top of `stack_trace` in
// `uncompilable_nodes`, keyed by the encapsulating function's
// ShortDebugString (an empty NameAttrList for top-level nodes). No-op when
// `uncompilable_nodes` is null.
/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
    const absl::string_view reason,
    const std::vector<StackFrameView>& stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
  if (!uncompilable_nodes) return;

  UncompilableNodeInfo node_info;
  node_info.uncompilable_reason = std::string(reason);
  // Deep-copy the view-based frames into owning StackFrame entries, since the
  // views may not outlive this map.
  absl::c_transform(stack_trace, std::back_inserter(node_info.stack_trace),
                    [](const StackFrameView& stack_element) {
                      return StackFrame{
                          std::string(stack_element.name),
                          std::string(stack_element.function_name),
                          stack_element.stack_trace};
                    });

  // The innermost frame names the uncompilable node itself.
  node_info.name = std::string(stack_trace.back().name);
  auto function =
      encapsulating_function ? *encapsulating_function : NameAttrList();
  auto function_identifier = function.ShortDebugString();

  auto it = uncompilable_nodes->find(function_identifier);
  if (it == uncompilable_nodes->end()) {
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
        uncompilable_node_info{std::move(node_info)};
    uncompilable_nodes->emplace(
        std::move(function_identifier),
        std::make_pair(function, std::move(uncompilable_node_info)));
  } else {
    it->second.second.emplace_back(std::move(node_info));
  }
}
609
610 // Returns `true` iff node has a given `attr` set to `true`. Returns `false`
611 // both for the missing attr, and the attr set to `false`.
HasBoolAttr(const NodeDef & node,const char * attr)612 static bool HasBoolAttr(const NodeDef& node, const char* attr) {
613 const auto& it = node.attr().find(attr);
614 return it != node.attr().end() && it->second.b();
615 }
616
// Returns true iff `node_def` carries kXlaMustCompileAttr set to true, i.e.
// the user explicitly requested XLA compilation for this node.
bool CanCreateXlaKernel(const NodeDef& node_def) {
  return HasBoolAttr(node_def, kXlaMustCompileAttr);
}
620
// Instantiates `function` in `flr` and returns its body in `*fbody`, plus the
// argument indices that must be compile-time constants
// (`constant_arg_indices`) and those with type DT_RESOURCE
// (`resource_arg_indices`). Both index vectors come out sorted ascending,
// as required by SinglePassSearch-based consumers.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args. Bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (size_t i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
}
655
// Returns the memory type for each argument of `fbody`: compile-time
// constants and DT_RESOURCE handles go to host memory, everything else to
// device memory. `constant_arg_indices` and `resource_arg_indices` must be
// sorted ascending (SinglePassSearch requirement).
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices) {
  // Set input and output memory types.
  tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(),
                                                  tensorflow::DEVICE_MEMORY);
  // These indices are used only for optimization purposes. They allow us
  // to loop over constant_arg_indices and resource_arg_indices only once
  // while iterating over all the function arguments checking if it is a
  // resource or a constant.
  // The reason we optimized this code is because functions can have a lot of
  // captured arguments. For example, the backward pass of ResNet50 takes in all
  // 214 variables and a similar number of activations.
  SinglePassSearch constants_search(constant_arg_indices);
  SinglePassSearch resources_search(resource_arg_indices);
  for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return input_memory_types;
}
681
// Returns the memory type for each return value of `fbody`: DT_RESOURCE
// outputs live in host memory, all others in device memory.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody) {
  const size_t num_returns = fbody->ret_types.size();
  tensorflow::MemoryTypeVector output_memory_types(num_returns,
                                                   tensorflow::DEVICE_MEMORY);
  for (size_t i = 0; i < num_returns; ++i) {
    if (fbody->ret_types[i] != tensorflow::DT_RESOURCE) continue;
    output_memory_types[i] = tensorflow::HOST_MEMORY;
  }
  return output_memory_types;
}
693
// Ops whose mere presence in a graph (or its function library) indicates the
// graph was written against the XLA ops API and must be compiled with XLA.
// Never freed: lives for the lifetime of the process.
static auto const ops_triggering_xla_compilation =
    new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                         "XlaConv",
                                         "XlaDequantize",
                                         "XlaDot",
                                         "XlaDynamicSlice",
                                         "XlaDynamicUpdateSlice",
                                         "XlaEinsum",
                                         "XlaGather",
                                         "XlaIf",
                                         "XlaKeyValueSort",
                                         "XlaPad",
                                         "XlaRecv",
                                         "XlaReduce",
                                         "XlaReduceWindow",
                                         "XlaReplicaId",
                                         "XlaScatter",
                                         "XlaSelectAndScatter",
                                         "XlaSelfAdjointEig",
                                         "XlaSend",
                                         "XlaSharding",
                                         "XlaSort",
                                         "XlaSpmdFullToShardShape",
                                         "XlaSpmdShardToFullShape",
                                         "XlaSvd",
                                         "XlaWhile"};
720
NodeCanTriggerXlaCompilation(const NodeDef & node)721 static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
722 return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
723 HasBoolAttr(node, kXlaMustCompileAttr) ||
724 HasBoolAttr(node, kXlaCompileAttr) ||
725 HasBoolAttr(node, kXlaScopeAttr) ||
726 HasBoolAttr(node, kXlaInternalScopeAttr) ||
727 ops_triggering_xla_compilation->count(node.op());
728 }
729
CanTriggerXlaCompilation(const GraphDef & graph)730 bool CanTriggerXlaCompilation(const GraphDef& graph) {
731 for (const FunctionDef& function : graph.library().function()) {
732 for (const NodeDef& node : function.node_def()) {
733 if (NodeCanTriggerXlaCompilation(node)) {
734 return true;
735 }
736 }
737 }
738
739 for (const NodeDef& node : graph.node()) {
740 if (NodeCanTriggerXlaCompilation(node)) {
741 return true;
742 }
743 }
744
745 return false;
746 }
747
748 } // namespace tensorflow
749