/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace lmhlo {
namespace {

// Clones and adapts the code in `lhlo_block`, which operates on buffers and
// has a single output buffer, so that it works on `operands` holding the
// element types of the respective buffers. Returns the computed value.
//
// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
// with signature:
//   ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
//     <LHLO_ops>
//
// inserts the necessary alloc and store ops to compute and return a result of
// `i1` type.
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
                                Block* lhlo_block, OpBuilder* b) {
  SmallVector<Value, 2> arg_bufs;
  for (auto arg_type : lhlo_block->getArgumentTypes()) {
    arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
  }
  for (auto operand : llvm::enumerate(operands)) {
    b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
  }
  // Clone the ops from `lhlo_block`.
  BlockAndValueMapping mapping;
  mapping.map(lhlo_block->getArguments(), arg_bufs);
  for (auto& nested : lhlo_block->without_terminator()) {
    auto clone = b->clone(nested, mapping);
    mapping.map(nested.getResults(), clone->getResults());
  }
  return b->create<LoadOp>(loc, arg_bufs.back());
}

// Converts a block with LHLO ops and with signature:
//   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into a reduction operator of scf.reduce by allocating buffers for the
// scalar arguments and the result of `scf.reduce`, so that the cloned LHLO
// ops can operate on them.
void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
                                Block* lhlo_block, OpBuilder* b) {
  Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
  OpBuilder::InsertionGuard guard(*b);
  b->setInsertionPointToStart(&loop_reduce_op_body);
  b->create<scf::ReduceReturnOp>(
      loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
                                     lhlo_block, b));
}

// Returns the result of a ConstantIndexOp if `dim` is static; otherwise emits
// a DimOp to extract the dimension at runtime.
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
                            size_t dim_index, int64_t dim, OpBuilder* b) {
  return dim == ShapedType::kDynamicSize
             ? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
             : b->create<ConstantIndexOp>(loc, dim);
}

struct MappedIvs {
  // False if the mapped indices are in the padding area, true otherwise.
  Value in_bounds;
  // Mapped indices.
  SmallVector<Value, 2> ivs;
};

template <typename OpTy>
MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
                              OpBuilder* b) {
  MappedIvs mapped_ivs;

  if (!op.window_strides().hasValue()) {
    op.emitOpError("No window strides specified.");
  }
  auto window_strides = op.window_strides().getValue();

  if (!op.padding().hasValue()) {
    op.emitOpError("No padding specified.");
  }
  auto padding = op.padding().getValue();

  auto loc = op.getLoc();
  auto operand = op.operand();
  auto operand_shape = operand.getType().template cast<MemRefType>().getShape();

  // `in_bounds` is false when the mapped indices are in the padding area.
  mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
      loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
  for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
    auto stride = window_strides.template getValue<llvm::APInt>(i);
    auto pad_low = padding.template getValue<llvm::APInt>({i, 0});

    Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
    Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());

    Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
    Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
    Value index = b->create<AddIOp>(loc, center, offset);
    Value upper_bound =
        GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
    // We must check whether 0 <= index_i < shape_i, as otherwise we are in
    // the pad and then we have to use the neutral element for reduction.
    // Equivalently, it can be computed as the unsigned comparison index_i <
    // shape_i, since a negative value wraps to a large positive value.
    mapped_ivs.in_bounds = b->create<mlir::AndOp>(
        loc, mapped_ivs.in_bounds,
        b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
    mapped_ivs.ivs.push_back(index);
  }
  return mapped_ivs;
}

