1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
17
18 #include <atomic>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23
24 #include "absl/base/call_once.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/jit/compilability_check_util.h"
29 #include "tensorflow/compiler/jit/deadness_analysis.h"
30 #include "tensorflow/compiler/jit/defs.h"
31 #include "tensorflow/compiler/jit/device_util.h"
32 #include "tensorflow/compiler/jit/flags.h"
33 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
34 #include "tensorflow/compiler/jit/xla_cluster_util.h"
35 #include "tensorflow/compiler/tf2xla/const_analysis.h"
36 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
37 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
38 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
39 #include "tensorflow/compiler/xla/statusor.h"
40 #include "tensorflow/compiler/xla/union_find.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/common_runtime/function.h"
43 #include "tensorflow/core/common_runtime/graph_constructor.h"
44 #include "tensorflow/core/framework/bounds_check.h"
45 #include "tensorflow/core/framework/graph_def_util.h"
46 #include "tensorflow/core/framework/memory_types.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/op_kernel.h"
49 #include "tensorflow/core/framework/tensor.pb.h"
50 #include "tensorflow/core/framework/types.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/control_flow.h"
53 #include "tensorflow/core/lib/gtl/cleanup.h"
54 #include "tensorflow/core/lib/strings/stringprintf.h"
55 #include "tensorflow/core/public/version.h"
56 #include "tensorflow/core/util/dump_graph.h"
57
58 namespace tensorflow {
59
60 namespace {
61 using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
62 using jit::DeviceId;
63 using jit::DeviceSet;
64 using xla::StatusOr;
65
66 // The clusters we create here are eventually lowered into an
67 // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
68 // PartitionedCall op to execute the cluster in the regular graph executor if
69 // need be. PartitionedCall, however, reruns the entire TF graph optimization
70 // pipeline over the cluster which includes this mark for compilation pass. To
71 // avoid endlessly recursing we tag nodes that we've already visited with this
72 // attribute so that we can bail out if we see them a second time.
73 //
74 // TODO(sanjoy): This method is not robust since it is possible that the
75 // optimizations run by PartitionedCall can mutate the cluster arbitrarily,
76 // dropping the kXlaAlreadyClustered attributes from all nodes in the process.
77 // The correct fix is to use the ConfigProto to pass in some sort of flag into
78 // the PartitionedCall kernel that tells it to not rerun auto-clustering on the
79 // cluster.
80 const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";
81
82 class MarkForCompilationPassImpl {
83 public:
84 struct DebugOptions {
85 // If true, do not respect the results of deadness analysis.
86 bool ignore_deadness_checks;
87
88 // If true, do not do safety checks to preserve TensorFlow's resource
89 // variable concurrency semantics.
90 bool ignore_resource_variable_checks;
91
92 // If true, do not respect the _XlaCompile=false attribute.
93 bool ignore_xla_compile_attr;
94
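    // Upper and lower bounds on the clusters we form: TryToContractEdge()
    // refuses to merge clusters whose combined size would exceed
    // max_cluster_size, and CreateClusters() only marks clusters whose
    // effective size reaches min_cluster_size (unless the cluster has
    // functional control flow or an _XlaCompile=true node).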
95 int max_cluster_size;
96 int min_cluster_size;
97
98 // Compiler fuel for the auto-clustering algorithm.
99 //
100     // We decrement this value by one every time we choose a compilation
101     // candidate, and we stop clustering when it hits zero. This means the
102 // initial value for this variable (via --tf_xla_clustering_fuel=N)
103 // effectively acts as a "cap" for how much we cluster and we can bisect
104 // over this initial value to discover clustering decisions that cause a
105 // miscompile or a performance regression.
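    //
    // For example, --tf_xla_clustering_fuel=0 effectively disables
    // auto-clustering, while --tf_xla_clustering_fuel=16 admits at most 16
    // candidate nodes.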
106 std::atomic<int64>* fuel;
107
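    // If true, dump the TF graph before the clustering decisions are applied
    // and again afterwards (see CreateClusters() and DumpDebugInfo()).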
108 bool dump_graphs;
109 };
110
111   MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph,
112 FunctionLibraryDefinition* flib_def, Env* env,
113 OptimizerOptions::GlobalJitLevel global_jit_level)
114 : debug_options_(debug_options),
115 graph_(graph),
116 flib_def_(flib_def),
117 env_(env),
118 global_jit_level_(global_jit_level) {}
119
120 Status Run();
121
122 private:
123 // Represents a "cluster" or a connected subgraph of a TensorFlow graph.
124 class Cluster {
125 public:
126 // Constructs a trivial cluster representing a single TF node.
127     Cluster(int tf_graph_node_id, int effective_cluster_size,
128 bool has_functional_control_flow, DeviceSet devices,
129 absl::optional<DeviceId> resource_op_device,
130 absl::optional<int> resource_var_operation_node_id,
131 absl::optional<DeadnessPredicate> deadness_predicate,
132 bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
133 : cycles_graph_node_id_(tf_graph_node_id),
134 effective_cluster_size_(effective_cluster_size),
135 has_functional_control_flow_(has_functional_control_flow),
136 devices_(std::move(devices)),
137 resource_op_device_(resource_op_device),
138 deadness_predicate_(deadness_predicate),
139 is_xla_compile_attr_true_(is_xla_compile_attr_true),
140 xla_scope_(std::move(xla_scope)) {
141 if (resource_var_operation_node_id.has_value()) {
142 resource_var_operation_node_ids_.push_back(
143 *resource_var_operation_node_id);
144 }
145 }
146
147 // Merges `other` into this cluster, and clears `other`. This method is
148 // closely tied with the implementation of `MarkForCompilationPassImpl`.
149 void Merge(Cluster* other);
150
151 // If this is a trivial cluster containing only one node then return the ID
152 // of that node. May not be called otherwise.
153     int GetIdOfOnlyNode() const {
154 DCHECK_EQ(cluster_size(), 1);
155 return cycles_graph_node_id();
156 }
157
158 // The number of TF nodes in this cluster.
159     int cluster_size() const { return cluster_size_; }
160
161 // The ID of the cluster as represented in `cycles_graph_`.
162     int cycles_graph_node_id() const { return cycles_graph_node_id_; }
163
164 // Sets the ID of the cluster as represented in `cycles_graph_`.
165     void set_cycles_graph_node_id(int cycles_graph_node_id) {
166 cycles_graph_node_id_ = cycles_graph_node_id;
167 }
168
169 // The size of the cluster excluding constant and identity nodes.
170     int effective_cluster_size() const { return effective_cluster_size_; }
171
172 // True if the cluster has functional control flow like `If` and `While`.
173     bool has_functional_control_flow() const {
174 return has_functional_control_flow_;
175 }
176
177     // The set of devices that the nodes in this cluster are placed on.
178     const DeviceSet& devices() const { return devices_; }
179
180 // If the cluster has a resource operation then the device the resource
181 // operation is placed on. A cluster may have resource ops placed only on a
182 // single device.
183     const absl::optional<DeviceId>& resource_op_device() const {
184 return resource_op_device_;
185 }
186
187     // If not nullopt, a predicate that is true iff the cluster is alive.
188 // Otherwise the user has (unsafely) disabled deadness analysis. If this is
189 // unset on a single Cluster instance then it is unset on all Cluster
190 // instances.
191     const absl::optional<DeadnessPredicate>& deadness_predicate() const {
192 return deadness_predicate_;
193 }
194
195     // If true then the cluster has an _XlaCompile=true attribute on one of its
196 // nodes.
197     bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; }
198
199     // If not nullopt then all nodes in the cluster either do not have the
200 // XlaScope attribute set or have it set to the value returned.
201     const absl::optional<string>& xla_scope() const { return xla_scope_; }
202
203 // Returns the TF graph node IDs for the resource variable operations in
204 // this cluster.
205     absl::Span<const int> resource_var_operation_node_ids() const {
206 return resource_var_operation_node_ids_;
207 }
208
209     string DebugString(const Graph& graph) const {
210 Node* node = graph.FindNodeId(cycles_graph_node_id());
211 if (!node) {
212 // This should never happen but we try to be resilient because this is a
213 // debugging aid.
214 return absl::StrCat("NULL NODE IN #", cycles_graph_node_id());
215 }
216
217 if (cluster_size() == 1) {
218 return absl::StrCat("<", node->name(), " #", cycles_graph_node_id(),
219 ">");
220 }
221
222 return absl::StrCat("<", node->name(), " + ", cluster_size() - 1,
223 " others #", cycles_graph_node_id(), ">");
224 }
225
226 private:
227 int cluster_size_ = 1;
228 int cycles_graph_node_id_;
229 int effective_cluster_size_;
230 bool has_functional_control_flow_;
231 DeviceSet devices_;
232 absl::optional<DeviceId> resource_op_device_;
233 absl::optional<DeadnessPredicate> deadness_predicate_;
234 bool is_xla_compile_attr_true_;
235 absl::optional<string> xla_scope_;
236 std::vector<int> resource_var_operation_node_ids_;
237
238 TF_DISALLOW_COPY_AND_ASSIGN(Cluster);
239 };
240
241 // If `cluster` has only a single node then returns that, otherwise returns
242 // nullptr.
243 Node* GetOnlyNodeIn(const Cluster& cluster);
244
245 // Returns true if `cluster` is a trivial cluster containing a "sink like"
246 // node -- a NoOp node that only the Sink node control depends on.
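  // (For instance, the kind of NoOp used for "group_deps"-style control
  // dependencies whose only out-edge is a control edge to the sink node; see
  // the phase 1 comment in RunEdgeContractionLoop.)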
247 bool IsSinkLike(const Cluster& cluster);
248
249 // Returns true if `cluster` looks like an "i++" operation on an integer
250 // scalar resource variable.
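  // (For instance, an AssignAddVariableOp that adds a Const to an integer
  // scalar resource variable, like the "iteration++" update discussed in
  // RunEdgeContractionLoop.)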
251 bool IsScalarIntegerResourceOperation(const Cluster& cluster);
252
253 // ---------------------------------------------------------------------------
254 // The pass proceeds in four steps, out of which `RunEdgeContractionLoop` and
255 // `CreateClusters` do most of the heavy lifting.
256
257 // Initializes some internal data structures.
258 //
259 // If this returns false then Initialize exited early (either because there is
260 // nothing to do or we saw a graph that we can't handle) and not all the
261 // fields in this MarkForCompilationPassImpl instance are set up.
262 StatusOr<bool> Initialize();
263
264 // Runs through the entire cluster graph in post-order and calls `fn(from,
265 // to)` on each edge. `fn(from, to)` is expected to return true if it was
266 // able to contract `from`->`to`.
267 //
268 // Returns true if `fn` returned true for any edge.
269 template <typename FnTy>
270 StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
271
272 // Contracts as many edges as possible to create XLA clusters. After this
273 // finishes the clustering decisions made are implicitly stored in
274 // `clusters_`.
275 Status RunEdgeContractionLoop();
276
277 // Manifests the clustering decisions into the TF graph by tagging nodes with
278 // an `_XlaCluster` attribute. Also some basic filter logic, like
279   // tf_xla_min_cluster_size, is applied here.
280 Status CreateClusters();
281
282 Status DumpDebugInfo();
283
284   bool IsCompilationCandidate(Node* n) const {
285 return compilation_candidates_.find(n) != compilation_candidates_.end();
286 }
287
288 // Tries to contract the edge from cluster `from` to cluster `to`. Returns
289 // true if successful.
290 StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
291
292 // Nodes that XLA can compile are put in `compilation_candidates_`.
293 Status FindCompilationCandidates();
294
295 bool CompilationDisallowedByXlaCompileAttr(Node* node);
296
297 // Populates `clusters_`.
298 Status BuildInitialClusterSet();
299
300 StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);
301
302 StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);
303
304 StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
305 const Cluster& from, const Cluster& to);
306
307 // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
308 // and therefore not a hindrance for combining the two clusters into a larger
309 // cluster.
310 StatusOr<bool> AreDevicesCompatible(const Cluster& cluster_a,
311 const Cluster& cluster_b);
312
313 void DumpPostClusteringGraphs();
314 void VLogClusteringSummary();
315
316   Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
317 bool has_functional_control_flow,
318 const DeviceSet& device_set,
319 absl::optional<DeviceId> resource_op_device,
320 absl::optional<int> resource_var_operation_node_id,
321 absl::optional<DeadnessPredicate> deadness_predicate,
322 bool is_xla_compile_attr_true,
323 absl::optional<string> xla_scope) {
324 cluster_storage_.push_back(absl::make_unique<Cluster>(
325 cycles_graph_node_id, effective_cluster_size,
326 has_functional_control_flow, device_set, resource_op_device,
327 resource_var_operation_node_id, deadness_predicate,
328 is_xla_compile_attr_true, xla_scope));
329 return cluster_storage_.back().get();
330 }
331
332 absl::optional<string> GetXlaScope(Node* n);
333
334 // Returns the cluster for node `n`. If two nodes, N1 and N2, are placed in
335 // the same cluster by the clustering algorithm then this function will return
336 // the same Cluster instance for N1 and N2.
337 //
338 // Returns nullptr if `n` is not a compilation candidate.
339   Cluster* GetClusterForNode(Node* n) {
340 return cluster_for_node_[n->id()].Get();
341 }
342
343 // Returns the cluster for a node in `cycles_graph_`. This uses the same
344 // underlying map because of how we set things up, but we can do an additional
345 // CHECK in this accessor.
346 //
347 // Returns nullptr if `node_id` is not a compilation candidate.
348   Cluster* GetClusterForCyclesGraphNode(int node_id) {
349 // We have to check `graph_->FindNodeId(node) == nullptr` because we add all
350 // nodes in [0, graph_->num_node_ids()) to the cycle detection graph but the
351 // TF graph may be missing some node ids.
352 if (node_id >= graph_->num_node_ids() ||
353 graph_->FindNodeId(node_id) == nullptr) {
354 return nullptr;
355 }
356 Cluster* cluster = cluster_for_node_[node_id].Get();
357 if (cluster) {
358 DCHECK_EQ(cluster->cycles_graph_node_id(), node_id);
359 }
360 return cluster;
361 }
362
363 bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to,
364 absl::string_view reason);
365
366 // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct
367 // edge from `from` to `to`.
368 //
369 // Tries to find a path that contains at least one unclusterable node.
370 std::vector<int> FindAlternatePathForDebugging(int from, int to);
371
372 // Returns a string representing `cycles_graph_node_id`. If the node is
373   // unclusterable (either it is a phantom "frame" node or is not a compilation
374 // candidate) then set `*found_unclustered` to true.
375 string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered);
376
377 // We could not contract the edge from `from` to `to`. Return a string
378 // describing an alternate path from `from` to `to` (besides the direct edge
379 // from `from` to `to`) which would have created a cycle had we contracted the
380 // edge.
381 //
382 // Tries (if possible) to find a path that contains at least one unclusterable
383 // node as it is surprising to the user if we print "A->B could not be
384 // contracted because of the path [P,Q,R]" where P, Q and R are all clusters
385 // since in that case a natural question is why we could not form a {A, P, Q,
386 // R, B} cluster.
387 string DescribePotentialCycle(int from, int to);
388
389 // Merge the clusters `cluster_from` and `cluster_to`. After this step the
390 // larger combined cluster is represented by `cluster_from`, but can have
391 // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
392   // which way requires fewer operations.
393   bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
394 int from = cluster_from->cycles_graph_node_id();
395 int to = cluster_to->cycles_graph_node_id();
396
397 auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
398 if (!optional_merged_node.has_value()) {
399 VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
400 << " -> " << cluster_to->DebugString(*graph_)
401 << " because contracting the edge would create a cycle via "
402 << DescribePotentialCycle(from, to) << ".";
403 return false;
404 }
405
406 // Merge the clusters.
407 cluster_from->Merge(cluster_to);
408 // Update `cycle_graph_`'s ID.
409 cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
410
411 // Merge the UnionFind<Cluster*>.
412 cluster_for_node_[from].Merge(&cluster_for_node_[to]);
413
414 return true;
415 }
416
417   string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
418 absl::string_view reason) {
419 return absl::StrCat("Could not contract ", from->DebugString(*graph_),
420 " -> ", to->DebugString(*graph_), " because ", reason,
421 ".");
422 }
423
424 DebugOptions debug_options_;
425 Graph* graph_;
426 FunctionLibraryDefinition* flib_def_;
427 Env* env_;
428 OptimizerOptions::GlobalJitLevel global_jit_level_;
429 absl::flat_hash_map<const Cluster*, bool> should_compile_cluster_cache_;
430 jit::DeviceInfoCache device_info_cache_;
431
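  // Phase tracking: Initialize(), RunEdgeContractionLoop() and
  // CreateClusters() must each run exactly once and in that order; the
  // TF_RET_CHECKs at the top of those methods enforce this.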
432 bool initialized_ = false;
433 bool edges_contracted_ = false;
434 bool clusters_created_ = false;
435
436 std::vector<std::unique_ptr<Cluster>> cluster_storage_;
437 std::vector<UnionFind<Cluster*>> cluster_for_node_;
438 GraphCycles cycles_graph_;
439 OrderedNodeSet compilation_candidates_;
440 std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
441 int64 iteration_count_ = 0;
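  // Pairs of resource operation node ids that must not be placed in the same
  // cluster, computed by ComputeIncompatibleResourceOperationPairs in
  // BuildInitialClusterSet() and consulted in TryToContractEdge().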
442 absl::flat_hash_set<std::pair<int, int>> unsafe_resource_deps_;
443 };
444
445 std::vector<int> MarkForCompilationPassImpl::FindAlternatePathForDebugging(
446 int from, int to) {
447 std::vector<int> rpo = cycles_graph_.AllNodesInPostOrder();
448 absl::c_reverse(rpo);
449
450 // best_pred_for_node[n] contains a predecessor of `n` that has an
451 // unclusterable node in some path from `from` to itself.
452 // best_pred_for_node[n] is unpopulated for nodes that are not reachable from
453 // `from`. We build this table up inductively by traversing the cycles graph
454 // in RPO.
455 absl::flat_hash_map<int, int> best_pred_for_node;
456 best_pred_for_node[from] = -1;
457
458 int rpo_index = 0, current_rpo_node;
459 do {
460 current_rpo_node = rpo[rpo_index++];
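    // `some_pred` is any predecessor already known to be reachable from
    // `from`; `preferred_pred` is such a predecessor that is itself
    // unclusterable, which we prefer so that the reported path contains an
    // unclusterable node.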
461 absl::optional<int> some_pred, preferred_pred;
462 for (int pred : cycles_graph_.Predecessors(current_rpo_node)) {
463 if (!best_pred_for_node.contains(pred)) {
464 continue;
465 }
466
467 // Ignore the from->to edge since we're trying to find an alternate path.
468 if (current_rpo_node == to && pred == from) {
469 continue;
470 }
471
472 some_pred = pred;
473 if (GetClusterForCyclesGraphNode(pred) == nullptr) {
474 preferred_pred = pred;
475 }
476 }
477
478 if (some_pred || preferred_pred) {
479 best_pred_for_node[current_rpo_node] =
480 preferred_pred.has_value() ? *preferred_pred : *some_pred;
481 }
482 } while (current_rpo_node != to);
483
484 auto get_best_pred = [&](int n) {
485 auto it = best_pred_for_node.find(n);
486 CHECK(it != best_pred_for_node.end());
487 return it->second;
488 };
489
490 std::vector<int> path;
491 int current_path_node = get_best_pred(to);
492 while (current_path_node != from) {
493 path.push_back(current_path_node);
494 current_path_node = get_best_pred(current_path_node);
495 }
496
497 absl::c_reverse(path);
498 return path;
499 }
500
501 string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode(
502 int cycles_graph_node_id, bool* found_unclustered) {
503 Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id);
504 if (cluster) {
505 return cluster->DebugString(*graph_);
506 }
507
508 *found_unclustered = true;
509 if (cycles_graph_node_id >= graph_->num_node_ids()) {
510 return absl::StrCat("<oob #", cycles_graph_node_id, ">");
511 }
512
513 Node* node = graph_->FindNodeId(cycles_graph_node_id);
514 if (!node) {
515 return absl::StrCat("<bad #", cycles_graph_node_id, ">");
516 }
517
518 return node->name();
519 }
520
521 string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) {
522 std::vector<string> path_str;
523 bool found_unclustered = false;
524 absl::c_transform(FindAlternatePathForDebugging(from, to),
525 std::back_inserter(path_str), [&](int node_id) {
526 return DebugStringForCyclesGraphNode(node_id,
527 &found_unclustered);
528 });
529 return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[",
530 absl::StrJoin(path_str, ","), "]");
531 }
532
533 void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
534 // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does.
535
536 // Clearing out data structures in `other` is just a memory saving
537 // optimization and not needed for correctness.
538
539 cluster_size_ += other->cluster_size_;
540 effective_cluster_size_ += other->effective_cluster_size_;
541 has_functional_control_flow_ |= other->has_functional_control_flow_;
542
543 devices_.UnionWith(other->devices_);
544
545 DCHECK(!(resource_op_device_.has_value() &&
546 other->resource_op_device_.has_value()) ||
547 *resource_op_device_ == *other->resource_op_device_)
548 << "AreDevicesCompatible should have returned false otherwise!";
549
550 if (!resource_op_device_.has_value()) {
551 resource_op_device_ = other->resource_op_device_;
552 }
553
554 is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
555
556 if (!xla_scope_.has_value()) {
557 xla_scope_ = std::move(other->xla_scope_);
558 }
559
560 resource_var_operation_node_ids_.reserve(
561 resource_var_operation_node_ids_.size() +
562 other->resource_var_operation_node_ids_.size());
563 absl::c_copy(other->resource_var_operation_node_ids_,
564 std::back_inserter(resource_var_operation_node_ids_));
565 other->resource_var_operation_node_ids_.clear();
566 }
567
568 Status IgnoreResourceOpForSafetyAnalysis(
569 jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) {
570 // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
571 // ignore it during resource operation safety analysis. We need this hack
572   // for two reasons:
573 //
574 // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
575 // 2. We don't support live-out values of type DT_RESOURCE and live-in values
576 // of type DT_RESOURCE that are not resource variables.
577 //
578 // Together these imply we cannot let resource variable safety analysis
579 // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
580 // clusters: both of them will have to be clustered because of (1) and we
581 // won't be able to keep the edge between the two as neither the input to the
582 // second XLA cluster nor the output from the first XLA cluster are supported
583 // because of (2).
584 //
585 // TODO(b/113100872): This can be fixed if the TensorFlow representation for
586 // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
587 // (2) would no longer hold.
588
589 if (n.assigned_device_name().empty()) {
590 *ignore = false;
591 return Status::OK();
592 }
593
594 TF_ASSIGN_OR_RETURN(
595 const XlaOpRegistry::DeviceRegistration* registration,
596 device_info_cache->GetCompilationDevice(n.assigned_device_name()));
597
598 if (!registration) {
599 *ignore = true;
600 } else {
601 *ignore = registration->cluster_resource_variable_ops_unsafely;
602 }
603 return Status::OK();
604 }
605
606 StatusOr<bool> MarkForCompilationPassImpl::Initialize() {
607 TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
608 initialized_ = true;
609
610 TF_RETURN_IF_ERROR(FindCompilationCandidates());
611
612 if (compilation_candidates_.empty()) {
613 VLOG(2) << "No compilable candidates";
614 return false;
615 }
616
617 TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
618 CreateCycleDetectionGraph(graph_, &cycles_graph_));
619 if (!cycle_detection_graph_ok) {
620 // TODO(sanjoy): This should be logged via the XLA activity listener.
621 VLOG(2) << "Could not form cycle detection graph";
622 return false;
623 }
624
625 if (!debug_options_.ignore_deadness_checks) {
626 XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
627 TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
628 }
629
630 // Each compilation candidate belongs to a cluster. The cluster's
631 // representative names the node in the 'cycles' graph that represents the
632 // cluster.
633 TF_RETURN_IF_ERROR(BuildInitialClusterSet());
634 return true;
635 }
636
637 template <typename FnTy>
638 StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
639 bool changed = false;
640 for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
641 Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
642 if (!cluster_from) {
643 continue;
644 }
645
646 // Make a copy of the set of successors because we may modify the graph in
647 // TryToContractEdge.
648 std::vector<int32> successors_copy =
649 cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
650
651 for (int to : successors_copy) {
652 iteration_count_++;
653
654 Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
655 if (!cluster_to) {
656 continue;
657 }
658
659 TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
660 changed |= contracted_edge;
661 }
662 }
663
664 return changed;
665 }
666
667 Node* MarkForCompilationPassImpl::GetOnlyNodeIn(const Cluster& cluster) {
668 return cluster.cluster_size() == 1
669 ? graph_->FindNodeId(cluster.GetIdOfOnlyNode())
670 : nullptr;
671 }
672
673 bool MarkForCompilationPassImpl::IsSinkLike(const Cluster& cluster) {
674 if (Node* n = GetOnlyNodeIn(cluster)) {
675 return n->type_string() == "NoOp" && n->out_edges().size() == 1 &&
676 (*n->out_edges().begin())->dst()->IsSink();
677 }
678
679 return false;
680 }
681
682 bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
683 const Cluster& cluster) {
684 Node* n = GetOnlyNodeIn(cluster);
685 if (!n) {
686 return false;
687 }
688
689 if (n->type_string() != "AssignAddVariableOp" &&
690 n->type_string() != "AssignSubVariableOp") {
691 return false;
692 }
693
694 DataType dtype;
695 if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
696 return false;
697 }
698
699 Node* const_input = nullptr;
700 for (const Edge* e : n->in_edges()) {
701 if (!e->IsControlEdge() && e->src()->IsConstant()) {
702 const_input = e->src();
703 break;
704 }
705 }
706
707 if (!const_input) {
708 return false;
709 }
710
711 const TensorProto* proto = nullptr;
712 if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
713 return false;
714 }
715
716 return TensorShapeUtils::IsScalar(proto->tensor_shape());
717 }
718
719 Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
720 TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
721 edges_contracted_ = true;
722
723 // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
724 // example, from the Grappler fusion pass).
725
726 // In general there are multiple maximal clusterings, but they are not all
727   // equally performant. Some clustering decisions are likely to improve
728   // performance much more than others, but we cannot order contractions by this
729   // cost function, nor can we look at global information while deciding on
730   // individual edges to contract. Instead, we make decisions on these
731   // important edges first and then on all other edges, which maximizes the
732   // chance that the most important edges are contracted.
733 //
734 // An example of where this might occur is with a digraph:
735 // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
736 // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
737 // should be clustered with A because it will prevent a potentially large
738 // tensor from A being computed and copied.
739 //
740 // To choose better maximal clusterings we make multiple iterations over the
741 // graph in post-order, where each such iteration is called a "phase".
742
743 // Phase 0: contract metadata operations with their producer.
744
745 VLOG(4) << "Running phase 0";
746 TF_RETURN_IF_ERROR(
747 ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
748 // Shape consuming operations are desirable to cluster with their
749 // operands because they return a small set of scalar values after
750 // consuming a large amount of data. For example, given a graph X -> Y
751 // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
752 // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
753         // output of Size will be a small tensor while Y is a potentially large
754         // tensor that must be computed and possibly transposed/copied before
755 // the second cluster executes.
756 Node* n = GetOnlyNodeIn(*to);
757 bool is_shape_consumer_op = n && IsShapeConsumerOp(*n);
758 if (!is_shape_consumer_op) {
759 return false;
760 }
761
762 return TryToContractEdge(from, to);
763 }).status());
764
765 // Phase 1: apply a heuristic to ensure that we don't mess up clustering due
766 // to "group_deps". After this phase most edges should have been contracted.
767
768 VLOG(4) << "Running phase 1";
769 TF_RETURN_IF_ERROR(
770 ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
771 // We split out this phase to get good clustering in the presence of a
772 // specific pattern seen in some graphs:
773 //
774 // digraph {
775 // ApplyWeightUpdates_0 -> "iteration++"
776 // ApplyWeightUpdates_1 -> "iteration++"
777 // ApplyWeightUpdates_2 -> "iteration++"
778 // ApplyWeightUpdates_0 -> Computation_A
779 // ApplyWeightUpdates_1 -> Computation_B
780 // ApplyWeightUpdates_2 -> Computation_C
781 // Computation_A -> NoOp
782 // Computation_B -> NoOp
783 // Computation_C -> NoOp
784 // "iteration++" -> NoOp
785 // }
786 //
787 // In the graph above we can't cluster iteration++ with any of the
788 // gradient update operations since that will break the TF resource
789 // variable memory model. Given that constraint the ideal clustering
790 // would be to put all the gradient updates and all of the Computation_*
791 // nodes in one cluster, and leave iteration++ and NoOp unclustered.
792 //
793 // A naive post-order traversal would not create this good clustering,
794 // however. Instead it will first create a cluster that puts
795 // Computation_* nodes, the NoOp and iteration++ node in a single
796 // cluster, after which it will fail to put any of the
797 // ApplyWeightUpdates_* nodes into this cluster. To avoid this fate we
798 // instead run a pass that avoids contracting edges _into_ NoOps like
799 // the above, and avoid clustering edges _from_ "iteration++" like the
800 // above. Then we run a second pass that contracts the edges we could
801 // not contract the first time around.
802
803 if (IsSinkLike(*to)) {
804 return false;
805 }
806
807 if (IsScalarIntegerResourceOperation(*from)) {
808 return false;
809 }
810
811 return TryToContractEdge(from, to);
812 }).status());
813
814 // Phase 2: contract any remaining edges. After this phase we should have a
815 // maximal clustering:
816 //
817 // A. We visit a cluster only after maximally clustering all its children.
818 // B. By the time we're done with a node all of its children that could have
819 // been absorbed into the node have been absorbed.
820 // C. We have an invariant that making a cluster larger does not make edges
821 // leaving it more contractable. That is, if we have
822 // digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
823 // to contract Y->Z if Y->Z was not contractible originally.
824 VLOG(4) << "Running phase 2";
825 TF_RETURN_IF_ERROR(ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
826 return TryToContractEdge(from, to);
827 }).status());
828
829 // Check that the conclusion made above (that iterating over the graph once in
830 // post order gives a maximal clustering) holds. Once the linear time
831 // post-order scheme has been battle tested we can move this to happen only in
832 // debug builds.
833 VLOG(2) << "Checking idempotence";
834 TF_ASSIGN_OR_RETURN(bool changed,
835 ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
836 return TryToContractEdge(from, to);
837 }));
838 TF_RET_CHECK(!changed);
839
840 return Status::OK();
841 }
842
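// Monotonically increasing counter used by CreateClusters() to give each new
// cluster a unique name of the form "cluster_<N>".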
843 std::atomic<int64> cluster_sequence_num;
844
845 int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
846
847 Status MarkForCompilationPassImpl::CreateClusters() {
848 TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
849 clusters_created_ = true;
850
851 // Names for each cluster.
852 std::unordered_map<int, string> cluster_names;
853
854 if (debug_options_.dump_graphs) {
855 DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
856 }
857
858 // Mark clusters for compilation that:
859 // * are placed on a device that requires compilation (an XlaDevice),
860 // * are explicitly marked for compilation (_XlaCompile=true), or
861   // * have at least debug_options_.min_cluster_size effective elements (applicable
862 // only if compilation is enabled, otherwise there will be no such
863 // candidates).
864 for (Node* n : compilation_candidates_) {
865 Cluster* cluster = GetClusterForNode(n);
866 TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
867 ShouldCompileCluster(*cluster));
868 if (!should_compile_cluster) {
869 continue;
870 }
871
872 // We assume that functional If and While nodes have at least
873 // min_cluster_size non-trivial nodes in them. It would be more principled
874 // to (recursively) verify this fact, but that's probably not worth the
875 // trouble.
876
877 if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size ||
878 cluster->has_functional_control_flow() ||
879 cluster->is_xla_compile_attr_true()) {
880 string& name = cluster_names[cluster->cycles_graph_node_id()];
881
882 if (name.empty()) {
883 name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
884 }
885
886 n->AddAttr(kXlaClusterAttr, name);
887 n->AddAttr(kXlaAlreadyClustered, true);
888 VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
889 }
890 }
891
892 return Status::OK();
893 }
894
895 Status MarkForCompilationPassImpl::DumpDebugInfo() {
896 TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);
897
898 if (debug_options_.dump_graphs) {
899 DumpPostClusteringGraphs();
900 }
901
902 VLogClusteringSummary();
903
904 return Status::OK();
905 }
906
907 StatusOr<bool>
908 MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
909 const Cluster& cluster_from, const Cluster& cluster_to) {
910 // If any of the consumer's producers are on a different device, do not
911 // cluster these nodes. This prevents other work on this device from being
912 // delayed by work on other devices. We consider predecessors of the entire
913 // cluster rather than just the inputs to the node to prevent the cluster
914 // still being combined in cases where the 'to' cluster has multiple
915 // dependencies on the 'from' cluster and another dependency leads to a
916 // merging of the clusters.
917 //
918 // TODO(b/117085735): We probably want to handle the reciprocal of this case
919 // where a cluster is producing data for multiple devices.
920 for (const auto& in_id :
921 cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) {
922 const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id);
923 if (cluster_in) {
924 TF_ASSIGN_OR_RETURN(bool devices_compatible,
925 AreDevicesCompatible(cluster_to, *cluster_in));
926 if (!devices_compatible) {
927 return true;
928 }
929 TF_ASSIGN_OR_RETURN(devices_compatible,
930 AreDevicesCompatible(cluster_from, *cluster_in));
931 if (!devices_compatible) {
932 return true;
933 }
934 }
935 }
936
937 return false;
938 }
939
940 absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
941 // Look for either _XlaScope or _XlaInternalScope on both nodes to guide
942 // clustering. If both nodes have a scope and the scopes do not match, do
943 // not cluster along this edge. If even one of the nodes lacks a scope
944 // attribute, then it is treated as a "bridge" and a cluster may be created
945 // along it.
946 //
947 // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is
948 // provided by users through jit_scope APIs, while _XlaInternalScope is
949 // automatically generated by the ClusterScopingPass when auto_jit is on. As
950 // such, we respect _XlaScope only when auto_jit is off, while respecting
951 // _XlaInternalScope only when auto_jit is on.
952 //
953 // We may want to restrict the _XlaScope behavior to require all nodes marked
954 // with _XlaCompile=true to also have a _XlaScope property set (and raise an
955 // error otherwise); but for now we don't do this.
956
957 if (global_jit_level_ != OptimizerOptions::OFF) {
958 // If global_jit_level_ is ON, respect only _XlaInternalScope.
959 const string& scope =
960 GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr);
961 if (!scope.empty()) {
962 return scope;
963 }
964 } else {
965 // If global_jit_level_ is OFF, respect only _XlaScope.
966 const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
967 if (!scope.empty()) {
968 return scope;
969 }
970 }
971
972 return absl::nullopt;
973 }
974
975 // Returns true iff the attribute `attr_name` is attached to either the node or
976 // to its callee.
977 static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
978 const char* attr_name) {
979 bool out = false;
980 bool attr_value;
981 if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
982 out |= attr_value;
983 }
984
985 if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
986 out |= attr_value;
987 }
988 return out;
989 }
990
991 Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
992 auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
993 return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
994 };
995
996 std::vector<std::pair<int, int>> unsafe_resource_deps_vect;
997 TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
998 *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect));
999 absl::c_copy(
1000 unsafe_resource_deps_vect,
1001 std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin()));
1002
1003 cluster_for_node_.resize(graph_->num_node_ids());
1004 for (Node* node : graph_->nodes()) {
1005 if (!IsCompilationCandidate(node)) {
1006 cluster_for_node_[node->id()].Get() = nullptr;
1007 continue;
1008 }
1009
1010 // We want clusters to be big enough that the benefit from XLA's
1011 // optimizations offsets XLA related overhead (for instance we add some
1012 // Switch/Merge nodes into the graph to implement lazy compilation). To
1013 // this end, we don't count Identity and Constant nodes because they do not
1014 // enable interesting optimizations by themselves.
1015 int effective_cluster_size =
1016 (node->IsIdentity() || node->IsConstant()) ? 0 : 1;
1017
1018 bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
1019
1020 absl::optional<DeadnessPredicate> deadness_predicate;
1021 if (deadness_analysis_) {
1022 TF_ASSIGN_OR_RETURN(
1023 deadness_predicate,
1024 deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
1025 }
1026
1027 const string& device_name_str = !node->assigned_device_name().empty()
1028 ? node->assigned_device_name()
1029 : node->requested_device();
1030 TF_ASSIGN_OR_RETURN(DeviceId device,
1031 device_info_cache_.GetIdFor(device_name_str));
1032
1033 bool is_resource_op = HasResourceInputOrOutput(*node);
1034 absl::optional<DeviceId> resource_op_device;
1035 if (is_resource_op) {
1036 resource_op_device = device;
1037 }
1038
1039 absl::optional<int> resource_var_operation_node_id;
1040 if (is_resource_op || MayCallFunction(*node, flib_def_)) {
1041 resource_var_operation_node_id = node->id();
1042 }
1043
1044 bool is_xla_compile_attr_true =
1045 GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
1046 (global_jit_level_ != OptimizerOptions::OFF &&
1047 GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr));
1048
1049 DeviceSet devices;
1050 devices.Insert(device);
1051
1052 Cluster* new_cluster = MakeNewCluster(
1053 /*cycles_graph_node_id=*/node->id(),
1054 /*effective_cluster_size=*/effective_cluster_size,
1055 /*has_functional_control_flow=*/has_functional_control_flow, devices,
1056 resource_op_device, resource_var_operation_node_id, deadness_predicate,
1057 /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
1058 GetXlaScope(node));
1059
1060 cluster_for_node_[node->id()].Get() = new_cluster;
1061 }
1062
1063 return Status::OK();
1064 }
1065
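// Returns true if `node` is an Identity fed by the true output of a
// loop-condition Switch and driving Const nodes through control edges. See
// the comment at the call site in FindCompilationCandidates() for why such
// Identities are kept out of clusters.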
1066 StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
1067 if (!node->IsIdentity()) {
1068 return false;
1069 }
1070
1071 // Check if the Identity is driven by a Switch on its true path.
1072 auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
1073 return e->src()->IsSwitch() && e->src_output() == 1;
1074 });
1075 if (it == node->in_edges().end()) {
1076 return false;
1077 }
1078 const Node* switch_node = (*it)->src();
1079
1080 // Check if the Switch is driven by LoopCond.
1081 const Node* maybe_loop_cond;
1082 TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
1083 if (!maybe_loop_cond->IsLoopCond()) {
1084 return false;
1085 }
1086
1087 // Check if the Identity is driving any const nodes through a control edge.
1088 bool driving_any_consts =
1089 absl::c_any_of(node->out_edges(), [](const Edge* e) {
1090 return e->dst()->IsConstant() && e->IsControlEdge();
1091 });
1092 if (!driving_any_consts) {
1093 return false;
1094 }
1095
1096 return true;
1097 }
1098
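// Parses the --tf_xla_ops_to_cluster flag into the set of TF operations that
// may be auto-clustered. The special entry "FUSIBLE" expands to every op in
// GetAllowlistTable(); an empty result means no restriction.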
1099 absl::flat_hash_set<string> GetOrCreateAllowlist() {
1100 absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
1101 tensorflow::GetAllowlistTable();
1102 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1103 absl::flat_hash_set<string> allowlist;
1104
1105 for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
1106 if (s == "FUSIBLE") {
1107 for (auto pair : *allowlist_table) {
1108 allowlist.insert(pair.second.begin(), pair.second.end());
1109 }
1110 } else if (allowlist_table->contains(s)) {
1111 auto v = allowlist_table->at(s);
1112 allowlist.insert(v.begin(), v.end());
1113 } else if (!s.empty()) {
1114 // Should be a user provided TF operation.
1115 allowlist.insert(string(s));
1116 }
1117 }
1118
1119 if (VLOG_IS_ON(2) && !allowlist.empty()) {
1120 std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
1121 absl::c_sort(vallowlist);
1122 VLOG(2) << "XLA clustering will only consider the following TF operations: "
1123 << absl::StrJoin(vallowlist, " ");
1124 }
1125 return allowlist;
1126 }
1127
1128 Status MarkForCompilationPassImpl::FindCompilationCandidates() {
1129 OptimizerOptions opts;
1130 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1131 new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
1132 TF_GRAPH_DEF_VERSION, flib_def_, opts));
1133 FunctionLibraryRuntime* lib_runtime =
1134 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
1135 std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);
1136 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
1137 *graph_, /*compile_time_const_arg_indices=*/nullptr,
1138 &compile_time_const_nodes, lib_runtime));
1139 // Iterate over nodes in sorted order so that compiler fuel is deterministic.
1140 // We can't simply pass op_nodes().begin() and op_nodes().end() to the
1141 // std::vector constructor because they're not proper iterators, with
1142 // iterator_traits defined and so on.
1143 std::vector<Node*> sorted_nodes;
1144 for (Node* node : graph_->op_nodes()) {
1145 sorted_nodes.push_back(node);
1146 }
1147 std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
1148
1149 if (*debug_options_.fuel >= std::numeric_limits<int64>::max() / 2) {
1150 // The assumption is that if fuel started out as INT64_MAX, it will forever
1151 // stay greater than INT64_MAX / 2.
1152 VLOG(2) << "Starting fuel: infinity";
1153 } else {
1154 VLOG(2) << "Starting fuel: " << *debug_options_.fuel;
1155 }
1156
1157 VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
1158
1159 auto allowlist = GetOrCreateAllowlist();
1160
1161 std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
1162 absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
1163   // Check that the user-provided TF operations are actually supported by XLA.
1164 for (const auto& s : allowlist) {
1165 if (!all_ops.contains(string(s))) {
1166 return errors::InvalidArgument(
1167 "The operation '", s,
1168 "' passed to --tf_xla_ops_to_cluster is not supported by XLA.");
1169 }
1170 }
1171
1172 for (Node* node : sorted_nodes) {
1173 if (*debug_options_.fuel <= 0) {
1174 VLOG(1)
1175 << "Hit fuel limit; not marking any remaining ops as clusterable.";
1176 break;
1177 }
1178
1179 TF_ASSIGN_OR_RETURN(
1180 const DeviceType& device_type,
1181 device_info_cache_.GetDeviceTypeFor(node->assigned_device_name()));
1182 VLOG(4) << "Device type for " << node->name() << ": "
1183 << device_type.type_string();
1184
1185 if (CompilationDisallowedByXlaCompileAttr(node)) {
1186 VLOG(2) << "Not clustering " << node->name()
1187 << ": disallowed by _XlaCompile attribute";
1188 continue;
1189 }
1190
1191 const XlaOpRegistry::DeviceRegistration* registration;
1192 if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
1193                                              &registration)) {
1194 VLOG(2) << "Rejecting " << node->name()
1195 << ": could not find JIT device for " << device_type.type();
1196 continue;
1197 }
1198
1199 RecursiveCompilabilityChecker::OperationFilter filter =
1200 CreateOperationFilter(*registration);
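    // Tighten the device's default filter: every clustered op must be
    // always-compilable, and string constants are excluded (XLA
    // auto-clustering does not support DT_STRING; see the Const check below).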
1201 filter.require_always_compilable = true;
1202 filter.allow_string_consts = false;
1203
1204 RecursiveCompilabilityChecker checker(
1205 filter, DeviceType{registration->compilation_device_name});
1206
1207 if (!checker.IsCompilableNode(*node, lib_runtime)) {
1208 continue;
1209 }
1210
1211 if (node->type_string() == "Const") {
1212 // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
1213 // support it.
1214 const AttrValue* attr = node->attrs().Find("dtype");
1215 if (attr != nullptr && attr->type() == DT_STRING) {
1216 continue;
1217 }
1218 }
1219
1220 if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
1221 VLOG(1) << "Rejecting TF operation " << node->def().op()
1222 << " as it is not listed in --tf_xla_ops_to_cluster.";
1223 continue;
1224 }
1225
1226 if (compile_time_const_nodes[node->id()]) {
1227 const OpDef* op_def;
1228 TF_RETURN_IF_ERROR(
1229 graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def));
1230 if (op_def->is_stateful()) {
1231 // It is easiest to demonstrate the problem we're trying to solve with
1232 // an example. Say we have this graph:
1233 //
1234 // shape = RandomUniformInt();
1235 // reshape = Reshape(input, shape)
1236 //
1237 // Both RandomUniformInt and Reshape are compilable by XLA so, absent
1238 // any other reason, we will try to put both shape and reshape in the
1239 // same cluster. However, since XLA only supports statically shaped
1240 // values, it will expect to be able to constant fold `shape` to get a
1241 // static shape for `reshape`. This is a problem because side-effecting
1242 // ops like RandomUniformInt() cannot be constant folded. We fix this
1243 // by putting `shape` and `reshape` in different clusters, which results
1244 // in us recompiling `reshape`'s cluster for every new value of `shape`,
1245 // making `reshape` statically sized within each compilation. We
1246 // simplify the solution even further by disallowing operations like
1247 // `shape` from being part of *any* non-trivial cluster. They're either
1248 // not compiled by XLA altogether or, if assigned to an XLA_* device
1249 // with "must compile" semantics, compiled into a trivial single-op
1250 // cluster. This approach leaves some room for improvement, and we can
1251 // consider implementing a more aggressive data-flow-analysis based
1252 // solution in the future if needed.
1253 //
1254 // One ugly problem we have to contend with: certain sets of ops *have*
1255 // to be in the same cluster because values flowing between them have
1256 // types that can't be live-in or live-out of a cluster. These ops are:
1257 //
1258 // - TensorArray ops operating on the same TensorArray instance.
1259 // - Stack ops operating on the same Stack instance.
1260 //
1261 // To work around this we avoid isolating these specific ops. Because
1262 // of this concession it is unsound to auto-cluster them because then
1263 // we'd create clusters we could not compile (because we can't constant
1264 // fold, say, a TensorArrayRead or a StackPopV2). But we don't
1265 // auto-cluster these operations today so we're good for now.
1266 const XlaResourceOpInfo* op_info =
1267 GetResourceOpInfoForOp(node->type_string());
1268 bool is_tensor_array_or_stack_op =
1269 op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
1270 if (!is_tensor_array_or_stack_op) {
1271 VLOG(2) << "Isolating " << node->name()
1272 << ": must-be-constant stateful op";
1273 continue;
1274 }
1275 }
1276 }
1277
1278 // This is a heuristic to avoid creating dependency between while loop
1279 // condition and body computations. Dependency between them can be created
1280 // if a special Identity node in the following pattern is clustered in.
1281 // That is, an Identity node in the loop cond computation is used to drive
1282 // const nodes consumed by the loop body. If this Identity node goes into
1283     // the same cluster as nodes from the loop body, an extra dependency is
1284     // created between the loop cond and body computations, which hinders the
1285     // progression of the loop cond computation at runtime with significant
1286 // overhead. Specifically, we look for the below pattern and do not cluster
1287 // in this Identity to avoid the described issue. Since Identity has low
1288 // execution cost in native TF, the fact that this heuristic gives up these
1289     // special Identity nodes as candidates should not harm performance. If
1290     // other considerations emerge in the future, we can revisit the heuristic
1291     // and only disallow these Identities from joining the cluster with nodes
1292     // from the loop body while still considering them candidates.
1293 //
1294 // LoopCond ->
1295 // Merge -> Switch -> Identity -> i++ -> ... -> NextIteration
1296 // ..> Const -> LoopBody
1297 // (control edge)
1298 TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
1299 IsIdentityDrivingConstsInLoop(node));
1300 if (is_identity_driving_consts_in_loop) {
1301 VLOG(2) << "Rejecting " << node->name()
1302 << ": including it can create dependencies between while loop "
1303 "condition and body computations with runtime overhead.";
1304 continue;
1305 }
1306
1307 compilation_candidates_.insert(node);
1308 --(*debug_options_.fuel);
1309 }
1310
1311 VLOG(2) << "compilation_candidates_.size() = "
1312 << compilation_candidates_.size();
1313 return Status::OK();
1314 }
1315
1316 bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
1317 Node* node) {
1318 if (debug_options_.ignore_xla_compile_attr) {
1319 return false;
1320 }
1321
1322 // If there is a _XlaCompile annotation, use its value.
1323 bool compile = false;
1324 Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
1325 if (status.ok()) {
1326 if (!compile) {
1327 VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1328 << kXlaCompileAttr << ") is false.";
1329 }
1330 return !compile;
1331 }
1332
1333 status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile);
1334 if (status.ok()) {
1335 if (!compile) {
1336 VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1337 << kXlaCompileAttr << ") on callee is false.";
1338 }
1339 return !compile;
1340 }
1341
1342 return false;
1343 }
1344
1345 bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
1346 Cluster* from, Cluster* to, absl::string_view reason) {
1347 VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
1348 return false;
1349 }
1350
1351 StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
1352 Cluster* to) {
1353 DCHECK(from->deadness_predicate().has_value() ==
1354 to->deadness_predicate().has_value());
1355 if (from->deadness_predicate() != to->deadness_predicate()) {
1356 VLOG(3) << EdgeContractionFailureMsg(
1357 from, to,
1358 absl::StrCat(
1359 "the two nodes have mismatching deadness: ",
1360 deadness_analysis_->DebugString(*from->deadness_predicate()),
1361 " and ",
1362 deadness_analysis_->DebugString(*to->deadness_predicate())));
1363 return false;
1364 }
1365
1366 TF_ASSIGN_OR_RETURN(bool devices_compatible,
1367 AreDevicesCompatible(*from, *to));
1368 if (!devices_compatible) {
1369 return LogNotContractableAndReturnFalse(
1370 from, to, "the two nodes have incompatible devices");
1371 }
1372
1373 if (from->xla_scope().has_value() && to->xla_scope().has_value() &&
1374 *from->xla_scope() != *to->xla_scope()) {
1375 return LogNotContractableAndReturnFalse(
1376 from, to, "the two nodes have mismatching XLA scopes");
1377 }
1378
1379 // Don't exceed the maximum cluster size.
1380 if (from->cluster_size() + to->cluster_size() >
1381 debug_options_.max_cluster_size) {
1382 return LogNotContractableAndReturnFalse(
1383 from, to, "the new cluster will be larger than the max cluster size");
1384 }
1385
1386 TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
1387 ClusteringWillIntroduceInterDeviceDependency(*from, *to));
1388
1389 if (will_introduce_cross_device_dependency) {
1390 return LogNotContractableAndReturnFalse(
1391 from, to, "the new cluster will introduce a cross device dependency");
1392 }
1393
1394 // Check whether contracting this edge would break resource variable
1395 // concurrency semantics. In theory this check is quadratic in the number of
1396 // resource operations, but it has not been a problem in practice so far.
1397 if (!debug_options_.ignore_resource_variable_checks) {
1398 for (int resource_var_from : from->resource_var_operation_node_ids()) {
1399 for (int resource_var_to : to->resource_var_operation_node_ids()) {
1400 // If unsafe_resource_deps_ contains {A, B} then
1401 //
1402 // a. A and B are resource operations.
1403 // b. A and B cannot be placed in the same cluster.
1404 // c. There is no path from B to A in the cycles graph (but there may
1405 // be a path from A to B).
1406 //
1407 // So check the legality of the edge contraction by checking if any of
1408 // the n^2 pairs of resource variable operations are forbidden.
1409 if (unsafe_resource_deps_.contains(
1410 {resource_var_from, resource_var_to})) {
1411 return LogNotContractableAndReturnFalse(
1412 from, to,
1413 "the new cluster would break resource variable semantics");
1414 }
1415 }
1416 }
1417 }
1418
1419 return MergeClusters(from, to);
1420 }
1421
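// Top-level driver for the pass implementation: makes sure the XLA
// compilation kernels are registered, then runs the Initialize /
// RunEdgeContractionLoop / CreateClusters / DumpDebugInfo phases in order.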
1422 Status MarkForCompilationPassImpl::Run() {
1423 // Make sure that kernels have been registered on the JIT device.
1424 XlaOpRegistry::RegisterCompilationKernels();
1425
1426 // Start the timer after XlaOpRegistry::RegisterCompilationKernels, which
1427 // does some one-time work.
1428 XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);
1429
1430 TF_ASSIGN_OR_RETURN(bool initialized, Initialize());
1431 if (!initialized) {
1432 // Initialization exited early, which means this instance of
1433 // MarkForCompilationPassImpl is not set up to run the subsequent phases.
1434 return Status::OK();
1435 }
1436
1437 TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
1438 TF_RETURN_IF_ERROR(CreateClusters());
1439 TF_RETURN_IF_ERROR(DumpDebugInfo());
1440
1441 return Status::OK();
1442 }
1443
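// Dumps the graph after clustering, plus an annotated copy whose node names
// are prefixed with their cluster assignment, via DumpGraphToFile (the dump
// location is typically controlled by the TF_DUMP_GRAPH_PREFIX environment
// variable).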
1444 void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
1445 DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
1446
1447 // We also dump out an annotated version of the TF graph where the node
1448 // names are prefixed with the cluster names. This can help with visualizing
1449 // the clustering decisions on TensorBoard.
1450 Graph new_graph(graph_->op_registry());
1451 CopyGraph(*graph_, &new_graph);
1452
1453 for (Node* n : new_graph.nodes()) {
1454 if (absl::optional<absl::string_view> cluster_name =
1455 GetXlaClusterForNode(*n)) {
1456 n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
1457 } else if (n->type_string() == "VarHandleOp") {
1458 n->set_name(absl::StrCat("varhandle/", n->name()));
1459 } else {
1460 // There is room for improvement here. In particular, it may help to
1461 // split these unclustered nodes into classes where every node in a
1462 // specific class has edges to and from the same set of clusters.
1463 n->set_name(absl::StrCat("unclustered/", n->name()));
1464 }
1465 }
1466
1467 DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_);
1468 }
1469
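// Formats a ratio for the VLOG output below; for example,
// RatioToString(3, 10) returns "3 / 10 (30.00%)".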
1470 string RatioToString(int numerator, int denominator) {
1471 return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
1472 (100.0 * numerator) / denominator);
1473 }
1474
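// Logs a human-readable summary of the clustering decisions: per-cluster
// sizes and op histograms at VLOG(2)/VLOG(3), and inter-cluster edges at
// VLOG(4).  Does nothing unless VLOG level 2 is enabled.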
1475 void MarkForCompilationPassImpl::VLogClusteringSummary() {
1476 if (!VLOG_IS_ON(2)) {
1477 return;
1478 }
1479
1480 XlaAutoClusteringSummary auto_clustering_info =
1481 GetXlaAutoClusteringSummary(*graph_);
1482
1483 VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
1484 VLOG(2) << " Built " << auto_clustering_info.clusters_size()
1485 << " clusters, size "
1486 << RatioToString(auto_clustering_info.clustered_node_count(),
1487 graph_->num_nodes());
1488
1489 for (const XlaAutoClusteringSummary::Cluster& cluster :
1490 auto_clustering_info.clusters()) {
1491 absl::string_view cluster_name = cluster.name();
1492 int size = cluster.size();
1493 VLOG(2) << " " << cluster_name << " "
1494 << RatioToString(size, graph_->num_nodes());
1495 for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1496 cluster.op_histogram()) {
1497 VLOG(3) << " " << op_count.op() << ": " << op_count.count()
1498 << " instances";
1499 }
1500 }
1501
1502 if (!auto_clustering_info.unclustered_op_histogram().empty()) {
1503 VLOG(2) << " Unclustered nodes: "
1504 << RatioToString(auto_clustering_info.unclustered_node_count(),
1505 graph_->num_nodes());
1506 for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1507 auto_clustering_info.unclustered_op_histogram()) {
1508 VLOG(3) << " " << op_count.op() << ": " << op_count.count()
1509 << " instances";
1510 }
1511 }
1512
1513 struct EdgeInfo {
1514 absl::string_view node_name;
1515 absl::optional<absl::string_view> cluster_name;
1516
1517 absl::string_view GetClusterName() const {
1518 return cluster_name ? *cluster_name : "[none]";
1519 }
1520
1521 std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
1522 const {
1523 return {node_name, cluster_name};
1524 }
1525
1526 bool operator<(const EdgeInfo& other) const {
1527 return AsPair() < other.AsPair();
1528 }
1529 };
1530
1531 using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
1532
1533 EdgeInfoMap incoming_edge_infos;
1534 EdgeInfoMap outgoing_edge_infos;
1535
1536 std::set<absl::string_view> cluster_names_to_print;
1537
1538 for (const Edge* e : graph_->edges()) {
1539 const Node* from = e->src();
1540 absl::optional<absl::string_view> from_cluster_name =
1541 GetXlaClusterForNode(*from);
1542
1543 const Node* to = e->dst();
1544 absl::optional<absl::string_view> to_cluster_name =
1545 GetXlaClusterForNode(*to);
1546
1547 if (to_cluster_name == from_cluster_name) {
1548 continue;
1549 }
1550
1551 if (to_cluster_name) {
1552 incoming_edge_infos[*to_cluster_name]
1553 [EdgeInfo{from->name(), from_cluster_name}]++;
1554 cluster_names_to_print.insert(*to_cluster_name);
1555 }
1556
1557 if (from_cluster_name) {
1558 outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
1559 cluster_names_to_print.insert(*from_cluster_name);
1560 }
1561 }
1562
1563 VLOG(4) << "*** Inter-Cluster edges:";
1564 if (cluster_names_to_print.empty()) {
1565 VLOG(4) << " [none]";
1566 }
1567
1568 auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
1569 const EdgeInfoMap& edge_info_map,
1570 absl::string_view desc) {
1571 auto it = edge_info_map.find(cluster_name);
1572 if (it != edge_info_map.end()) {
1573 VLOG(4) << " " << it->second.size() << " " << desc << " edges";
1574 for (const auto& edge_info_count_pair : it->second) {
1575 VLOG(4) << " " << edge_info_count_pair.first.GetClusterName() << " "
1576 << edge_info_count_pair.first.node_name << " # "
1577 << edge_info_count_pair.second;
1578 }
1579 } else {
1580 VLOG(4) << " No " << desc << " edges.";
1581 }
1582 };
1583
1584 for (absl::string_view cluster_name : cluster_names_to_print) {
1585 VLOG(4) << " ** Cluster " << cluster_name;
1586 print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
1587 "incoming");
1588 print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
1589 "outgoing");
1590 }
1591 }
1592
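// Returns true if the devices assigned to `cluster_a` and `cluster_b` can be
// served by a single XLA device, and any resource operations in either
// cluster are already placed on that device.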
1593 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
1594 const Cluster& cluster_a, const Cluster& cluster_b) {
1595 DeviceSet devices = cluster_a.devices();
1596 devices.UnionWith(cluster_b.devices());
1597
1598 TF_ASSIGN_OR_RETURN(
1599 absl::optional<jit::DeviceId> maybe_chosen_device,
1600 MaybePickDeviceForXla(device_info_cache_, devices,
1601 /*allow_mixing_unknown_and_cpu=*/false));
1602 if (!maybe_chosen_device.has_value()) {
1603 return false;
1604 }
1605
1606 jit::DeviceId chosen_device = *maybe_chosen_device;
1607
1608 // If we are able to pick a device `chosen_device` for the merged cluster,
1609 // the resource operations in `cluster_a` and `cluster_b` must be placed on
1610 // that same device. This is because the _XlaCompile and _XlaRun kernels will
1611 // run on `chosen_device` and will therefore try to access the resource
1612 // variables from it, which is an error if those resource variables are
1613 // placed on some other device.
1614 auto resource_op_device_ok =
1615 [&](absl::optional<DeviceId> resource_op_device) {
1616 return !resource_op_device.has_value() ||
1617 *resource_op_device == chosen_device;
1618 };
1619
1620 return resource_op_device_ok(cluster_a.resource_op_device()) &&
1621 resource_op_device_ok(cluster_b.resource_op_device());
1622 }
1623
1624 // Returns `true` iff we should compile `cluster`.
1625 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
1626 const Cluster& cluster) {
1627 TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
1628 PickDeviceForXla(device_info_cache_, cluster.devices(),
1629 /*allow_mixing_unknown_and_cpu=*/false));
1630
1631 const DeviceType& device_type =
1632 device_info_cache_.GetDeviceTypeFor(chosen_device);
1633 const XlaOpRegistry::DeviceRegistration* registration =
1634 device_info_cache_.GetCompilationDevice(chosen_device);
1635 TF_RET_CHECK(registration)
1636 << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
1637 << "; device type = " << device_type.type() << "; devices ("
1638 << device_info_cache_.DebugString(cluster.devices()) << ")";
1639
1640 bool should_compile =
1641 cluster.is_xla_compile_attr_true() ||
1642 registration->autoclustering_policy ==
1643 XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1644 (registration->autoclustering_policy ==
1645 XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1646 global_jit_level_ != OptimizerOptions::OFF);
1647
1648 if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
1649 device_type.type_string() == DEVICE_CPU) {
1650 static absl::once_flag once;
1651 absl::call_once(once, [] {
1652 LOG(WARNING)
1653 << "(One-time warning): Not using XLA:CPU for cluster because envvar "
1654 "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "
1655 "XLA:CPU, either set that envvar, or use experimental_jit_scope "
1656 "to enable XLA:CPU. To confirm that XLA is active, pass "
1657 "--vmodule=xla_compilation_cache=1 (as a proper command-line "
1658 "flag, not via TF_XLA_FLAGS) or set the envvar "
1659 "XLA_FLAGS=--xla_hlo_profile.";
1660 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1661 if (flags->tf_xla_cpu_global_jit) {
1662 LOG(WARNING)
1663 << "(Although the tf_xla_cpu_global_jit flag is currently enabled, "
1664 "perhaps it wasn't enabled at process startup?)";
1665 }
1666 });
1667 }
1668
1669 VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
1670 << " cluster with device "
1671 << device_info_cache_.GetNameFor(chosen_device);
1672
1673 return should_compile;
1674 }
1675
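// Memoizing wrapper around ShouldCompileClusterImpl, keyed on the cluster's
// identity.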
1676 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileCluster(
1677 const Cluster& cluster) {
1678 auto it = should_compile_cluster_cache_.find(&cluster);
1679 if (it != should_compile_cluster_cache_.end()) {
1680 return it->second;
1681 }
1682
1683 TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster));
1684 should_compile_cluster_cache_.insert({&cluster, should_compile});
1685 return should_compile;
1686 }
1687
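// Shared implementation behind MarkForCompilationPass::Run and ::RunForTest:
// fixes up source/sink edges for deadness analysis, returns early if the
// graph has already been clustered (see the explanation of
// kXlaAlreadyClustered), and otherwise runs MarkForCompilationPassImpl with
// the given debug options.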
1688 Status MarkForCompilation(
1689 const GraphOptimizationPassOptions& options,
1690 const MarkForCompilationPassImpl::DebugOptions& debug_options) {
1691 Graph* graph = options.graph->get();
1692 FunctionLibraryDefinition* flib_def = options.flib_def;
1693
1694 // Deadness analysis expects a graph with source and sink edges properly
1695 // connected, but the incoming graph does not always satisfy this invariant,
1696 // so fix up the source and sink edges before running deadness analysis.
1697 FixupSourceAndSinkEdges(graph);
1698
1699 // See explanation on `kXlaAlreadyClustered`.
1700 for (Node* n : graph->nodes()) {
1701 if (n->attrs().Find(kXlaAlreadyClustered)) {
1702 return Status::OK();
1703 }
1704 }
1705
1706 return MarkForCompilationPassImpl{debug_options, graph, flib_def,
1707 options.session_options != nullptr
1708 ? options.session_options->env
1709 : Env::Default(),
1710 GetGlobalJitLevelForGraph(options)}
1711 .Run();
1712 }
1713
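// Returns a pointer to a process-wide atomic "fuel" counter.  The counter is
// allocated and set to `initial_value` on the first call; subsequent calls
// return the same pointer and ignore `initial_value`.  Each node admitted as
// a compilation candidate decrements the fuel (see above), which can be used
// to limit clustering for debugging (e.g. to bisect a miscompilation).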
1714 std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
1715 static std::atomic<int64>* fuel = [&]() {
1716 std::atomic<int64>* fuel = new std::atomic<int64>;
1717 *fuel = initial_value;
1718 return fuel;
1719 }();
1720
1721 return fuel;
1722 }
1723 } // anonymous namespace
1724
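// Production entry point for the pass.  The debug options are populated from
// MarkForCompilationPassFlags; these flags are normally set through the
// TF_XLA_FLAGS environment variable using the same names as the fields read
// below (e.g. --tf_xla_max_cluster_size, --tf_xla_clustering_fuel).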
1725 Status MarkForCompilationPass::Run(
1726 const GraphOptimizationPassOptions& options) {
1727 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1728
1729 MarkForCompilationPassImpl::DebugOptions debug_options;
1730 debug_options.ignore_deadness_checks =
1731 flags->tf_xla_disable_deadness_safety_checks_for_debugging;
1732 debug_options.ignore_resource_variable_checks =
1733 flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1734 debug_options.ignore_xla_compile_attr = false;
1735 debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1736 debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1737 debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1738 debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1739
1740 return MarkForCompilation(options, debug_options);
1741 }
1742
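// Test-only entry point: same as Run, except that deadness safety checks can
// be disabled explicitly and the kXlaCompileAttr attribute is ignored.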
1743 Status MarkForCompilationPass::RunForTest(
1744 const GraphOptimizationPassOptions& options,
1745 bool disable_deadness_analysis) {
1746 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1747
1748 MarkForCompilationPassImpl::DebugOptions debug_options;
1749 debug_options.ignore_deadness_checks = disable_deadness_analysis;
1750 debug_options.ignore_resource_variable_checks =
1751 flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1752 debug_options.ignore_xla_compile_attr = true;
1753 debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1754 debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1755 debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1756 debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1757
1758 return MarkForCompilation(options, debug_options);
1759 }
1760
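// Returns a lazily initialized, process-lifetime table that groups
// allowlisted TF operations by category; judging by the contents, "PW"
// covers pointwise ops, "RED" reductions, "BN" batch-norm ops, and so on.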
1761 absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
1762 // Table format: category name: {list of TF operations in that category}
1763 static absl::flat_hash_map<string, std::vector<string>>* result =
1764 new absl::flat_hash_map<string, std::vector<string>>{
1765 // Unary
1766 {"PW",
1767 {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
1768 "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
1769 "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
1770 "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
1771 "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
1772 "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
1773 "Lgamma", "Digamma",
1774 // Binary
1775 "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
1776 "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
1777 "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
1778 "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
1779 "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
1780 "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
1781 "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
1782 "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
1783 // Others
1784 "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
1785 "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
1786 "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
1787 "SelectV2", "Transpose", "ConjugateTranspose",
1788 "_UnaryOpsComposition",
1789 // The following 4 operations are converted to identity
1790 "PlaceholderWithDefault", "PreventGradient", "StopGradient",
1791 "Snapshot"}},
1792 // clang-format off
1793 {"RED",
1794 {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
1795 // clang-format on
1796 {"PWRED",
1797 {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1798 "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1799 {"REDUCEWINDOW",
1800 {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1801 "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1802 {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
1803 {"BN",
1804 {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
1805 "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
1806 "FusedBatchNormGradV3"}},
1807 {"SORT", {"TopKV2"}},  // XLA version much faster than TF version.
1808 {"MISC",
1809 // clang-format off
1810 {"BroadcastTo", "ExpandDims", "Fill", "NoOp",
1811 "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
1812 "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
1813 "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
1814 "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
1815 "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
1816 "SplitV", "StridedSlice", "StridedSliceGrad",
1817 "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
1818 "Unpack", "DeviceIndex", "TensorStridedSliceUpdate",
1819 }}};
1820 // clang-format on
1821 return result;
1822 }
1823
1824 namespace testing {
1825 void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
1826
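// Returns a hard-coded set of op names known to be supported by XLA.  It is
// exposed under namespace testing, presumably so tests can detect when the
// set of XLA-compilable ops changes without the expectations being updated.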
1827 absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
1828 absl::flat_hash_set<string> result{"AdjustContrastv2",
1829 "AdjustHue",
1830 "AdjustSaturation",
1831 "Asinh",
1832 "Assert",
1833 "AssignAddVariableOp",
1834 "AssignSubVariableOp",
1835 "AssignVariableOp",
1836 "AvgPool",
1837 "AvgPool3D",
1838 "AvgPool3DGrad",
1839 "AvgPoolGrad",
1840 "BatchMatMul",
1841 "BatchMatMulV2",
1842 "BatchToSpace",
1843 "BatchToSpaceND",
1844 "BesselI0e",
1845 "BesselI1e",
1846 "Betainc",
1847 "BiasAddV1",
1848 "Bucketize",
1849 "Case",
1850 "CheckNumerics",
1851 "Cholesky",
1852 "ControlTrigger",
1853 "Conv2D",
1854 "Conv2DBackpropFilter",
1855 "Conv2DBackpropInput",
1856 "Conv3D",
1857 "Conv3DBackpropFilterV2",
1858 "Conv3DBackpropInputV2",
1859 "Cross",
1860 "Cumprod",
1861 "Cumsum",
1862 "DataFormatDimMap",
1863 "DataFormatVecPermute",
1864 "DepthToSpace",
1865 "DepthwiseConv2dNative",
1866 "DepthwiseConv2dNativeBackpropFilter",
1867 "DepthwiseConv2dNativeBackpropInput",
1868 "Dequantize",
1869 "Diag",
1870 "DynamicStitch",
1871 "Einsum",
1872 "EmptyTensorList",
1873 "EnsureShape",
1874 "ExtractImagePatches",
1875 "Igamma",
1876 "IgammaGradA",
1877 "RandomGammaGrad",
1878 "Igammac",
1879 "FFT",
1880 "FFT2D",
1881 "FFT3D",
1882 "FakeParam",
1883 "FakeQuantWithMinMaxArgs",
1884 "FakeQuantWithMinMaxArgsGradient",
1885 "FakeQuantWithMinMaxVars",
1886 "FakeQuantWithMinMaxVarsGradient",
1887 "Gather",
1888 "GatherNd",
1889 "GatherV2",
1890 "HSVToRGB",
1891 "IFFT",
1892 "IFFT2D",
1893 "IFFT3D",
1894 "IRFFT",
1895 "IRFFT2D",
1896 "IRFFT3D",
1897 "If",
1898 "InTopKV2",
1899 "L2Loss",
1900 "LeakyRelu",
1901 "LinSpace",
1902 "ListDiff",
1903 "LogMatrixDeterminant",
1904 "LowerBound",
1905 "MatMul",
1906 "MatrixBandPart",
1907 "MatrixDiag",
1908 "MatrixDiagPart",
1909 "MatrixDiagPartV2",
1910 "MatrixDiagPartV3",
1911 "MatrixDiagV2",
1912 "MatrixDiagV3",
1913 "MatrixInverse",
1914 "MatrixSetDiag",
1915 "MatrixSetDiagV2",
1916 "MatrixSetDiagV3",
1917 "MatrixSolve",
1918 "MatrixTriangularSolve",
1919 "MaxPool",
1920 "MaxPool3D",
1921 "MaxPool3DGrad",
1922 "MaxPool3DGradGrad",
1923 "MaxPoolGrad",
1924 "MaxPoolGradGrad",
1925 "MaxPoolGradGradV2",
1926 "MaxPoolGradV2",
1927 "MaxPoolV2",
1928 "Multinomial",
1929 "NextAfter",
1930 "NonMaxSuppressionV4",
1931 "ParallelDynamicStitch",
1932 "ParameterizedTruncatedNormal",
1933 "PartitionedCall",
1934 "Polygamma",
1935 "PopulationCount",
1936 "Qr",
1937 "QuantizeAndDequantizeV2",
1938 "QuantizeAndDequantizeV3",
1939 "RFFT",
1940 "RFFT2D",
1941 "RFFT3D",
1942 "RGBToHSV",
1943 "RandomShuffle",
1944 "RandomStandardNormal",
1945 "RandomUniform",
1946 "RandomUniformInt",
1947 "ReadVariableOp",
1948 "ResizeBilinear",
1949 "ResizeBilinearGrad",
1950 "ResizeNearestNeighbor",
1951 "ResourceApplyAdaMax",
1952 "ResourceApplyAdadelta",
1953 "ResourceApplyAdagrad",
1954 "ResourceApplyAdagradDA",
1955 "ResourceApplyAdagradV2",
1956 "ResourceApplyAdam",
1957 "ResourceApplyAddSign",
1958 "ResourceApplyCenteredRMSProp",
1959 "ResourceApplyFtrl",
1960 "ResourceApplyFtrlV2",
1961 "ResourceApplyGradientDescent",
1962 "ResourceApplyKerasMomentum",
1963 "ResourceApplyMomentum",
1964 "ResourceApplyPowerSign",
1965 "ResourceApplyProximalAdagrad",
1966 "ResourceApplyProximalGradientDescent",
1967 "ResourceApplyRMSProp",
1968 "ResourceGather",
1969 "ResourceScatterAdd",
1970 "ResourceScatterDiv",
1971 "ResourceScatterMax",
1972 "ResourceScatterMin",
1973 "ResourceScatterMul",
1974 "ResourceScatterNdAdd",
1975 "ResourceScatterNdSub",
1976 "ResourceScatterNdUpdate",
1977 "ResourceScatterSub",
1978 "ResourceScatterUpdate",
1979 "RngReadAndSkip",
1980 "RngSkip",
1981 "Roll",
1982 "ScatterNd",
1983 "SelfAdjointEigV2",
1984 "SoftmaxCrossEntropyWithLogits",
1985 "SpaceToBatch",
1986 "SpaceToBatchND",
1987 "SpaceToDepth",
1988 "SparseMatMul",
1989 "SparseToDense",
1990 "StackCloseV2",
1991 "StackPopV2",
1992 "StackPushV2",
1993 "StackV2",
1994 "StatefulPartitionedCall",
1995 "StatefulStandardNormalV2",
1996 "StatefulTruncatedNormal",
1997 "StatefulUniform",
1998 "StatefulUniformFullInt",
1999 "StatefulUniformInt",
2000 "StatelessCase",
2001 "StatelessIf",
2002 "StatelessMultinomial",
2003 "StatelessRandomGetAlg",
2004 "StatelessRandomGetKeyCounter",
2005 "StatelessRandomGetKeyCounterAlg",
2006 "StatelessRandomNormal",
2007 "StatelessRandomNormalV2",
2008 "StatelessRandomUniform",
2009 "StatelessRandomUniformV2",
2010 "StatelessRandomUniformInt",
2011 "StatelessRandomUniformIntV2",
2012 "StatelessRandomUniformFullInt",
2013 "StatelessRandomUniformFullIntV2",
2014 "StatelessTruncatedNormal",
2015 "StatelessTruncatedNormalV2",
2016 "StatelessWhile",
2017 "Svd",
2018 "SymbolicGradient",
2019 "TensorArrayCloseV3",
2020 "TensorArrayConcatV3",
2021 "TensorArrayGatherV3",
2022 "TensorArrayGradV3",
2023 "TensorArrayReadV3",
2024 "TensorArrayScatterV3",
2025 "TensorArraySizeV3",
2026 "TensorArraySplitV3",
2027 "TensorArrayV3",
2028 "TensorArrayWriteV3",
2029 "TensorListConcatV2",
2030 "TensorListElementShape",
2031 "TensorListFromTensor",
2032 "TensorListGather",
2033 "TensorListGetItem",
2034 "TensorListLength",
2035 "TensorListPopBack",
2036 "TensorListPushBack",
2037 "TensorListReserve",
2038 "TensorListSetItem",
2039 "TensorListSplit",
2040 "TensorListStack",
2041 "TensorScatterAdd",
2042 "TensorScatterMax",
2043 "TensorScatterMin",
2044 "TensorScatterSub",
2045 "TensorScatterUpdate",
2046 "TridiagonalSolve",
2047 "TruncatedNormal",
2048 "Unique",
2049 "UpperBound",
2050 "UnsortedSegmentMax",
2051 "UnsortedSegmentMin",
2052 "UnsortedSegmentProd",
2053 "UnsortedSegmentSum",
2054 "VarIsInitializedOp",
2055 "VariableShape",
2056 "Where",
2057 "While",
2058 "XlaBroadcastHelper",
2059 "XlaConv",
2060 "XlaDequantize",
2061 "XlaDot",
2062 "XlaDynamicSlice",
2063 "XlaDynamicUpdateSlice",
2064 "XlaEinsum",
2065 "XlaGather",
2066 "XlaIf",
2067 "XlaKeyValueSort",
2068 "XlaPad",
2069 "XlaRecv",
2070 "XlaReduce",
2071 "XlaReduceWindow",
2072 "XlaReplicaId",
2073 "XlaScatter",
2074 "XlaSelectAndScatter",
2075 "XlaSelfAdjointEig",
2076 "XlaSend",
2077 "XlaSetBound",
2078 "XlaSetDynamicDimensionSize",
2079 "XlaSharding",
2080 "XlaSort",
2081 "XlaSpmdFullToShardShape",
2082 "XlaSpmdShardToFullShape",
2083 "XlaSvd",
2084 "XlaVariadicReduce",
2085 "XlaVariadicSort",
2086 "XlaWhile",
2087 "Zeta",
2088 "_Arg",
2089 "_ArrayToList",
2090 "_ListToArray",
2091 "_Retval"};
2092 return result;
2093 }
2094
2095 } // namespace testing
2096 } // namespace tensorflow
2097