/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/compilability_check_util.h"

#include <algorithm>
#include <atomic>
#include <deque>
#include <iterator>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

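// Returns true iff at least one of `node`'s inputs is a DT_RESOURCE handle.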
bool HasResourceInput(const Node& node) {
  return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
}

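// Logs, at VLOG level 3, that `node` was found uncompilable, together with an
// optional human-readable `reason`.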
void LogNotCompilable(const Node& node, absl::string_view reason = "") {
  VLOG(3) << "Found uncompilable node " << node.name() << " (op "
          << node.type_string() << ")" << (reason.empty() ? "" : ": ")
          << reason;
}

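// Builds a function-call NodeDef from the NameAttrList stored in `node`'s
// attribute `attr_name` (e.g. the "cond" or "body" attribute of a While node),
// so that the callee can be checked like any other function call.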
Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
                                 NodeDef* node_def) {
  const NameAttrList* name_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr));
  node_def->set_op(name_attr->name());
  *(node_def->mutable_attr()) = name_attr->attr();
  return Status::OK();
}

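// Builds one function-call NodeDef per entry of `node`'s list-valued attribute
// `attr_name` (e.g. the "branches" attribute of a Case node). The synthesized
// calls are named `call_name`_0, `call_name`_1, and so on.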
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
    const Node& node, absl::string_view attr_name,
    absl::string_view call_name) {
  std::vector<NameAttrList> attr_lists;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));

  std::vector<NodeDef> out;
  for (int i = 0; i < attr_lists.size(); i++) {
    out.emplace_back();
    NodeDef& inserted = out.back();
    inserted.set_name(absl::StrCat(call_name, "_", i));
    inserted.set_op(attr_lists[i].name());
    *inserted.mutable_attr() = attr_lists[i].attr();
  }
  return out;
}

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
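//
// Example (illustrative):
//
//   std::vector<int> sorted = {1, 3, 5, 7};
//   SinglePassSearch search(sorted);
//   search.ScanForValue(3);  // true; the scan position moves past 3.
//   search.ScanForValue(2);  // false; the scan is already past 2's slot.
//   search.ScanForValue(5);  // true.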
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(absl::Span<int const> values)
      : current_index_(0), values_(values) {}

  // Scans forward in the list looking for `value`, updating the internal
  // position into the list.
  // Returns true iff the list contains the given value at or after the
  // current position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_.size() &&
           values_[current_index_] <= value) {
      if (values_[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  int current_index_;
  const absl::Span<int const> values_;
};

}  // anonymous namespace

RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
        node_stack_trace) const {
  std::vector<StackFrameView> stack_trace;
  // If `node_stack_trace` is provided, that means `node` is inside
  // a function body, and therefore, arg nodes and retval nodes are
  // not considered uncompilable.
  if (node_stack_trace != nullptr) {
    for (const auto& frame : *node_stack_trace) {
      stack_trace.emplace_back(
          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
    }
  }
  stack_trace.emplace_back(
      StackFrameView{node.name(), "", node.GetStackTrace()});

  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
  IsCompilableNode(node, lib_runtime, &stack_trace,
                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
  return uncompilable_nodes;
}

RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
        node_stack_trace) const {
  // If `node_stack_trace` is provided, that means `call_def` is inside
  // a function body, and therefore, arg nodes and retval nodes are
  // not considered uncompilable.
  std::vector<StackFrameView> stack_trace;
  if (node_stack_trace != nullptr) {
    for (const auto& frame : *node_stack_trace) {
      stack_trace.emplace_back(
          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
    }
  }
  stack_trace.emplace_back(StackFrameView{call_def.name(), "", nullptr});

  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
  IsCompilableCall(call_def, lib_runtime, &stack_trace,
                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
  return uncompilable_nodes;
}

bool RecursiveCompilabilityChecker::HasXLAKernel(
    const Node& node, string* uncompilable_reason) const {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") {
    *uncompilable_reason =
        "SymbolicGradient should be handled by IsCompilableCall().";
    return false;
  }

  if (node.type_string() == "Const") {
    const AttrValue* attr = node.attrs().Find("dtype");
    if (!op_filter_.allow_string_consts && attr != nullptr &&
        attr->type() == DT_STRING) {
      *uncompilable_reason =
          "Const op with type DT_STRING is not supported by XLA.";
      return false;
    }
  }

  // XLA does not offer guaranteed aliasing between the input and output of the
  // XLA cluster so it can't implement the forward-tensor-ref semantic.  Leave
  // such nodes out of XLA clusters.
  if (HasForwardedRefInput(node)) {
    VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
    *uncompilable_reason = "Identity with unsafe cast.";
    return false;
  }

  Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
  if (!s.ok()) {
    *uncompilable_reason = s.error_message();
    return false;
  }
  return true;
}

// Tests whether 'if_node' is compilable. Every operator in the then_branch and
// else_branch functions must be compilable for 'if_node' to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableIf(
    const Node& if_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);
  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

bool RecursiveCompilabilityChecker::IsCompilableCase(
    const Node& case_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  xla::StatusOr<std::vector<NodeDef>> calls =
      MakeCallNodesFromAttribute(case_node, "branches", "branch");
  if (!calls.ok()) {
    VLOG(2) << "Rejecting node " << case_node.name() << ": "
            << "missing attribute 'branches'";
    return false;
  }

  bool is_compilable = true;

  for (const NodeDef& call : *calls) {
    is_compilable &=
        IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
                         uncompilable_nodes);
  }
  return is_compilable;
}

// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableWhile(
    const Node& while_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "body", "while_body", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
    const Node& node, const std::string& attr_name,
    const std::string& call_name, NameAttrList* encapsulating_function,
    FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  NodeDef call;
  call.set_name(call_name);
  if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
    const auto uncompilable_reason = absl::StrCat(
        "missing '", attr_name, "' attribute from node ", node.name());
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
            << ".";
    return false;
  }
  if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    VLOG(2) << "Rejecting node " << node.name()
            << ": can't compile : " << call.op();
    return false;
  }
  return true;
}

// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  if (stack_trace->size() > kMaxRecursionDepth) {
    std::string uncompilable_reason = "function depth limit exceeded";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
            << ".";
    return false;
  }

  FunctionLibraryRuntime::Handle handle;
  Status s;
  NameAttrList function;
  s = NameAndAttrsFromFunctionCall(call_def, &function);
  if (s.ok()) {
    s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                                 &handle);
  }
  if (!s.ok()) {
    std::string uncompilable_reason =
        absl::StrCat("could not instantiate call: '", function.name(), "'");
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
            << uncompilable_reason << " : " << s;
    return false;
  }

  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  bool is_compilable = true;
  for (const Node* node : fbody->graph->op_nodes()) {
    stack_trace->emplace_back(
        StackFrameView{node->name(), function.name(), node->GetStackTrace()});
    is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                      &function, uncompilable_nodes);
    stack_trace->pop_back();
    if (!uncompilable_nodes && !is_compilable) return is_compilable;
  }

  return is_compilable;
}

bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
  // b/127344411: SelfAdjointEigV2 and Svd precision issues.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd";
}

bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
  // b/128001705: SelfAdjointEigV2 and Svd performance issues.
  // b/135640736: MatrixInverse performance issues.
  // b/111271662: MatrixSolve performance issues.
  // https://github.com/tensorflow/tensorflow/pull/31012:
  //    ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
  //    create convolutions too large for CuDNN to handle.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd" || node.type_string() == "Qr" ||
         node.type_string() == "MatrixInverse" ||
         node.type_string() == "MatrixSolve" ||
         node.type_string() == "ResizeNearestNeighbor" ||
         node.type_string() == "ResizeBilinear" ||
         node.type_string() == "ResizeBilinearGrad";
}

bool RecursiveCompilabilityChecker::IsCompilableNode(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  auto stack_depth = stack_trace->size();
  if (node.IsSource() || node.IsSink()) {
    absl::string_view uncompilable_reason = "source or sink node";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  // _Arg nodes in a top-level function represent feeds and _Retval nodes in a
  // top-level function represent fetches.
  if (stack_depth == 1 &&
      (node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
    absl::string_view uncompilable_reason = "top level _Arg or _Retval";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.attrs().Find("_scoped_allocator") ||
      node.attrs().Find("_forward_from")) {
    // TODO(b/128858118): XLA does not support _scoped_allocator and
    // _forward_from.
    absl::string_view uncompilable_reason =
        "_scoped_allocator or _forward_from attribute";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  string uncompilable_reason;
  if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
    if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
                          encapsulating_function, uncompilable_nodes)) {
      LogNotCompilable(node, "unsupported function");
      return false;
    }
  } else if (!HasXLAKernel(node, &uncompilable_reason)) {
    MaybeMarkUncompilableNode(
        absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
        encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.IsWhileNode() &&
      !IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function,
                         uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported while");
    return false;
  }

  if (node.IsIfNode() &&
      !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
                      uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported if");
    return false;
  }

  if (op_filter_.require_always_compilable && node.IsCaseNode() &&
      !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported case");
    return false;
  }

  if (!op_filter_.allow_stateful_rng_ops &&
      IsStatefulRandomOp(node.type_string())) {
    absl::string_view uncompilable_reason = "stateful random op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
    absl::string_view uncompilable_reason = "not allowed control trigger";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
      IsAssertOrCheckNumerics(node.type_string())) {
    absl::string_view uncompilable_reason = "Assert or CheckNumerics";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_ops_producing_or_consuming_variant &&
      OpProducesOrConsumesVariant(node)) {
    absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
    absl::string_view uncompilable_reason = "Stack op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
    absl::string_view uncompilable_reason = "TensorArray op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_resource_ops_in_called_functions && stack_depth > 1 &&
      HasResourceInput(node)) {
    absl::string_view uncompilable_reason =
        "resource variable op in called function";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
    absl::string_view uncompilable_reason =
        "operation with numerical accuracy issues";
    BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
    absl::string_view uncompilable_reason = "slow operation";
    BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  return true;
}

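// Builds an OperationFilter from a device's XLA registration: each allow_*
// knob mirrors the corresponding cluster_* / elide_* capability the device
// declared.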
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration) {
  RecursiveCompilabilityChecker::OperationFilter op_filter;
  op_filter.allow_resource_ops_in_called_functions =
      registration.cluster_resource_variable_ops_unsafely;
  op_filter.allow_stack_ops = registration.cluster_stack_ops;
  op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
  op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
  op_filter.allow_control_trigger = registration.cluster_control_trigger;
  op_filter.allow_eliding_assert_and_checknumerics_ops =
      registration.elide_assert_and_checknumerics;
  op_filter.allow_ops_producing_or_consuming_variant =
      registration.cluster_variant_ops;
  op_filter.allow_slow_ops = registration.cluster_slow_ops;
  op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
  return op_filter;
}

/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
    const absl::string_view reason,
    const std::vector<StackFrameView>& stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
  if (!uncompilable_nodes) return;

  UncompilableNodeInfo node_info;
  node_info.uncompilable_reason = std::string(reason);
  absl::c_transform(stack_trace, std::back_inserter(node_info.stack_trace),
                    [](const StackFrameView& stack_element) {
                      return StackFrame{
                          std::string(stack_element.name),
                          std::string(stack_element.function_name),
                          stack_element.stack_trace};
                    });

  node_info.name = std::string(stack_trace.back().name);
  auto function =
      encapsulating_function ? *encapsulating_function : NameAttrList();
  auto function_identifier = function.ShortDebugString();

  auto it = uncompilable_nodes->find(function_identifier);
  if (it == uncompilable_nodes->end()) {
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
        uncompilable_node_info{std::move(node_info)};
    uncompilable_nodes->emplace(
        std::move(function_identifier),
        std::make_pair(function, std::move(uncompilable_node_info)));
  } else {
    it->second.second.emplace_back(std::move(node_info));
  }
}

// Returns `true` iff the node has the boolean attribute `attr` set to `true`.
// Returns `false` both when the attribute is missing and when it is set to
// `false`.
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
  const auto& it = node.attr().find(attr);
  return it != node.attr().end() && it->second.b();
}

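// Returns true iff `node_def` is explicitly marked for XLA compilation via
// the kXlaMustCompileAttr attribute.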
bool CanCreateXlaKernel(const NodeDef& node_def) {
  return HasBoolAttr(node_def, kXlaMustCompileAttr);
}

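// Instantiates `function` through `flr`, returning its body along with the
// indices (in ascending order) of arguments that must be compile-time
// constants and of arguments carrying DT_RESOURCE handles.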
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (size_t i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
}

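// Computes the per-argument memory type: compile-time constants and resource
// handles are placed in host memory, all other arguments in device memory.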
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices) {
  // Set input memory types.
  tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(),
                                                  tensorflow::DEVICE_MEMORY);
  // These searches are used only as an optimization: they let us walk
  // constant_arg_indices and resource_arg_indices exactly once while iterating
  // over all the function arguments, checking whether each one is a resource
  // or a constant.
  // This matters because functions can have many captured arguments. For
  // example, the backward pass of ResNet50 takes in all 214 variables and a
  // similar number of activations.
  SinglePassSearch constants_search(constant_arg_indices);
  SinglePassSearch resources_search(resource_arg_indices);
  for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return input_memory_types;
}

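// Computes the per-result memory type: DT_RESOURCE results are returned in
// host memory, all other results in device memory.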
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody) {
  tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(),
                                                   tensorflow::DEVICE_MEMORY);
  for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
    if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) {
      output_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return output_memory_types;
}

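// Ops from the tf2xla "Xla*" op set. Their presence anywhere in a graph is
// treated as a request for XLA compilation.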
static auto const ops_triggering_xla_compilation =
    new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                         "XlaConv",
                                         "XlaDequantize",
                                         "XlaDot",
                                         "XlaDynamicSlice",
                                         "XlaDynamicUpdateSlice",
                                         "XlaEinsum",
                                         "XlaGather",
                                         "XlaIf",
                                         "XlaKeyValueSort",
                                         "XlaPad",
                                         "XlaRecv",
                                         "XlaReduce",
                                         "XlaReduceWindow",
                                         "XlaReplicaId",
                                         "XlaScatter",
                                         "XlaSelectAndScatter",
                                         "XlaSelfAdjointEig",
                                         "XlaSend",
                                         "XlaSharding",
                                         "XlaSort",
                                         "XlaSpmdFullToShardShape",
                                         "XlaSpmdShardToFullShape",
                                         "XlaSvd",
                                         "XlaWhile"};

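// Returns true iff `node` is marked for XLA compilation, either via one of
// the XLA clustering/scope/compile attributes or by being one of the Xla*
// ops listed above.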
static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
  return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
         HasBoolAttr(node, kXlaMustCompileAttr) ||
         HasBoolAttr(node, kXlaCompileAttr) ||
         HasBoolAttr(node, kXlaScopeAttr) ||
         HasBoolAttr(node, kXlaInternalScopeAttr) ||
         ops_triggering_xla_compilation->count(node.op());
}

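// Returns true iff any node in `graph`, or in any function of its library,
// can trigger XLA compilation.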
bool CanTriggerXlaCompilation(const GraphDef& graph) {
  for (const FunctionDef& function : graph.library().function()) {
    for (const NodeDef& node : function.node_def()) {
      if (NodeCanTriggerXlaCompilation(node)) {
        return true;
      }
    }
  }

  for (const NodeDef& node : graph.node()) {
    if (NodeCanTriggerXlaCompilation(node)) {
      return true;
    }
  }

  return false;
}

}  // namespace tensorflow