// Returns an scf::ParallelOp that iterates over all dimensions of
// `shaped_value`, handling both static and dynamic shapes.
scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
                                  OpBuilder* b) {
  Value zero = b->create<ConstantIndexOp>(loc, 0);
  Value one = b->create<ConstantIndexOp>(loc, 1);

  ArrayRef<int64_t> shape =
      shaped_value.getType().cast<ShapedType>().getShape();
  SmallVector<Value, 2> lower, upper, step;
  for (auto dim : llvm::enumerate(shape)) {
    upper.push_back(
        GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
    lower.push_back(zero);
    step.push_back(one);
  }
  return b->create<scf::ParallelOp>(loc, lower, upper, step);
}

// Converts `lmhlo.ReduceOp` into two scf::ParallelOp ops and an
// scf::ReduceOp. The outer `ParallelOp` refers to the parallel loops, if
// there are any. The inner `ParallelOp` refers to the reduction loops and the
// `ReduceOp` contains the reduction operator.
//
// Example:
//
//  "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
//    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//      <LHLO ops>
//  } ) {dimensions = dense<[1]> : tensor<1xi64>}
//      : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
//
// is roughly converted into:
//
//  %init = load %init_buf[] : memref<f32>
//  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
//    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
//      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
//      scf.reduce(%elem_to_reduce) {
//        ^bb0(%elem: f32, %acc: f32):   // no predecessors
//          %elem_buf = alloc() : memref<f32>
//          store %elem, %elem_buf[] : memref<f32>
//          %acc_buf = alloc() : memref<f32>
//          store %acc, %acc_buf[] : memref<f32>
//          <LHLO_ops>
//          %acc_result = load %acc_buf[] : memref<f32>
//          scf.reduce.return %acc_result : f32
//      } : f32
//      scf.yield
//    } : f32
//    scf.yield
//  }
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/137624192) Implement variadic reduce.
    if (reduce_op.out().size() != 1) return failure();

    scf::ReduceOp scf_reduce_op =
        CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter);
    ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op,
                               &reduce_op.body().front(), &rewriter);
    rewriter.replaceOp(reduce_op, llvm::None);
    return success();
  }

 private:
  // Creates nested `scf.parallel` ops with an `scf.reduce`. The outer
  // ParallelOp refers to the parallel dimensions of `reduce_op`, if there are
  // any, and the inner ParallelOp refers to the reduction dimensions. The
  // scf.reduce op is returned.
  //
  // If the reduction argument is a memref<100x10x5xf32> and the
  // reduction is performed along dimension 1 then this method will generate
  //
  //  %init = load %init_buf[] : memref<f32>
  //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
  //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
  //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
  //      scf.reduce(%elem_to_reduce) {
  //        <THE BLOCK PTR TO BE RETURNED>
  //      } : f32
  //      scf.yield
  //    } : f32
  //    scf.yield
  //  }
  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const {
    auto loc = reduce_op.getLoc();
    DenseSet<int> reducing_dims;
    for (const auto& rdim : reduce_op.dimensions().getIntValues()) {
      reducing_dims.insert(rdim.getSExtValue());
    }

    Value operand = *reduce_op.operands().begin();
    Value out = *reduce_op.out().begin();
    SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
    SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
    auto operand_shape = operand.getType().cast<MemRefType>().getShape();
    for (auto dim : llvm::enumerate(operand_shape)) {
      const bool is_reducing_dim = reducing_dims.count(dim.index());

      Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
                                       rewriter);
      Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
      Value step = rewriter->create<ConstantIndexOp>(loc, 1);
      (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
      (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
      (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
    }
    // Load the initial value from a memref<element_type>.
    SmallVector<Value, 1> init_value = {
        rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
    // The outer ParallelOp is not needed if it is a reduction across all dims.
    scf::ParallelOp outer;
    if (!parallel_lower.empty()) {
      outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
                                                parallel_upper, parallel_step);
      rewriter->setInsertionPointToStart(outer.getBody());
    }
    scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
        loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value));
    Value reduction_result = *inner.getResults().begin();

    SmallVector<Value, 1> out_indices;
    if (outer != nullptr) {
      out_indices.reserve(outer.getNumLoops());
      for (Value iv : outer.getInductionVars()) {
        out_indices.push_back(iv);
      }
    } else {
      out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
    }

    rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);

    // Load the element to reduce.
    SmallVector<Value, 2> indices;
    indices.reserve(operand_shape.size());

    if (outer) {
      auto inner_ivs_it = inner.getInductionVars().begin();
      auto outer_ivs_it = outer.getInductionVars().begin();
      for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
        indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
                                                 : *outer_ivs_it++);
      }
    } else {
      indices.assign(inner.getInductionVars().begin(),
                     inner.getInductionVars().end());
    }

    rewriter->setInsertionPointToStart(inner.getBody());
    Value elem = rewriter->create<mlir::LoadOp>(
        loc, *reduce_op.operands().begin(), indices);
    return rewriter->create<scf::ReduceOp>(loc, elem);
  }
};

