1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
17 
18 #include <atomic>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 
24 #include "absl/base/call_once.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/jit/compilability_check_util.h"
29 #include "tensorflow/compiler/jit/deadness_analysis.h"
30 #include "tensorflow/compiler/jit/defs.h"
31 #include "tensorflow/compiler/jit/device_util.h"
32 #include "tensorflow/compiler/jit/flags.h"
33 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
34 #include "tensorflow/compiler/jit/xla_cluster_util.h"
35 #include "tensorflow/compiler/tf2xla/const_analysis.h"
36 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
37 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
38 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
39 #include "tensorflow/compiler/xla/statusor.h"
40 #include "tensorflow/compiler/xla/union_find.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/common_runtime/function.h"
43 #include "tensorflow/core/common_runtime/graph_constructor.h"
44 #include "tensorflow/core/framework/bounds_check.h"
45 #include "tensorflow/core/framework/graph_def_util.h"
46 #include "tensorflow/core/framework/memory_types.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/op_kernel.h"
49 #include "tensorflow/core/framework/tensor.pb.h"
50 #include "tensorflow/core/framework/types.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/control_flow.h"
53 #include "tensorflow/core/lib/gtl/cleanup.h"
54 #include "tensorflow/core/lib/strings/stringprintf.h"
55 #include "tensorflow/core/public/version.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 
58 namespace tensorflow {
59 
60 namespace {
61 using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
62 using jit::DeviceId;
63 using jit::DeviceSet;
64 using xla::StatusOr;
65 
66 // The clusters we create here are eventually lowered into an
67 // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
68 // PartitionedCall op to execute the cluster in the regular graph executor if
69 // need be.  PartitionedCall, however, reruns the entire TF graph optimization
70 // pipeline over the cluster which includes this mark for compilation pass.  To
71 // avoid endlessly recursing we tag nodes that we've already visited with this
72 // attribute so that we can bail out if we see them a second time.
73 //
74 // TODO(sanjoy): This method is not robust since it is possible that the
75 // optimizations run by PartitionedCall can mutate the cluster arbitrarily,
76 // dropping the kXlaAlreadyClustered attributes from all nodes in the process.
77 // The correct fix is to use the ConfigProto to pass in some sort of flag into
78 // the PartitionedCall kernel that tells it to not rerun auto-clustering on the
79 // cluster.
80 const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";
81 
82 class MarkForCompilationPassImpl {
83  public:
84   struct DebugOptions {
85     // If true, do not respect the results of deadness analysis.
86     bool ignore_deadness_checks;
87 
88     // If true, do not do safety checks to preserve TensorFlow's resource
89     // variable concurrency semantics.
90     bool ignore_resource_variable_checks;
91 
92     // If true, do not respect the _XlaCompile=false attribute.
93     bool ignore_xla_compile_attr;
94 
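    // Bounds on cluster sizes: clusters are not grown past max_cluster_size
    // during edge contraction, and clusters smaller than min_cluster_size (by
    // effective size) are not marked for compilation.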
95     int max_cluster_size;
96     int min_cluster_size;
97 
98     // Compiler fuel for the auto-clustering algorithm.
99     //
100     // We decrement this value by one every time we choose a compilation
101     // candidate, and we stop clustering when it hits zero.  This means the
102     // initial value for this variable (via --tf_xla_clustering_fuel=N)
103     // effectively acts as a "cap" for how much we cluster and we can bisect
104     // over this initial value to discover clustering decisions that cause a
105     // miscompile or a performance regression.
106     std::atomic<int64>* fuel;
107 
108     bool dump_graphs;
109   };
110 
111   MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph,
112                              FunctionLibraryDefinition* flib_def, Env* env,
113                              OptimizerOptions::GlobalJitLevel global_jit_level)
114       : debug_options_(debug_options),
115         graph_(graph),
116         flib_def_(flib_def),
117         env_(env),
118         global_jit_level_(global_jit_level) {}
119 
120   Status Run();
121 
122  private:
123   // Represents a "cluster" or a connected subgraph of a TensorFlow graph.
124   class Cluster {
125    public:
126     // Constructs a trivial cluster representing a single TF node.
127     Cluster(int tf_graph_node_id, int effective_cluster_size,
128             bool has_functional_control_flow, DeviceSet devices,
129             absl::optional<DeviceId> resource_op_device,
130             absl::optional<int> resource_var_operation_node_id,
131             absl::optional<DeadnessPredicate> deadness_predicate,
132             bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
133         : cycles_graph_node_id_(tf_graph_node_id),
134           effective_cluster_size_(effective_cluster_size),
135           has_functional_control_flow_(has_functional_control_flow),
136           devices_(std::move(devices)),
137           resource_op_device_(resource_op_device),
138           deadness_predicate_(deadness_predicate),
139           is_xla_compile_attr_true_(is_xla_compile_attr_true),
140           xla_scope_(std::move(xla_scope)) {
141       if (resource_var_operation_node_id.has_value()) {
142         resource_var_operation_node_ids_.push_back(
143             *resource_var_operation_node_id);
144       }
145     }
146 
147     // Merges `other` into this cluster, and clears `other`.  This method is
148     // closely tied with the implementation of `MarkForCompilationPassImpl`.
149     void Merge(Cluster* other);
150 
151     // If this is a trivial cluster containing only one node then return the ID
152     // of that node.  May not be called otherwise.
153     int GetIdOfOnlyNode() const {
154       DCHECK_EQ(cluster_size(), 1);
155       return cycles_graph_node_id();
156     }
157 
158     // The number of TF nodes in this cluster.
159     int cluster_size() const { return cluster_size_; }
160 
161     // The ID of the cluster as represented in `cycles_graph_`.
162     int cycles_graph_node_id() const { return cycles_graph_node_id_; }
163 
164     // Sets the ID of the cluster as represented in `cycles_graph_`.
165     void set_cycles_graph_node_id(int cycles_graph_node_id) {
166       cycles_graph_node_id_ = cycles_graph_node_id;
167     }
168 
169     // The size of the cluster excluding constant and identity nodes.
170     int effective_cluster_size() const { return effective_cluster_size_; }
171 
172     // True if the cluster has functional control flow like `If` and `While`.
173     bool has_functional_control_flow() const {
174       return has_functional_control_flow_;
175     }
176 
177     // The set of devices nodes in the cluster are placed on.
178     const DeviceSet& devices() const { return devices_; }
179 
180     // If the cluster has a resource operation then the device the resource
181     // operation is placed on.  A cluster may have resource ops placed only on a
182     // single device.
183     const absl::optional<DeviceId>& resource_op_device() const {
184       return resource_op_device_;
185     }
186 
187     // If not nullopt then this is a predicate that is true iff the cluster is alive.
188     // Otherwise the user has (unsafely) disabled deadness analysis.  If this is
189     // unset on a single Cluster instance then it is unset on all Cluster
190     // instances.
191     const absl::optional<DeadnessPredicate>& deadness_predicate() const {
192       return deadness_predicate_;
193     }
194 
195     // If true then the cluster has an XlaCompile=true attribute on one of its
196     // nodes.
197     bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; }
198 
199     // If not nullopt then all nodes in the cluster either do not have the
200     // XlaScope attribute set or have it set to the value returned.
201     const absl::optional<string>& xla_scope() const { return xla_scope_; }
202 
203     // Returns the TF graph node IDs for the resource variable operations in
204     // this cluster.
205     absl::Span<const int> resource_var_operation_node_ids() const {
206       return resource_var_operation_node_ids_;
207     }
208 
209     string DebugString(const Graph& graph) const {
210       Node* node = graph.FindNodeId(cycles_graph_node_id());
211       if (!node) {
212         // This should never happen but we try to be resilient because this is a
213         // debugging aid.
214         return absl::StrCat("NULL NODE IN #", cycles_graph_node_id());
215       }
216 
217       if (cluster_size() == 1) {
218         return absl::StrCat("<", node->name(), " #", cycles_graph_node_id(),
219                             ">");
220       }
221 
222       return absl::StrCat("<", node->name(), " + ", cluster_size() - 1,
223                           " others #", cycles_graph_node_id(), ">");
224     }
225 
226    private:
227     int cluster_size_ = 1;
228     int cycles_graph_node_id_;
229     int effective_cluster_size_;
230     bool has_functional_control_flow_;
231     DeviceSet devices_;
232     absl::optional<DeviceId> resource_op_device_;
233     absl::optional<DeadnessPredicate> deadness_predicate_;
234     bool is_xla_compile_attr_true_;
235     absl::optional<string> xla_scope_;
236     std::vector<int> resource_var_operation_node_ids_;
237 
238     TF_DISALLOW_COPY_AND_ASSIGN(Cluster);
239   };
240 
241   // If `cluster` has only a single node then returns that, otherwise returns
242   // nullptr.
243   Node* GetOnlyNodeIn(const Cluster& cluster);
244 
245   // Returns true if `cluster` is a trivial cluster containing a "sink like"
246   // node -- a NoOp node whose single outgoing edge is a control dependency
247   // into the sink node.
247   bool IsSinkLike(const Cluster& cluster);
248 
249   // Returns true if `cluster` looks like an "i++" operation on an integer
250   // scalar resource variable.
251   bool IsScalarIntegerResourceOperation(const Cluster& cluster);
252 
253   // ---------------------------------------------------------------------------
254   // The pass proceeds in four steps, out of which `RunEdgeContractionLoop` and
255   // `CreateClusters` do most of the heavy lifting.
256 
257   // Initializes some internal data structures.
258   //
259   // If this returns false then Initialize exited early (either because there is
260   // nothing to do or we saw a graph that we can't handle) and not all the
261   // fields in this MarkForCompilationPassImpl instance are set up.
262   StatusOr<bool> Initialize();
263 
264   // Runs through the entire cluster graph in post-order and calls `fn(from,
265   // to)` on each edge.  `fn(from, to)` is expected to return true if it was
266   // able to contract `from`->`to`.
267   //
268   // Returns true if `fn` returned true for any edge.
269   template <typename FnTy>
270   StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
271 
272   // Contracts as many edges as possible to create XLA clusters.  After this
273   // finishes the clustering decisions made are implicitly stored in
274   // `clusters_`.
275   Status RunEdgeContractionLoop();
276 
277   // Manifests the clustering decisions into the TF graph by tagging nodes with
278   // an `_XlaCluster` attribute.  Some basic filtering logic, like
279   // tf_xla_min_cluster_size, is also applied here.
280   Status CreateClusters();
281 
282   Status DumpDebugInfo();
283 
284   bool IsCompilationCandidate(Node* n) const {
285     return compilation_candidates_.find(n) != compilation_candidates_.end();
286   }
287 
288   // Tries to contract the edge from cluster `from` to cluster `to`.  Returns
289   // true if successful.
290   StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
291 
292   // Nodes that XLA can compile are put in `compilation_candidates_`.
293   Status FindCompilationCandidates();
294 
295   bool CompilationDisallowedByXlaCompileAttr(Node* node);
296 
297   // Populates `clusters_`.
298   Status BuildInitialClusterSet();
299 
300   StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);
301 
302   StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);
303 
304   StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
305       const Cluster& from, const Cluster& to);
306 
307   // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
308   // and therefore not a hindrance for combining the two clusters into a larger
309   // cluster.
310   StatusOr<bool> AreDevicesCompatible(const Cluster& cluster_a,
311                                       const Cluster& cluster_b);
312 
313   void DumpPostClusteringGraphs();
314   void VLogClusteringSummary();
315 
316   Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
317                           bool has_functional_control_flow,
318                           const DeviceSet& device_set,
319                           absl::optional<DeviceId> resource_op_device,
320                           absl::optional<int> resource_var_operation_node_id,
321                           absl::optional<DeadnessPredicate> deadness_predicate,
322                           bool is_xla_compile_attr_true,
323                           absl::optional<string> xla_scope) {
324     cluster_storage_.push_back(absl::make_unique<Cluster>(
325         cycles_graph_node_id, effective_cluster_size,
326         has_functional_control_flow, device_set, resource_op_device,
327         resource_var_operation_node_id, deadness_predicate,
328         is_xla_compile_attr_true, xla_scope));
329     return cluster_storage_.back().get();
330   }
331 
332   absl::optional<string> GetXlaScope(Node* n);
333 
334   // Returns the cluster for node `n`.  If two nodes, N1 and N2, are placed in
335   // the same cluster by the clustering algorithm then this function will return
336   // the same Cluster instance for N1 and N2.
337   //
338   // Returns nullptr if `n` is not a compilation candidate.
339   Cluster* GetClusterForNode(Node* n) {
340     return cluster_for_node_[n->id()].Get();
341   }
342 
343   // Returns the cluster for a node in `cycles_graph_`.  This uses the same
344   // underlying map because of how we set things up, but we can do an additional
345   // CHECK in this accessor.
346   //
347   // Returns nullptr if `node_id` is not a compilation candidate.
348   Cluster* GetClusterForCyclesGraphNode(int node_id) {
349     // We have to check `graph_->FindNodeId(node) == nullptr` because we add all
350     // nodes in [0, graph_->num_node_ids()) to the cycle detection graph but the
351     // TF graph may be missing some node ids.
352     if (node_id >= graph_->num_node_ids() ||
353         graph_->FindNodeId(node_id) == nullptr) {
354       return nullptr;
355     }
356     Cluster* cluster = cluster_for_node_[node_id].Get();
357     if (cluster) {
358       DCHECK_EQ(cluster->cycles_graph_node_id(), node_id);
359     }
360     return cluster;
361   }
362 
363   bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to,
364                                         absl::string_view reason);
365 
366   // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct
367   // edge from `from` to `to`.
368   //
369   // Tries to find a path that contains at least one unclusterable node.
370   std::vector<int> FindAlternatePathForDebugging(int from, int to);
371 
372   // Returns a string representing `cycles_graph_node_id`.  If the node is
373   // unclusterable (either it is a phantom "frame" node or is not a compilation
374   // candidate) then set `*found_unclustered` to true.
375   string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered);
376 
377   // We could not contract the edge from `from` to `to`.  Return a string
378   // describing an alternate path from `from` to `to` (besides the direct edge
379   // from `from` to `to`) which would have created a cycle had we contracted the
380   // edge.
381   //
382   // Tries (if possible) to find a path that contains at least one unclusterable
383   // node as it is surprising to the user if we print "A->B could not be
384   // contracted because of the path [P,Q,R]" where P, Q and R are all clusters
385   // since in that case a natural question is why we could not form a {A, P, Q,
386   // R, B} cluster.
387   string DescribePotentialCycle(int from, int to);
388 
389   // Merge the clusters `cluster_from` and `cluster_to`. After this step the
390   // larger combined cluster is represented by `cluster_from`, but can have
391   // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
392   // which way requires fewer operations.
393   bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
394     int from = cluster_from->cycles_graph_node_id();
395     int to = cluster_to->cycles_graph_node_id();
396 
397     auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
398     if (!optional_merged_node.has_value()) {
399       VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
400               << " -> " << cluster_to->DebugString(*graph_)
401               << " because contracting the edge would create a cycle via "
402               << DescribePotentialCycle(from, to) << ".";
403       return false;
404     }
405 
406     // Merge the clusters.
407     cluster_from->Merge(cluster_to);
408     // Update `cycle_graph_`'s ID.
409     cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
410 
411     // Merge the UnionFind<Cluster*>.
412     cluster_for_node_[from].Merge(&cluster_for_node_[to]);
413 
414     return true;
415   }
416 
417   string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
418                                    absl::string_view reason) {
419     return absl::StrCat("Could not contract ", from->DebugString(*graph_),
420                         " -> ", to->DebugString(*graph_), " because ", reason,
421                         ".");
422   }
423 
424   DebugOptions debug_options_;
425   Graph* graph_;
426   FunctionLibraryDefinition* flib_def_;
427   Env* env_;
428   OptimizerOptions::GlobalJitLevel global_jit_level_;
429   absl::flat_hash_map<const Cluster*, bool> should_compile_cluster_cache_;
430   jit::DeviceInfoCache device_info_cache_;
431 
432   bool initialized_ = false;
433   bool edges_contracted_ = false;
434   bool clusters_created_ = false;
435 
436   std::vector<std::unique_ptr<Cluster>> cluster_storage_;
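  // Maps a TF graph node id to its cluster; the UnionFind wrapper lets merged
  // clusters share a single representative Cluster*.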
437   std::vector<UnionFind<Cluster*>> cluster_for_node_;
438   GraphCycles cycles_graph_;
439   OrderedNodeSet compilation_candidates_;
440   std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
441   int64 iteration_count_ = 0;
442   absl::flat_hash_set<std::pair<int, int>> unsafe_resource_deps_;
443 };
444 
445 std::vector<int> MarkForCompilationPassImpl::FindAlternatePathForDebugging(
446     int from, int to) {
447   std::vector<int> rpo = cycles_graph_.AllNodesInPostOrder();
448   absl::c_reverse(rpo);
449 
450   // best_pred_for_node[n] contains a predecessor of `n` that has an
451   // unclusterable node in some path from `from` to itself.
452   // best_pred_for_node[n] is unpopulated for nodes that are not reachable from
453   // `from`.  We build this table up inductively by traversing the cycles graph
454   // in RPO.
455   absl::flat_hash_map<int, int> best_pred_for_node;
456   best_pred_for_node[from] = -1;
457 
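  // Walk the cycles graph in reverse post-order until `to` has been processed,
  // recording for each reachable node a best predecessor (preferring
  // predecessors that are not part of any cluster).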
458   int rpo_index = 0, current_rpo_node;
459   do {
460     current_rpo_node = rpo[rpo_index++];
461     absl::optional<int> some_pred, preferred_pred;
462     for (int pred : cycles_graph_.Predecessors(current_rpo_node)) {
463       if (!best_pred_for_node.contains(pred)) {
464         continue;
465       }
466 
467       // Ignore the from->to edge since we're trying to find an alternate path.
468       if (current_rpo_node == to && pred == from) {
469         continue;
470       }
471 
472       some_pred = pred;
473       if (GetClusterForCyclesGraphNode(pred) == nullptr) {
474         preferred_pred = pred;
475       }
476     }
477 
478     if (some_pred || preferred_pred) {
479       best_pred_for_node[current_rpo_node] =
480           preferred_pred.has_value() ? *preferred_pred : *some_pred;
481     }
482   } while (current_rpo_node != to);
483 
484   auto get_best_pred = [&](int n) {
485     auto it = best_pred_for_node.find(n);
486     CHECK(it != best_pred_for_node.end());
487     return it->second;
488   };
489 
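  // Reconstruct the path by following best predecessors backwards from `to`,
  // then reverse it so it reads from `from` to `to` (endpoints excluded).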
490   std::vector<int> path;
491   int current_path_node = get_best_pred(to);
492   while (current_path_node != from) {
493     path.push_back(current_path_node);
494     current_path_node = get_best_pred(current_path_node);
495   }
496 
497   absl::c_reverse(path);
498   return path;
499 }
500 
501 string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode(
502     int cycles_graph_node_id, bool* found_unclustered) {
503   Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id);
504   if (cluster) {
505     return cluster->DebugString(*graph_);
506   }
507 
508   *found_unclustered = true;
509   if (cycles_graph_node_id >= graph_->num_node_ids()) {
510     return absl::StrCat("<oob #", cycles_graph_node_id, ">");
511   }
512 
513   Node* node = graph_->FindNodeId(cycles_graph_node_id);
514   if (!node) {
515     return absl::StrCat("<bad #", cycles_graph_node_id, ">");
516   }
517 
518   return node->name();
519 }
520 
521 string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) {
522   std::vector<string> path_str;
523   bool found_unclustered = false;
524   absl::c_transform(FindAlternatePathForDebugging(from, to),
525                     std::back_inserter(path_str), [&](int node_id) {
526                       return DebugStringForCyclesGraphNode(node_id,
527                                                            &found_unclustered);
528                     });
529   return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[",
530                       absl::StrJoin(path_str, ","), "]");
531 }
532 
533 void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
534   // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does.
535 
536   // Clearing out data structures in `other` is just a memory saving
537   // optimization and not needed for correctness.
538 
539   cluster_size_ += other->cluster_size_;
540   effective_cluster_size_ += other->effective_cluster_size_;
541   has_functional_control_flow_ |= other->has_functional_control_flow_;
542 
543   devices_.UnionWith(other->devices_);
544 
545   DCHECK(!(resource_op_device_.has_value() &&
546            other->resource_op_device_.has_value()) ||
547          *resource_op_device_ == *other->resource_op_device_)
548       << "AreDevicesCompatible should have returned false otherwise!";
549 
550   if (!resource_op_device_.has_value()) {
551     resource_op_device_ = other->resource_op_device_;
552   }
553 
554   is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
555 
556   if (!xla_scope_.has_value()) {
557     xla_scope_ = std::move(other->xla_scope_);
558   }
559 
560   resource_var_operation_node_ids_.reserve(
561       resource_var_operation_node_ids_.size() +
562       other->resource_var_operation_node_ids_.size());
563   absl::c_copy(other->resource_var_operation_node_ids_,
564                std::back_inserter(resource_var_operation_node_ids_));
565   other->resource_var_operation_node_ids_.clear();
566 }
567 
568 Status IgnoreResourceOpForSafetyAnalysis(
569     jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) {
570   // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
571   // ignore it during resource operation safety analysis.  We need this hack
572   // because of two reasons:
573   //
574   //  1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
575   //  2. We don't support live-out values of type DT_RESOURCE and live-in values
576   //     of type DT_RESOURCE that are not resource variables.
577   //
578   // Together these imply we cannot let resource variable safety analysis
579   // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
580   // clusters: both of them will have to be clustered because of (1) and we
581   // won't be able to keep the edge between the two as neither the input to the
582   // second XLA cluster nor the output from the first XLA cluster are supported
583   // because of (2).
584   //
585   // TODO(b/113100872): This can be fixed if the TensorFlow representation for
586   // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
587   // (2) would no longer hold.
588 
589   if (n.assigned_device_name().empty()) {
590     *ignore = false;
591     return Status::OK();
592   }
593 
594   TF_ASSIGN_OR_RETURN(
595       const XlaOpRegistry::DeviceRegistration* registration,
596       device_info_cache->GetCompilationDevice(n.assigned_device_name()));
597 
598   if (!registration) {
599     *ignore = true;
600   } else {
601     *ignore = registration->cluster_resource_variable_ops_unsafely;
602   }
603   return Status::OK();
604 }
605 
606 StatusOr<bool> MarkForCompilationPassImpl::Initialize() {
607   TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
608   initialized_ = true;
609 
610   TF_RETURN_IF_ERROR(FindCompilationCandidates());
611 
612   if (compilation_candidates_.empty()) {
613     VLOG(2) << "No compilable candidates";
614     return false;
615   }
616 
617   TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
618                       CreateCycleDetectionGraph(graph_, &cycles_graph_));
619   if (!cycle_detection_graph_ok) {
620     // TODO(sanjoy): This should be logged via the XLA activity listener.
621     VLOG(2) << "Could not form cycle detection graph";
622     return false;
623   }
624 
625   if (!debug_options_.ignore_deadness_checks) {
626     XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
627     TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
628   }
629 
630   // Each compilation candidate belongs to a cluster. The cluster's
631   // representative names the node in the 'cycles' graph that represents the
632   // cluster.
633   TF_RETURN_IF_ERROR(BuildInitialClusterSet());
634   return true;
635 }
636 
637 template <typename FnTy>
638 StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
639   bool changed = false;
640   for (int32 node : cycles_graph_.AllNodesInPostOrder()) {
641     Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
642     if (!cluster_from) {
643       continue;
644     }
645 
646     // Make a copy of the set of successors because we may modify the graph in
647     // TryToContractEdge.
648     std::vector<int32> successors_copy =
649         cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
650 
651     for (int to : successors_copy) {
652       iteration_count_++;
653 
654       Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
655       if (!cluster_to) {
656         continue;
657       }
658 
659       TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
660       changed |= contracted_edge;
661     }
662   }
663 
664   return changed;
665 }
666 
667 Node* MarkForCompilationPassImpl::GetOnlyNodeIn(const Cluster& cluster) {
668   return cluster.cluster_size() == 1
669              ? graph_->FindNodeId(cluster.GetIdOfOnlyNode())
670              : nullptr;
671 }
672 
673 bool MarkForCompilationPassImpl::IsSinkLike(const Cluster& cluster) {
674   if (Node* n = GetOnlyNodeIn(cluster)) {
675     return n->type_string() == "NoOp" && n->out_edges().size() == 1 &&
676            (*n->out_edges().begin())->dst()->IsSink();
677   }
678 
679   return false;
680 }
681 
682 bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
683     const Cluster& cluster) {
684   Node* n = GetOnlyNodeIn(cluster);
685   if (!n) {
686     return false;
687   }
688 
689   if (n->type_string() != "AssignAddVariableOp" &&
690       n->type_string() != "AssignSubVariableOp") {
691     return false;
692   }
693 
694   DataType dtype;
695   if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
696     return false;
697   }
698 
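  // Look for a Constant operand -- the pattern of interest is a resource
  // variable updated by a constant amount (e.g. "i += 1").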
699   Node* const_input = nullptr;
700   for (const Edge* e : n->in_edges()) {
701     if (!e->IsControlEdge() && e->src()->IsConstant()) {
702       const_input = e->src();
703       break;
704     }
705   }
706 
707   if (!const_input) {
708     return false;
709   }
710 
711   const TensorProto* proto = nullptr;
712   if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
713     return false;
714   }
715 
716   return TensorShapeUtils::IsScalar(proto->tensor_shape());
717 }
718 
719 Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
720   TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
721   edges_contracted_ = true;
722 
723   // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
724   // example, from the Grappler fusion pass).
725 
726   // In general there are multiple maximal clusterings, but they are not all
727   // equally performant.  Some clustering decisions are likely to improve
728   // performance much more than others, but we cannot order contractions by such
729   // a cost function, nor can we look at global information while deciding on
730   // individual edges to contract.  Instead, we make decisions on these
731   // important edges first and then on all other edges, maximizing the chance
732   // that the most important edges are contracted.
733   //
734   // An example of where this might occur is with a digraph:
735   // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
736   // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
737   // should be clustered with A because it will prevent a potentially large
738   // tensor from A being computed and copied.
739   //
740   // To choose better maximal clusterings we make multiple iterations over the
741   // graph in post-order, where each such iteration is called a "phase".
742 
743   // Phase 0: contract metadata operations with their producer.
744 
745   VLOG(4) << "Running phase 0";
746   TF_RETURN_IF_ERROR(
747       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
748         // Shape consuming operations are desirable to cluster with their
749         // operands because they return a small set of scalar values after
750         // consuming a large amount of data.  For example, given a graph X -> Y
751         // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
752         // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
753         // output of Size will be a small tensor while Y is a potentially large
754         // tensor that must be computed and possibly transposed/copied before
755         // the second cluster executes.
756         Node* n = GetOnlyNodeIn(*to);
757         bool is_shape_consumer_op = n && IsShapeConsumerOp(*n);
758         if (!is_shape_consumer_op) {
759           return false;
760         }
761 
762         return TryToContractEdge(from, to);
763       }).status());
764 
765   // Phase 1: apply a heuristic to ensure that we don't mess up clustering due
766   // to "group_deps".  After this phase most edges should have been contracted.
767 
768   VLOG(4) << "Running phase 1";
769   TF_RETURN_IF_ERROR(
770       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
771         // We split out this phase to get good clustering in the presence of a
772         // specific pattern seen in some graphs:
773         //
774         // digraph {
775         //   ApplyWeightUpdates_0 -> "iteration++"
776         //   ApplyWeightUpdates_1 -> "iteration++"
777         //   ApplyWeightUpdates_2 -> "iteration++"
778         //   ApplyWeightUpdates_0 -> Computation_A
779         //   ApplyWeightUpdates_1 -> Computation_B
780         //   ApplyWeightUpdates_2 -> Computation_C
781         //   Computation_A -> NoOp
782         //   Computation_B -> NoOp
783         //   Computation_C -> NoOp
784         //   "iteration++" -> NoOp
785         // }
786         //
787         // In the graph above we can't cluster iteration++ with any of the
788         // gradient update operations since that will break the TF resource
789         // variable memory model.  Given that constraint the ideal clustering
790         // would be to put all the gradient updates and all of the Computation_*
791         // nodes in one cluster, and leave iteration++ and NoOp unclustered.
792         //
793         // A naive post-order traversal would not create this good clustering,
794         // however.  Instead it will first create a cluster that puts
795         // Computation_* nodes, the NoOp and iteration++ node in a single
796         // cluster, after which it will fail to put any of the
797         // ApplyWeightUpdates_* nodes into this cluster. To avoid this fate we
798         // instead run a pass that avoids contracting edges _into_ NoOps like
799         // the above, and avoid clustering edges _from_ "iteration++" like the
800         // above.  Then we run a second pass that contracts the edges we could
801         // not contract the first time around.
802 
803         if (IsSinkLike(*to)) {
804           return false;
805         }
806 
807         if (IsScalarIntegerResourceOperation(*from)) {
808           return false;
809         }
810 
811         return TryToContractEdge(from, to);
812       }).status());
813 
814   // Phase 2: contract any remaining edges.  After this phase we should have a
815   // maximal clustering:
816   //
817   // A. We visit a cluster only after maximally clustering all its children.
818   // B. By the time we're done with a node all of its children that could have
819   //    been absorbed into the node have been absorbed.
820   // C. We have an invariant that making a cluster larger does not make edges
821   //    leaving it more contractable. That is, if we have
822   //    digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
823   //    to contract Y->Z if Y->Z was not contractible originally.
824   VLOG(4) << "Running phase 2";
825   TF_RETURN_IF_ERROR(ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
826                        return TryToContractEdge(from, to);
827                      }).status());
828 
829   // Check that the conclusion made above (that iterating over the graph once in
830   // post order gives a maximal clustering) holds.  Once the linear time
831   // post-order scheme has been battle tested we can move this to happen only in
832   // debug builds.
833   VLOG(2) << "Checking idempotence";
834   TF_ASSIGN_OR_RETURN(bool changed,
835                       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
836                         return TryToContractEdge(from, to);
837                       }));
838   TF_RET_CHECK(!changed);
839 
840   return Status::OK();
841 }
842 
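// Global sequence number used by GetNextClusterSequenceNumber() to generate
// unique cluster names of the form "cluster_<N>".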
843 std::atomic<int64> cluster_sequence_num;
844 
845 int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
846 
847 Status MarkForCompilationPassImpl::CreateClusters() {
848   TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
849   clusters_created_ = true;
850 
851   // Names for each cluster.
852   std::unordered_map<int, string> cluster_names;
853 
854   if (debug_options_.dump_graphs) {
855     DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
856   }
857 
858   // Mark clusters for compilation that:
859   // * are placed on a device that requires compilation (an XlaDevice),
860   // * are explicitly marked for compilation (_XlaCompile=true), or
861   // * have at least debug_options_.min_cluster_size effective nodes (applicable
862   //   only if compilation is enabled, otherwise there will be no such
863   //   candidates).
864   for (Node* n : compilation_candidates_) {
865     Cluster* cluster = GetClusterForNode(n);
866     TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
867                         ShouldCompileCluster(*cluster));
868     if (!should_compile_cluster) {
869       continue;
870     }
871 
872     // We assume that functional If and While nodes have at least
873     // min_cluster_size non-trivial nodes in them.  It would be more principled
874     // to (recursively) verify this fact, but that's probably not worth the
875     // trouble.
876 
877     if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size ||
878         cluster->has_functional_control_flow() ||
879         cluster->is_xla_compile_attr_true()) {
880       string& name = cluster_names[cluster->cycles_graph_node_id()];
881 
882       if (name.empty()) {
883         name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
884       }
885 
886       n->AddAttr(kXlaClusterAttr, name);
887       n->AddAttr(kXlaAlreadyClustered, true);
888       VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
889     }
890   }
891 
892   return Status::OK();
893 }
894 
895 Status MarkForCompilationPassImpl::DumpDebugInfo() {
896   TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);
897 
898   if (debug_options_.dump_graphs) {
899     DumpPostClusteringGraphs();
900   }
901 
902   VLogClusteringSummary();
903 
904   return Status::OK();
905 }
906 
907 StatusOr<bool>
908 MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
909     const Cluster& cluster_from, const Cluster& cluster_to) {
910   // If any of the consumer's producers are on a different device, do not
911   // cluster these nodes. This prevents other work on this device from being
912   // delayed by work on other devices. We consider predecessors of the entire
913   // cluster rather than just the inputs to the node to prevent the cluster
914   // still being combined in cases where the 'to' cluster has multiple
915   // dependencies on the 'from' cluster and another dependency leads to a
916   // merging of the clusters.
917   //
918   // TODO(b/117085735): We probably want to handle the reciprocal of this case
919   // where a cluster is producing data for multiple devices.
920   for (const auto& in_id :
921        cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) {
922     const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id);
923     if (cluster_in) {
924       TF_ASSIGN_OR_RETURN(bool devices_compatible,
925                           AreDevicesCompatible(cluster_to, *cluster_in));
926       if (!devices_compatible) {
927         return true;
928       }
929       TF_ASSIGN_OR_RETURN(devices_compatible,
930                           AreDevicesCompatible(cluster_from, *cluster_in));
931       if (!devices_compatible) {
932         return true;
933       }
934     }
935   }
936 
937   return false;
938 }
939 
940 absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
941   // Look for either _XlaScope or _XlaInternalScope on both nodes to guide
942   // clustering.  If both nodes have a scope and the scopes do not match, do
943   // not cluster along this edge.  If even one of the nodes lacks a scope
944   // attribute, then it is treated as a "bridge" and a cluster may be created
945   // along it.
946   //
947   // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is
948   // provided by users through jit_scope APIs, while _XlaInternalScope is
949   // automatically generated by the ClusterScopingPass when auto_jit is on.  As
950   // such, we respect _XlaScope only when auto_jit is off, while respecting
951   // _XlaInternalScope only when auto_jit is on.
952   //
953   // We may want to restrict the _XlaScope behavior to require all nodes marked
954   // with _XlaCompile=true to also have a _XlaScope property set (and raise an
955   // error otherwise); but for now we don't do this.
956 
957   if (global_jit_level_ != OptimizerOptions::OFF) {
958     // If global_jit_level_ is ON, respect only _XlaInternalScope.
959     const string& scope =
960         GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr);
961     if (!scope.empty()) {
962       return scope;
963     }
964   } else {
965     // If global_jit_level_ is OFF, respect only _XlaScope.
966     const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
967     if (!scope.empty()) {
968       return scope;
969     }
970   }
971 
972   return absl::nullopt;
973 }
974 
975 // Returns true iff the attribute `attr_name` is attached to either the node or
976 // to its callee.
977 static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
978                               const char* attr_name) {
979   bool out = false;
980   bool attr_value;
981   if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
982     out |= attr_value;
983   }
984 
985   if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
986     out |= attr_value;
987   }
988   return out;
989 }
990 
991 Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
992   auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
993     return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
994   };
995 
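  // Compute pairs of resource operations with ordering constraints that
  // clustering must not violate, and keep them in a hash set for fast lookup
  // while contracting edges.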
996   std::vector<std::pair<int, int>> unsafe_resource_deps_vect;
997   TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
998       *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect));
999   absl::c_copy(
1000       unsafe_resource_deps_vect,
1001       std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin()));
1002 
1003   cluster_for_node_.resize(graph_->num_node_ids());
1004   for (Node* node : graph_->nodes()) {
1005     if (!IsCompilationCandidate(node)) {
1006       cluster_for_node_[node->id()].Get() = nullptr;
1007       continue;
1008     }
1009 
1010     // We want clusters to be big enough that the benefit from XLA's
1011     // optimizations offsets XLA related overhead (for instance we add some
1012     // Switch/Merge nodes into the graph to implement lazy compilation).  To
1013     // this end, we don't count Identity and Constant nodes because they do not
1014     // enable interesting optimizations by themselves.
1015     int effective_cluster_size =
1016         (node->IsIdentity() || node->IsConstant()) ? 0 : 1;
1017 
1018     bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
1019 
1020     absl::optional<DeadnessPredicate> deadness_predicate;
1021     if (deadness_analysis_) {
1022       TF_ASSIGN_OR_RETURN(
1023           deadness_predicate,
1024           deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
1025     }
1026 
1027     const string& device_name_str = !node->assigned_device_name().empty()
1028                                         ? node->assigned_device_name()
1029                                         : node->requested_device();
1030     TF_ASSIGN_OR_RETURN(DeviceId device,
1031                         device_info_cache_.GetIdFor(device_name_str));
1032 
1033     bool is_resource_op = HasResourceInputOrOutput(*node);
1034     absl::optional<DeviceId> resource_op_device;
1035     if (is_resource_op) {
1036       resource_op_device = device;
1037     }
1038 
1039     absl::optional<int> resource_var_operation_node_id;
1040     if (is_resource_op || MayCallFunction(*node, flib_def_)) {
1041       resource_var_operation_node_id = node->id();
1042     }
1043 
1044     bool is_xla_compile_attr_true =
1045         GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
1046         (global_jit_level_ != OptimizerOptions::OFF &&
1047          GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr));
1048 
1049     DeviceSet devices;
1050     devices.Insert(device);
1051 
1052     Cluster* new_cluster = MakeNewCluster(
1053         /*cycles_graph_node_id=*/node->id(),
1054         /*effective_cluster_size=*/effective_cluster_size,
1055         /*has_functional_control_flow=*/has_functional_control_flow, devices,
1056         resource_op_device, resource_var_operation_node_id, deadness_predicate,
1057         /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
1058         GetXlaScope(node));
1059 
1060     cluster_for_node_[node->id()].Get() = new_cluster;
1061   }
1062 
1063   return Status::OK();
1064 }
1065 
1066 StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
1067   if (!node->IsIdentity()) {
1068     return false;
1069   }
1070 
1071   // Check if the Identity is driven by a Switch on its true path.
1072   auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
1073     return e->src()->IsSwitch() && e->src_output() == 1;
1074   });
1075   if (it == node->in_edges().end()) {
1076     return false;
1077   }
1078   const Node* switch_node = (*it)->src();
1079 
1080   // Check if the Switch is driven by LoopCond.
1081   const Node* maybe_loop_cond;
1082   TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
1083   if (!maybe_loop_cond->IsLoopCond()) {
1084     return false;
1085   }
1086 
1087   // Check if the Identity is driving any const nodes through a control edge.
1088   bool driving_any_consts =
1089       absl::c_any_of(node->out_edges(), [](const Edge* e) {
1090         return e->dst()->IsConstant() && e->IsControlEdge();
1091       });
1092   if (!driving_any_consts) {
1093     return false;
1094   }
1095 
1096   return true;
1097 }
1098 
1099 absl::flat_hash_set<string> GetOrCreateAllowlist() {
1100   absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
1101       tensorflow::GetAllowlistTable();
1102   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1103   absl::flat_hash_set<string> allowlist;
1104 
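  // Each comma-separated entry in --tf_xla_ops_to_cluster is either the token
  // "FUSIBLE" (which expands to every op in the allowlist table), the name of a
  // category present in the table, or an individual TF op name.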
1105   for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
1106     if (s == "FUSIBLE") {
1107       for (auto pair : *allowlist_table) {
1108         allowlist.insert(pair.second.begin(), pair.second.end());
1109       }
1110     } else if (allowlist_table->contains(s)) {
1111       auto v = allowlist_table->at(s);
1112       allowlist.insert(v.begin(), v.end());
1113     } else if (!s.empty()) {
1114       // Should be a user provided TF operation.
1115       allowlist.insert(string(s));
1116     }
1117   }
1118 
1119   if (VLOG_IS_ON(2) && !allowlist.empty()) {
1120     std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
1121     absl::c_sort(vallowlist);
1122     VLOG(2) << "XLA clustering will only consider the following TF operations: "
1123             << absl::StrJoin(vallowlist, " ");
1124   }
1125   return allowlist;
1126 }
1127 
1128 Status MarkForCompilationPassImpl::FindCompilationCandidates() {
1129   OptimizerOptions opts;
1130   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1131       new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
1132                                         TF_GRAPH_DEF_VERSION, flib_def_, opts));
1133   FunctionLibraryRuntime* lib_runtime =
1134       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
1135   std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);
1136   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
1137       *graph_, /*compile_time_const_arg_indices=*/nullptr,
1138       &compile_time_const_nodes, lib_runtime));
1139   // Iterate over nodes in sorted order so that compiler fuel is deterministic.
1140   // We can't simply pass op_nodes().begin() and op_nodes().end() to the
1141   // std::vector constructor because they're not proper iterators (they lack
1142   // iterator_traits and so on).
1143   std::vector<Node*> sorted_nodes;
1144   for (Node* node : graph_->op_nodes()) {
1145     sorted_nodes.push_back(node);
1146   }
1147   std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
1148 
1149   if (*debug_options_.fuel >= std::numeric_limits<int64>::max() / 2) {
1150     // The assumption is that if fuel started out as INT64_MAX, it will forever
1151     // stay greater than INT64_MAX / 2.
1152     VLOG(2) << "Starting fuel: infinity";
1153   } else {
1154     VLOG(2) << "Starting fuel: " << *debug_options_.fuel;
1155   }
1156 
1157   VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
1158 
1159   auto allowlist = GetOrCreateAllowlist();
1160 
1161   std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
1162   absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
1163   // Check that the user-provided TF operations really exist.
1164   for (const auto& s : allowlist) {
1165     if (!all_ops.contains(string(s))) {
1166       return errors::InvalidArgument(
1167           "The operation '", s,
1168           "' passed to --tf_xla_ops_to_cluster is not supported by XLA.");
1169     }
1170   }
1171 
1172   for (Node* node : sorted_nodes) {
1173     if (*debug_options_.fuel <= 0) {
1174       VLOG(1)
1175           << "Hit fuel limit; not marking any remaining ops as clusterable.";
1176       break;
1177     }
1178 
1179     TF_ASSIGN_OR_RETURN(
1180         const DeviceType& device_type,
1181         device_info_cache_.GetDeviceTypeFor(node->assigned_device_name()));
1182     VLOG(4) << "Device type for " << node->name() << ": "
1183             << device_type.type_string();
1184 
1185     if (CompilationDisallowedByXlaCompileAttr(node)) {
1186       VLOG(2) << "Not clustering " << node->name()
1187               << ": disallowed by _XlaCompile attribute";
1188       continue;
1189     }
1190 
1191     const XlaOpRegistry::DeviceRegistration* registration;
1192     if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
1193                                              &registration)) {
1194       VLOG(2) << "Rejecting " << node->name()
1195               << ": could not find JIT device for " << device_type.type();
1196       continue;
1197     }
1198 
1199     RecursiveCompilabilityChecker::OperationFilter filter =
1200         CreateOperationFilter(*registration);
1201     filter.require_always_compilable = true;
1202     filter.allow_string_consts = false;
1203 
1204     RecursiveCompilabilityChecker checker(
1205         filter, DeviceType{registration->compilation_device_name});
1206 
1207     if (!checker.IsCompilableNode(*node, lib_runtime)) {
1208       continue;
1209     }
1210 
1211     if (node->type_string() == "Const") {
1212       // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
1213       // support it.
1214       const AttrValue* attr = node->attrs().Find("dtype");
1215       if (attr != nullptr && attr->type() == DT_STRING) {
1216         continue;
1217       }
1218     }
1219 
1220     if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
1221       VLOG(1) << "Rejecting TF operation " << node->def().op()
1222               << " as it is not listed in --tf_xla_ops_to_cluster.";
1223       continue;
1224     }
1225 
1226     if (compile_time_const_nodes[node->id()]) {
1227       const OpDef* op_def;
1228       TF_RETURN_IF_ERROR(
1229           graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def));
1230       if (op_def->is_stateful()) {
1231         // It is easiest to demonstrate the problem we're trying to solve with
1232         // an example.  Say we have this graph:
1233         //
1234         //   shape = RandomUniformInt();
1235         //   reshape = Reshape(input, shape)
1236         //
1237         // Both RandomUniformInt and Reshape are compilable by XLA so, absent
1238         // any other reason, we will try to put both shape and reshape in the
1239         // same cluster.  However, since XLA only supports statically shaped
1240         // values, it will expect to be able to constant fold `shape` to get a
1241         // static shape for `reshape`.  This is a problem because side-effecting
1242         // ops like RandomUniformInt() cannot be constant folded.  We fix this
1243         // by putting `shape` and `reshape` in different clusters, which results
1244         // in us recompiling `reshape`'s cluster for every new value of `shape`,
1245         // making `reshape` statically sized within each compilation.  We
1246         // simplify the solution even further by disallowing operations like
1247         // `shape` from being part of *any* non-trivial cluster.  They're either
1248         // not compiled by XLA altogether or, if assigned to an XLA_* device
1249         // with "must compile" semantics, compiled into a trivial single-op
1250         // cluster.  This approach leaves some room for improvement, and we can
1251         // consider implementing a more aggressive data-flow-analysis based
1252         // solution in the future if needed.
1253         //
1254         // One ugly problem we have to contend with: certain sets of ops *have*
1255         // to be in the same cluster because values flowing between them have
1256         // types that can't be live-in or live-out of a cluster.  These ops are:
1257         //
1258         //  - TensorArray ops operating on the same TensorArray instance.
1259         //  - Stack ops operating on the same Stack instance.
1260         //
1261         // To work around this we avoid isolating these specific ops.  This
1262         // concession makes it unsound to auto-cluster them, since doing so
1263         // could create clusters we could not compile (we can't constant
1264         // fold, say, a TensorArrayRead or a StackPopV2).  But we don't
1265         // auto-cluster these operations today, so we're good for now.
1266         const XlaResourceOpInfo* op_info =
1267             GetResourceOpInfoForOp(node->type_string());
1268         bool is_tensor_array_or_stack_op =
1269             op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
1270         if (!is_tensor_array_or_stack_op) {
1271           VLOG(2) << "Isolating " << node->name()
1272                   << ": must-be-constant stateful op";
1273           continue;
1274         }
1275       }
1276     }
1277 
1278     // This is a heuristic to avoid creating a dependency between the while
1279     // loop condition and body computations.  Such a dependency can be created
1280     // if a special Identity node in the following pattern is clustered in.
1281     // That is, an Identity node in the loop cond computation is used to drive
1282     // const nodes consumed by the loop body.  If this Identity node goes into
1283     // the same cluster as nodes from the loop body, an extra dependency is
1284     // created between the loop cond and body computations, which hinders the
1285     // progress of the loop cond computation at runtime and adds significant
1286     // overhead.  Specifically, we look for the pattern below and do not
1287     // cluster in this Identity to avoid the described issue.  Since Identity
1288     // has a low execution cost in native TF, giving up these special Identity
1289     // nodes as candidates should not harm performance.  If other
1290     // considerations emerge in the future, we can revisit the heuristic and
1291     // only disallow these Identities from joining clusters with nodes from
1292     // the loop body, while still considering them candidates.
1293     //
1294     // LoopCond ->
1295     // Merge    -> Switch -> Identity -> i++ -> ... -> NextIteration
1296     //                               ..> Const -> LoopBody
1297     //                            (control edge)
1298     TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
1299                         IsIdentityDrivingConstsInLoop(node));
1300     if (is_identity_driving_consts_in_loop) {
1301       VLOG(2) << "Rejecting " << node->name()
1302               << ": including it can create dependencies between while loop "
1303                  "condition and body computations with runtime overhead.";
1304       continue;
1305     }
1306 
1307     compilation_candidates_.insert(node);
1308     --(*debug_options_.fuel);
1309   }
1310 
1311   VLOG(2) << "compilation_candidates_.size() = "
1312           << compilation_candidates_.size();
1313   return Status::OK();
1314 }
1315 
1316 bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
1317     Node* node) {
1318   if (debug_options_.ignore_xla_compile_attr) {
1319     return false;
1320   }
1321 
1322   // If there is a _XlaCompile annotation, use its value.
1323   bool compile = false;
1324   Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
1325   if (status.ok()) {
1326     if (!compile) {
1327       VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1328               << kXlaCompileAttr << ") is false.";
1329     }
1330     return !compile;
1331   }
1332 
1333   status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile);
1334   if (status.ok()) {
1335     if (!compile) {
1336       VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1337               << kXlaCompileAttr << ") on callee is false.";
1338     }
1339     return !compile;
1340   }
1341 
1342   return false;
1343 }
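
// Illustrative note on CompilationDisallowedByXlaCompileAttr() above.  This is
// a hedged sketch, not part of the pass: one way a test might exercise the
// attribute check is to attach kXlaCompileAttr directly to a node
// (Node::AddAttr is assumed to be available here, as it is elsewhere in this
// codebase):
//
//   Node* n = ...;                       // some node in the graph under test
//   n->AddAttr(kXlaCompileAttr, false);
//   // CompilationDisallowedByXlaCompileAttr(n) should now return true, unless
//   // debug_options_.ignore_xla_compile_attr is set.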
1344 
1345 bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
1346     Cluster* from, Cluster* to, absl::string_view reason) {
1347   VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
1348   return false;
1349 }
1350 
1351 StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
1352                                                              Cluster* to) {
1353   DCHECK(from->deadness_predicate().has_value() ==
1354          to->deadness_predicate().has_value());
1355   if (from->deadness_predicate() != to->deadness_predicate()) {
1356     VLOG(3) << EdgeContractionFailureMsg(
1357         from, to,
1358         absl::StrCat(
1359             "the two nodes have mismatching deadness: ",
1360             deadness_analysis_->DebugString(*from->deadness_predicate()),
1361             " and ",
1362             deadness_analysis_->DebugString(*to->deadness_predicate())));
1363     return false;
1364   }
1365 
1366   TF_ASSIGN_OR_RETURN(bool devices_compatible,
1367                       AreDevicesCompatible(*from, *to));
1368   if (!devices_compatible) {
1369     return LogNotContractableAndReturnFalse(
1370         from, to, "the two nodes have incompatible devices");
1371   }
1372 
1373   if (from->xla_scope().has_value() && to->xla_scope().has_value() &&
1374       *from->xla_scope() != *to->xla_scope()) {
1375     return LogNotContractableAndReturnFalse(
1376         from, to, "the two nodes have mismatching XLA scopes");
1377   }
1378 
1379   // Don't exceed the maximum cluster size.
1380   if (from->cluster_size() + to->cluster_size() >
1381       debug_options_.max_cluster_size) {
1382     return LogNotContractableAndReturnFalse(
1383         from, to, "the new cluster will be larger than the max cluster size");
1384   }
1385 
1386   TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
1387                       ClusteringWillIntroduceInterDeviceDependency(*from, *to));
1388 
1389   if (will_introduce_cross_device_dependency) {
1390     return LogNotContractableAndReturnFalse(
1391         from, to, "the new cluster will introduce a cross device dependency");
1392   }
1393 
1394   // Check if contracting this edge will break the resource variable
1395   // concurrency semantics.  In theory this check is quadratic in the number
1396   // of nodes, but it does not seem to be a problem in practice so far.
1397   if (!debug_options_.ignore_resource_variable_checks) {
1398     for (int resource_var_from : from->resource_var_operation_node_ids()) {
1399       for (int resource_var_to : to->resource_var_operation_node_ids()) {
1400         // If unsafe_resource_deps_ contains {A, B} then
1401         //
1402         //  a. A and B are resource operations.
1403         //  b. A and B cannot be placed in the same cluster.
1404         //  c. There is no path from B to A in the cycles graph (but there may
1405         //     be a path from A to B).
1406         //
1407         // So check the legality of the edge contraction by checking if any of
1408         // the n^2 pairs of resource variable operations are forbidden.
1409         if (unsafe_resource_deps_.contains(
1410                 {resource_var_from, resource_var_to})) {
1411           return LogNotContractableAndReturnFalse(
1412               from, to,
1413               "the new cluster would break resource variable semantics");
1414         }
1415       }
1416     }
1417   }
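
  // Worked example of the check above (node ids and op names are
  // hypothetical): suppose `from` contains a resource operation with id 4
  // (say an AssignVariableOp) and `to` contains a resource operation with id 9
  // (say a ReadVariableOp), and the safety analysis recorded the pair {4, 9}
  // in unsafe_resource_deps_.  Then this contraction is rejected, because
  // merging the clusters would place nodes 4 and 9 together and could violate
  // the resource variable semantics described above.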
1418 
1419   return MergeClusters(from, to);
1420 }
1421 
1422 Status MarkForCompilationPassImpl::Run() {
1423   // Make sure that kernels have been registered on the JIT device.
1424   XlaOpRegistry::RegisterCompilationKernels();
1425 
1426   // Start the timer after XlaOpRegistry::RegisterCompilationKernels which does
1427   // some one-time work.
1428   XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);
1429 
1430   TF_ASSIGN_OR_RETURN(bool initialized, Initialize());
1431   if (!initialized) {
1432     // Initialization exited early which means this instance of
1433     // MarkForCompilationPassImpl is not set up to run the subsequent phases.
1434     return Status::OK();
1435   }
1436 
1437   TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
1438   TF_RETURN_IF_ERROR(CreateClusters());
1439   TF_RETURN_IF_ERROR(DumpDebugInfo());
1440 
1441   return Status::OK();
1442 }
1443 
1444 void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
1445   DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
1446 
1447   // We also dump out an annotated version of the TF graph where the node
1448   // names are prefixed with the cluster names.  This can help with
1449   // visualizing the clustering decisions on TensorBoard.
1450   Graph new_graph(graph_->op_registry());
1451   CopyGraph(*graph_, &new_graph);
1452 
1453   for (Node* n : new_graph.nodes()) {
1454     if (absl::optional<absl::string_view> cluster_name =
1455             GetXlaClusterForNode(*n)) {
1456       n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
1457     } else if (n->type_string() == "VarHandleOp") {
1458       n->set_name(absl::StrCat("varhandle/", n->name()));
1459     } else {
1460       // There is room for improvement here.  In particular, it may help to
1461       // split these unclustered nodes into classes where every node in a
1462       // specific class has edges to and from the same set of clusters.
1463       n->set_name(absl::StrCat("unclustered/", n->name()));
1464     }
1465   }
1466 
1467   DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_);
1468 }
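
// Illustrative note on DumpPostClusteringGraphs() above (hedged; the cluster
// name is hypothetical, only the prefixing scheme is taken from the code): a
// node "foo" assigned to a cluster named, say, "cluster_0" is renamed to
// "cluster_0/foo" in the annotated dump, a VarHandleOp named "w" becomes
// "varhandle/w", and any other unclustered node "bar" becomes
// "unclustered/bar".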
1469 
1470 string RatioToString(int numerator, int denominator) {
1471   return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
1472                          (100.0 * numerator) / denominator);
1473 }
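
// For example, RatioToString(3, 4) yields "3 / 4 (75.00%)".  The callers below
// always pass graph_->num_nodes() as the denominator.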
1474 
1475 void MarkForCompilationPassImpl::VLogClusteringSummary() {
1476   if (!VLOG_IS_ON(2)) {
1477     return;
1478   }
1479 
1480   XlaAutoClusteringSummary auto_clustering_info =
1481       GetXlaAutoClusteringSummary(*graph_);
1482 
1483   VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
1484   VLOG(2) << " Built " << auto_clustering_info.clusters_size()
1485           << " clusters, size "
1486           << RatioToString(auto_clustering_info.clustered_node_count(),
1487                            graph_->num_nodes());
1488 
1489   for (const XlaAutoClusteringSummary::Cluster& cluster :
1490        auto_clustering_info.clusters()) {
1491     absl::string_view cluster_name = cluster.name();
1492     int size = cluster.size();
1493     VLOG(2) << "  " << cluster_name << " "
1494             << RatioToString(size, graph_->num_nodes());
1495     for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1496          cluster.op_histogram()) {
1497       VLOG(3) << "   " << op_count.op() << ": " << op_count.count()
1498               << " instances";
1499     }
1500   }
1501 
1502   if (!auto_clustering_info.unclustered_op_histogram().empty()) {
1503     VLOG(2) << " Unclustered nodes: "
1504             << RatioToString(auto_clustering_info.unclustered_node_count(),
1505                              graph_->num_nodes());
1506     for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1507          auto_clustering_info.unclustered_op_histogram()) {
1508       VLOG(3) << "  " << op_count.op() << ": " << op_count.count()
1509               << " instances";
1510     }
1511   }
1512 
1513   struct EdgeInfo {
1514     absl::string_view node_name;
1515     absl::optional<absl::string_view> cluster_name;
1516 
1517     absl::string_view GetClusterName() const {
1518       return cluster_name ? *cluster_name : "[none]";
1519     }
1520 
1521     std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
1522         const {
1523       return {node_name, cluster_name};
1524     }
1525 
1526     bool operator<(const EdgeInfo& other) const {
1527       return AsPair() < other.AsPair();
1528     }
1529   };
1530 
1531   using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
1532 
1533   EdgeInfoMap incoming_edge_infos;
1534   EdgeInfoMap outgoing_edge_infos;
1535 
1536   std::set<absl::string_view> cluster_names_to_print;
1537 
1538   for (const Edge* e : graph_->edges()) {
1539     const Node* from = e->src();
1540     absl::optional<absl::string_view> from_cluster_name =
1541         GetXlaClusterForNode(*from);
1542 
1543     const Node* to = e->dst();
1544     absl::optional<absl::string_view> to_cluster_name =
1545         GetXlaClusterForNode(*to);
1546 
1547     if (to_cluster_name == from_cluster_name) {
1548       continue;
1549     }
1550 
1551     if (to_cluster_name) {
1552       incoming_edge_infos[*to_cluster_name]
1553                          [EdgeInfo{from->name(), from_cluster_name}]++;
1554       cluster_names_to_print.insert(*to_cluster_name);
1555     }
1556 
1557     if (from_cluster_name) {
1558       outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
1559       cluster_names_to_print.insert(*from_cluster_name);
1560     }
1561   }
1562 
1563   VLOG(4) << "*** Inter-Cluster edges:";
1564   if (cluster_names_to_print.empty()) {
1565     VLOG(4) << "   [none]";
1566   }
1567 
1568   auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
1569                                              const EdgeInfoMap& edge_info_map,
1570                                              absl::string_view desc) {
1571     auto it = edge_info_map.find(cluster_name);
1572     if (it != edge_info_map.end()) {
1573       VLOG(4) << "  " << it->second.size() << " " << desc << " edges";
1574       for (const auto& edge_info_count_pair : it->second) {
1575         VLOG(4) << "   " << edge_info_count_pair.first.GetClusterName() << " "
1576                 << edge_info_count_pair.first.node_name << " # "
1577                 << edge_info_count_pair.second;
1578       }
1579     } else {
1580       VLOG(4) << "  No " << desc << " edges.";
1581     }
1582   };
1583 
1584   for (absl::string_view cluster_name : cluster_names_to_print) {
1585     VLOG(4) << " ** Cluster " << cluster_name;
1586     print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
1587                                     "incoming");
1588     print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
1589                                     "outgoing");
1590   }
1591 }
1592 
1593 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
1594     const Cluster& cluster_a, const Cluster& cluster_b) {
1595   DeviceSet devices = cluster_a.devices();
1596   devices.UnionWith(cluster_b.devices());
1597 
1598   TF_ASSIGN_OR_RETURN(
1599       absl::optional<jit::DeviceId> maybe_chosen_device,
1600       MaybePickDeviceForXla(device_info_cache_, devices,
1601                             /*allow_mixing_unknown_and_cpu=*/false));
1602   if (!maybe_chosen_device.has_value()) {
1603     return false;
1604   }
1605 
1606   jit::DeviceId chosen_device = *maybe_chosen_device;
1607 
1608   // If we are able to pick a device `chosen_device` for the larger cluster,
1609   // the resource operations in `cluster_a` and `cluster_b` must be placed on
1610   // the same device as `chosen_device`.  This is because the _XlaCompile and
1611   // _XlaRun kernels will run on `chosen_device` and therefore try to access
1612   // the resource variables from there, which is an error if the resource
1613   // variables are placed on some other device.
1614   auto resource_op_device_ok =
1615       [&](absl::optional<DeviceId> resource_op_device) {
1616         return !resource_op_device.has_value() ||
1617                *resource_op_device == chosen_device;
1618       };
1619 
1620   return resource_op_device_ok(cluster_a.resource_op_device()) &&
1621          resource_op_device_ok(cluster_b.resource_op_device());
1622 }
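
// Illustrative note on AreDevicesCompatible() above (hedged sketch with
// hypothetical device names): if the devices of `cluster_a` and `cluster_b`
// all resolve to "GPU:0", MaybePickDeviceForXla can pick "GPU:0" and the merge
// is considered compatible; but if either cluster performs resource operations
// pinned to a different device (say "CPU:0"), the merge is rejected, since the
// fused _XlaCompile/_XlaRun kernels would run on "GPU:0" and accessing those
// resource variables from there would be an error.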
1623 
1624 // Returns `true` iff we should compile `cluster`.
1625 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
1626     const Cluster& cluster) {
1627   TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
1628                       PickDeviceForXla(device_info_cache_, cluster.devices(),
1629                                        /*allow_mixing_unknown_and_cpu=*/false));
1630 
1631   const DeviceType& device_type =
1632       device_info_cache_.GetDeviceTypeFor(chosen_device);
1633   const XlaOpRegistry::DeviceRegistration* registration =
1634       device_info_cache_.GetCompilationDevice(chosen_device);
1635   TF_RET_CHECK(registration)
1636       << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
1637       << "; device type = " << device_type.type() << "; devices ("
1638       << device_info_cache_.DebugString(cluster.devices());
1639 
1640   bool should_compile =
1641       cluster.is_xla_compile_attr_true() ||
1642       registration->autoclustering_policy ==
1643           XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1644       (registration->autoclustering_policy ==
1645            XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1646        global_jit_level_ != OptimizerOptions::OFF);
1647 
1648   if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
1649       device_type.type_string() == DEVICE_CPU) {
1650     static absl::once_flag once;
1651     absl::call_once(once, [] {
1652       LOG(WARNING)
1653           << "(One-time warning): Not using XLA:CPU for cluster because envvar "
1654              "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set.  If you want "
1655              "XLA:CPU, either set that envvar, or use experimental_jit_scope "
1656              "to enable XLA:CPU.  To confirm that XLA is active, pass "
1657              "--vmodule=xla_compilation_cache=1 (as a proper command-line "
1658              "flag, not via TF_XLA_FLAGS) or set the envvar "
1659              "XLA_FLAGS=--xla_hlo_profile.";
1660       MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1661       if (flags->tf_xla_cpu_global_jit) {
1662         LOG(WARNING)
1663             << "(Although the tf_xla_cpu_global_jit flag is currently enabled, "
1664                "perhaps it wasn't enabled at process startup?)";
1665       }
1666     });
1667   }
1668 
1669   VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
1670           << " cluster with device "
1671           << device_info_cache_.GetNameFor(chosen_device);
1672 
1673   return should_compile;
1674 }
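
// Illustrative note on ShouldCompileClusterImpl() above (hedged restatement of
// the predicate computed there): a cluster is compiled when (a) it carries an
// explicit must-compile request (is_xla_compile_attr_true()), or (b) its
// chosen device registers with autoclustering policy kAlways, or (c) the
// policy is kIfEnabledGlobally and global_jit_level_ is not OFF.  For example,
// with global_jit_level_ == OFF and no _XlaCompile attribute, only clusters on
// devices registered with kAlways are compiled.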
1675 
1676 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileCluster(
1677     const Cluster& cluster) {
1678   auto it = should_compile_cluster_cache_.find(&cluster);
1679   if (it != should_compile_cluster_cache_.end()) {
1680     return it->second;
1681   }
1682 
1683   TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster));
1684   should_compile_cluster_cache_.insert({&cluster, should_compile});
1685   return should_compile;
1686 }
1687 
1688 Status MarkForCompilation(
1689     const GraphOptimizationPassOptions& options,
1690     const MarkForCompilationPassImpl::DebugOptions& debug_options) {
1691   Graph* graph = options.graph->get();
1692   FunctionLibraryDefinition* flib_def = options.flib_def;
1693 
1694   // Deadness analysis expects a graph with source and sink edges properly
1695   // connected but sometimes the incoming graph does not follow this invariant.
1696   // So fix up the source and sink edges before calling into deadness analysis.
1697   FixupSourceAndSinkEdges(graph);
1698 
1699   // See explanation on `kXlaAlreadyClustered`.
1700   for (Node* n : graph->nodes()) {
1701     if (n->attrs().Find(kXlaAlreadyClustered)) {
1702       return Status::OK();
1703     }
1704   }
1705 
1706   return MarkForCompilationPassImpl{debug_options, graph, flib_def,
1707                                     options.session_options != nullptr
1708                                         ? options.session_options->env
1709                                         : Env::Default(),
1710                                     GetGlobalJitLevelForGraph(options)}
1711       .Run();
1712 }
1713 
1714 std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
1715   static std::atomic<int64>* fuel = [&]() {
1716     std::atomic<int64>* fuel = new std::atomic<int64>;
1717     *fuel = initial_value;
1718     return fuel;
1719   }();
1720 
1721   return fuel;
1722 }
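
// Illustrative note on GetPointerToFuel() above (hedged): the returned pointer
// is a process-wide atomic counter, intentionally leaked, shared by every
// invocation of the pass.  It is seeded from the tf_xla_clustering_fuel flag
// below and decremented once per accepted compilation candidate (see the
// `--(*debug_options_.fuel)` in the candidate loop above), presumably so that
// clustering decisions can be bisected by limiting how many nodes are ever
// admitted as candidates.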
1723 }  // anonymous namespace
1724 
1725 Status MarkForCompilationPass::Run(
1726     const GraphOptimizationPassOptions& options) {
1727   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1728 
1729   MarkForCompilationPassImpl::DebugOptions debug_options;
1730   debug_options.ignore_deadness_checks =
1731       flags->tf_xla_disable_deadness_safety_checks_for_debugging;
1732   debug_options.ignore_resource_variable_checks =
1733       flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1734   debug_options.ignore_xla_compile_attr = false;
1735   debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1736   debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1737   debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1738   debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1739 
1740   return MarkForCompilation(options, debug_options);
1741 }
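
// Hedged usage sketch for the entry points above and below (flag names are
// taken from the assignments above; the exact command line is illustrative):
//
//   TF_XLA_FLAGS="--tf_xla_max_cluster_size=50 --tf_xla_clustering_fuel=100 \
//                 --tf_xla_clustering_debug=true" python train.py
//
// caps cluster sizes, limits how many candidates are admitted, and enables
// graph dumping.  RunForTest() below differs from Run() only in ignoring the
// _XlaCompile attribute and in taking the deadness-check toggle as an
// argument.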
1742 
1743 Status MarkForCompilationPass::RunForTest(
1744     const GraphOptimizationPassOptions& options,
1745     bool disable_deadness_analysis) {
1746   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1747 
1748   MarkForCompilationPassImpl::DebugOptions debug_options;
1749   debug_options.ignore_deadness_checks = disable_deadness_analysis;
1750   debug_options.ignore_resource_variable_checks =
1751       flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1752   debug_options.ignore_xla_compile_attr = true;
1753   debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1754   debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1755   debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1756   debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1757 
1758   return MarkForCompilation(options, debug_options);
1759 }
1760 
1761 absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
1762   // Table format: category name: {list of TF operations in that category}
1763   static absl::flat_hash_map<string, std::vector<string>>* result =
1764       new absl::flat_hash_map<string, std::vector<string>>{
1765           // Unary
1766           {"PW",
1767            {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
1768             "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
1769             "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
1770             "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
1771             "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
1772             "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
1773             "Lgamma", "Digamma",
1774             // Binary
1775             "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
1776             "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
1777             "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
1778             "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
1779             "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
1780             "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
1781             "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
1782             "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
1783             // Others
1784             "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
1785             "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
1786             "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
1787             "SelectV2", "Transpose", "ConjugateTranspose",
1788             "_UnaryOpsComposition",
1789             // The following 4 operations are converted to identity
1790             "PlaceholderWithDefault", "PreventGradient", "StopGradient",
1791             "Snapshot"}},
1792           // clang-format off
1793     {"RED",
1794      {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
1795           // clang-format on
1796           {"PWRED",
1797            {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1798             "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1799           {"REDUCEWINDOW",
1800            {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1801             "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1802           {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
1803           {"BN",
1804            {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
1805             "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
1806             "FusedBatchNormGradV3"}},
1807           {"SORT", {"TopKV2"}},  // XLA version much faster then TF version.
1808           {"MISC",
1809            // clang-format off
1810      {"BroadcastTo", "ExpandDims", "Fill", "NoOp",
1811       "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
1812       "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
1813       "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
1814       "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
1815       "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
1816       "SplitV", "StridedSlice", "StridedSliceGrad",
1817       "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
1818       "Unpack", "DeviceIndex", "TensorStridedSliceUpdate",
1819      }}};
1820   // clang-format on
1821   return result;
1822 }
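
// Illustrative note on GetAllowlistTable() above (hedged): the keys ("PW",
// "RED", "PWRED", ...) are category names, and the apparent intent is that
// --tf_xla_ops_to_cluster (see the rejection message in the candidate loop
// above) may name either individual TF operations or one of these categories,
// with a category expanding to its listed operations.  Under that assumption,
// something like
//
//   TF_XLA_FLAGS="--tf_xla_ops_to_cluster=PW,RED,Reshape"
//
// would allow the pointwise and reduction ops above plus Reshape, and nothing
// else.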
1823 
1824 namespace testing {
1825 void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
1826 
1827 absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
1828   absl::flat_hash_set<string> result{"AdjustContrastv2",
1829                                      "AdjustHue",
1830                                      "AdjustSaturation",
1831                                      "Asinh",
1832                                      "Assert",
1833                                      "AssignAddVariableOp",
1834                                      "AssignSubVariableOp",
1835                                      "AssignVariableOp",
1836                                      "AvgPool",
1837                                      "AvgPool3D",
1838                                      "AvgPool3DGrad",
1839                                      "AvgPoolGrad",
1840                                      "BatchMatMul",
1841                                      "BatchMatMulV2",
1842                                      "BatchToSpace",
1843                                      "BatchToSpaceND",
1844                                      "BesselI0e",
1845                                      "BesselI1e",
1846                                      "Betainc",
1847                                      "BiasAddV1",
1848                                      "Bucketize",
1849                                      "Case",
1850                                      "CheckNumerics",
1851                                      "Cholesky",
1852                                      "ControlTrigger",
1853                                      "Conv2D",
1854                                      "Conv2DBackpropFilter",
1855                                      "Conv2DBackpropInput",
1856                                      "Conv3D",
1857                                      "Conv3DBackpropFilterV2",
1858                                      "Conv3DBackpropInputV2",
1859                                      "Cross",
1860                                      "Cumprod",
1861                                      "Cumsum",
1862                                      "DataFormatDimMap",
1863                                      "DataFormatVecPermute",
1864                                      "DepthToSpace",
1865                                      "DepthwiseConv2dNative",
1866                                      "DepthwiseConv2dNativeBackpropFilter",
1867                                      "DepthwiseConv2dNativeBackpropInput",
1868                                      "Dequantize",
1869                                      "Diag",
1870                                      "DynamicStitch",
1871                                      "Einsum",
1872                                      "EmptyTensorList",
1873                                      "EnsureShape",
1874                                      "ExtractImagePatches",
1875                                      "Igamma",
1876                                      "IgammaGradA",
1877                                      "RandomGammaGrad",
1878                                      "Igammac",
1879                                      "FFT",
1880                                      "FFT2D",
1881                                      "FFT3D",
1882                                      "FakeParam",
1883                                      "FakeQuantWithMinMaxArgs",
1884                                      "FakeQuantWithMinMaxArgsGradient",
1885                                      "FakeQuantWithMinMaxVars",
1886                                      "FakeQuantWithMinMaxVarsGradient",
1887                                      "Gather",
1888                                      "GatherNd",
1889                                      "GatherV2",
1890                                      "HSVToRGB",
1891                                      "IFFT",
1892                                      "IFFT2D",
1893                                      "IFFT3D",
1894                                      "IRFFT",
1895                                      "IRFFT2D",
1896                                      "IRFFT3D",
1897                                      "If",
1898                                      "InTopKV2",
1899                                      "L2Loss",
1900                                      "LeakyRelu",
1901                                      "LinSpace",
1902                                      "ListDiff",
1903                                      "LogMatrixDeterminant",
1904                                      "LowerBound",
1905                                      "MatMul",
1906                                      "MatrixBandPart",
1907                                      "MatrixDiag",
1908                                      "MatrixDiagPart",
1909                                      "MatrixDiagPartV2",
1910                                      "MatrixDiagPartV3",
1911                                      "MatrixDiagV2",
1912                                      "MatrixDiagV3",
1913                                      "MatrixInverse",
1914                                      "MatrixSetDiag",
1915                                      "MatrixSetDiagV2",
1916                                      "MatrixSetDiagV3",
1917                                      "MatrixSolve",
1918                                      "MatrixTriangularSolve",
1919                                      "MaxPool",
1920                                      "MaxPool3D",
1921                                      "MaxPool3DGrad",
1922                                      "MaxPool3DGradGrad",
1923                                      "MaxPoolGrad",
1924                                      "MaxPoolGradGrad",
1925                                      "MaxPoolGradGradV2",
1926                                      "MaxPoolGradV2",
1927                                      "MaxPoolV2",
1928                                      "Multinomial",
1929                                      "NextAfter",
1930                                      "NonMaxSuppressionV4",
1931                                      "ParallelDynamicStitch",
1932                                      "ParameterizedTruncatedNormal",
1933                                      "PartitionedCall",
1934                                      "Polygamma",
1935                                      "PopulationCount",
1936                                      "Qr",
1937                                      "QuantizeAndDequantizeV2",
1938                                      "QuantizeAndDequantizeV3",
1939                                      "RFFT",
1940                                      "RFFT2D",
1941                                      "RFFT3D",
1942                                      "RGBToHSV",
1943                                      "RandomShuffle",
1944                                      "RandomStandardNormal",
1945                                      "RandomUniform",
1946                                      "RandomUniformInt",
1947                                      "ReadVariableOp",
1948                                      "ResizeBilinear",
1949                                      "ResizeBilinearGrad",
1950                                      "ResizeNearestNeighbor",
1951                                      "ResourceApplyAdaMax",
1952                                      "ResourceApplyAdadelta",
1953                                      "ResourceApplyAdagrad",
1954                                      "ResourceApplyAdagradDA",
1955                                      "ResourceApplyAdagradV2",
1956                                      "ResourceApplyAdam",
1957                                      "ResourceApplyAddSign",
1958                                      "ResourceApplyCenteredRMSProp",
1959                                      "ResourceApplyFtrl",
1960                                      "ResourceApplyFtrlV2",
1961                                      "ResourceApplyGradientDescent",
1962                                      "ResourceApplyKerasMomentum",
1963                                      "ResourceApplyMomentum",
1964                                      "ResourceApplyPowerSign",
1965                                      "ResourceApplyProximalAdagrad",
1966                                      "ResourceApplyProximalGradientDescent",
1967                                      "ResourceApplyRMSProp",
1968                                      "ResourceGather",
1969                                      "ResourceScatterAdd",
1970                                      "ResourceScatterDiv",
1971                                      "ResourceScatterMax",
1972                                      "ResourceScatterMin",
1973                                      "ResourceScatterMul",
1974                                      "ResourceScatterNdAdd",
1975                                      "ResourceScatterNdSub",
1976                                      "ResourceScatterNdUpdate",
1977                                      "ResourceScatterSub",
1978                                      "ResourceScatterUpdate",
1979                                      "RngReadAndSkip",
1980                                      "RngSkip",
1981                                      "Roll",
1982                                      "ScatterNd",
1983                                      "SelfAdjointEigV2",
1984                                      "SoftmaxCrossEntropyWithLogits",
1985                                      "SpaceToBatch",
1986                                      "SpaceToBatchND",
1987                                      "SpaceToDepth",
1988                                      "SparseMatMul",
1989                                      "SparseToDense",
1990                                      "StackCloseV2",
1991                                      "StackPopV2",
1992                                      "StackPushV2",
1993                                      "StackV2",
1994                                      "StatefulPartitionedCall",
1995                                      "StatefulStandardNormalV2",
1996                                      "StatefulTruncatedNormal",
1997                                      "StatefulUniform",
1998                                      "StatefulUniformFullInt",
1999                                      "StatefulUniformInt",
2000                                      "StatelessCase",
2001                                      "StatelessIf",
2002                                      "StatelessMultinomial",
2003                                      "StatelessRandomGetAlg",
2004                                      "StatelessRandomGetKeyCounter",
2005                                      "StatelessRandomGetKeyCounterAlg",
2006                                      "StatelessRandomNormal",
2007                                      "StatelessRandomNormalV2",
2008                                      "StatelessRandomUniform",
2009                                      "StatelessRandomUniformV2",
2010                                      "StatelessRandomUniformInt",
2011                                      "StatelessRandomUniformIntV2",
2012                                      "StatelessRandomUniformFullInt",
2013                                      "StatelessRandomUniformFullIntV2",
2014                                      "StatelessTruncatedNormal",
2015                                      "StatelessTruncatedNormalV2",
2016                                      "StatelessWhile",
2017                                      "Svd",
2018                                      "SymbolicGradient",
2019                                      "TensorArrayCloseV3",
2020                                      "TensorArrayConcatV3",
2021                                      "TensorArrayGatherV3",
2022                                      "TensorArrayGradV3",
2023                                      "TensorArrayReadV3",
2024                                      "TensorArrayScatterV3",
2025                                      "TensorArraySizeV3",
2026                                      "TensorArraySplitV3",
2027                                      "TensorArrayV3",
2028                                      "TensorArrayWriteV3",
2029                                      "TensorListConcatV2",
2030                                      "TensorListElementShape",
2031                                      "TensorListFromTensor",
2032                                      "TensorListGather",
2033                                      "TensorListGetItem",
2034                                      "TensorListLength",
2035                                      "TensorListPopBack",
2036                                      "TensorListPushBack",
2037                                      "TensorListReserve",
2038                                      "TensorListSetItem",
2039                                      "TensorListSplit",
2040                                      "TensorListStack",
2041                                      "TensorScatterAdd",
2042                                      "TensorScatterMax",
2043                                      "TensorScatterMin",
2044                                      "TensorScatterSub",
2045                                      "TensorScatterUpdate",
2046                                      "TridiagonalSolve",
2047                                      "TruncatedNormal",
2048                                      "Unique",
2049                                      "UpperBound",
2050                                      "UnsortedSegmentMax",
2051                                      "UnsortedSegmentMin",
2052                                      "UnsortedSegmentProd",
2053                                      "UnsortedSegmentSum",
2054                                      "VarIsInitializedOp",
2055                                      "VariableShape",
2056                                      "Where",
2057                                      "While",
2058                                      "XlaBroadcastHelper",
2059                                      "XlaConv",
2060                                      "XlaDequantize",
2061                                      "XlaDot",
2062                                      "XlaDynamicSlice",
2063                                      "XlaDynamicUpdateSlice",
2064                                      "XlaEinsum",
2065                                      "XlaGather",
2066                                      "XlaIf",
2067                                      "XlaKeyValueSort",
2068                                      "XlaPad",
2069                                      "XlaRecv",
2070                                      "XlaReduce",
2071                                      "XlaReduceWindow",
2072                                      "XlaReplicaId",
2073                                      "XlaScatter",
2074                                      "XlaSelectAndScatter",
2075                                      "XlaSelfAdjointEig",
2076                                      "XlaSend",
2077                                      "XlaSetBound",
2078                                      "XlaSetDynamicDimensionSize",
2079                                      "XlaSharding",
2080                                      "XlaSort",
2081                                      "XlaSpmdFullToShardShape",
2082                                      "XlaSpmdShardToFullShape",
2083                                      "XlaSvd",
2084                                      "XlaVariadicReduce",
2085                                      "XlaVariadicSort",
2086                                      "XlaWhile",
2087                                      "Zeta",
2088                                      "_Arg",
2089                                      "_ArrayToList",
2090                                      "_ListToArray",
2091                                      "_Retval"};
2092   return result;
2093 }
2094 
2095 }  // namespace testing
2096 }  // namespace tensorflow
2097