1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements logic for lowering MHLO control-flow ops (mhlo.if and
// mhlo.while) to unstructured Standard-dialect control flow (br / cond_br).
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
25 #include "mlir/IR/Block.h"
26 #include "mlir/IR/BlockAndValueMapping.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/PatternMatch.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Pass/PassRegistry.h"
34 #include "mlir/Support/LogicalResult.h"
35 
36 namespace mlir {
37 namespace mhlo {
38 namespace {
39 struct LegalizeControlFlowPass
40     : public mlir::PassWrapper<LegalizeControlFlowPass, FunctionPass> {
41   // Perform the lowering to MLIR control flow.
42   void runOnFunction() override;
43 };
44 
45 // Replaces terminators for the newly created blocks from a targe region.
46 // These terminators are replaced with branch operations to a target block.
ReplaceTerminators(Region * region,Block * target_block,Location loc,const BlockAndValueMapping & mapper,OpBuilder * builder)47 LogicalResult ReplaceTerminators(Region* region, Block* target_block,
48                                  Location loc,
49                                  const BlockAndValueMapping& mapper,
50                                  OpBuilder* builder) {
51   for (auto& old_block : region->getBlocks()) {
52     Block* block = mapper.lookup(&old_block);
53     auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
54     if (!return_op) continue;
55     builder->setInsertionPointToEnd(block);
56     builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
57     return_op.erase();
58   }
59 
60   return success();
61 }
62 
LowerIfOp(mlir::mhlo::IfOp if_op)63 LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
64   Operation* op_inst = if_op.getOperation();
65   mlir::OpBuilder builder(if_op);
66   auto orig_block = op_inst->getBlock();
67   auto* tail_block = orig_block->splitBlock(op_inst);
68   auto loc = if_op.getLoc();
69 
70   // Duplicate the true and false regions in the block between the sections
71   // before and after the conditional.
72   BlockAndValueMapping mapper;
73   if_op.true_branch().cloneInto(orig_block->getParent(),
74                                 Region::iterator(tail_block), mapper);
75   if_op.false_branch().cloneInto(orig_block->getParent(),
76                                  Region::iterator(tail_block), mapper);
77 
78   // Determine the blocks for the start of the true and false regions.
79   Block* true_block = mapper.lookup(&if_op.true_branch().front());
80   Block* false_block = mapper.lookup(&if_op.false_branch().front());
81 
82   // Perform the conditional branch into the true/false cases.
83   builder.setInsertionPointToEnd(orig_block);
84 
85   // Extract the predicate for checking branching, then branch to the true and
86   // false regions appropriately.
87   auto cond_value = builder.create<mlir::tensor::ExtractOp>(loc, if_op.pred());
88   builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
89                                      if_op.true_arg(), false_block,
90                                      if_op.false_arg());
91 
92   // Replace the true case's return operations with a branch to the tail of
93   // the condition.
94   if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper,
95                                 &builder)))
96     return failure();
97   if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper,
98                                 &builder)))
99     return failure();
100 
101   tail_block->addArguments(if_op.getResult().getType());
102   if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
103 
104   op_inst->erase();
105   return success();
106 }
107 
LowerWhileOp(mlir::mhlo::WhileOp while_op)108 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
109   // Converts a MHLO while loop into control flow. This generates a set of MLIR
110   // blocks and branches, along with inlining the regions provided by the MHLO
111   // while loop. The structure should be similar to below:
112   //
113   //   <prior operations>
114   //   %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
115   //   <post operations>
116   auto* op_inst = while_op.getOperation();
117   mlir::OpBuilder builder(while_op);
118   auto loc = while_op.getLoc();
119 
120   // Break the block into four sections:
121   // orig_block - operations before the while and the branch into looping check.
122   // tail_block - operations after the while loop completes.
123   // cond_block - check the looping condition, then conditionally branch into
124   //              the loop or, if condition is false, jump to the tail branch.
125   // body_block - inlined loop body, then jump back to the condition block.
126   auto* orig_block = op_inst->getBlock();
127   auto* tail_block = orig_block->splitBlock(op_inst);
128 
129   BlockAndValueMapping mapper;
130   while_op.cond().cloneInto(orig_block->getParent(),
131                             Region::iterator(tail_block), mapper);
132   while_op.body().cloneInto(orig_block->getParent(),
133                             Region::iterator(tail_block), mapper);
134 
135   // Lookup the entry blocks for both condition and body.
136   auto* cond_block = mapper.lookup(&while_op.cond().front());
137   auto* body_block = mapper.lookup(&while_op.body().front());
138 
139   // Setup the end of the original block:
140   //     <prior operations>
141   //     br ^cond(%arg0) // Jumps to the condition statement.
142   builder.setInsertionPointToEnd(orig_block);
143   builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
144 
145   // Updates the inlined condition blocks by replacing the return op with an
146   // tensor.extract and conditional branch. This changes the block below:
147   //   ^cond(%0):
148   //     <inlined conditional region>
149   //    "mhlo".return(%1)
150   //
151   //  Into:
152   //   ^cond(%0):
153   //     <inlined conditional region>
154   //     %2 = tensor.extract %1[] : tensor<i1> // Extract the condition value.
155   //     cond_br %2, ^body(%0), ^tail(%0) // Branch.
156   builder.setInsertionPointToStart(cond_block);
157 
158   // Replace the mhlo::ReturnOp with a branch back to the condition block.
159   // This is required as the mhlo::ReturnOp is used to mark the end of a
160   // block for regions nested inside of a operations (MLIR ReturnOp cannot be
161   // nested within an non-function region).
162   for (auto& block : while_op.cond()) {
163     auto new_block = mapper.lookup(&block);
164 
165     auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
166     if (!return_op) continue;
167     builder.setInsertionPointToEnd(new_block);
168 
169     auto return_value = return_op.getOperand(0);
170     auto cond_value =
171         builder.create<mlir::tensor::ExtractOp>(loc, return_value);
172 
173     // Get the body block arguments.
174     llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
175                                                cond_block->args_end());
176     builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
177                                        successor_args, tail_block,
178                                        successor_args);
179     return_op.erase();
180   }
181 
182   // Updates the body blocks by replace the return op with an branch to the
183   // conditional block. This changes the block below:
184   //   ^body(%0):
185   //     <inlined body block>
186   //    "mhlo".return(%1)
187   //
188   //  Into:
189   //   ^body(%0):
190   //     <inlined body block>
191   //     br ^cond(%0) // Branch.
192   for (auto& block : while_op.body()) {
193     auto new_block = mapper.lookup(&block);
194     auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
195     if (!return_op) continue;
196     builder.setInsertionPointToEnd(new_block);
197     builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
198     return_op.erase();
199   }
200 
201   // Erase the original while loop.
202   tail_block->addArgument(while_op.getType());
203   while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
204   op_inst->erase();
205 
206   return success();
207 }
208 
runOnFunction()209 void LegalizeControlFlowPass::runOnFunction() {
210   auto func = getFunction();
211   llvm::SmallVector<IfOp, 4> if_ops;
212   func.walk([&](IfOp op) { if_ops.push_back(op); });
213 
214   for (auto& op : if_ops) {
215     if (failed(LowerIfOp(op))) return signalPassFailure();
216   }
217 
218   llvm::SmallVector<WhileOp, 4> while_ops;
219   func.walk([&](WhileOp op) { while_ops.push_back(op); });
220 
221   for (auto& op : while_ops) {
222     if (failed(LowerWhileOp(op))) return signalPassFailure();
223   }
224 }
225 }  // namespace
226 }  // namespace mhlo
227 }  // namespace mlir
228 
229 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeControlFlowPass()230 mlir::mhlo::createLegalizeControlFlowPass() {
231   return std::make_unique<LegalizeControlFlowPass>();
232 }
233