1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <tuple>
20 #include <utility>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "llvm/Support/Casting.h"
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Builders.h"  // from @llvm-project
33 #include "mlir/IR/Identifier.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/Types.h"  // from @llvm-project
37 #include "mlir/IR/Value.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
46 
47 namespace mlir {
48 namespace TFTPU {
49 
50 namespace {
51 
// Attribute whose string value names the cluster an op belongs to; ops with
// equal values are grouped together into one `tf_device.cluster`.
constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
// Device attribute stripped from ops (including nested ops) that are moved into
// a cluster.
constexpr char kDeviceAttr[] = "device";
// Attribute dropped from the collected TPUReplicateMetadata attributes.
constexpr char kNameAttr[] = "name";
// Optional integer metadata attribute; treated as 1 when absent or not an
// integer.
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
// Required integer metadata attribute giving the replica count; it is removed
// from the cluster once replication has been materialized.
constexpr char kNumReplicasAttr[] = "num_replicas";
// Set on `tf_device.replicate`: the `index` of each TPUReplicatedInput that
// feeds the replicate op.
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
// Set on `tf_device.replicate`: positions of replicate inputs that originate
// from mirrored-variable TPUReplicatedInput ops.
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

// Error emitted when a TPUReplicateMetadata op lacks a non-empty string
// `_tpu_replicate` attribute.
constexpr char kBadTPUReplicateAttrMsg[] =
    "requires '_tpu_replicate' string attribute";

// Maps a `_tpu_replicate` value to the attributes of its TPUReplicateMetadata
// op.
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttrList, 8>;

// The ops of one cluster, kept in insertion order.
using ClusterOps = llvm::SmallSetVector<Operation*, 8>;

// Maps a `_tpu_replicate` value to the ops carrying that value.
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, ClusterOps, 8>;
71 
72 struct TPUClusterFormationPass
73     : public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
getDependentDialectsmlir::TFTPU::__anonff832dfe0111::TPUClusterFormationPass74   void getDependentDialects(DialectRegistry& registry) const override {
75     registry.insert<tf_device::TensorFlowDeviceDialect>();
76   }
77 
78   void runOnOperation() override;
79 };
80 
81 // Creates a mapping from the TPUReplicateMetadata ops `_tpu_replicate`
82 // attribute to its attributes and removes the ops. If multiple
83 // TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
84 // will be returned.
CollectMetadata(Block * block,MetadataMap * metadata_map)85 LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
86   // Just look at top-level operations in the block (not nested ones)
87   for (Operation& op : llvm::make_early_inc_range(*block)) {
88     auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
89     if (!metadata_op) continue;
90 
91     NamedAttrList attrs(metadata_op->getAttrDictionary());
92 
93     // Missing or bad `_tpu_replicate` attribute.
94     auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
95     if (!tpu_replicate_attr)
96       return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
97 
98     auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
99     if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty())
100       return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
101 
102     // Remove `name` attribute.
103     attrs.erase(Identifier::get(kNameAttr, metadata_op.getContext()));
104 
105     auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
106                                         std::move(attrs));
107 
108     // There are multiple TPUReplicateMetadata ops with the same
109     // `_tpu_replicate` attribute.
110     if (!it.second) {
111       return metadata_op.emitError()
112              << "multiple TPUReplicateMetadata ops with the same '"
113              << kTPUReplicateAttr << "' attribute '"
114              << tpu_replicate_attr_str.getValue() << "' found";
115     }
116     metadata_op.erase();
117   }
118   return success();
119 }
120 
121 // Collects and clusters ops with the same `_tpu_replicate` attribute. This will
122 // return an error if a `_tpu_replicate` attribute of an op is empty.
CollectAndGroupClusterOps(Block * block,ClusterMap * clusters)123 LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
124   for (Operation& op : *block) {
125     if (auto attr = op.getAttrOfType<StringAttr>(kTPUReplicateAttr)) {
126       if (attr.getValue().empty())
127         return op.emitError()
128                << "attribute '" << kTPUReplicateAttr << "' is empty";
129 
130       auto it = clusters->try_emplace(attr.getValue());
131       it.first->getSecond().insert(&op);
132     }
133   }
134 
135   return success();
136 }
137 
138 // Collects all resource ids from an op.
CollectResourceIdsFromOp(Operation & op,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis,llvm::SmallDenseSet<int64_t> & observed_resource_ids)139 void CollectResourceIdsFromOp(
140     Operation& op,
141     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
142     llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
143   op.walk([&](Operation* inner_op) {
144     for (Value operand : TF::filter_resources(inner_op->getOperands())) {
145       if (resource_alias_analysis.IsUnknownResource(operand)) continue;
146       const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand);
147       observed_resource_ids.insert(ids.begin(), ids.end());
148     }
149     for (Value result : TF::filter_resources(inner_op->getResults())) {
150       if (resource_alias_analysis.IsUnknownResource(result)) continue;
151       const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result);
152       observed_resource_ids.insert(ids.begin(), ids.end());
153     }
154   });
155 }
156 
157 // Checks if an op should be moved after a cluster. There may be users of a
158 // cluster interleaved among the cluster ops.
ShouldMoveOpAfterCluster(Block * block,Operation * op,const ClusterOps & cluster_ops,const llvm::SmallSetVector<Operation *,8> & preceding_users,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis,const llvm::SmallDenseSet<int64_t> & observed_resource_ids)159 bool ShouldMoveOpAfterCluster(
160     Block* block, Operation* op, const ClusterOps& cluster_ops,
161     const llvm::SmallSetVector<Operation*, 8>& preceding_users,
162     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
163     const llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
164   const bool is_replicate = llvm::isa<tf_device::ReplicateOp>(op);
165   auto result = op->walk([&](Operation* inner_op) {
166     for (Value operand : inner_op->getOperands()) {
167       Operation* def = operand.getDefiningOp();
168       // Operands may not have a defining op (BlockArgument) or is from a
169       // different block.
170       if (!def || def->getBlock() != block) continue;
171 
172       if (cluster_ops.count(def) != 0 || preceding_users.count(def) != 0) {
173         // Op is a user of a cluster or another op that is a user of the
174         // cluster (transitively), but is before the cluster.
175         return WalkResult::interrupt();
176       }
177     }
178 
179     // Don't visit replicate op inner op operands as new resource
180     // values/arguments may have been created but are not known in
181     // `resource_alias_analysis`.
182     if (is_replicate && inner_op != op) return WalkResult::advance();
183 
184     // Check for uses of any resource in or after cluster.
185     for (Value operand : TF::filter_resources(inner_op->getOperands())) {
186       if (resource_alias_analysis.IsUnknownResource(operand)) continue;
187       auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
188       for (const auto& id : ids)
189         if (observed_resource_ids.contains(id)) return WalkResult::interrupt();
190     }
191     return WalkResult::advance();
192   });
193 
194   return result.wasInterrupted();
195 }
196 
197 // Collects ops that are before ops in the cluster but are users of other ops
198 // in the cluster. This may happen because users of individual ops in the
199 // cluster may be interleaved with other ops in the cluster. Resource id's are
200 // also captured, to keep track of resource usage before, in, or after the
201 // cluster.
202 // TODO(b/175701589): Extend this to handle all side effecting ops while
203 // handling transitive data dependencies.
CollectClusterPrecedingUsers(Block * block,const ClusterOps & cluster_ops,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)204 llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
205     Block* block, const ClusterOps& cluster_ops,
206     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
207   llvm::SmallSetVector<Operation*, 8> preceding_users;
208   llvm::SmallDenseSet<int64_t> observed_resource_ids;
209 
210   auto front = Block::iterator(cluster_ops.front());
211   auto back = Block::iterator(cluster_ops.back());
212   for (Operation& op : llvm::make_range(front, back)) {
213     if (cluster_ops.contains(&op)) {
214       CollectResourceIdsFromOp(op, resource_alias_analysis,
215                                observed_resource_ids);
216     } else if (ShouldMoveOpAfterCluster(
217                    block, &op, cluster_ops, preceding_users,
218                    resource_alias_analysis, observed_resource_ids)) {
219       preceding_users.insert(&op);
220       CollectResourceIdsFromOp(op, resource_alias_analysis,
221                                observed_resource_ids);
222     }
223   }
224 
225   return preceding_users;
226 }
227 
228 // Collects results and associated types of the cluster that are used outside of
229 // the cluster. These results and types are used to create the clusters
230 // `tf_device.cluster` and associated terminator. Results that have no uses
231 // outside of the cluster (i.e. results of ops in the cluster are only consumed
232 // by other ops in the cluster) are pruned.
CollectClusterResults(Block * block,const ClusterOps & cluster_ops)233 llvm::SmallVector<Value, 8> CollectClusterResults(
234     Block* block, const ClusterOps& cluster_ops) {
235   llvm::SmallVector<Value, 8> results;
236 
237   for (Operation* op : cluster_ops) {
238     for (Value result : op->getResults()) {
239       for (Operation* user : result.getUsers()) {
240         // Check if user is not an op in the cluster.
241         if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
242           results.push_back(result);
243           break;
244         }
245       }
246     }
247   }
248 
249   return results;
250 }
251 
252 // Creates a `tf_device.cluster` to wrap cluster ops.
CreateClusterOp(Block * block,const ClusterOps & cluster_ops,llvm::ArrayRef<Value> results,llvm::ArrayRef<Operation * > preceding_users)253 tf_device::ClusterOp CreateClusterOp(
254     Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef<Value> results,
255     llvm::ArrayRef<Operation*> preceding_users) {
256   // `tf_device.cluster` will be placed at where the last op of the cluster is.
257   Operation* last_cluster_op = cluster_ops.back();
258   OpBuilder builder(last_cluster_op);
259 
260   llvm::SmallVector<Type, 8> result_types;
261   for (Value result : results) result_types.push_back(result.getType());
262   auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
263                                                       result_types);
264 
265   Block* body = new Block;
266   cluster.body().push_back(body);
267 
268   // Move cluster ops to the cluster body. Also remove `_tpu_replicate` and
269   // `device` attribute from ops in the cluster as that information will be
270   // present in the `tf_device.cluster`. Do this for all ops including nested
271   // ops.
272   for (Operation* cluster_op : cluster_ops) {
273     cluster_op->moveBefore(body, body->end());
274     cluster_op->walk([&](Operation* inner_op) {
275       inner_op->removeAttr(kTPUReplicateAttr);
276       inner_op->removeAttr(kDeviceAttr);
277     });
278   }
279 
280   // Add terminator.
281   builder.setInsertionPointToEnd(body);
282   builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
283 
284   // Replaces uses of cluster ops results outside of cluster with the associated
285   // `tf_device.cluster` results.
286   for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
287     Value old_ret = std::get<0>(ret_vals);
288     Value new_ret = std::get<1>(ret_vals);
289     for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
290       Operation* user = use.getOwner();
291       if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
292     }
293   }
294 
295   // Move users of cluster that are before the cluster to after the cluster.
296   Operation* op_after_cluster = cluster.getOperation()->getNextNode();
297   for (Operation* user : preceding_users) user->moveBefore(op_after_cluster);
298   return cluster;
299 }
300 
301 // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
302 // of -1 are always after ops with a non negative `index`, and an arbitrary
303 // ordering is used as there are no dependencies on their relative ordering. If
304 // there are multiple `tf.TPUReplicatedInput` ops with the same non negative
305 // index or if indices are less than -1, an error will be returned.
SortTPUReplicatedInputsByIndex(llvm::ArrayRef<Operation * > inputs,llvm::SmallVectorImpl<Operation * > * sorted_inputs)306 LogicalResult SortTPUReplicatedInputsByIndex(
307     llvm::ArrayRef<Operation*> inputs,
308     llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
309   llvm::SmallDenseSet<int64_t, 8> unique_indices;
310   for (Operation* input : inputs) {
311     int64_t index = llvm::cast<TF::TPUReplicatedInputOp>(input).index();
312     if (index < -1)
313       return input->emitOpError()
314              << "requires index to be at least -1, but got " << index;
315     if (index == -1) continue;
316     if (!unique_indices.insert(index).second)
317       return input->emitOpError()
318              << "requires indices to be unique, but found multiple '"
319              << input->getName() << "' ops with index " << index;
320   }
321 
322   // Sort all TPUReplicatedInputs by `index` attribute to have
323   // TPUReplicatedInputs with indices be added to the `tf_device.replicate` op
324   // deterministically. If `index` attribute is -1, instead move them to the
325   // end.
326   sorted_inputs->assign(inputs.begin(), inputs.end());
327   std::stable_sort(
328       sorted_inputs->begin(), sorted_inputs->end(),
329       [](Operation* l, Operation* r) {
330         int64_t l_index = llvm::cast<TF::TPUReplicatedInputOp>(l).index();
331         int64_t r_index = llvm::cast<TF::TPUReplicatedInputOp>(r).index();
332         if (l_index == -1 && r_index != -1) return false;
333         if (r_index == -1 && l_index != -1) return true;
334         return l_index < r_index;
335       });
336 
337   return success();
338 }
339 
340 // Creates a `tf_device.replicate` to represent replication for the cluster, if
341 // necessary.
ReplicateCluster(tf_device::ClusterOp cluster,int num_replicas,int num_cores_per_replica)342 LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
343                                int num_cores_per_replica) {
344   // No need to replicate.
345   if (num_replicas == 1) return success();
346 
347   if (num_replicas < 1)
348     return cluster.emitError() << "requires '" << kNumReplicasAttr
349                                << "' int attribute to be at least 1";
350 
351   LogicalResult status = success();
352   // Collect all used TPUReplicatedInput ops and sort by `index`.
353   llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
354   mlir::visitUsedValuesDefinedAbove(
355       cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
356         Operation* def = operand->get().getDefiningOp();
357         if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(def))
358           unique_replicated_input_ops.insert(def);
359         // When model parallelism is used in conjunction with data parallelism
360         // for resource inputs, we need to collect the per replica resource
361         // inputs from input to `tf.TPUPartitionedInput` ops.
362         if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
363           if (pi->getNumOperands() != num_cores_per_replica)
364             status = pi.emitOpError()
365                      << "requires " << num_cores_per_replica
366                      << " operands but found " << pi->getNumOperands();
367           for (auto operand : pi.inputs()) {
368             if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(
369                     operand.getDefiningOp()))
370               unique_replicated_input_ops.insert(operand.getDefiningOp());
371           }
372         }
373       });
374 
375   if (failed(status)) return failure();
376   llvm::SmallVector<Operation*, 8> replicated_input_ops;
377   if (failed(SortTPUReplicatedInputsByIndex(
378           unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
379     return failure();
380 
381   // Index attribute value stored on TPUReplicatedInput op. These will be used
382   // later for dynamic padder.
383   llvm::SmallVector<int64_t, 8> replicated_input_indices;
384   llvm::SmallVector<int64_t, 8> packed_input_indices;
385   bool has_replicated_input_index = false;
386 
387   // Indices of the replicate op's arguments that are mirrored variables.
388   llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
389 
390   // Check if number of operands of each used TPUReplicatedInput op matches
391   // `num_replicas` or 1. Collect all their operands and associated type for
392   // creating the replicate op.
393   llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
394   llvm::SmallVector<Value, 8> packed_inputs;
395   for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
396     auto input = pos_and_input.value();
397     bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
398     const int num_operands = input->getNumOperands();
399     int num_inputs = is_packed ? 1 : num_replicas;
400     if (num_operands != num_inputs)
401       return input->emitOpError() << "requires " << num_inputs << " operands";
402 
403     auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
404     int64_t tpu_replicated_input_index = tpu_replicated_input.index();
405     if (is_packed) {
406       packed_inputs.push_back(input->getOperand(0));
407       packed_input_indices.push_back(tpu_replicated_input_index);
408     } else {
409       replicated_inputs.push_back(
410           {input->getOperands(), input->getOperand(0).getType()});
411       replicated_input_indices.push_back(tpu_replicated_input_index);
412     }
413     if (tpu_replicated_input_index != -1) has_replicated_input_index = true;
414 
415     if (tpu_replicated_input.is_mirrored_variable())
416       mirrored_variable_indices.push_back(pos_and_input.index());
417   }
418 
419   replicated_input_indices.append(packed_input_indices.begin(),
420                                   packed_input_indices.end());
421 
422   // Create replicate op.
423   OpBuilder builder(cluster);
424   auto replicate_op = builder.create<tf_device::ReplicateOp>(
425       cluster.getLoc(), num_replicas,
426       llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
427       replicated_inputs, packed_inputs, cluster.getResultTypes());
428   if (has_replicated_input_index)
429     replicate_op->setAttr(kReplicatedInputIndicesAttr,
430                           builder.getI64ArrayAttr(replicated_input_indices));
431 
432   if (!mirrored_variable_indices.empty())
433     replicate_op->setAttr(kMirroredVariableIndicesAttr,
434                           builder.getI64ArrayAttr(mirrored_variable_indices));
435 
436   // Replace replicated cluster results with replicate op results.
437   for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
438     Value result = result_and_idx.value();
439     int idx = result_and_idx.index();
440     auto replicate_outputs = llvm::make_range(
441         std::next(replicate_op.result_begin(), idx * num_replicas),
442         std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
443 
444     for (auto& use : llvm::make_early_inc_range(result.getUses())) {
445       Operation* def = use.getOwner();
446       if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
447         // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
448         // replica output. Certain Graphs under V1 create `tf.Identity` users of
449         // replicated ops to pin the TPU computation for execution.
450         use.set(*replicate_outputs.begin());
451         continue;
452       }
453 
454       const int def_num_results = def->getNumResults();
455       if (def_num_results != num_replicas)
456         return def->emitOpError() << "requires " << num_replicas << " results";
457 
458       def->replaceAllUsesWith(replicate_outputs);
459     }
460   }
461 
462   // Collect all `tf.TPUPartitionedInput` ops to be moved inside the
463   // `tf_device.replicate` later.
464   llvm::SmallSet<Operation*, 4> partitioned_inputs;
465   // Update replicated inputs with replicate op block arguments.
466   for (auto input_and_block_arg :
467        llvm::zip(replicated_input_ops, replicate_op.GetBody().getArguments())) {
468     Operation* input = std::get<0>(input_and_block_arg);
469     Value block_arg = std::get<1>(input_and_block_arg);
470     mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
471                                      cluster.body());
472     // Update replicated input use in tf.TPUPartitionedInput op.
473     for (auto& use : input->getUses()) {
474       auto pi = llvm::dyn_cast<TF::TPUPartitionedInputOp>(use.getOwner());
475       if (pi) {
476         pi.setOperand(use.getOperandNumber(), block_arg);
477         partitioned_inputs.insert(pi.getOperation());
478       }
479     }
480   }
481 
482   // Create terminator for replicate op and move `tf_device.cluster` and
483   // `tf.TPUPartitionedInput`(s) into replicate body.
484   builder.setInsertionPointToEnd(&replicate_op.GetBody());
485   auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
486                                                        cluster.getResults());
487   for (auto pi : partitioned_inputs) pi->moveBefore(return_op);
488 
489   cluster.getOperation()->moveBefore(return_op);
490 
491   return success();
492 }
493 
494 // Forms clusters with ops of the same `_tpu_replicate` attribute under a block.
495 //
496 // For a given block, clusters are formed via grouping ops by `_tpu_replicate`
497 // attributes.
498 // For every cluster formed:
499 //   1. Find associated TPUReplicateMetadata attributes with the same
500 //      `_tpu_replicate` attribute.
501 //   2. Find users not in cluster that are interleaved between cluster ops.
502 //   3. Find external uses of cluster ops.
503 //   4. Create `tf_device.cluster` with results consisting of the external uses
504 //      of cluster ops determined at 3.
505 //   5. Move cluster ops to `tf_device.cluster` body.
506 //   6. Replace external uses of cluster ops uses with `tf_device.cluster`
507 //      results.
508 //   7. Move users from 2 to after the `tf_device.cluster`.
509 //   8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
510 //      attribute `num_replicas` is greater than 1.
511 //   9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
FormClustersInBlock(Block * block,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)512 LogicalResult FormClustersInBlock(
513     Block* block,
514     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
515   MetadataMap metadata_map;
516   LogicalResult result = CollectMetadata(block, &metadata_map);
517   if (failed(result)) return result;
518 
519   // If there is no TPUReplicateMetadata op in this block, process blocks in
520   // regions attached to the op's in the block.
521   if (metadata_map.empty()) {
522     for (Operation& op : *block) {
523       for (Region& region : op.getRegions()) {
524         if (!llvm::hasSingleElement(region))
525           return op.emitOpError("Expected single block region");
526         if (failed(
527                 FormClustersInBlock(&region.front(), resource_alias_analysis)))
528           return failure();
529       }
530     }
531     return success();
532   }
533 
534   ClusterMap clusters;
535   result = CollectAndGroupClusterOps(block, &clusters);
536   if (failed(result)) return result;
537 
538   for (const auto& cluster_metadata_and_ops : clusters) {
539     const auto& cluster_ops = cluster_metadata_and_ops.getSecond();
540 
541     auto cluster_metadata =
542         metadata_map.find(cluster_metadata_and_ops.getFirst());
543 
544     // No TPUReplicateMetadata for a `_tpu_replicate` attribute.
545     if (cluster_metadata == metadata_map.end()) {
546       cluster_ops.front()->emitWarning()
547           << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr
548           << "' attribute '" << cluster_metadata_and_ops.getFirst()
549           << "' is missing";
550       continue;
551     }
552 
553     llvm::SmallSetVector<Operation*, 8> preceding_users =
554         CollectClusterPrecedingUsers(block, cluster_ops,
555                                      resource_alias_analysis);
556 
557     llvm::SmallVector<Value, 8> results =
558         CollectClusterResults(block, cluster_ops);
559 
560     tf_device::ClusterOp cluster = CreateClusterOp(
561         block, cluster_ops, results, preceding_users.getArrayRef());
562 
563     auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
564     if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
565       return cluster.emitError()
566              << "requires '" << kNumReplicasAttr << "' int attribute";
567 
568     int num_cores_per_replica = 1;
569     auto num_cores_per_replica_attr =
570         cluster_metadata->getSecond()
571             .get(kNumCoresPerReplicaAttr)
572             .dyn_cast_or_null<mlir::IntegerAttr>();
573     if (num_cores_per_replica_attr)
574       num_cores_per_replica = num_cores_per_replica_attr.getInt();
575 
576     if (failed(ReplicateCluster(cluster,
577                                 num_replicas.cast<mlir::IntegerAttr>().getInt(),
578                                 num_cores_per_replica)))
579       return failure();
580 
581     // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
582     cluster->setAttrs(
583         cluster_metadata->second.getDictionary(cluster.getContext()));
584     // Exclude `num_replicas` as cluster should be replicated if necessary.
585     cluster.removeAttr(kNumReplicasAttr);
586   }
587 
588   return success();
589 }
590 
FormClustersInFunction(FuncOp func,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)591 LogicalResult FormClustersInFunction(
592     FuncOp func,
593     const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
594   if (!llvm::hasSingleElement(func))
595     return func.emitOpError("Expecting a single block function");
596 
597   if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis)))
598     return failure();
599 
600   // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
601   auto remove_result = func.walk([&](Operation* op) {
602     if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
603       return WalkResult::advance();
604 
605     // Forward operand to result. When `num_replicas` attribute is 1, no
606     // `tf_device.replicate` is created and replicated (1) operands/results are
607     // untouched.
608     if (op->getNumOperands() == 1 && op->getNumResults() == 1)
609       op->getResult(0).replaceAllUsesWith(op->getOperand(0));
610 
611     // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
612     // `num_replicas` to 1.
613     if (!op->use_empty()) {
614       op->emitOpError() << "is expected to have no uses, but it is operand#"
615                         << op->use_begin()->getOperandNumber() << " of "
616                         << *op->use_begin()->getOwner();
617       return WalkResult::interrupt();
618     }
619 
620     op->erase();
621 
622     return WalkResult::advance();
623   });
624 
625   return failure(remove_result.wasInterrupted());
626 }
627 
runOnOperation()628 void TPUClusterFormationPass::runOnOperation() {
629   auto& resource_alias_analysis = getAnalysis<TF::ResourceAliasAnalysis>();
630   for (auto func : getOperation().getOps<FuncOp>())
631     if (!func.isExternal() &&
632         failed(FormClustersInFunction(
633             func, resource_alias_analysis.GetAnalysisForFunc(func))))
634       return signalPassFailure();
635 }
636 }  // anonymous namespace
637 
CreateTPUClusterFormationPass()638 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
639   return std::make_unique<TPUClusterFormationPass>();
640 }
641 
642 }  // namespace TFTPU
643 }  // namespace mlir
644