1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering TensorFlow dialect's communication
17 // ops (TF/XLA) to the HLO dialect.
18 
19 #include <memory>
20 #include <string>
21 
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
34 #include "mlir/IR/Value.h"  // from @llvm-project
35 #include "mlir/IR/Visitors.h"  // from @llvm-project
36 #include "mlir/Pass/Pass.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
42 #include "tensorflow/compiler/xla/client/sharding_builder.h"
43 #include "tensorflow/compiler/xla/primitive_util.h"
44 
45 namespace mlir {
46 namespace mhlo {
47 
48 namespace {
// Attribute names set directly on ops created by this pass.
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
// Keys of the entries stored in the `mhlo.frontend_attributes` dictionary.
constexpr char kXlaHostTransferRendezvousNameAttr[] =
    "_xla_host_transfer_rendezvous";
constexpr char kXlaHostTransferOriginalTypeAttr[] =
    "_xla_host_transfer_original_type";
55 
56 // A pass that legalizes TF/XLA communication ops, propagate their respective
57 // tokens (for ordering), and rewrite their respective functions and control
58 // flow ops when necessary.
59 // Note, this currently does not handle nested modules/functions or region based
60 // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
61 class LegalizeTFCommunication
62     : public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const63   void getDependentDialects(DialectRegistry& registry) const override {
64     registry.insert<mhlo::MhloDialect>();
65   }
66 
67  public:
68   void runOnOperation() override;
69 };
70 
71 // Checks if an op is a TF/XLA communication op.
IsCommunicationOp(Operation * op)72 bool IsCommunicationOp(Operation* op) {
73   return isa<TF::_XlaHostComputeMlirOp, TF::XlaSendToHostOp,
74              TF::XlaRecvFromHostOp>(op);
75 }
76 
77 // Checks if an op is a supported HLO control flow op.
IsControlFlowOp(Operation * op)78 bool IsControlFlowOp(Operation* op) { return isa<IfOp, WhileOp>(op); }
79 
80 // Collects control flow op ancestors of a given op, up until FuncOp. If any
81 // ancestor is not a control flow op or a FuncOp, or of a single block region,
82 // an error will be returned.
GetControlFlowAncestors(Operation * op,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)83 LogicalResult GetControlFlowAncestors(
84     Operation* op, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
85     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
86   Block* block = op->getBlock();
87   Operation* parent = block->getParentOp();
88   while (block && parent && !isa<FuncOp>(parent)) {
89     if (!IsControlFlowOp(parent))
90       return op->emitOpError()
91              << "expects ancestor(s) to be of ['" << IfOp::getOperationName()
92              << "', '" << FuncOp::getOperationName() << "']";
93 
94     if (!llvm::hasSingleElement(block->getParent()->getBlocks()))
95       return op->emitOpError() << "expects single block region ancestor(s)";
96 
97     control_flow_ops.insert(parent);
98     control_flow_blocks.insert(block);
99 
100     parent = block->getParentOp();
101     block = parent->getBlock();
102   }
103   return success();
104 }
105 
106 // Finds communication ops in a function. `control_flow_ops` and
107 // `control_flow_blocks` will be populated with control flow op ancestors for
108 // every communication op.
FindCommunicationOps(FuncOp func,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool & has_communication_ops)109 LogicalResult FindCommunicationOps(
110     FuncOp func, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
111     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
112     bool& has_communication_ops) {
113   auto result = func.walk([&](Operation* op) {
114     if (!IsCommunicationOp(op)) return WalkResult::advance();
115     has_communication_ops = true;
116     if (failed(
117             GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks)))
118       return WalkResult::interrupt();
119     return WalkResult::advance();
120   });
121   return failure(result.wasInterrupted());
122 }
123 
// Helper struct holding a function to be rewritten, its control flow ops that
// lead to a communication op or a function call with a communication op
// (transitively), and an optional clone of itself. If `clone` is set, function
// calls to `original` will be replaced with `clone`.
// Fields are aggregate-initialized positionally, so their order matters.
struct FuncToRewrite {
  // Function containing (transitively) communication ops.
  FuncOp original;
  // Control flow ops in `original` enclosing a communication op or a call to
  // another function that needs rewriting.
  llvm::SmallPtrSet<Operation*, 4> control_flow_ops;
  // Blocks collected by GetControlFlowAncestors along the path to such ops.
  llvm::SmallPtrSet<Block*, 4> control_flow_blocks;
  // Private copy of a public `original` that has known uses; null otherwise.
  FuncOp clone;
};
134 
135 // Finds all functions that need to be rewritten with communication ops and
136 // and associated tokens.
GetFunctionsToRewrite(ModuleOp module,llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)137 LogicalResult GetFunctionsToRewrite(
138     ModuleOp module,
139     llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
140   // Find functions containing communication ops.
141   SmallVector<FuncOp, 4> funcs_to_visit;
142   for (FuncOp func : module.getOps<FuncOp>()) {
143     FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{},
144                                   /*control_flow_blocks=*/{},
145                                   /*clone=*/nullptr};
146     bool has_communication_ops = false;
147     if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops,
148                                     func_to_rewrite.control_flow_blocks,
149                                     has_communication_ops)))
150       return failure();
151 
152     if (!has_communication_ops) continue;
153     funcs_to_rewrite.insert({func.getName(), func_to_rewrite});
154     funcs_to_visit.push_back(func);
155   }
156 
157   // Find functions that call functions with communication ops, transitively.
158   while (!funcs_to_visit.empty()) {
159     SmallVector<FuncOp, 4> new_funcs_to_visit;
160     for (FuncOp& func : funcs_to_visit) {
161       auto uses = func.getSymbolUses(module);
162       if (!uses) continue;
163       for (auto& use : *uses) {
164         // Only `mlir::CallOp` is supported as this requires knowing how to
165         // rewrite arguments and results to a function.
166         if (!isa<mlir::CallOp>(use.getUser())) continue;
167         auto caller_parent_func = use.getUser()->getParentOfType<FuncOp>();
168         if (!caller_parent_func) continue;
169 
170         FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func,
171                                       /*control_flow_ops=*/{},
172                                       /*control_flow_blocks=*/{},
173                                       /*clone=*/nullptr};
174         if (failed(GetControlFlowAncestors(
175                 use.getUser(), func_to_rewrite.control_flow_ops,
176                 func_to_rewrite.control_flow_blocks)))
177           return failure();
178 
179         auto it = funcs_to_rewrite.insert(
180             {caller_parent_func.getName(), func_to_rewrite});
181         if (it.second) {
182           new_funcs_to_visit.push_back(caller_parent_func);
183         } else {
184           it.first->getSecond().control_flow_ops.insert(
185               func_to_rewrite.control_flow_ops.begin(),
186               func_to_rewrite.control_flow_ops.end());
187           it.first->getSecond().control_flow_blocks.insert(
188               func_to_rewrite.control_flow_blocks.begin(),
189               func_to_rewrite.control_flow_blocks.end());
190         }
191       }
192     }
193 
194     funcs_to_visit.swap(new_funcs_to_visit);
195   }
196 
197   // Clone public functions that need to be rewritten. Function calls to this
198   // function will be replaced with the cloned function.
199   SymbolTable symbol_table(module);
200   for (auto& func : funcs_to_rewrite) {
201     if (func.getSecond().original.isPublic() &&
202         !func.getSecond().original.symbolKnownUseEmpty(module)) {
203       auto clone = func.getSecond().original.clone();
204       clone.setPrivate();
205       symbol_table.insert(clone);
206       func.getSecond().clone = clone;
207     }
208   }
209 
210   return success();
211 }
212 
213 // Assigns op sharding to an op for a given device core.
SetOpSharding(Operation * op,int64_t tpu_core)214 void SetOpSharding(Operation* op, int64_t tpu_core) {
215   std::string sharding_serialized =
216       ::xla::sharding_builder::AssignDevice(tpu_core).SerializeAsString();
217   op->setAttr(kShardingAttr,
218               StringAttr::get(op->getContext(), sharding_serialized));
219 }
220 
221 // Assigns frontend attributes holding information about data type and
222 // TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is
223 // handled differently as individual names are used per data send and receive.
SetFrontendAttributes(Operation * op,int32_t index,StringRef key,Type type,bool device_to_host)224 void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
225                            Type type, bool device_to_host) {
226   MLIRContext* context = op->getContext();
227 
228   std::string formatted_key =
229       device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
230                      : llvm::formatv("{0}_htod_{1}", key, index).str();
231 
232   auto rendezvous_name = StringAttr::get(context, formatted_key);
233   auto rendezvous_name_attr = NamedAttribute(
234       Identifier::get(kXlaHostTransferRendezvousNameAttr, context),
235       rendezvous_name);
236 
237   auto element_type = getElementTypeOrSelf(type);
238   auto xla_element_type = ::xla::TypeToPrimitiveType(element_type);
239   const std::string& xla_element_type_str =
240       ::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type);
241   auto original_type = StringAttr::get(context, xla_element_type_str);
242   auto original_type_attr =
243       NamedAttribute(Identifier::get(kXlaHostTransferOriginalTypeAttr, context),
244                      original_type);
245 
246   auto frontend_attributes = DictionaryAttr::get(
247       context,
248       ArrayRef<NamedAttribute>{rendezvous_name_attr, original_type_attr});
249   op->setAttr(kFrontendAttributesAttr, frontend_attributes);
250 }
251 
252 // Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set,
253 // op sharding for the respective device will be set.
CreateSendOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value operand,StringRef key,size_t index,const Optional<int64_t> & tpu_core,Value token)254 Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
255                    Value operand, StringRef key, size_t index,
256                    const Optional<int64_t>& tpu_core, Value token) {
257   // type 2 == DEVICE_TO_HOST
258   auto channel_handle = ChannelHandle::get(
259       /*handle=*/builder.getI64IntegerAttr(channel_id++),
260       /*type=*/builder.getI64IntegerAttr(2), builder.getContext());
261   auto send = builder.create<SendOp>(
262       loc, token.getType(), operand, token, channel_handle,
263       /*is_host_transfer=*/builder.getBoolAttr(true));
264 
265   SetFrontendAttributes(send, index, key, operand.getType(),
266                         /*device_to_host=*/true);
267 
268   if (tpu_core) SetOpSharding(send, *tpu_core);
269 
270   return send.getResult();
271 }
272 
273 // Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op
274 // sharding for the respective device will be set.
CreateRecvOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value result,StringRef key,size_t index,const Optional<int64_t> & tpu_core,Value token)275 Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
276                    Value result, StringRef key, size_t index,
277                    const Optional<int64_t>& tpu_core, Value token) {
278   // type 3 == HOST_TO_DEVICE
279   auto channel_handle = ChannelHandle::get(
280       /*handle=*/builder.getI64IntegerAttr(channel_id++),
281       /*type=*/builder.getI64IntegerAttr(3), builder.getContext());
282   auto result_type = result.getType();
283   auto recv_result_type =
284       TupleType::get(builder.getContext(), {result_type, token.getType()});
285   auto recv =
286       builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
287                              /*is_host_transfer=*/builder.getBoolAttr(true));
288 
289   SetFrontendAttributes(recv, index, key, result_type,
290                         /*device_to_host=*/false);
291 
292   if (tpu_core) SetOpSharding(recv, *tpu_core);
293 
294   auto get_tuple_element =
295       builder.create<GetTupleElementOp>(loc, recv.getResult(), /*index=*/0);
296   if (tpu_core) SetOpSharding(get_tuple_element, *tpu_core);
297 
298   result.replaceAllUsesWith(get_tuple_element);
299 
300   auto new_token = builder.create<GetTupleElementOp>(loc, recv.getResult(),
301                                                      /*index=*/1);
302   if (tpu_core) SetOpSharding(new_token, *tpu_core);
303 
304   return new_token.getResult();
305 }
306 
307 // Creates a new token if necessary, acting as a sink to previous tokens. If
308 // there is only one token in `tokens`, the only token is returned. If `tokens`
309 // is empty, `original_token` is returned instead.
CreateSinkToken(OpBuilder & builder,Location loc,ArrayRef<Value> tokens,Value original_token)310 Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef<Value> tokens,
311                       Value original_token) {
312   if (tokens.empty()) {
313     return original_token;
314   } else if (llvm::hasSingleElement(tokens)) {
315     return tokens[0];
316   } else {
317     return builder.create<AfterAllOp>(loc, original_token.getType(), tokens)
318         .getResult();
319   }
320 }
321 
322 // Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and `mhlo.recv`
323 // ops per operand and result. Unique Channel Id's are assigned per transfer.
324 // Sink tokens are created across all `mhlo.send` ops first and then by
325 // all `mhlo.recv` ops.
RewriteHostComputeOp(OpBuilder & builder,int64_t & channel_id,TF::_XlaHostComputeMlirOp host_compute,Value token)326 Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id,
327                            TF::_XlaHostComputeMlirOp host_compute,
328                            Value token) {
329   builder.setInsertionPoint(host_compute);
330   Location loc = host_compute.getLoc();
331   int64_t tpu_core = host_compute.tpu_coreAttr().getInt();
332 
333   SmallVector<Value, 4> send_tokens;
334   for (auto operand : llvm::enumerate(host_compute.inputs())) {
335     auto send_token =
336         CreateSendOp(builder, channel_id, loc, operand.value(),
337                      host_compute.send_key(), operand.index(), tpu_core, token);
338     send_tokens.push_back(send_token);
339   }
340   token = CreateSinkToken(builder, loc, send_tokens, token);
341 
342   SmallVector<Value, 4> recv_tokens;
343   for (auto result : llvm::enumerate(host_compute.outputs())) {
344     auto recv_token =
345         CreateRecvOp(builder, channel_id, loc, result.value(),
346                      host_compute.recv_key(), result.index(), tpu_core, token);
347     recv_tokens.push_back(recv_token);
348   }
349   token = CreateSinkToken(builder, loc, recv_tokens, token);
350 
351   host_compute.erase();
352   return token;
353 }
354 
355 // Replaces `tf.XlaSendToHost` with a `mhlo.send`.
RewriteSendToHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaSendToHostOp send_to_host,Value token)356 Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
357                           TF::XlaSendToHostOp send_to_host, Value token) {
358   builder.setInsertionPoint(send_to_host);
359   token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
360                        send_to_host.input(), send_to_host.key(),
361                        /*index=*/0, /*tpu_core=*/llvm::None, token);
362 
363   send_to_host.erase();
364   return token;
365 }
366 
367 // Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`.
RewriteRecvFromHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaRecvFromHostOp recv_from_host,Value token)368 Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
369                             TF::XlaRecvFromHostOp recv_from_host, Value token) {
370   builder.setInsertionPoint(recv_from_host);
371   token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
372                        recv_from_host.output(), recv_from_host.key(),
373                        /*index=*/0, /*tpu_core=*/llvm::None, token);
374 
375   recv_from_host.erase();
376   return token;
377 }
378 
379 // Replaces a `mlir::CallOp` with one that has an extra `!mhlo.token` operand
380 // and `!mhlo.token` result. If `new_symbol` is set, the new call will be
381 // updated to call the `new_symbol` instead.
RewriteCallOp(OpBuilder & builder,CallOp call,const Optional<StringRef> & new_symbol,Value token)382 Value RewriteCallOp(OpBuilder& builder, CallOp call,
383                     const Optional<StringRef>& new_symbol, Value token) {
384   builder.setInsertionPoint(call);
385   auto new_operands = llvm::to_vector<4>(call.getArgOperands());
386   new_operands.push_back(token);
387   auto new_result_types = llvm::to_vector<4>(call.getResultTypes());
388   new_result_types.push_back(token.getType());
389   auto new_call = builder.create<CallOp>(
390       call.getLoc(), new_result_types, new_symbol ? *new_symbol : call.callee(),
391       new_operands);
392 
393   for (auto results : llvm::zip(call.getResults(), new_call.getResults()))
394     std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
395   call.erase();
396   return new_call.getResults().back();
397 }
398 
// Worklist entry describing what to visit next.
// - When `region_idx` is set, `op` is a control flow op and `region_idx` is
//   the index of its region to process next.
// - When `region_idx` is unset, `op` itself is the next op to visit.
// `token` is the current token: the one produced by the last communication op,
// or by a control flow op that (transitively) contains communication ops.
// Entries are brace-initialized positionally, so field order matters.
struct OpVisitorState {
  // Region of the control flow op `op` to process next, if any.
  Optional<unsigned> region_idx;
  // Token to order the next communication op after.
  Value token;
  // Op to visit.
  Operation* op;
};
408 
409 // Creates a tuple from a sequence of values.
CreateTuple(OpBuilder & builder,Location loc,ArrayRef<Value> operands)410 Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef<Value> operands) {
411   return builder.create<TupleOp>(loc, operands).getResult();
412 }
413 
414 // Replaces a value `value` with a new value but the token attached. If `value`
415 // is not a tuple, a new tuple is formed with `token`. If `value` is a tuple,
416 // `value` is extended instead. New tuple values created are cached.
GetValueWithToken(OpBuilder & builder,Value value,Value token,llvm::SmallDenseMap<Value,Value> & rewritten_values)417 Value GetValueWithToken(OpBuilder& builder, Value value, Value token,
418                         llvm::SmallDenseMap<Value, Value>& rewritten_values) {
419   // If value with token already exists, reuse it.
420   auto it = rewritten_values.find(value);
421   if (it != rewritten_values.end()) return it->getSecond();
422 
423   auto create_tuple = [&](ArrayRef<Value> operands) {
424     auto new_result = CreateTuple(builder, value.getLoc(), operands);
425     rewritten_values.insert({value, new_result});
426     return new_result;
427   };
428 
429   auto tuple_type = value.getType().dyn_cast<TupleType>();
430   // `value` is not a tuple, create a new tuple.
431   if (!tuple_type) return create_tuple({value, token});
432 
433   // Extend tuple if `value` is a tuple.
434   // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack
435   // the tuple.
436   if (auto tuple_op = value.getDefiningOp<TupleOp>()) {
437     auto tuple_operands = llvm::to_vector<4>(tuple_op.getOperands());
438     tuple_operands.push_back(token);
439     return create_tuple(tuple_operands);
440   }
441 
442   // `value` is not created via a `mhlo.tuple` directly, unpack individual
443   // elements directly with `mhlo.get_tuple_element`.
444   SmallVector<Value, 4> tuple_operands;
445   for (auto idx : llvm::seq<int32_t>(0, tuple_type.getTypes().size()))
446     tuple_operands.push_back(
447         builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
448             .getResult());
449 
450   tuple_operands.push_back(token);
451   return create_tuple(tuple_operands);
452 }
453 
454 // Extends a type to include a `mhlo.token` type. If `type` is not a tuple type,
455 // a new tuple type with `type` and `mhlo.token` type is created instead.
GetTypeWithToken(OpBuilder & builder,Type type)456 TupleType GetTypeWithToken(OpBuilder& builder, Type type) {
457   auto token_type = TokenType::get(builder.getContext());
458   if (auto tuple_type = type.dyn_cast<TupleType>()) {
459     auto result_types = llvm::to_vector<4>(tuple_type.getTypes());
460     result_types.push_back(token_type);
461     return builder.getTupleType(result_types);
462   }
463 
464   return builder.getTupleType({type, token_type});
465 }
466 
467 // Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0
468 // to `end`, exclusive.
CreateSubTuple(OpBuilder & builder,Value value,size_t end)469 Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) {
470   SmallVector<Value, 4> tuple_operands;
471   for (auto idx : llvm::seq<int32_t>(0, end))
472     tuple_operands.push_back(
473         builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
474             .getResult());
475 
476   return CreateTuple(builder, value.getLoc(), tuple_operands);
477 }
478 
479 // Replaces uses of `value` with `replacement`. If `value` is not a tuple type,
480 // an explicit `mhlo.get_tuple_element` is created to unpack the tuple and
481 // return the first element. Otherwise, `mhlo.get_tuple_element` users are
482 // simply updated with `replacement`, and all other users are updated with a
483 // slice of `replacement`.
ReplaceWithTupleResult(OpBuilder & builder,Value value,Value replacement)484 void ReplaceWithTupleResult(OpBuilder& builder, Value value,
485                             Value replacement) {
486   auto tuple_type = value.getType().dyn_cast<TupleType>();
487   if (!tuple_type) {
488     if (!value.use_empty()) {
489       auto new_element = builder.create<GetTupleElementOp>(replacement.getLoc(),
490                                                            replacement, 0);
491       value.replaceAllUsesWith(new_element.getResult());
492     }
493     return;
494   }
495 
496   Value sub_tuple;
497   for (auto& use : llvm::make_early_inc_range(value.getUses())) {
498     if (isa<GetTupleElementOp>(use.getOwner())) {
499       use.set(replacement);
500       continue;
501     }
502 
503     if (!sub_tuple)
504       sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size());
505 
506     use.set(sub_tuple);
507   }
508 }
509 
510 // Replaces control flow op block single block argument with new block argument
511 // of type `new_type` (tuple type). The last element of the new block argument
512 // (token) is returned.
UpdateControlFlowBlockArgWithToken(OpBuilder & builder,Block & block,Type token_type)513 Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block,
514                                          Type token_type) {
515   assert(block.getNumArguments() == 1);
516   builder.setInsertionPointToStart(&block);
517   auto new_arg = block.addArgument(token_type);
518   ReplaceWithTupleResult(builder, block.getArgument(0), new_arg);
519   block.eraseArgument(0);
520   return builder
521       .create<GetTupleElementOp>(new_arg.getLoc(), new_arg,
522                                  token_type.cast<TupleType>().size() - 1)
523       .getResult();
524 }
525 
526 // Updates control flow op terminator with an extra element `token`. If the
527 // original return value is not a tuple, a new tuple is formed. Otherwise the
528 // tuple is extended.
RewriteControlFlowTerminator(OpBuilder & builder,Operation * terminator,Value token)529 void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator,
530                                   Value token) {
531   assert(terminator->getNumOperands() == 1);
532   assert(terminator->getBlock()->getNumArguments() == 1);
533   // `mhlo.while` cond terminator does not need to be rewritten as it always
534   // returns a tensor<i1> predicate value.
535   if (auto while_parent = dyn_cast_or_null<WhileOp>(terminator->getParentOp()))
536     if (terminator->getParentRegion() == &while_parent.cond()) return;
537 
538   builder.setInsertionPoint(terminator);
539   llvm::SmallDenseMap<Value, Value> rewritten_operands;
540   Value new_result = GetValueWithToken(builder, terminator->getOperand(0),
541                                        token, rewritten_operands);
542   terminator->setOperand(0, new_result);
543 }
544 
545 // Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. Operands to
546 // the op for all of its regions are extended to have an extra operand `token`.
RewriteRegionIfOp(OpBuilder & builder,IfOp region_if,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)547 void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if,
548                        SmallVectorImpl<OpVisitorState>& ops_to_visit,
549                        Value token) {
550   llvm::SmallDenseMap<Value, Value> rewritten_operands;
551 
552   // Rewrite all region operands to have an extra operand `token`.
553   Value new_true_operand = GetValueWithToken(builder, region_if.true_arg(),
554                                              token, rewritten_operands);
555   Value new_false_operand = GetValueWithToken(builder, region_if.false_arg(),
556                                               token, rewritten_operands);
557 
558   auto new_result_type = GetTypeWithToken(builder, region_if.getType());
559 
560   // Create new `mhlo.if` op with extra token operands and result.
561   auto new_if = builder.create<IfOp>(region_if.getLoc(), new_result_type,
562                                      region_if.pred(), new_true_operand,
563                                      new_false_operand);
564 
565   // Move all regions from the old `mhlo.if` op to its replacement.
566   new_if.true_branch().takeBody(region_if.true_branch());
567   new_if.false_branch().takeBody(region_if.false_branch());
568 
569   // Forward result from old `mhlo.if` with replacement, and unpack result when
570   // necessary.
571   ReplaceWithTupleResult(builder, region_if.getResult(), new_if.getResult());
572 
573   auto new_token = builder.create<GetTupleElementOp>(
574       new_if.getLoc(), new_if.getResult(),
575       new_if.getResult().getType().cast<TupleType>().size() - 1);
576 
577   region_if.erase();
578 
579   // Remove leftover operands to old `mhlo.if` if they have no uses.
580   for (auto& rewritten_operand : rewritten_operands)
581     if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
582       if (tuple_op.use_empty()) tuple_op.erase();
583 
584   // Next op to visit. The replacement is visited but at its first region. The
585   // token result of the new region if is propagated.
586   ops_to_visit.push_back({/*region_idx=*/0, new_token, new_if});
587 }
588 
589 // Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a
590 // `mhlo.token`. The block argument is updated to have an extra `mhlo.token`
591 // element. If the region block is to be rewritten, the next op to visit is set
592 // to the first op in the block. Otherwise the terminator is updated to forward
593 // `token`.
RewriteControlFlowOpRegion(OpBuilder & builder,Operation * region_op,unsigned region_idx,Type block_arg_type,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)594 void RewriteControlFlowOpRegion(
595     OpBuilder& builder, Operation* region_op, unsigned region_idx,
596     Type block_arg_type, SmallVectorImpl<OpVisitorState>& ops_to_visit,
597     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
598   ops_to_visit.push_back({region_idx + 1, token, region_op});
599 
600   Region& region = region_op->getRegion(region_idx);
601   assert(llvm::hasSingleElement(region));
602 
603   auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(),
604                                                         block_arg_type);
605 
606   if (control_flow_blocks.contains(&region.front())) {
607     ops_to_visit.push_back({/*region_idx=*/llvm::None, block_token,
608                             block_token.getDefiningOp()->getNextNode()});
609     return;
610   }
611 
612   RewriteControlFlowTerminator(builder, region.front().getTerminator(),
613                                block_token);
614 }
615 
616 // Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op
617 // operands and results are rewritten. If `region_idx` is set, region
618 // `region_idx` is rewritten to take in and return an additional token. Returns
619 // true if the op or its region was rewritten.
ProcessRegionIfOp(OpBuilder & builder,IfOp region_if,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)620 bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if,
621                        Optional<unsigned> region_idx,
622                        SmallVectorImpl<OpVisitorState>& ops_to_visit,
623                        const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
624                        Value token) {
625   builder.setInsertionPoint(region_if);
626 
627   if (!region_idx) {
628     RewriteRegionIfOp(builder, region_if, ops_to_visit, token);
629     return true;
630   }
631 
632   if (*region_idx < region_if.getNumRegions()) {
633     RewriteControlFlowOpRegion(builder, region_if, *region_idx,
634                                region_if.getOperand(*region_idx + 1).getType(),
635                                ops_to_visit, control_flow_blocks, token);
636     return true;
637   }
638 
639   return false;
640 }
641 
642 // Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to
643 // the op for all of its regions are extended to have an extra operand `token`.
RewriteRegionWhileOp(OpBuilder & builder,WhileOp region_while,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)644 void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while,
645                           SmallVectorImpl<OpVisitorState>& ops_to_visit,
646                           Value token) {
647   llvm::SmallDenseMap<Value, Value> rewritten_operands;
648 
649   // Rewrite region operand to have an extra operand `token`.
650   Value new_val_operand =
651       GetValueWithToken(builder, region_while.val(), token, rewritten_operands);
652 
653   auto new_result_type = GetTypeWithToken(builder, region_while.getType());
654 
655   // Create new `mhlo.while` op with extra token operand and result.
656   auto new_while = builder.create<WhileOp>(region_while.getLoc(),
657                                            new_result_type, new_val_operand);
658 
659   // Move all regions from the old `mhlo.while` op to its replacement.
660   new_while.cond().takeBody(region_while.cond());
661   new_while.body().takeBody(region_while.body());
662 
663   // Forward result from old `mhlo.while` with replacement, and unpack result
664   // when necessary.
665   ReplaceWithTupleResult(builder, region_while.getResult(),
666                          new_while.getResult());
667 
668   auto new_token = builder.create<GetTupleElementOp>(
669       new_while.getLoc(), new_while.getResult(),
670       new_while.getResult().getType().cast<TupleType>().size() - 1);
671 
672   region_while.erase();
673 
674   // Remove leftover operands to old `mhlo.while` if they have no uses.
675   for (auto& rewritten_operand : rewritten_operands)
676     if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
677       if (tuple_op.use_empty()) tuple_op.erase();
678 
679   // Next op to visit. The replacement is visited but at its first region. The
680   // token result of the new region if is propagated.
681   ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while});
682 }
683 
684 // Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op
685 // operands and results are rewritten. If `region_idx` is set, region
686 // `region_idx` is rewritten to take in and return an additional token. Returns
687 // true if the op or its region was rewritten.
ProcessRegionWhileOp(OpBuilder & builder,WhileOp region_while,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)688 bool ProcessRegionWhileOp(
689     OpBuilder& builder, WhileOp region_while, Optional<unsigned> region_idx,
690     SmallVectorImpl<OpVisitorState>& ops_to_visit,
691     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
692   builder.setInsertionPoint(region_while);
693 
694   if (!region_idx) {
695     RewriteRegionWhileOp(builder, region_while, ops_to_visit, token);
696     return true;
697   }
698 
699   if (*region_idx < region_while.getNumRegions()) {
700     RewriteControlFlowOpRegion(builder, region_while, *region_idx,
701                                region_while.val().getType(), ops_to_visit,
702                                control_flow_blocks, token);
703     return true;
704   }
705 
706   return false;
707 }
708 
709 // Updates function type based on current function body block arguments and
710 // terminator operand types.
UpdateFunctionType(OpBuilder & builder,FuncOp func,Block & func_body)711 void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) {
712   auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
713   auto new_result_types =
714       llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
715   func.setType(FunctionType::get(builder.getContext(), new_argument_types,
716                                  new_result_types));
717 }
718 
719 // Replaces a function terminator `return` with another `return` that has an
720 // extra `mhlo.token` operand.
RewriteFunctionTerminator(OpBuilder & builder,mlir::ReturnOp terminator,Value token)721 void RewriteFunctionTerminator(OpBuilder& builder, mlir::ReturnOp terminator,
722                                Value token) {
723   auto new_results = llvm::to_vector<4>(terminator.getOperands());
724   new_results.push_back(token);
725   builder.setInsertionPoint(terminator);
726   builder.create<mlir::ReturnOp>(terminator.getLoc(), new_results);
727   terminator.erase();
728 }
729 
// Rewrites a function body and the communication ops inside it, threading a
// `mhlo.token` through them in program order. Region based control flow ops
// (`mhlo.if` / `mhlo.while`) recorded in `control_flow_ops` are updated to take
// in and return the token as well.
//
// The signature changes only in some cases:
// * If `is_clone` is set, or `func` is non-public and has known uses in
//   `module`, a token block argument is appended and the token is returned
//   from the function. The function type is then updated to match.
// * Otherwise the signature is left untouched, and a token is created at the
//   start of the body.
//
// `channel_id` is handed by reference to the communication op rewriters and is
// shared across the module so that channel ids stay unique.
//
// Returns failure (with an error emitted on `func`) if the body has more than
// one block.
LogicalResult RewriteFunction(
    OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func,
    const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs,
    const llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
    const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, bool is_clone) {
  MLIRContext* context = module.getContext();
  if (!llvm::hasSingleElement(func.getBody()))
    return func.emitError()
           << "'" << FuncOp::getOperationName()
           << "' ops with more than one block are not supported";

  // True when the function's signature may be changed to take and return a
  // token.
  bool rewrite_block =
      is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module));
  Block& func_body = func.front();

  builder.setInsertionPointToStart(&func_body);
  auto token_type = TokenType::get(context);
  // If a function is public, its signature should not be modified, and instead
  // a token will be created. Otherwise a token block argument is inserted.
  Value init_token =
      rewrite_block ? func_body.addArgument(token_type)
                    : builder.create<CreateTokenOp>(func.getLoc(), token_type)
                          .getResult();

  // Stack to keep track of region based control flow op nesting and current
  // op to visit. Each entry holds the op to visit, the token in scope there,
  // and, for region control flow ops, the index of the region to rewrite
  // (unset means the op itself).
  SmallVector<OpVisitorState, 4> ops_to_visit{
      {/*region_idx=*/llvm::None, init_token, &func_body.front()}};

  while (!ops_to_visit.empty()) {
    OpVisitorState op_to_visit = ops_to_visit.pop_back_val();
    Operation* curr_op = op_to_visit.op;

    Value token = op_to_visit.token;
    // Ops may be removed, so the next op is kept track of beforehand.
    Operation* next_op = curr_op->getNextNode();

    // Each communication op rewrite returns the token to thread into the ops
    // that follow.
    if (auto host_compute = dyn_cast<TF::_XlaHostComputeMlirOp>(curr_op)) {
      token = RewriteHostComputeOp(builder, channel_id, host_compute, token);
    } else if (auto send_to_host = dyn_cast<TF::XlaSendToHostOp>(curr_op)) {
      token = RewriteSendToHostOp(builder, channel_id, send_to_host, token);
    } else if (auto recv_from_host = dyn_cast<TF::XlaRecvFromHostOp>(curr_op)) {
      token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token);
    } else if (auto call = dyn_cast<mlir::CallOp>(curr_op)) {
      // Only `mlir::CallOp` is supported as this requires knowing how to
      // rewrite arguments and results to a function.
      auto it = funcs.find(call.getCallee());
      if (it != funcs.end()) {
        FuncOp clone = it->getSecond().clone;
        Optional<StringRef> symbol_name =
            clone ? Optional<StringRef>(clone.getName()) : llvm::None;
        // If the function being called is to be cloned, update the call to also
        // point to the cloned function.
        token = RewriteCallOp(builder, call, symbol_name, token);
      }
    } else if (auto region_if = dyn_cast<IfOp>(curr_op)) {
      // Only control flow ops recorded in `control_flow_ops`, or ones being
      // revisited for a region (`region_idx` set), are rewritten. On success
      // the helper is responsible for queuing any follow-up visit, so the
      // default push of `next_op` below is skipped. A false return (region
      // index past the last region) falls through to the op after this one.
      if (op_to_visit.region_idx || control_flow_ops.contains(region_if))
        if (ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx,
                              ops_to_visit, control_flow_blocks, token))
          continue;
    } else if (auto region_while = dyn_cast<WhileOp>(curr_op)) {
      // Same handling as `mhlo.if` above.
      if (op_to_visit.region_idx || control_flow_ops.contains(region_while))
        if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx,
                                 ops_to_visit, control_flow_blocks, token))
          continue;
    } else if (auto region_terminator = dyn_cast<mhlo::ReturnOp>(curr_op)) {
      RewriteControlFlowTerminator(builder, region_terminator, token);
      // There is no next op after the control flow op terminator, simply let
      // stack have one less element.
      continue;
    } else if (auto func_terminator = dyn_cast<mlir::ReturnOp>(curr_op)) {
      // The token is only returned when the signature is being rewritten.
      if (rewrite_block)
        RewriteFunctionTerminator(builder, func_terminator, token);

      // There is no next op after the function terminator, simply let stack
      // have one less element/be empty.
      continue;
    }

    // Visit next op.
    ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op});
  }

  // Sync the function type with the new token argument and result.
  if (rewrite_block) UpdateFunctionType(builder, func, func_body);

  return success();
}
821 
822 // Checks if a function call is pointing to a function with communication ops.
IsFunctionCallWithCommunication(Operation * op,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)823 bool IsFunctionCallWithCommunication(
824     Operation* op,
825     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
826   if (auto call = dyn_cast<mlir::CallOp>(op))
827     return funcs_to_rewrite.count(call.callee());
828 
829   return false;
830 }
831 
832 // Collects all control flow op ancestors of communication ops or function calls
833 // with communication ops (transitively).
GetCommunicationControlFlowOps(FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)834 void GetCommunicationControlFlowOps(
835     FuncOp func,
836     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite,
837     llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
838     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
839   func.walk([&](Operation* op) {
840     if (IsCommunicationOp(op) ||
841         IsFunctionCallWithCommunication(op, funcs_to_rewrite))
842       if (failed(GetControlFlowAncestors(op, control_flow_ops,
843                                          control_flow_blocks)))
844         llvm_unreachable(
845             "checking original function for control flow ancestors should have "
846             "errored first");
847   });
848 }
849 
runOnOperation()850 void LegalizeTFCommunication::runOnOperation() {
851   auto module = getOperation();
852   llvm::SmallDenseMap<StringRef, FuncToRewrite> funcs_to_rewrite;
853   if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite)))
854     return signalPassFailure();
855 
856   // Module level counter to make sure Channel Id's are unique.
857   int64_t channel_id = 1;
858   OpBuilder builder(&getContext());
859   for (const auto& func_and_name : funcs_to_rewrite) {
860     const auto& func_to_rewrite = func_and_name.getSecond();
861     FuncOp func = func_to_rewrite.original;
862     if (failed(RewriteFunction(builder, channel_id, module, func,
863                                funcs_to_rewrite,
864                                func_to_rewrite.control_flow_ops,
865                                func_to_rewrite.control_flow_blocks,
866                                /*is_clone=*/false)))
867       return signalPassFailure();
868 
869     FuncOp clone = func_and_name.getSecond().clone;
870     if (!clone) continue;
871     llvm::SmallPtrSet<Operation*, 4> clone_control_flow_ops;
872     llvm::SmallPtrSet<Block*, 4> clone_control_flow_blocks;
873     GetCommunicationControlFlowOps(clone, funcs_to_rewrite,
874                                    clone_control_flow_ops,
875                                    clone_control_flow_blocks);
876     if (failed(RewriteFunction(builder, channel_id, module, clone,
877                                funcs_to_rewrite, clone_control_flow_ops,
878                                clone_control_flow_blocks,
879                                /*is_clone=*/true)))
880       llvm_unreachable(
881           "rewriting of original function should have errored first");
882   }
883 }
884 
// Registers `LegalizeTFCommunication` with MLIR's pass registry under the
// argument "xla-legalize-tf-communication".
static PassRegistration<LegalizeTFCommunication> pass(
    "xla-legalize-tf-communication",
    "Legalize TF/XLA communication ops (TensorFlow dialect) to the HLO "
    "dialect");
889 }  // namespace
890 
CreateLegalizeTFCommunicationPass()891 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCommunicationPass() {
892   return std::make_unique<LegalizeTFCommunication>();
893 }
894 
895 }  // namespace mhlo
896 }  // namespace mlir
897