// Pseudocode:
//  for each index O in output
//    accumulator = neutral_value
//    in_bounds = true
//    for each index W in window
//      for each dimension i from 0 to rank - 1
//        index = O[i] * stride[i] + W[i] - pad_low[i]
//        in_bounds = in_bounds && (index `ult` shape[i])
//        I[i] = index
//      if (in_bounds)
//        value = input[I]
//      else
//        value = neutral_value
//      accumulator = reduction_operator(accumulator, value)
//    output[O] = accumulator
//
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp ops and an
// scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the
// output buffer. The inner `ParallelOp` refers to the reduction loops that
// traverse the reduction windows and the `ReduceOp` contains the reduction
// operator.
//
// Example:
//
// func @reduce_window(%arg: memref<112x112xf32>,
//                     %init: memref<f32>,
//                     %result: memref<56x56xf32>) {
//   "lmhlo.reduce_window"(%arg, %init, %result) ( {
//     ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//       "lmhlo.maximum"(%lhs, %rhs, %res)
//         : (memref<f32>, memref<f32>, memref<f32>) -> ()
//       "lmhlo.terminator"() : () -> ()
//     }) {
//       padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
//       window_dimensions = dense<[3, 3]> : tensor<2xi64>,
//       window_strides = dense<[2, 2]> : tensor<2xi64>
//     } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
//   return
// }
//
// is roughly converted into:
//
//  %neutral_elem = load %init_buf[] : memref<f32>
//  scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
//    %result = scf.parallel (%iw, %jw) = (%c0, %c0)
//                to (%c3, %c3) step (%c1, %c1) init (%neutral_elem) -> f32 {
//      %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
//      %elem = load %operand[%computed_i, %computed_j]
//      %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
//      scf.reduce(%elem_or_neutral) : f32 {
//        ^bb0(%arg7: f32, %arg8: f32):
//          <LHLO ops>
//      }
//      scf.yield
//    }
//    store %result, %output_buffer[%i, %j] : memref<56x56xf32>
//    scf.yield
//  }
//  return
class ReduceWindowOpConverter
    : public OpConversionPattern<lmhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    scf::ParallelOp output_loop, window_loop;
    std::tie(output_loop, window_loop) =
        CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
                                                     &rewriter);

    scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
        reduce_window_op, output_loop, window_loop, &rewriter);

    ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op,
                               &reduce_window_op.body().front(), &rewriter);
    rewriter.replaceOp(reduce_window_op, llvm::None);
    return success();
  }

 private:
  std::pair<scf::ParallelOp, scf::ParallelOp>
  CreateParallelLoopsToTraverseOutputAndWindow(
      lmhlo::ReduceWindowOp reduce_window_op,
      ConversionPatternRewriter* rewriter) const {
    auto loc = reduce_window_op.getLoc();
    Value init_value =
        rewriter->create<LoadOp>(loc, reduce_window_op.init_value());

    Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
    Value one = rewriter->create<ConstantIndexOp>(loc, 1);

    // Create an outer parallel loop that spans the output of ReduceWindowOp.
    Value output = reduce_window_op.out();
    auto output_loop = MakeLoopOverShape(loc, output, rewriter);

    // Create a nested loop that traverses the window.
    SmallVector<Value, 2> window_lower, window_upper, window_step;
    rewriter->setInsertionPointToStart(output_loop.getBody());
    for (const auto& window_dim : reduce_window_op.window_dimensions()) {
      window_step.push_back(one);
      window_lower.push_back(zero);
      window_upper.push_back(
          rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
    }
    auto window_loop = rewriter->create<scf::ParallelOp>(
        loc, window_lower, window_upper, window_step, ValueRange(init_value));

    Value reduction_result = *window_loop.getResults().begin();
    auto output_ivs = output_loop.getInductionVars();
    rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
    return std::make_pair(output_loop, window_loop);
  }

  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop,
      scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
    rewriter->setInsertionPointToStart(window_loop.getBody());
    auto loc = reduce_window_op.getLoc();

    if (reduce_window_op.base_dilations().hasValue() ||
        reduce_window_op.window_dilations().hasValue()) {
      reduce_window_op.emitRemark(
          "Lowering to parallel loops does not support `base_dilations` or "
          "`window_dilations` attributes yet. The attributes will be ignored.");
    }

    Value operand = reduce_window_op.operand();
    auto operand_type = operand.getType().cast<MemRefType>();

    // Compute ivs in the 'arg' buffer and whether these ivs are in the pad
    // area or not.
    MappedIvs mapped_ivs =
        MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(),
                            window_loop.getInductionVars(), rewriter);

    auto elem_or_init = rewriter->create<scf::IfOp>(
        loc, operand_type.getElementType(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    OpBuilder then_builder =
        elem_or_init.getThenBodyBuilder(rewriter->getListener());
    Value elem = then_builder.create<mlir::LoadOp>(
        loc, reduce_window_op.operand(), mapped_ivs.ivs);
    then_builder.create<scf::YieldOp>(loc, elem);

    OpBuilder else_builder =
        elem_or_init.getElseBodyBuilder(rewriter->getListener());
    else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());

    return rewriter->create<scf::ReduceOp>(loc,
                                           *elem_or_init.results().begin());
  }
};

