1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering MHLO dialect to Standard dialect.
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
25 #include "mlir/IR/Block.h"
26 #include "mlir/IR/BlockAndValueMapping.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/PatternMatch.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Pass/PassRegistry.h"
34 #include "mlir/Support/LogicalResult.h"
35
36 namespace mlir {
37 namespace mhlo {
38 namespace {
39 struct LegalizeControlFlowPass
40 : public mlir::PassWrapper<LegalizeControlFlowPass, FunctionPass> {
41 // Perform the lowering to MLIR control flow.
42 void runOnFunction() override;
43 };
44
45 // Replaces terminators for the newly created blocks from a targe region.
46 // These terminators are replaced with branch operations to a target block.
ReplaceTerminators(Region * region,Block * target_block,Location loc,const BlockAndValueMapping & mapper,OpBuilder * builder)47 LogicalResult ReplaceTerminators(Region* region, Block* target_block,
48 Location loc,
49 const BlockAndValueMapping& mapper,
50 OpBuilder* builder) {
51 for (auto& old_block : region->getBlocks()) {
52 Block* block = mapper.lookup(&old_block);
53 auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
54 if (!return_op) continue;
55 builder->setInsertionPointToEnd(block);
56 builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
57 return_op.erase();
58 }
59
60 return success();
61 }
62
LowerIfOp(mlir::mhlo::IfOp if_op)63 LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
64 Operation* op_inst = if_op.getOperation();
65 mlir::OpBuilder builder(if_op);
66 auto orig_block = op_inst->getBlock();
67 auto* tail_block = orig_block->splitBlock(op_inst);
68 auto loc = if_op.getLoc();
69
70 // Duplicate the true and false regions in the block between the sections
71 // before and after the conditional.
72 BlockAndValueMapping mapper;
73 if_op.true_branch().cloneInto(orig_block->getParent(),
74 Region::iterator(tail_block), mapper);
75 if_op.false_branch().cloneInto(orig_block->getParent(),
76 Region::iterator(tail_block), mapper);
77
78 // Determine the blocks for the start of the true and false regions.
79 Block* true_block = mapper.lookup(&if_op.true_branch().front());
80 Block* false_block = mapper.lookup(&if_op.false_branch().front());
81
82 // Perform the conditional branch into the true/false cases.
83 builder.setInsertionPointToEnd(orig_block);
84
85 // Extract the predicate for checking branching, then branch to the true and
86 // false regions appropriately.
87 auto cond_value = builder.create<mlir::tensor::ExtractOp>(loc, if_op.pred());
88 builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
89 if_op.true_arg(), false_block,
90 if_op.false_arg());
91
92 // Replace the true case's return operations with a branch to the tail of
93 // the condition.
94 if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper,
95 &builder)))
96 return failure();
97 if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper,
98 &builder)))
99 return failure();
100
101 tail_block->addArguments(if_op.getResult().getType());
102 if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
103
104 op_inst->erase();
105 return success();
106 }
107
LowerWhileOp(mlir::mhlo::WhileOp while_op)108 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
109 // Converts a MHLO while loop into control flow. This generates a set of MLIR
110 // blocks and branches, along with inlining the regions provided by the MHLO
111 // while loop. The structure should be similar to below:
112 //
113 // <prior operations>
114 // %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
115 // <post operations>
116 auto* op_inst = while_op.getOperation();
117 mlir::OpBuilder builder(while_op);
118 auto loc = while_op.getLoc();
119
120 // Break the block into four sections:
121 // orig_block - operations before the while and the branch into looping check.
122 // tail_block - operations after the while loop completes.
123 // cond_block - check the looping condition, then conditionally branch into
124 // the loop or, if condition is false, jump to the tail branch.
125 // body_block - inlined loop body, then jump back to the condition block.
126 auto* orig_block = op_inst->getBlock();
127 auto* tail_block = orig_block->splitBlock(op_inst);
128
129 BlockAndValueMapping mapper;
130 while_op.cond().cloneInto(orig_block->getParent(),
131 Region::iterator(tail_block), mapper);
132 while_op.body().cloneInto(orig_block->getParent(),
133 Region::iterator(tail_block), mapper);
134
135 // Lookup the entry blocks for both condition and body.
136 auto* cond_block = mapper.lookup(&while_op.cond().front());
137 auto* body_block = mapper.lookup(&while_op.body().front());
138
139 // Setup the end of the original block:
140 // <prior operations>
141 // br ^cond(%arg0) // Jumps to the condition statement.
142 builder.setInsertionPointToEnd(orig_block);
143 builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
144
145 // Updates the inlined condition blocks by replacing the return op with an
146 // tensor.extract and conditional branch. This changes the block below:
147 // ^cond(%0):
148 // <inlined conditional region>
149 // "mhlo".return(%1)
150 //
151 // Into:
152 // ^cond(%0):
153 // <inlined conditional region>
154 // %2 = tensor.extract %1[] : tensor<i1> // Extract the condition value.
155 // cond_br %2, ^body(%0), ^tail(%0) // Branch.
156 builder.setInsertionPointToStart(cond_block);
157
158 // Replace the mhlo::ReturnOp with a branch back to the condition block.
159 // This is required as the mhlo::ReturnOp is used to mark the end of a
160 // block for regions nested inside of a operations (MLIR ReturnOp cannot be
161 // nested within an non-function region).
162 for (auto& block : while_op.cond()) {
163 auto new_block = mapper.lookup(&block);
164
165 auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
166 if (!return_op) continue;
167 builder.setInsertionPointToEnd(new_block);
168
169 auto return_value = return_op.getOperand(0);
170 auto cond_value =
171 builder.create<mlir::tensor::ExtractOp>(loc, return_value);
172
173 // Get the body block arguments.
174 llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
175 cond_block->args_end());
176 builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
177 successor_args, tail_block,
178 successor_args);
179 return_op.erase();
180 }
181
182 // Updates the body blocks by replace the return op with an branch to the
183 // conditional block. This changes the block below:
184 // ^body(%0):
185 // <inlined body block>
186 // "mhlo".return(%1)
187 //
188 // Into:
189 // ^body(%0):
190 // <inlined body block>
191 // br ^cond(%0) // Branch.
192 for (auto& block : while_op.body()) {
193 auto new_block = mapper.lookup(&block);
194 auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
195 if (!return_op) continue;
196 builder.setInsertionPointToEnd(new_block);
197 builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
198 return_op.erase();
199 }
200
201 // Erase the original while loop.
202 tail_block->addArgument(while_op.getType());
203 while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
204 op_inst->erase();
205
206 return success();
207 }
208
runOnFunction()209 void LegalizeControlFlowPass::runOnFunction() {
210 auto func = getFunction();
211 llvm::SmallVector<IfOp, 4> if_ops;
212 func.walk([&](IfOp op) { if_ops.push_back(op); });
213
214 for (auto& op : if_ops) {
215 if (failed(LowerIfOp(op))) return signalPassFailure();
216 }
217
218 llvm::SmallVector<WhileOp, 4> while_ops;
219 func.walk([&](WhileOp op) { while_ops.push_back(op); });
220
221 for (auto& op : while_ops) {
222 if (failed(LowerWhileOp(op))) return signalPassFailure();
223 }
224 }
225 } // namespace
226 } // namespace mhlo
227 } // namespace mlir
228
229 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeControlFlowPass()230 mlir::mhlo::createLegalizeControlFlowPass() {
231 return std::make_unique<LegalizeControlFlowPass>();
232 }
233