/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_cluster_util.h"

#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
const char* const kXlaCompileTimeConstantInputsAttr =
    "_XlaCompileTimeConstantInputs";

namespace {
// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
                     int dst) {
  int32 max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&graph](int node_id) {
    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
      return string("(null)");
    }
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
                  node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32 node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}

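// Returns true if `node` is an op that always forwards its input to its
// output (and hence would forward a ref input); currently only Identity.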
bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }

}  // namespace

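// Returns true if `node` forwards one of its inputs and that input is a ref
// tensor.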
bool HasForwardedRefInput(const Node& node) {
  if (AlwaysForwardsRefInput(node)) {
    for (const Edge* incoming_edge : node.in_edges()) {
      if (incoming_edge->IsControlEdge()) {
        continue;
      }

      Node* incoming_node = incoming_edge->src();
      if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
        VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
                << incoming_node->name() << " " << incoming_node->type_string();
        return true;
      }
    }
  }
  return false;
}

xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
                                              GraphCycles* cycles) {
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles->NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop backedges (edges outgoing from
  //   "NextIteration" nodes).
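  //
  // For example, given a while loop
  //
  //   x -> Enter -> Merge -> ... -> NextIteration -> (back edge to Merge)
  //                   \-> ... -> Exit -> y
  //
  // the cycle detection graph gets the edges x -> frame and frame -> y, the
  // NextIteration back edge is dropped, and the loop body becomes a separate
  // acyclic component.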

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles->NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter() || edge->src()->IsExit()) {
      const char* src_type = "pre-enter";
      const char* dst_type = "post-exit";
      int src = edge->src()->id();
      int dst = edge->dst()->id();

      if (edge->dst()->IsEnter()) {
        // Lift edges to an "Enter" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->dst()->id()].frame_name;
        dst = GetOrAddFrameNodeId(frame_name);
        dst_type = "frame";
      }

      if (edge->src()->IsExit()) {
        // Lift edges from an "Exit" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->src()->id()].frame_name;
        src = GetOrAddFrameNodeId(frame_name);
        src_type = "frame";
      }

      if (!cycles->InsertEdge(src, dst)) {
        // TODO(b/127521408): We can probably handle this situation with a more
        // sophisticated SCC based algorithm, but for now we bail out.
        VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
                << " edge: " << DescribeCycle(cycles, *graph, src, dst);
        return false;
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  return true;
}

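// Returns the name of the XLA cluster `node` has been assigned to, or nullopt
// if it is not in a cluster.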
absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
  const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
  if (attr_value == nullptr) {
    return absl::nullopt;
  }
  Status s = AttrValueHasType(*attr_value, "string");
  if (!s.ok()) {
    return absl::nullopt;
  }
  return attr_value->s();
}

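// Returns true if `node` has an input or output of type DT_RESOURCE.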
bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

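// Removes the _XlaCluster attribute, taking the node out of its XLA cluster.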
void RemoveFromXlaCluster(NodeDef* node_def) {
  node_def->mutable_attr()->erase(kXlaClusterAttr);
}

void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }

namespace {
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;

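// Computes the effective global jit level for single-GPU graphs and for all
// other ("general") graphs, giving the tf_xla_auto_jit flag precedence over
// the jit level from the session options.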
XlaGlobalJitLevel GetXlaGlobalJitLevel(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  XlaGlobalJitLevel result;

  if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    result.single_gpu = result.general = OptimizerOptions::OFF;
  } else {
    result.single_gpu = result.general = jit_level_in_session_opts;
  }

  // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
  // the setting in ConfigProto.
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->xla_auto_jit_flag.optimization_level_single_gpu !=
      OptimizerOptions::DEFAULT) {
    result.single_gpu = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_single_gpu);
  }
  if (flags->xla_auto_jit_flag.optimization_level_general !=
      OptimizerOptions::DEFAULT) {
    result.general = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_general);
  }

  return result;
}

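// Returns the GPU id parsed from `device_name`, or -1 if `device_name` does
// not name a GPU device.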
int GetGpuNumber(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return -1;
  }

  return parsed_name.type == DEVICE_GPU ? parsed_name.id : -1;
}
}  // namespace

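// Returns true if the nodes of `g` are assigned to exactly one GPU; nodes
// assigned to non-GPU devices are ignored.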
bool IsSingleGpuGraph(const Graph& g) {
  int gpus_seen = 0;
  absl::flat_hash_set<string> devices_seen;

  for (Node* n : g.op_nodes()) {
    if (devices_seen.contains(n->assigned_device_name())) {
      continue;
    }

    int gpu_number = GetGpuNumber(n->assigned_device_name());
    if (gpu_number != -1) {
      if (++gpus_seen > 1) {
        return false;
      }
    }

    devices_seen.insert(n->assigned_device_name());
  }

  return gpus_seen == 1;
}

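// Returns the global jit level to use for the graph being optimized, picking
// the single-GPU or general level depending on the devices the graph uses.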
OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
    const GraphOptimizationPassOptions& options) {
  OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  XlaGlobalJitLevel xla_global_jit_level =
      GetXlaGlobalJitLevel(jit_level_in_session_opts);
  if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) {
    VLOG(4) << "GetGlobalJitLevelForGraph returning "
            << xla_global_jit_level.single_gpu;
    return xla_global_jit_level.single_gpu;
  }
  OptimizerOptions::GlobalJitLevel result =
      IsSingleGpuGraph(**options.graph) ? xla_global_jit_level.single_gpu
                                        : xla_global_jit_level.general;
  VLOG(4) << "GetGlobalJitLevelForGraph returning " << result;
  return result;
}

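// Returns true if `n` may call a function: either its op is registered as a
// function in `flib_def`, or it has a func-valued attribute.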
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
  if (flib_def->Contains(n.type_string())) {
    return true;
  }

  // This is a conservative check: there may be nodes with a `func`
  // attribute that do not make function calls.
  return absl::c_any_of(n.def().attr(),
                        [](const std::pair<string, AttrValue>& name_attr_pair) {
                          return name_attr_pair.second.has_func();
                        });
}
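
// Returns true if `node` is an op that only consumes the shape of its input
// (Shape, Rank or Size), not its data.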
bool IsShapeConsumerOp(const Node& node) {
  return node.type_string() == "Shape" || node.type_string() == "Rank" ||
         node.type_string() == "Size";
}

namespace {
struct ClusterInfo {
  int size;

  // Maps op names to the number of times they appear in the cluster.
  absl::flat_hash_map<absl::string_view, int> op_histogram;
};

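// Flattens `histogram` into `result` as a list of (op, count) entries, sorted
// by op name so the output is deterministic.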
void HistogramMapToRepeatedOpAndCount(
    protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
    const absl::flat_hash_map<absl::string_view, int>& histogram) {
  for (const auto& pair : histogram) {
    XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
    new_entry->set_op(std::string(pair.first));
    new_entry->set_count(pair.second);
  }

  absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
                           const XlaAutoClusteringSummary::OpAndCount& b) {
    return a.op() < b.op();
  });
}

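// Fills in `result` with the name, size and op histogram of the cluster
// described by `info`.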
void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
                           absl::string_view name, const ClusterInfo& info) {
  result->set_name(std::string(name));
  result->set_size(info.size);
  HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
                                   info.op_histogram);
}
}  // namespace

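// Builds a summary of the auto-clustering decisions made for `graph`: the
// number of clustered and unclustered nodes, a per-cluster op histogram, and
// an op histogram for the unclustered nodes.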
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
  absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
  XlaAutoClusteringSummary result;

  absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;

  for (Node* n : graph.nodes()) {
    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
    if (cluster_name) {
      result.set_clustered_node_count(result.clustered_node_count() + 1);
      ClusterInfo* info = &cluster_name_to_info[*cluster_name];
      info->size++;
      info->op_histogram[n->type_string()]++;
    } else {
      result.set_unclustered_node_count(result.unclustered_node_count() + 1);
      unclustered_op_histogram[n->type_string()]++;
    }
  }

  for (const auto& pair : cluster_name_to_info) {
    XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
    ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
  }

  absl::c_sort(*result.mutable_clusters(),
               [&](const XlaAutoClusteringSummary::Cluster& a,
                   const XlaAutoClusteringSummary::Cluster& b) {
                 return a.name() < b.name();
               });

  HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
                                   unclustered_op_histogram);

  return result;
}

namespace {
using CallTargetListTy = absl::InlinedVector<NameAttrList, 2>;

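// Returns the list of functions `n` may call: the function named by its op
// type if there is one, otherwise every function referenced by one of its
// func- or list(func)-valued attributes.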
CallTargetListTy GetCallTargetListFromNode(
    const Node& n, FunctionLibraryRuntime* lib_runtime) {
  const FunctionLibraryDefinition& flib_def =
      *lib_runtime->GetFunctionLibraryDefinition();
  if (flib_def.Find(n.type_string())) {
    NameAttrList callee;
    callee.set_name(n.type_string());
    *callee.mutable_attr() = n.def().attr();
    return {callee};
  }

  CallTargetListTy result;
  for (const auto& name_attr_pair : n.attrs()) {
    const AttrValue& attr_value = name_attr_pair.second;
    if (attr_value.value_case() == AttrValue::kFunc) {
      result.push_back(attr_value.func());
    } else if (attr_value.value_case() == AttrValue::kList) {
      result.insert(result.end(), attr_value.list().func().begin(),
                    attr_value.list().func().end());
    }
  }

  return result;
}

enum class Direction { kForward, kBackward };

Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result);

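// Returns true if any callee in `call_target_list` (or a function it
// transitively calls) has nodes related to ref variables in `direction`.
// Conservatively returns true if a callee cannot be instantiated or the
// recursion depth limit is hit.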
xla::StatusOr<bool> DoesAnyCalleeHaveRefNodes(
    const CallTargetListTy& call_target_list,
    FunctionLibraryRuntime* lib_runtime, Direction direction, int depth) {
  const int kMaxDepth = 10;

  if (depth == kMaxDepth && !call_target_list.empty()) {
    // Conservative answer to avoid recursing too much.
    return true;
  }

  absl::flat_hash_set<Node*> callee_ref_nodes;
  for (const NameAttrList& call_target : call_target_list) {
    const OpRegistrationData* op_reg;
    if (OpRegistry::Global()->LookUp(call_target.name(), &op_reg).ok()) {
      const OpDef& op = op_reg->op_def;
      if (absl::c_any_of(op.output_arg(), [](const OpDef::ArgDef arg) {
            return arg.is_ref();
          })) {
        return true;
      }
      continue;
    }

    callee_ref_nodes.clear();
    FunctionLibraryRuntime::Handle handle;
    if (!lib_runtime
             ->Instantiate(call_target.name(), AttrSlice(&call_target.attr()),
                           &handle)
             .ok()) {
      VLOG(2) << "Could not find " << call_target.name()
              << " in the function library.";
      // Since we don't know the semantics of `n`, we don't know if this is an
      // error.  We return true to signal a conservative answer.
      return true;
    }

    auto release_handle_on_return = gtl::MakeCleanup(
        [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });

    const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
    TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
        *fbody->graph, lib_runtime, direction, depth + 1, &callee_ref_nodes));

    // We could possibly use something cheaper than
    // GetNodesRelatedToRefVariablesInDirection since we only care about
    // whether `callee_ref_nodes` is empty, but for now we don't bother.
    if (!callee_ref_nodes.empty()) {
      return true;
    }
  }

  return false;
}

// Helper for GetNodesRelatedToRefVariables that traverses the graph in one
// direction.
Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result) {
  std::vector<Node*> nodes_in_order;
  if (direction == Direction::kForward) {
    GetReversePostOrder(graph, &nodes_in_order,
                        /*stable_comparator=*/NodeComparatorName());
  } else {
    GetPostOrder(graph, &nodes_in_order,
                 /*stable_comparator=*/NodeComparatorName());
  }

  size_t old_result_size;
  int iterations = 0;

  const int kMaxIterations = 10 * 1000;

  std::vector<bool> callee_has_ref_nodes_cache;
  callee_has_ref_nodes_cache.resize(graph.num_node_ids());

  auto does_callee_have_ref_nodes = [&](Node* n) -> xla::StatusOr<bool> {
    if (iterations == 1) {
      TF_ASSIGN_OR_RETURN(
          bool callee_has_ref_nodes,
          DoesAnyCalleeHaveRefNodes(GetCallTargetListFromNode(*n, lib_runtime),
                                    lib_runtime, direction, depth));
      callee_has_ref_nodes_cache[n->id()] = callee_has_ref_nodes;
      return callee_has_ref_nodes;
    } else {
      return {callee_has_ref_nodes_cache[n->id()]};
    }
  };

  do {
    TF_RET_CHECK(iterations++ < kMaxIterations) << "infinite loop?";

    old_result_size = result->size();
    for (Node* n : nodes_in_order) {
      if (n->IsSource() || n->IsSink()) {
        continue;
      }

      bool inserted_n = false;
      const EdgeSet& edges =
          direction == Direction::kForward ? n->in_edges() : n->out_edges();
      for (const Edge* e : edges) {
        if (result->contains(direction == Direction::kForward ? e->src()
                                                              : e->dst())) {
          result->insert(n);
          inserted_n = true;
          break;
        }
      }

      if (inserted_n) {
        continue;
      }

      if (direction == Direction::kForward &&
          absl::c_any_of(n->output_types(), IsRefType)) {
        result->insert(n);
        continue;
      }

      TF_ASSIGN_OR_RETURN(bool callee_has_ref_nodes,
                          does_callee_have_ref_nodes(n));
      if (callee_has_ref_nodes) {
        result->insert(n);
        continue;
      }
    }

    // Loop until convergence.
  } while (result->size() != old_result_size);

  VLOG(2) << "# iterations = " << iterations;

  return Status::OK();
}
}  // namespace

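// Returns the set of nodes that are related to ref variables: nodes with a
// transitive data dependency to or from a node that operates on ref tensors
// (directly or via a function it calls).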
xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime) {
  absl::flat_hash_set<Node*> result;
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kForward, 0, &result));
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kBackward, 0, &result));

  VLOG(1) << "GetNodesRelatedToRefVariables() found " << result.size()
          << " nodes";
  return result;
}

// Register a callback for querying XlaGlobalJitLevel.
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);

}  // namespace tensorflow