1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
20 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 
27 namespace mlir {
28 namespace lmhlo {
29 namespace {
30 
31 // Clones and adapts the code in `lhlo_block` that works on buffers and has a
32 // single output buffer to make it compatible with `operands` that have element
33 // types of the respective buffers. Returns the computed value.
34 //
35 // Example. For `operands` with (f32, i32) types and a block with LHLO ops and
36 // with signature:
37 //   ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
38 //     <LHLO_ops>
39 //
40 // inserts necessary alloc and store ops to compute and return result that has
41 // `i1` type.
ApplySingleResultLhloCode(Location loc,ValueRange operands,Block * lhlo_block,OpBuilder * b)42 Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
43                                 Block* lhlo_block, OpBuilder* b) {
44   SmallVector<Value, 2> arg_bufs;
45   for (auto arg_type : lhlo_block->getArgumentTypes()) {
46     arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
47   }
48   for (auto operand : llvm::enumerate(operands)) {
49     b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
50   }
51   // Clone the ops from `lhlo_block`.
52   BlockAndValueMapping mapping;
53   mapping.map(lhlo_block->getArguments(), arg_bufs);
54   for (auto& nested : lhlo_block->without_terminator()) {
55     auto clone = b->clone(nested, mapping);
56     mapping.map(nested.getResults(), clone->getResults());
57   }
58   return b->create<LoadOp>(loc, arg_bufs.back());
59 }
60 
61 // Converts a block with LHLO ops and with signature:
62 //   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
63 // into a reduction operator of scf.reduce by doing buffer allocation for
64 // scalar arguments and the result of `scf.reduce` to make it compatible with
65 // LHLO ops.
ConvertToReductionOperator(Location loc,scf::ReduceOp reduce_op,Block * lhlo_block,OpBuilder * b)66 void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
67                                 Block* lhlo_block, OpBuilder* b) {
68   Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
69   OpBuilder::InsertionGuard guard(*b);
70   b->setInsertionPointToStart(&loop_reduce_op_body);
71   b->create<scf::ReduceReturnOp>(
72       loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
73                                      lhlo_block, b));
74 }
75 
76 // Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
77 // extract dimension at runtime.
GetStaticOrDynamicDim(mlir::Location loc,Value shaped_value,size_t dim_index,int64_t dim,OpBuilder * b)78 Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
79                             size_t dim_index, int64_t dim, OpBuilder* b) {
80   return dim == ShapedType::kDynamicSize
81              ? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
82              : b->create<ConstantIndexOp>(loc, dim);
83 }
84 
85 struct MappedIvs {
86   // False if the mapped indices are in the padding area, true otherwise.
87   Value in_bounds;
88   // Mapped indices.
89   SmallVector<Value, 2> ivs;
90 };
91 
92 template <typename OpTy>
MapWindowIvsToInput(OpTy op,ValueRange ivs,ValueRange window_ivs,OpBuilder * b)93 MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
94                               OpBuilder* b) {
95   MappedIvs mapped_ivs;
96 
97   if (!op.window_strides().hasValue()) {
98     op.emitOpError("No window strides specified.");
99   }
100   auto window_strides = op.window_strides().getValue();
101 
102   if (!op.padding().hasValue()) {
103     op.emitOpError("No padding specified.");
104   }
105   auto padding = op.padding().getValue();
106 
107   auto loc = op.getLoc();
108   auto operand = op.operand();
109   auto operand_shape = operand.getType().template cast<MemRefType>().getShape();
110 
111   // `in_bounds` is false when the mapped indices are in the padding area.
112   mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
113       loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
114   for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
115     auto stride = window_strides.template getValue<llvm::APInt>(i);
116     auto pad_low = padding.template getValue<llvm::APInt>({i, 0});
117 
118     Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
119     Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
120 
121     Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
122     Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
123     Value index = b->create<AddIOp>(loc, center, offset);
124     Value upper_bound =
125         GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
126     // We must check whether 0 <= index_i < shape_i, as otherwise we are in
127     // the pad and then we have to use the neutral element for reduction.
128     // Equivalently, it can be computed as the unsigned comparison index_i <
129     // shape_i, since a negative value wraps to a large positive value.
130     mapped_ivs.in_bounds = b->create<mlir::AndOp>(
131         loc, mapped_ivs.in_bounds,
132         b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
133     mapped_ivs.ivs.push_back(index);
134   }
135   return mapped_ivs;
136 }
137 
138 // Returns scf::Parallel over a shaped value with static or dynamic shape.
MakeLoopOverShape(Location loc,Value shaped_value,OpBuilder * b)139 scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
140                                   OpBuilder* b) {
141   Value zero = b->create<ConstantIndexOp>(loc, 0);
142   Value one = b->create<ConstantIndexOp>(loc, 1);
143 
144   ArrayRef<int64_t> shape =
145       shaped_value.getType().cast<ShapedType>().getShape();
146   SmallVector<Value, 2> lower, upper, step;
147   for (auto dim : llvm::enumerate(shape)) {
148     upper.push_back(
149         GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
150     lower.push_back(zero);
151     step.push_back(one);
152   }
153   return b->create<scf::ParallelOp>(loc, lower, upper, step);
154 }
155 
156 // Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
159 // contains the reduction operator.
160 //
161 // Example:
162 //
163 //  "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
164 //    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
165 //      <LHLO ops>
166 //    } ) {dimensions = dense<[1]> : tensor<1xi64>}
167 //      : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
168 //
169 //  is roughly converted into:
170 //
171 //  %init = load %init_buf[] : memref<f32>
172 //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
173 //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
174 //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
175 //      scf.reduce(%elem_to_reduce)  {
176 //        ^bb0(%elem: f32, %acc: f32):   // no predecessors
177 //          elem_buf = alloc() : memref<f32>
178 //          store %elem, elem_buf[] : memref<f32>
179 //          acc_buf = alloc() : memref<f32>
180 //          store %acc, acc_buf[] : memref<f32>
181 //          <LHLO_ops>
182 //          %acc_result = load acc_buf[] : memref<f32>
183 //          scf.reduce.return %acc_result : f32
184 //      } : f32
185 //      scf.yield
186 //    } : f32
187 //    scf.yield
188 //  }
189 class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
190  public:
191   using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
192 
matchAndRewrite(lmhlo::ReduceOp reduce_op,ArrayRef<Value>,ConversionPatternRewriter & rewriter) const193   LogicalResult matchAndRewrite(
194       lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
195       ConversionPatternRewriter& rewriter) const final {
196     // TODO(b/137624192) Implement variadic reduce.
197     if (reduce_op.out().size() != 1) return failure();
198 
199     scf::ReduceOp scf_reduce_op =
200         CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter);
201     ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op,
202                                &reduce_op.body().front(), &rewriter);
203     rewriter.replaceOp(reduce_op, llvm::None);
204     return success();
205   }
206 
207  private:
208   // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
209   // refers to the parallel dimensions of `reduce_op` if any and the inner
210   // ParallelOp refers to the reduction dimensions. The scf.reduce op is
211   // returned.
212   //
213   // If the reduction argument is a memref<100x10x5xf32> and the
214   // reduction is performed along dimension 1 then this method will generate
215   //
216   //  %init = load %init_buf[] : memref<f32>
217   //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
218   //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
219   //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
220   //      scf.reduce(%elem_to_reduce)  {
221   //        <THE BLOCK PTR TO BE RETURNED>
222   //      } : f32
223   //      scf.yield
224   //    } : f32
225   //    scf.yield
226   //  }
CreateReduceOpInNestedParallelLoops(lmhlo::ReduceOp reduce_op,ConversionPatternRewriter * rewriter) const227   scf::ReduceOp CreateReduceOpInNestedParallelLoops(
228       lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const {
229     auto loc = reduce_op.getLoc();
230     DenseSet<int> reducing_dims;
231     for (const auto& rdim : reduce_op.dimensions().getIntValues()) {
232       reducing_dims.insert(rdim.getSExtValue());
233     }
234 
235     Value operand = *reduce_op.operands().begin();
236     Value out = *reduce_op.out().begin();
237     SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
238     SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
239     auto operand_shape = operand.getType().cast<MemRefType>().getShape();
240     for (auto dim : llvm::enumerate(operand_shape)) {
241       const bool is_reducing_dim = reducing_dims.count(dim.index());
242 
243       Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
244                                        rewriter);
245       Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
246       Value step = rewriter->create<ConstantIndexOp>(loc, 1);
247       (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
248       (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
249       (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
250     }
251     // Load initial value from memref<element_type>.
252     SmallVector<Value, 1> init_value = {
253         rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
254     // Outer ParallelOp is not needed if it is a reduction across all dims.
255     scf::ParallelOp outer;
256     if (!parallel_lower.empty()) {
257       outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
258                                                 parallel_upper, parallel_step);
259       rewriter->setInsertionPointToStart(outer.getBody());
260     }
261     scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
262         loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value));
263     Value reduction_result = *inner.getResults().begin();
264 
265     SmallVector<Value, 1> out_indices;
266     if (outer != nullptr) {
267       out_indices.reserve(outer.getNumLoops());
268       for (Value iv : outer.getInductionVars()) {
269         out_indices.push_back(iv);
270       }
271     } else {
272       out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
273     }
274 
275     rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
276 
277     // Load the element to reduce.
278     SmallVector<Value, 2> indices;
279     indices.reserve(operand_shape.size());
280 
281     if (outer) {
282       auto inner_ivs_it = inner.getInductionVars().begin();
283       auto outer_ivs_it = outer.getInductionVars().begin();
284       for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
285         indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
286                                                  : *outer_ivs_it++);
287       }
288     } else {
289       indices = inner.getInductionVars();
290     }
291 
292     rewriter->setInsertionPointToStart(inner.getBody());
293     Value elem = rewriter->create<mlir::LoadOp>(
294         loc, *reduce_op.operands().begin(), indices);
295     return rewriter->create<scf::ReduceOp>(loc, elem);
296   }
297 };
298 
299 // Pseudocode:
300 // for each index O in output
301 //   accumulator = neutral_value
302 //   in_bounds = true
303 //   for each index W in window
304 //     for each dimension i from 0 to rank - 1
305 //       index = O[i] * stride[i] + W[i] - pad_low[i]
306 //       in_bounds = inbounds && (index `ult` shape[i])
307 //       I[i] = index
308 //     if (in_bounds)
309 //       value = input[I]
310 //     else
311 //       value = neutral_value
312 //     accumulator = reduction_operator(output[O], value)
313 //   output[O] = accumulator
314 //
315 // Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
316 // scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the output
// buffer. The inner `ParallelOp` refers to the reduction loops that traverse
319 // reduction windows and `ReduceOp` contains the reduction operator.
320 //
321 // Example:
322 //
323 // func @reduce_window(%arg: memref<112x112xf32>,
324 //              %init: memref<f32>,
325 //              %result: memref<56x56xf32>) {
326 //   "lmhlo.reduce_window"(%arg, %init, %result) ( {
327 //     ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
328 //       "lmhlo.maximum"(%lhs, %rhs, %res)
329 //         : (memref<f32>, memref<f32>, memref<f32>) -> ()
330 //       "lmhlo.terminator"() : () -> ()
331 //     }) {
332 //       padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
333 //       window_dimensions = dense<[3, 3]> : tensor<2xi64>,
334 //       window_strides = dense<[2, 2]> : tensor<2xi64>
335 //     } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
336 //   return
337 // }
338 //
339 // is roughly converted into:
340 //
341 //    %neutral_elem = load %init_buf[] : memref<f32>
342 //    scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
343 //      %result = scf.parallel (%iw, %jw) = (%c0, %c0)
344 //                  to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 {
345 //        %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
346 //        %elem = load %operand[%computed_i, %computed_j]
347 //        %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
348 //        scf.reduce(%elem_to_reduce)  : f32 {
349 //          ^bb0(%arg7: f32, %arg8: f32):
350 //            <LHLO ops>
351 //        }
352 //        scf.yield
353 //      }
354 //      store %result, %output_buffer[%i, %j] : memref<56x56xf32>
355 //      scf.yield
356 //    }
357 //    return
358 //  }
359 class ReduceWindowOpConverter
360     : public OpConversionPattern<lmhlo::ReduceWindowOp> {
361  public:
362   using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
363 
matchAndRewrite(lmhlo::ReduceWindowOp reduce_window_op,ArrayRef<Value>,ConversionPatternRewriter & rewriter) const364   LogicalResult matchAndRewrite(
365       lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
366       ConversionPatternRewriter& rewriter) const final {
367     scf::ParallelOp output_loop, window_loop;
368     std::tie(output_loop, window_loop) =
369         CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
370                                                      &rewriter);
371 
372     scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
373         reduce_window_op, output_loop, window_loop, &rewriter);
374 
375     ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op,
376                                &reduce_window_op.body().front(), &rewriter);
377     rewriter.replaceOp(reduce_window_op, llvm::None);
378     return success();
379   }
380 
381  private:
382   std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow(lmhlo::ReduceWindowOp reduce_window_op,ConversionPatternRewriter * rewriter) const383   CreateParallelLoopsToTraverseOutputAndWindow(
384       lmhlo::ReduceWindowOp reduce_window_op,
385       ConversionPatternRewriter* rewriter) const {
386     auto loc = reduce_window_op.getLoc();
387     Value init_value =
388         rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
389 
390     Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
391     Value one = rewriter->create<ConstantIndexOp>(loc, 1);
392 
393     // Create an outer parallel loop that spans the output of ReduceWindowOp.
394     Value output = reduce_window_op.out();
395     auto output_loop = MakeLoopOverShape(loc, output, rewriter);
396 
397     // Create a nested loop that traverses the window.
398     SmallVector<Value, 2> window_lower, window_upper, window_step;
399     rewriter->setInsertionPointToStart(output_loop.getBody());
400     for (const auto& window_dim : reduce_window_op.window_dimensions()) {
401       window_step.push_back(one);
402       window_lower.push_back(zero);
403       window_upper.push_back(
404           rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
405     }
406     auto window_loop = rewriter->create<scf::ParallelOp>(
407         loc, window_lower, window_upper, window_step, ValueRange(init_value));
408 
409     Value reduction_result = *window_loop.getResults().begin();
410     auto output_ivs = output_loop.getInductionVars();
411     rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
412     return std::make_pair(output_loop, window_loop);
413   }
414 
CreateReduceOpInNestedParallelLoops(lmhlo::ReduceWindowOp reduce_window_op,scf::ParallelOp output_loop,scf::ParallelOp window_loop,ConversionPatternRewriter * rewriter) const415   scf::ReduceOp CreateReduceOpInNestedParallelLoops(
416       lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop,
417       scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
418     rewriter->setInsertionPointToStart(window_loop.getBody());
419     auto loc = reduce_window_op.getLoc();
420 
421     if (reduce_window_op.base_dilations().hasValue() ||
422         reduce_window_op.window_dilations().hasValue()) {
423       reduce_window_op.emitRemark(
424           "Lowering to parallel loops does not support `base_dilations` or "
425           "`window_dilations` attributes yet. The attributes will be ignored.");
426     }
427 
428     Value operand = reduce_window_op.operand();
429     auto operand_type = operand.getType().cast<MemRefType>();
430 
431     // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
432     MappedIvs mapped_ivs =
433         MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(),
434                             window_loop.getInductionVars(), rewriter);
435 
436     auto elem_or_init = rewriter->create<scf::IfOp>(
437         loc, operand_type.getElementType(), mapped_ivs.in_bounds,
438         /*withElseRegion=*/true);
439 
440     OpBuilder then_builder =
441         elem_or_init.getThenBodyBuilder(rewriter->getListener());
442     Value elem = then_builder.create<mlir::LoadOp>(
443         loc, reduce_window_op.operand(), mapped_ivs.ivs);
444     then_builder.create<scf::YieldOp>(loc, elem);
445 
446     OpBuilder else_builder =
447         elem_or_init.getElseBodyBuilder(rewriter->getListener());
448     else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
449 
450     return rewriter->create<scf::ReduceOp>(loc,
451                                            *elem_or_init.results().begin());
452   }
453 };
454 
455 // See the operation semantics in
456 // https://www.tensorflow.org/xla/operation_semantics#selectandscatter
457 //
458 // Pseudocode:
459 //  scf.parallel(coordinates O in the output):
460 //    output[O] = init
461 //  scf.parallel(coordinates S in the source):
462 //    selected_ivs = 0
463 //    selected_val = 0
464 //    initialized_flag = false
465 //    scf.for (first dim W_1 in the window)
466 //         iter_args (selected_ivs, selected_val, initialized_flag):
467 //    ...
468 //      scf.for (last dim W_N in the window):
469 //           iter_args (selected_ivs, selected_val, initialized_flag):
470 //        I = S * stride + W - pad_low
471 //        if I within bounds of operand:
472 //          if (initialized_flag):
473 //            pred = select(selected_value, operand(I))):
474 //            if (pred)
475 //              selected_value = operand(I)
476 //              selected_index = I
477 //          else
478 //              selected_value = operand(I)
479 //              selected_index = I
480 //              initialized_flag = true
481 //    output(selected_index) = scatter(output(selected_index), source(S))
482 class SelectAndScatterOpConverter
483     : public OpConversionPattern<lmhlo::SelectAndScatterOp> {
484  public:
485   using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;
486 
matchAndRewrite(lmhlo::SelectAndScatterOp s_and_s_op,ArrayRef<Value>,ConversionPatternRewriter & rewriter) const487   LogicalResult matchAndRewrite(
488       lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
489       ConversionPatternRewriter& rewriter) const final {
490     auto loc = s_and_s_op.getLoc();
491     InitializeOutput(s_and_s_op, &rewriter);
492     scf::ParallelOp loop_over_src =
493         MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
494     rewriter.setInsertionPointToStart(loop_over_src.getBody());
495 
496     // Compute indices of the selected element in the window.
497     auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
498 
499     // Load `source[selected_ivs]`.
500     auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
501                                             loop_over_src.getInductionVars());
502 
503     // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
504     auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
505                                                    selected_ivs);
506     OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody());
507     auto acc_result =
508         ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
509                                   &s_and_s_op.scatter().front(), &rmw_builder);
510     rmw_builder.create<AtomicYieldOp>(loc, acc_result);
511 
512     rewriter.replaceOp(s_and_s_op, llvm::None);
513     return success();
514   }
515 
516  private:
InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,OpBuilder * b) const517   void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
518                         OpBuilder* b) const {
519     auto loc = s_and_s_op.getLoc();
520     Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
521 
522     scf::ParallelOp loop_over_output =
523         MakeLoopOverShape(loc, s_and_s_op.out(), b);
524     OpBuilder::InsertionGuard guard(*b);
525     b->setInsertionPointToStart(loop_over_output.getBody());
526     b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
527                        loop_over_output.getInductionVars());
528   }
529 
530   struct WindowLoops {
531     SmallVector<Value, 2> selected_ivs;
532     SmallVector<Value, 2> window_ivs;
533     scf::ForOp inner_loop;
534   };
InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,scf::ParallelOp loop_over_src,OpBuilder * b) const535   WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
536                                 scf::ParallelOp loop_over_src,
537                                 OpBuilder* b) const {
538     auto loc = s_and_s_op.getLoc();
539     Value zero = b->create<ConstantIndexOp>(loc, 0);
540     Value one = b->create<ConstantIndexOp>(loc, 1);
541 
542     auto element_type =
543         s_and_s_op.out().getType().cast<MemRefType>().getElementType();
544     auto rank = loop_over_src.getNumLoops();
545 
546     // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
547     SmallVector<Value, 4> iter_args(rank, zero);
548     iter_args.push_back(b->create<mlir::ConstantOp>(
549         loc, element_type, b->getFloatAttr(element_type, 0)));
550     iter_args.push_back(b->create<mlir::ConstantOp>(
551         loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));
552 
553     // Create a nested loop that traverses the window.
554     OpBuilder::InsertPoint ip;
555     WindowLoops result;
556     for (const auto& window_dim :
557          s_and_s_op.window_dimensions()->getIntValues()) {
558       Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
559       result.inner_loop =
560           b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
561       if (b->getInsertionBlock() == loop_over_src.getBody()) {
562         ip = b->saveInsertionPoint();
563         result.selected_ivs = result.inner_loop.getResults().take_front(rank);
564       } else {
565         b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
566       }
567       b->setInsertionPointToStart(result.inner_loop.getBody());
568       iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
569       result.window_ivs.push_back(result.inner_loop.getInductionVar());
570     }
571     b->restoreInsertionPoint(ip);
572     return result;
573   }
574 
575   // Adapter to store iteration arguments of sequential loops that perform
576   // select in a window.
577   class IterArgs {
578    public:
IterArgs(ValueRange ivs_val_flag)579     explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {}
IterArgs(ValueRange ivs,Value value,Value flag)580     IterArgs(ValueRange ivs, Value value, Value flag) {
581       ivs_val_flag_ = ivs;
582       ivs_val_flag_.push_back(value);
583       ivs_val_flag_.push_back(flag);
584     }
585 
to_vector() const586     ArrayRef<Value> to_vector() const { return ivs_val_flag_; }
587 
588     // Indices of the currently selected value.
ivs() const589     ArrayRef<Value> ivs() const { return to_vector().drop_back(2); }
590     // Currently selected value w.r.t. select() function.
value() const591     Value value() const { return ivs_val_flag_.end()[-2]; }
592     // i1 flag if value() and ivs() were initialized.
is_init() const593     Value is_init() const { return ivs_val_flag_.back(); }
594 
595    private:
596     // Vector that stores iv_1, ..., iv_N, value, init.
597     SmallVector<Value, 4> ivs_val_flag_;
598   };
599 
SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,scf::ParallelOp loop_over_src,OpBuilder * b) const600   SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
601                                   scf::ParallelOp loop_over_src,
602                                   OpBuilder* b) const {
603     auto loc = s_and_s_op.getLoc();
604 
605     WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b);
606     auto inner_loop_b =
607         OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());
608 
609     // Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
610     MappedIvs mapped_ivs =
611         MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(),
612                             window_loops.window_ivs, &inner_loop_b);
613 
614     IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
615 
616     auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
617         loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
618         /*withElseRegion=*/true);
619 
620     // Case when we are inside boundaries of 'arg' and not in the pad area.
621     {
622       OpBuilder in_bounds_then_b =
623           if_in_bounds.getThenBodyBuilder(b->getListener());
624       auto select_or_init_results = SelectOrInitialize(
625           s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
626       in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
627     }
628 
629     // Case when we are in the pad.
630     {
631       OpBuilder in_bounds_else_b =
632           if_in_bounds.getElseBodyBuilder(b->getListener());
633       in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
634     }
635 
636     inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
637     return window_loops.selected_ivs;
638   }
639 
SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,ArrayRef<Value> operand_ivs,IterArgs * ivs_val_flag,OpBuilder * b) const640   SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
641                                            ArrayRef<Value> operand_ivs,
642                                            IterArgs* ivs_val_flag,
643                                            OpBuilder* b) const {
644     auto loc = s_and_s_op.getLoc();
645     Value true_i1 = b->create<mlir::ConstantOp>(
646         loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
647 
648     TypeRange iter_arg_types{ivs_val_flag->to_vector()};
649     Value operand_elem =
650         b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
651     auto if_init =
652         b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
653                              /*withElseRegion=*/true);
654     // Init == true, i.e. iter args are already initialized with a selected
655     // element in boundaries of the operand. Select function has to be computed
656     // here.
657     {
658       OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());
659 
660       auto& lhlo_select = s_and_s_op.select().front();
661       Value pred =
662           ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
663                                     &lhlo_select, &if_init_then_b);
664 
665       auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
666                                                       /*withElseRegion=*/true);
667 
668       // Pred == true, therefore pack newly selected ivs, val and init flag back
669       // to iter_args and return.
670       {
671         OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
672         if_pred_then_b.create<scf::YieldOp>(
673             loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
674       }
675 
676       // Pred == false, therefore return old iter_args.
677       {
678         OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
679         if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
680       }
681 
682       if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
683     }
684     // Init == false, i.e. only pad was visited before and this is the first
685     // element in the boundaries of the operand.
686     {
687       OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());
688 
689       if_init_else_b.create<scf::YieldOp>(
690           loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
691     }
692     return if_init.getResults();
693   }
694 };
695 
696 struct LhloLegalizeToParallelLoopsPass
697     : public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
getDependentDialectsmlir::lmhlo::__anonfcbbac650111::LhloLegalizeToParallelLoopsPass698   void getDependentDialects(DialectRegistry& registry) const override {
699     registry.insert<StandardOpsDialect, scf::SCFDialect>();
700   }
701 
runOnFunctionmlir::lmhlo::__anonfcbbac650111::LhloLegalizeToParallelLoopsPass702   void runOnFunction() override {
703     auto func = getFunction();
704 
705     OwningRewritePatternList patterns;
706     // clang-format off
707     patterns.insert<
708         ReduceOpConverter,
709         ReduceWindowOpConverter,
710         SelectAndScatterOpConverter
711       >(func.getContext());
712     // clang-format on
713 
714     ConversionTarget target(getContext());
715     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
716                            scf::SCFDialect, LmhloDialect>();
717     target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
718                         lmhlo::SelectAndScatterOp>();
719 
720     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
721       signalPassFailure();
722     }
723   }
724 };
725 }  // namespace
726 
createLegalizeLhloToParallelLoopsPass()727 std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
728   return std::make_unique<LhloLegalizeToParallelLoopsPass>();
729 }
730 
731 }  // namespace lmhlo
732 }  // namespace mlir
733