// See the operation semantics in
// https://www.tensorflow.org/xla/operation_semantics#selectandscatter
//
// Pseudocode:
//  scf.parallel(coordinates O in the output):
//    output[O] = init
//  scf.parallel(coordinates S in the source):
//    selected_ivs = 0
//    selected_val = 0
//    initialized_flag = false
//    scf.for (first dim W_1 in the window)
//         iter_args (selected_ivs, selected_val, initialized_flag):
//      ...
//      scf.for (last dim W_N in the window)
//           iter_args (selected_ivs, selected_val, initialized_flag):
//        I = S * stride + W - pad_low
//        if I is within the bounds of the operand:
//          if (initialized_flag):
//            pred = select(selected_val, operand(I))
//            if (pred)
//              selected_val = operand(I)
//              selected_ivs = I
//          else
//            selected_val = operand(I)
//            selected_ivs = I
//            initialized_flag = true
//    output[selected_ivs] = scatter(output[selected_ivs], source(S))
class SelectAndScatterOpConverter
    : public OpConversionPattern<lmhlo::SelectAndScatterOp> {
 public:
  using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = s_and_s_op.getLoc();
    InitializeOutput(s_and_s_op, &rewriter);
    scf::ParallelOp loop_over_src =
        MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
    rewriter.setInsertionPointToStart(loop_over_src.getBody());

    // Compute the indices of the selected element in the window.
    auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);

    // Load the element of `source` at the current loop coordinates.
    auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
                                            loop_over_src.getInductionVars());

    // Compute `out[selected_ivs] = scatter(out[selected_ivs], src_elem)`.
    auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
                                                   selected_ivs);
    OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody());
    auto acc_result =
        ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
                                  &s_and_s_op.scatter().front(), &rmw_builder);
    rmw_builder.create<AtomicYieldOp>(loc, acc_result);

    rewriter.replaceOp(s_and_s_op, llvm::None);
    return success();
  }

 private:
  void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
                        OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());

    scf::ParallelOp loop_over_output =
        MakeLoopOverShape(loc, s_and_s_op.out(), b);
    OpBuilder::InsertionGuard guard(*b);
    b->setInsertionPointToStart(loop_over_output.getBody());
    b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
                       loop_over_output.getInductionVars());
  }

  struct WindowLoops {
    SmallVector<Value, 2> selected_ivs;
    SmallVector<Value, 2> window_ivs;
    scf::ForOp inner_loop;
  };
  WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
                                scf::ParallelOp loop_over_src,
                                OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value zero = b->create<ConstantIndexOp>(loc, 0);
    Value one = b->create<ConstantIndexOp>(loc, 1);

    auto element_type =
        s_and_s_op.out().getType().cast<MemRefType>().getElementType();
    auto rank = loop_over_src.getNumLoops();

    // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
    SmallVector<Value, 4> iter_args(rank, zero);
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, element_type, b->getFloatAttr(element_type, 0)));
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));

    // Create a nested loop that traverses the window.
    OpBuilder::InsertPoint ip;
    WindowLoops result;
    for (const auto& window_dim :
         s_and_s_op.window_dimensions()->getIntValues()) {
      Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
      result.inner_loop =
          b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
      if (b->getInsertionBlock() == loop_over_src.getBody()) {
        ip = b->saveInsertionPoint();
        auto selected = result.inner_loop.getResults().take_front(rank);
        result.selected_ivs.assign(selected.begin(), selected.end());
      } else {
        b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
      }
      b->setInsertionPointToStart(result.inner_loop.getBody());
      iter_args.assign(result.inner_loop.getRegionIterArgs().begin(),
                       result.inner_loop.getRegionIterArgs().end());
      result.window_ivs.push_back(result.inner_loop.getInductionVar());
    }
    b->restoreInsertionPoint(ip);
    return result;
  }

  // Adapter to store the iteration arguments of the sequential loops that
  // perform the select in a window.
  class IterArgs {
   public:
    explicit IterArgs(ValueRange ivs_val_flag)
        : ivs_val_flag_(ivs_val_flag.begin(), ivs_val_flag.end()) {}
    IterArgs(ValueRange ivs, Value value, Value flag)
        : ivs_val_flag_(ivs.begin(), ivs.end()) {
      ivs_val_flag_.push_back(value);
      ivs_val_flag_.push_back(flag);
    }

    ArrayRef<Value> to_vector() const { return ivs_val_flag_; }

    // Indices of the currently selected value.
    ArrayRef<Value> ivs() const { return to_vector().drop_back(2); }
    // Currently selected value w.r.t. the select() function.
    Value value() const { return ivs_val_flag_.end()[-2]; }
    // i1 flag that indicates whether value() and ivs() were initialized.
    Value is_init() const { return ivs_val_flag_.back(); }

   private:
    // Vector that stores iv_1, ..., iv_N, value, init.
    SmallVector<Value, 4> ivs_val_flag_;
  };

  SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
                                  scf::ParallelOp loop_over_src,
                                  OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();

    WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b);
    auto inner_loop_b =
        OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());

    // Compute ivs in the 'arg' buffer and whether these ivs are in the pad
    // area.
    MappedIvs mapped_ivs =
        MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(),
                            window_loops.window_ivs, &inner_loop_b);

    IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());

    auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
        loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    // Case when we are inside the boundaries of 'arg' and not in the pad
    // area.
    {
      OpBuilder in_bounds_then_b =
          if_in_bounds.getThenBodyBuilder(b->getListener());
      auto select_or_init_results = SelectOrInitialize(
          s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
      in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
    }

    // Case when we are in the pad area.
    {
      OpBuilder in_bounds_else_b =
          if_in_bounds.getElseBodyBuilder(b->getListener());
      in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
    }

    inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
    return window_loops.selected_ivs;
  }

  SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
                                           ArrayRef<Value> operand_ivs,
                                           IterArgs* ivs_val_flag,
                                           OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value true_i1 = b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));

    TypeRange iter_arg_types{ivs_val_flag->to_vector()};
    Value operand_elem =
        b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
    auto if_init =
        b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
                             /*withElseRegion=*/true);
    // Init == true, i.e. the iter args are already initialized with a
    // selected element within the boundaries of the operand. The select
    // function has to be computed here.
    {
      OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());

      auto& lhlo_select = s_and_s_op.select().front();
      Value pred =
          ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
                                    &lhlo_select, &if_init_then_b);

      auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
                                                      /*withElseRegion=*/true);

      // Pred == true, therefore pack the newly selected ivs, value and init
      // flag back into the iter_args and return.
      {
        OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
        if_pred_then_b.create<scf::YieldOp>(
            loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
      }

      // Pred == false, therefore return the old iter_args.
      {
        OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
        if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
      }

      if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
    }
    // Init == false, i.e. only the pad area was visited before and this is
    // the first element within the boundaries of the operand.
    {
      OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());

      if_init_else_b.create<scf::YieldOp>(
          loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
    }
    return SmallVector<Value, 4>(if_init.getResults().begin(),
                                 if_init.getResults().end());
  }
};
struct LhloLegalizeToParallelLoopsPass
    : public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<StandardOpsDialect, scf::SCFDialect>();
  }

  void runOnFunction() override {
    auto func = getFunction();

    OwningRewritePatternList patterns;
    // clang-format off
    patterns.insert<
        ReduceOpConverter,
        ReduceWindowOpConverter,
        SelectAndScatterOpConverter
      >(func.getContext());
    // clang-format on

    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                           scf::SCFDialect, LmhloDialect>();
    target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
                        lmhlo::SelectAndScatterOp>();

    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
  return std::make_unique<LhloLegalizeToParallelLoopsPass>();
}

}  // namespace lmhlo
}  // namespace mlir