1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
17
18 #include <numeric>
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
26 #include "mlir/Dialect/Affine/IR/AffineOps.h"
27 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
28 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
29 #include "mlir/Dialect/Math/IR/Math.h"
30 #include "mlir/Dialect/SCF/SCF.h"
31 #include "mlir/Dialect/StandardOps/IR/Ops.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/IR/AffineExpr.h"
34 #include "mlir/IR/Attributes.h"
35 #include "mlir/IR/Builders.h"
36 #include "mlir/IR/BuiltinOps.h"
37 #include "mlir/IR/BuiltinTypes.h"
38 #include "mlir/IR/Location.h"
39 #include "mlir/IR/MLIRContext.h"
40 #include "mlir/IR/Matchers.h"
41 #include "mlir/IR/Operation.h"
42 #include "mlir/IR/OperationSupport.h"
43 #include "mlir/IR/PatternMatch.h"
44 #include "mlir/IR/TypeUtilities.h"
45 #include "mlir/Pass/Pass.h"
46 #include "mlir/Pass/PassManager.h"
47 #include "mlir/Transforms/DialectConversion.h"
48
49 namespace mlir {
50 namespace {
51
52 /// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
53 /// are "parallel" except the last `nReduction` elements, where are "reduction"
54 /// attributes.
GetParallelAndReductionIterators(unsigned nLoops,unsigned nReduction)55 SmallVector<StringRef, 3> GetParallelAndReductionIterators(
56 unsigned nLoops, unsigned nReduction) {
57 SmallVector<StringRef, 3> res(nLoops - nReduction,
58 getParallelIteratorTypeName());
59 res.append(nReduction, getReductionIteratorTypeName());
60 return res;
61 }
62
GetNParallelLoopsAttrs(unsigned nParallelLoops)63 SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
64 return GetParallelAndReductionIterators(nParallelLoops, 0);
65 }
66
67 template <bool isLHLO = true>
GetResultValue(Operation * op)68 Value GetResultValue(Operation* op) {
69 return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
70 }
71
72 template <bool isLHLO = true>
GetHloOpResultType(Operation * op)73 ShapedType GetHloOpResultType(Operation* op) {
74 return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
75 }
76
77 template <bool isLHLO = true>
VerifyHloOpBufferOrTensorSemantics(Operation * op)78 bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
79 auto verify_type = [&](Value val) -> bool {
80 return (isLHLO && val.getType().isa<MemRefType>()) ||
81 (!isLHLO && val.getType().isa<RankedTensorType>());
82 };
83 if (!llvm::all_of(op->getOperands(), verify_type)) return false;
84 return isLHLO ? op->getResults().empty()
85 : llvm::all_of(op->getResults(), verify_type);
86 }
87
GetInitTensor(OpBuilder & b,Location loc,ShapedType type,ArrayRef<Value> dyn_sizes)88 Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
89 ArrayRef<Value> dyn_sizes) {
90 return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
91 type.getElementType());
92 }
93
ExtractDynamicSizes(OpBuilder & b,Location loc,Value tensor)94 SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
95 Value tensor) {
96 auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
97 if (!tensor_type) return {};
98 SmallVector<Value, 2> dyn_sizes;
99 for (auto& en : llvm::enumerate(tensor_type.getShape())) {
100 if (en.value() != ShapedType::kDynamicSize) continue;
101 dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
102 }
103 return dyn_sizes;
104 }
105
Extract1DVector(DenseIntElementsAttr elements)106 SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
107 SmallVector<int64_t, 4> ret;
108 for (const APInt& element : elements) {
109 ret.push_back(element.getLimitedValue());
110 }
111 return ret;
112 }
113
114 /// Returns the constant value associated with the init value if the defining
115 /// operation is a constant.
GetInitValueAsConst(Value init)116 Attribute GetInitValueAsConst(Value init) {
117 DenseElementsAttr attr;
118 if (!matchPattern(init, m_Constant(&attr))) return {};
119 auto type = attr.getType().dyn_cast<ShapedType>();
120 if (!type || type.getRank() != 0) return {};
121 return attr.getValue({});
122 }
123
124 /// Returns a permutation AffineMap that puts all reduction dimensions to the
125 /// last. The order of parallel loops and reduction loops are all sorted. E.g.,
126 /// if `rank` is 4 and `reductionDims` is {1, 3}, then
127 /// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
128 /// the AffineMap is returned.
GetTransposeMapForReduction(MLIRContext * context,int rank,ArrayRef<int64_t> reduction_dims)129 AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
130 ArrayRef<int64_t> reduction_dims) {
131 llvm::SmallSetVector<int, 4> s;
132 for (auto dim : reduction_dims) s.insert(dim);
133
134 SmallVector<unsigned, 4> permutation;
135 for (int i = 0; i < rank; ++i)
136 if (!s.count(i)) permutation.push_back(i);
137 for (auto dim : reduction_dims) permutation.push_back(dim);
138
139 auto map = AffineMap::getPermutationMap(permutation, context);
140 return inversePermutation(map);
141 }
142
143 template <typename OpTy, bool isLHLO = true>
144 class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
145 public:
146 using OpConversionPattern<OpTy>::OpConversionPattern;
147
matchAndRewrite(OpTy op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const148 LogicalResult matchAndRewrite(
149 OpTy op, ArrayRef<Value> args,
150 ConversionPatternRewriter& rewriter) const final {
151 auto loc = op.getLoc();
152 ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
153 if (!t0) return failure();
154
155 unsigned nloops = t0.getRank();
156 auto fail = [&](ShapedType t) {
157 return !t || !t.hasRank() || t.getRank() != nloops ||
158 !(t.getElementType().isSignlessIntOrFloat() ||
159 t.getElementType().isa<ComplexType>());
160 };
161 if (llvm::any_of(args,
162 [&](Value v) {
163 return fail(v.getType().dyn_cast<ShapedType>());
164 }) ||
165 llvm::any_of(op.getOperation()->getResultTypes(),
166 [&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
167 return emitError(loc,
168 "lhlo to linalg conversion expects ranked args of "
169 "signless int, float or complex element type with ")
170 << nloops << " parallel iterators: " << *(op.getOperation());
171
172 // Construct the indexing maps needed for linalg.generic ops.
173 SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
174
175 // This doesnt account for implicit broadcast, but the working assumption
176 // in HLO/LHLO is that are broadcasts are made explicit.
177
178 if (isLHLO && !nloops) return failure();
179
180 int num_inputs = (isLHLO ? args.size() - 1 : args.size());
181
182 ValueRange inputs(args.take_front(num_inputs));
183 for (Value in : inputs)
184 body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
185
186 SmallVector<Value, 4> output_buffers;
187 if (isLHLO) {
188 output_buffers.append(args.begin() + num_inputs, args.end());
189 } else {
190 Value result = op.getOperation()->getResult(0);
191 ShapedType result_type = result.getType().template cast<ShapedType>();
192 auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
193 output_buffers.push_back(
194 GetInitTensor(rewriter, loc, result_type, dyn_sizes));
195 op_result_types.push_back(result.getType());
196 }
197 body_result_types = llvm::to_vector<4>(llvm::map_range(
198 output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
199
200 AffineMap common_indexing_map =
201 nloops ? rewriter.getMultiDimIdentityMap(nloops)
202 : AffineMap::get(nloops, 0, rewriter.getContext());
203 SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
204 common_indexing_map);
205
206 bool failed = false;
207 auto linalg_op = rewriter.create<linalg::GenericOp>(
208 loc, op_result_types, inputs, output_buffers, indexing_maps,
209 GetNParallelLoopsAttrs(nloops),
210 [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
211 // TODO(ravishankarm) : For now use the method in lmhlo namespace.
212 // That method needs to be moved out of there.
213 Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
214 op, body_result_types,
215 llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
216 if (op_result == nullptr) {
217 failed = true;
218 } else {
219 nested_builder.create<linalg::YieldOp>(loc, op_result);
220 }
221 });
222 if (failed) return failure();
223 rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
224 return success();
225 }
226 };
227
228 template <typename LhloOp>
229 class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
230 public:
231 using OpConversionPattern<LhloOp>::OpConversionPattern;
232
matchAndRewrite(LhloOp lhlo_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const233 LogicalResult matchAndRewrite(
234 LhloOp lhlo_op, ArrayRef<Value> args,
235 ConversionPatternRewriter& rewriter) const final {
236 auto loc = lhlo_op.getLoc();
237 auto arg_type =
238 lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
239 if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
240 (arg_type.getRank() != 0)) {
241 return failure();
242 }
243
244 // Create two loads from the input.
245 auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
246 auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
247 // TODO(ravishankarm) : Move this method out of lmhlo namespace.
248 Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
249 lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
250 &rewriter);
251 rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
252 rewriter.eraseOp(lhlo_op);
253 return success();
254 }
255 };
256
257 //===----------------------------------------------------------------------===//
258 // lmhlo.convolution conversion pattern.
259 //===----------------------------------------------------------------------===//
260
261 /// Converts lmhlo.convolution operation to a linalg.conv op.
262 struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
263 public:
264 using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
265
266 // This code has been adapted from IREE's
267 // (https://github.com/google/iree/) mhlo -> linalg conversion.
matchAndRewritemlir::__anon00bff3f50111::ConvToLinalgConverter268 LogicalResult matchAndRewrite(
269 lmhlo::ConvOp op, ArrayRef<Value> args,
270 ConversionPatternRewriter& rewriter) const final {
271 // Check validity of dimension information.
272 if (const mhlo::ConvDimensionNumbers& dimension_numbers =
273 op.dimension_numbers()) {
274 const int input_spatial_rank =
275 llvm::size(dimension_numbers.input_spatial_dimensions());
276 // The dimensions for input should follow the order of
277 // batch_count, spatial_dims..., input_feature_count.
278 if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
279 dimension_numbers.input_feature_dimension().getInt() !=
280 (input_spatial_rank + 1))
281 return failure();
282
283 const int kernel_spatial_rank =
284 llvm::size(dimension_numbers.kernel_spatial_dimensions());
285 // The dimensions for filter should follow the order of
286 // spatial_dims..., input_feature_count, num_output_feature_count.
287 if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
288 kernel_spatial_rank ||
289 dimension_numbers.kernel_output_feature_dimension().getInt() !=
290 (kernel_spatial_rank + 1))
291 return failure();
292
293 const int output_spatial_rank =
294 llvm::size(dimension_numbers.output_spatial_dimensions());
295 // The dimensions for output should follow the order of
296 // batch_count, spatial_dims.., output_feature_count.
297 if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
298 dimension_numbers.output_feature_dimension().getInt() !=
299 (output_spatial_rank + 1))
300 return failure();
301
302 if (input_spatial_rank != output_spatial_rank ||
303 input_spatial_rank != kernel_spatial_rank)
304 return failure();
305
306 auto input_spatial_dim =
307 dimension_numbers.input_spatial_dimensions().begin();
308 auto kernel_spatial_dim =
309 dimension_numbers.kernel_spatial_dimensions().begin();
310 auto output_spatial_dim =
311 dimension_numbers.output_spatial_dimensions().begin();
312 // Check if spatial dims are ordered correctly.
313 for (int i = 0; i < input_spatial_rank; ++i) {
314 const int dim = i + 1;
315 if ((*input_spatial_dim++).getZExtValue() != dim ||
316 (*output_spatial_dim++).getZExtValue() != dim ||
317 (*kernel_spatial_dim++).getZExtValue() != i)
318 return failure();
319 }
320 }
321
322 // TODO: LHS dilation for deconvolution not supported yet.
323 // TODO(jurahul): Window reversal is not supported yet.
324 if (op.lhs_dilation() || op.hasWindowReversal()) {
325 return failure();
326 }
327
328 llvm::SmallVector<Attribute, 4> strides;
329 if (auto window_strides = op.window_strides()) {
330 auto range = window_strides->getAttributeValues();
331 strides.assign(range.begin(), range.end());
332 }
333 auto strides_arg = ArrayAttr::get(op.getContext(), strides);
334
335 llvm::SmallVector<Attribute, 2> dilation;
336 if (auto rhs_dilation = op.rhs_dilation()) {
337 auto range = rhs_dilation->getAttributeValues();
338 dilation.assign(range.begin(), range.end());
339 } else {
340 // Default dilation of 1.
341 dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
342 }
343 auto dilation_arg = ArrayAttr::get(op.getContext(), dilation);
344
345 // Set padding only if it is non-zero.
346 DenseIntElementsAttr padding = op.paddingAttr();
347 if (!padding ||
348 !llvm::any_of(padding.getValues<APInt>(),
349 [](APInt int_val) { return !int_val.isNullValue(); })) {
350 padding = nullptr;
351 }
352
353 // The order of input and filter are switched with linalg.conv.
354 rewriter.replaceOpWithNewOp<linalg::ConvOp>(
355 op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
356 return success();
357 }
358 };
359
360 /// Base class for lowering HLO operations that have one operand and one result,
361 /// and are semantically equivalent to a copy of the input to the output (like
362 /// transpose, some reshape, etc.). The derived classes need to provide a method
363 /// `getIndexingMaps` that returns AffineMaps for the index maps of the input
364 /// and the output.
365 template <typename Derived, typename OpTy, bool isLHLO = true>
366 class DataMovementOpConverter : public OpConversionPattern<OpTy> {
367 public:
368 using OpConversionPattern<OpTy>::OpConversionPattern;
369
matchAndRewrite(OpTy op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const370 LogicalResult matchAndRewrite(
371 OpTy op, ArrayRef<Value> args,
372 ConversionPatternRewriter& rewriter) const final {
373 if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
374 auto result_type = GetHloOpResultType<isLHLO>(op);
375
376 SmallVector<AffineMap, 2> indexing_maps =
377 Derived::getIndexingMaps(op, &rewriter);
378 if (indexing_maps.empty()) return failure();
379
380 auto nloops = result_type.getRank();
381 auto loc = op.getLoc();
382 // TODO(pifon): technically, the op itself could have size operands (e.g.
383 // broadcast into a dynamic dimension).Handle this case.
384 auto dyn_sizes = isLHLO ? SmallVector<Value, 2>()
385 : ExtractDynamicSizes(rewriter, loc, args[0]);
386 auto linalg_op = rewriter.create<linalg::GenericOp>(
387 loc,
388 /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
389 /*inputs=*/args.front(),
390 /*outputBuffers=*/
391 isLHLO
392 ? ValueRange{args.back()}
393 : ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)},
394 indexing_maps, GetNParallelLoopsAttrs(nloops),
395 [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
396 nested_builder.create<linalg::YieldOp>(loc, *args.begin());
397 });
398 rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
399 return success();
400 }
401 };
402
403 /// Pattern to convert BroadcastOp to Linalg ops.
404 template <typename OpTy, bool isLHLO = true>
405 class BroadcastConverter
406 : public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
407 isLHLO> {
408 public:
409 using DataMovementOpConverter<BroadcastConverter, OpTy,
410 isLHLO>::DataMovementOpConverter;
411
getIndexingMaps(OpTy broadcast_op,Builder * b)412 static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
413 Builder* b) {
414 ShapedType input_type =
415 broadcast_op.operand().getType().template cast<ShapedType>();
416 unsigned input_rank = input_type.getRank();
417 unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();
418
419 // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
420 // the input's dimensions.
421 unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
422 SmallVector<AffineExpr, 4> input_dim_exprs;
423 input_dim_exprs.reserve(input_rank);
424 for (int i = 0; i < input_rank; ++i) {
425 input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
426 }
427
428 AffineMap input_map;
429 MLIRContext* context = b->getContext();
430 if (input_dim_exprs.empty()) {
431 // The input is a scalar, i.e. this is a scalar broadcast op.
432 input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
433 } else {
434 input_map =
435 AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
436 }
437 return {input_map, b->getMultiDimIdentityMap(nloops)};
438 }
439 };
440
441 class HloBroadcastInDimConverter
442 : public DataMovementOpConverter<HloBroadcastInDimConverter,
443 mhlo::BroadcastInDimOp, false> {
444 public:
445 using DataMovementOpConverter<HloBroadcastInDimConverter,
446 mhlo::BroadcastInDimOp,
447 false>::DataMovementOpConverter;
448
getIndexingMaps(mhlo::BroadcastInDimOp broadcast_op,Builder * b)449 static SmallVector<AffineMap, 2> getIndexingMaps(
450 mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
451 auto result_type = GetHloOpResultType<false>(broadcast_op);
452 auto operand_type =
453 broadcast_op.operand().getType().template cast<ShapedType>();
454 unsigned nloops = result_type.getRank();
455
456 // The input is a scalar, i.e. this is a scalar broadcast op.
457 if (operand_type.getRank() == 0) {
458 return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
459 b->getMultiDimIdentityMap(nloops)};
460 }
461
462 auto operand_shape = operand_type.getShape();
463 SmallVector<AffineExpr, 4> dim_exprs;
464 dim_exprs.reserve(nloops);
465
466 if (broadcast_op.broadcast_dimensions()) {
467 for (const auto& broadcastDim :
468 enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
469 int size = broadcastDim.value().getSExtValue();
470 bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
471 result_type.getShape()[size] != 1;
472 dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
473 : b->getAffineDimExpr(size));
474 }
475 }
476 return {
477 AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
478 b->getMultiDimIdentityMap(nloops)};
479 }
480 };
481
482 class HloDynamicBroadcastInDimConverter
483 : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
484 public:
485 using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;
486
matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const487 LogicalResult matchAndRewrite(
488 mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
489 ConversionPatternRewriter& rewriter) const final {
490 // Convert only if the producer is an HLO constant. Ideally the pattern
491 // (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
492 // to an Tensor-dialect op similar to TF ConstantLikeOp.
493 if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();
494
495 mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
496 Value operand = adaptor.operand();
497 auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
498 if (!operand_type || operand_type.getRank() != 0) return failure();
499
500 Value shape = adaptor.output_dimensions();
501 auto shape_type = shape.getType().cast<RankedTensorType>();
502 int64_t result_rank = shape_type.getDimSize(0);
503
504 SmallVector<Value, 2> dyn_dims;
505 Location loc = op.getLoc();
506 for (int i = 0; i < result_rank; ++i) {
507 Value index = rewriter.create<ConstantIndexOp>(loc, i);
508 dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
509 }
510 auto result_type = op.getType().dyn_cast<RankedTensorType>();
511 if (!result_type) return failure();
512
513 int64_t nloops = result_type.getRank();
514 Value init = rewriter.create<linalg::InitTensorOp>(
515 loc, dyn_dims, result_type.getShape(), result_type.getElementType());
516 Operation* generic = rewriter.create<linalg::GenericOp>(
517 loc, TypeRange{init.getType()}, ValueRange{operand},
518 /*outputBuffers=*/ValueRange{init},
519 llvm::makeArrayRef(
520 {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
521 rewriter.getContext()),
522 rewriter.getMultiDimIdentityMap(nloops)}),
523 GetNParallelLoopsAttrs(nloops),
524 [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
525 nested_builder.create<linalg::YieldOp>(loc, *args.begin());
526 });
527 rewriter.replaceOp(op, generic->getResults());
528 return success();
529 }
530 };
531
532 class LhloBroadcastInDimConverter
533 : public OpConversionPattern<lmhlo::BroadcastInDimOp> {
534 public:
535 using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
536
matchAndRewrite(lmhlo::BroadcastInDimOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const537 LogicalResult matchAndRewrite(
538 lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
539 ConversionPatternRewriter& rewriter) const final {
540 lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
541 auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
542 auto result_shape = result_type.getShape();
543
544 auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
545
546 Value operand = std::get<0>(operand_and_dims);
547 auto broadcast_dims = std::get<1>(operand_and_dims);
548
549 auto loc = op.getLoc();
550 auto nloops = result_type.getRank();
551 auto operand_type = operand.getType().cast<MemRefType>();
552
553 // For a degenerate case, i.e. broadcasting with expansion of
554 // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
555 // Instead the value is loaded and used directly in `linalg.yield`.
556 if (operand_type.getRank() == 1 &&
557 operand_type.getDimSize(0) <
558 result_type.getDimSize(broadcast_dims.front())) {
559 Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
560 Value val =
561 rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
562 rewriter.create<linalg::GenericOp>(
563 loc, /*inputs=*/ValueRange{},
564 /*outputBuffers=*/ValueRange{operand_adaptor.output()},
565 llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
566 GetNParallelLoopsAttrs(nloops),
567 [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
568 nested_builder.create<linalg::YieldOp>(loc, val);
569 });
570
571 } else {
572 auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
573 operand_type, &rewriter);
574 rewriter.create<linalg::GenericOp>(
575 loc, /*inputs=*/ValueRange{operand},
576 /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
577 GetNParallelLoopsAttrs(nloops),
578 [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
579 nested_builder.create<linalg::YieldOp>(loc, *args.begin());
580 });
581 }
582 rewriter.replaceOp(op, llvm::None);
583 return success();
584 }
585
586 // Inserts 'linalg.reshape' if there is a size-1 dim expansion.
InsertReshapeIfNecessary(lmhlo::BroadcastInDimOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const587 std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
588 lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
589 ConversionPatternRewriter& rewriter) const {
590 lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
591 Value operand = operand_adaptor.operand();
592 auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
593 auto operand_shape = operand_type.getShape();
594
595 Value result = operand_adaptor.output();
596 auto result_type = result.getType().cast<MemRefType>();
597 auto result_shape = result_type.getShape();
598
599 SmallVector<int64_t, 2> operand_strides;
600 int64_t operand_offset;
601 if (failed(getStridesAndOffset(operand_type, operand_strides,
602 operand_offset))) {
603 op.emitOpError() << "Failed to get offset and strides.";
604 }
605
606 SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
607 SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
608 linalg::ReassociationIndices collapsed_dims;
609 for (const auto& item :
610 enumerate(op.broadcast_dimensions().getIntValues())) {
611 size_t index = item.index();
612 int dim = item.value().getSExtValue();
613
614 collapsed_dims.push_back(index);
615
616 bool expansion_needed =
617 operand_shape[index] == 1 && result_shape[dim] != 1;
618 if (expansion_needed) {
619 continue;
620 }
621 new_shape.push_back(operand_shape[index]);
622 new_strides.push_back(operand_strides[index]);
623 broadcast_dims.push_back(dim);
624
625 collapsed_dims_list.push_back(collapsed_dims);
626 collapsed_dims.clear();
627 }
628 // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
629 // and all dimensions need expansion. Such memref will be reshaped to a 1D
630 // memref with a single element. New shape and strides needs to be updated
631 // accordingly.
632 if (collapsed_dims_list.empty()) {
633 collapsed_dims_list.push_back({});
634 new_shape.push_back(1);
635 new_strides.push_back(1);
636 broadcast_dims.push_back(0);
637 }
638 for (const auto& dims : collapsed_dims) {
639 collapsed_dims_list.back().push_back(dims);
640 }
641
642 // `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
643 // reduced.
644 if (new_shape.size() < operand_shape.size()) {
645 auto new_memref_type = MemRefType::get(
646 new_shape, operand_type.getElementType(),
647 makeStridedLinearLayoutMap(new_strides, operand_offset,
648 rewriter.getContext()));
649 operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
650 operand_adaptor.operand(),
651 collapsed_dims_list);
652 }
653 return std::make_pair(operand, broadcast_dims);
654 }
655
getIndexingMaps(lmhlo::BroadcastInDimOp op,ArrayRef<int64_t> broadcast_dims,ArrayRef<int64_t> result_shape,MemRefType operand_type,Builder * b) const656 SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
657 ArrayRef<int64_t> broadcast_dims,
658 ArrayRef<int64_t> result_shape,
659 MemRefType operand_type,
660 Builder* b) const {
661 unsigned nloops = result_shape.size();
662
663 // The input is a scalar, i.e. this is a scalar broadcast op.
664 if (operand_type.getRank() == 0) {
665 return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
666 b->getMultiDimIdentityMap(nloops)};
667 }
668
669 auto operand_shape = operand_type.getShape();
670 SmallVector<AffineExpr, 4> dim_exprs;
671 dim_exprs.reserve(nloops);
672
673 for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
674 int size = broadcast_dim.value();
675 bool expansion_needed =
676 operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
677 if (expansion_needed) {
678 op.emitOpError(
679 "BroadcastInDimOp lowering to Linalg does not support size-1 "
680 "dimensions expansion.");
681 }
682 dim_exprs.push_back(b->getAffineDimExpr(size));
683 }
684 return {
685 AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
686 b->getMultiDimIdentityMap(nloops)};
687 }
688 };
689
690 template <typename OpTy, bool isLHLO = true>
691 class TransposeConverter
692 : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
693 isLHLO> {
694 public:
695 using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
696 isLHLO>::DataMovementOpConverter;
getIndexingMaps(OpTy op,Builder * b)697 static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
698 auto result_type =
699 GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
700 auto nloops = result_type.getRank();
701 SmallVector<AffineExpr, 2> input_exprs;
702 input_exprs.resize(result_type.getRank());
703 for (auto permutation : llvm::enumerate(op.permutation())) {
704 input_exprs[permutation.value().getZExtValue()] =
705 b->getAffineDimExpr(permutation.index());
706 }
707 return {
708 AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
709 b->getMultiDimIdentityMap(nloops)};
710 }
711 };
712
713 // Converts reshape ops that can be proven to be either a collapse of dimensions
714 // or expansion of dimensions of the operand.
715 template <typename OpTy, bool isLHLO = true>
716 class ReshapeOpConverter : public OpConversionPattern<OpTy> {
717 public:
718 using OpConversionPattern<OpTy>::OpConversionPattern;
719
matchAndRewrite(OpTy reshape_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const720 LogicalResult matchAndRewrite(
721 OpTy reshape_op, ArrayRef<Value> args,
722 ConversionPatternRewriter& rewriter) const final {
723 if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
724 return failure();
725 typename OpTy::Adaptor operands(args);
726 ShapedType operand_type =
727 operands.operand().getType().template cast<ShapedType>();
728 ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
729
730 if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
731 return failure();
732
733 // Compute the reassociation maps for the linalg operation.
734 ArrayRef<int64_t> src_shape =
735 (operand_type.getRank() > result_type.getRank()
736 ? operand_type.getShape()
737 : result_type.getShape());
738 ArrayRef<int64_t> dst_shape =
739 (operand_type.getRank() > result_type.getRank()
740 ? result_type.getShape()
741 : operand_type.getShape());
742 unsigned curr_src_dim = 0, curr_dst_dim = 0;
743 SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
744 dst_shape.size());
745
746 // First scan all dimensions in the source shapes to see whether we have a
747 // perfect case where consecutive dimensions in source are collapsed. For
748 // such case we can just generate one single linalg.reshape.
749 bool is_collapsing_source = true;
750 while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
751 int64_t dst_size = dst_shape[curr_dst_dim];
752 int64_t src_size = src_shape[curr_src_dim];
753 while (src_size < dst_size && curr_src_dim < src_shape.size()) {
754 reassociation_map[curr_dst_dim].push_back(
755 rewriter.getAffineDimExpr(curr_src_dim++));
756 src_size *= src_shape[curr_src_dim];
757 }
758 if (src_size == dst_size) {
759 reassociation_map[curr_dst_dim].push_back(
760 rewriter.getAffineDimExpr(curr_src_dim++));
761 // If the next dim in dst_shape is not 1, treat subsequent dims in
762 // src_shape which are 1 to be collapsed.
763 if (curr_dst_dim == dst_shape.size() - 1 ||
764 dst_shape[curr_dst_dim + 1] != 1) {
765 while (curr_src_dim < src_shape.size() &&
766 src_shape[curr_src_dim] == 1) {
767 reassociation_map[curr_dst_dim].push_back(
768 rewriter.getAffineDimExpr(curr_src_dim++));
769 }
770 }
771 } else {
772 is_collapsing_source = false;
773 break;
774 }
775 curr_dst_dim++;
776 }
777 if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
778 is_collapsing_source = false;
779
780 // Otherwise, we need to first reduce all source dimensions into one and
781 // then expand to the destination dimensions.
782 if (!is_collapsing_source) {
783 auto get_identity_exprs = [&rewriter](int n) {
784 SmallVector<AffineExpr, 4> exprs;
785 for (int i = 0; i < n; ++i)
786 exprs.push_back(rewriter.getAffineDimExpr(i));
787 return exprs;
788 };
789 Location loc = reshape_op.getLoc();
790 int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
791 1, std::multiplies<int64_t>());
792 auto elem_type = operand_type.getElementType();
793 SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
794 // Use operand_type here because we need to collapse all operands
795 // dimensions.
796 get_identity_exprs(operand_type.getShape().size())};
797 SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
798 // Use result_type here because we need to expand to all result
799 // dimensions.
800 get_identity_exprs(result_type.getShape().size())};
801
802 if (isLHLO) {
803 auto collapsed_type = MemRefType::get({total_elems}, elem_type);
804 Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
805 loc, collapsed_type, args[0], collapsing_map);
806 Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
807 loc, result_type, collapsed_op, expanding_map);
808 rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
809 args[1]);
810 } else {
811 auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
812 Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
813 loc, collapsed_type, args[0], collapsing_map);
814 rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
815 reshape_op, result_type, collapsed_op, expanding_map);
816 }
817 return success();
818 }
819
820 if (isLHLO) {
821 Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
822 reshape_op.getLoc(), result_type, args[0], reassociation_map);
823 rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
824 args[1]);
825 } else {
826 rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
827 reshape_op, result_type, args[0], reassociation_map);
828 }
829 return success();
830 }
831 };
832
833 template <typename OpTy, bool isLHLO = true>
834 class IotaConverter : public OpConversionPattern<OpTy> {
835 public:
836 using OpConversionPattern<OpTy>::OpConversionPattern;
837
matchAndRewrite(OpTy iota_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const838 LogicalResult matchAndRewrite(
839 OpTy iota_op, ArrayRef<Value> args,
840 ConversionPatternRewriter& rewriter) const final {
841 ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
842 if (!result_shaped_type) return failure();
843
844 auto result_element_type = result_shaped_type.getElementType();
845 if (!result_element_type.isSignlessIntOrFloat()) return failure();
846
847 // Construct the indexing maps needed for linalg.generic ops.
848 unsigned nloops = result_shaped_type.getRank();
849
850 Location loc = iota_op.getLoc();
851 auto dyn_sizes = isLHLO
852 ? SmallVector<Value, 2>()
853 : ExtractDynamicSizes(rewriter, loc,
854 GetResultValue<isLHLO>(iota_op));
855 auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
856 loc,
857 /*resultTensorTypes=*/
858 isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
859 /*inputs=*/ValueRange{},
860 /*outputBuffers=*/
861 isLHLO ? ValueRange{args}
862 : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
863 dyn_sizes)},
864 llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
865 GetNParallelLoopsAttrs(nloops),
866 [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
867 ValueRange args) {
868 Value cast_op = nested_builder.create<IndexCastOp>(
869 nested_loc, ivs[iota_op.iota_dimension()],
870 nested_builder.getIntegerType(
871 result_element_type.getIntOrFloatBitWidth()));
872 if (result_element_type.template isa<FloatType>()) {
873 cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
874 result_element_type);
875 }
876 nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
877 });
878 if (isLHLO)
879 rewriter.replaceOp(iota_op, llvm::None);
880 else
881 rewriter.replaceOp(iota_op, linalg_op.result_tensors());
882 return success();
883 }
884 };
885
886 template <typename OpTy>
887 class ConstConverter : public OpConversionPattern<OpTy> {
888 public:
889 using OpConversionPattern<OpTy>::OpConversionPattern;
890
matchAndRewrite(OpTy const_op,ArrayRef<Value>,ConversionPatternRewriter & rewriter) const891 LogicalResult matchAndRewrite(
892 OpTy const_op, ArrayRef<Value> /*args*/,
893 ConversionPatternRewriter& rewriter) const final {
894 Location loc = const_op.getLoc();
895 auto value_attr = const_op.value().template cast<DenseElementsAttr>();
896 if (value_attr.getType().getRank() != 0) return failure();
897 ReplaceConstOp(loc, const_op, value_attr, rewriter);
898 return success();
899 }
900
901 private:
ReplaceConstOp(Location loc,mhlo::ConstOp op,DenseElementsAttr value_attr,ConversionPatternRewriter & rewriter) const902 void ReplaceConstOp(Location loc, mhlo::ConstOp op,
903 DenseElementsAttr value_attr,
904 ConversionPatternRewriter& rewriter) const {
905 Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
906 rewriter.replaceOp(op, {std_tensor_const});
907 }
ReplaceConstOp(Location loc,lmhlo::ConstOp op,DenseElementsAttr value_attr,ConversionPatternRewriter & rewriter) const908 void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
909 DenseElementsAttr value_attr,
910 ConversionPatternRewriter& rewriter) const {
911 Value std_scalar_const =
912 rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
913 rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
914 llvm::None);
915 rewriter.eraseOp(op);
916 }
917 };
918
919 class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
920 public:
921 using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
922
matchAndRewrite(lmhlo::ReduceOp reduce_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const923 LogicalResult matchAndRewrite(
924 lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
925 ConversionPatternRewriter& rewriter) const final {
926 auto loc = reduce_op.getLoc();
927 lmhlo::ReduceOp::Adaptor adaptor(args);
928 auto operand_shape =
929 adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
930 if (!operand_shape || !operand_shape.hasRank()) {
931 emitError(loc, "lhlo to linalg conversion expects known-rank args");
932 return failure();
933 }
934
935 // First fill the output buffer with the init value.
936 Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
937 rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
938
939 DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
940 SmallVector<int, 4> reduction_dims;
941 for (const auto& dim : dimensions_attr.getIntValues()) {
942 reduction_dims.push_back(dim.getSExtValue());
943 }
944
945 SmallVector<AffineExpr, 2> src_exprs;
946 SmallVector<AffineExpr, 2> dst_exprs;
947 SmallVector<StringRef, 4> types;
948 for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
949 bool is_reduced = llvm::is_contained(reduction_dims, i);
950 types.push_back(is_reduced ? getReductionIteratorTypeName()
951 : getParallelIteratorTypeName());
952
953 src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
954 if (!is_reduced) {
955 dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
956 }
957 }
958
959 auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});
960
961 auto linalg_op = rewriter.create<linalg::GenericOp>(
962 loc, /*resultTensorTypes=*/ArrayRef<Type>{},
963 /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), maps,
964 types);
965 rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
966 linalg_op.region().end());
967 {
968 OpBuilder::InsertionGuard region_guard(rewriter);
969 Block* block = linalg_op.getBody();
970 rewriter.setInsertionPoint(&block->front());
971
972 // The incoming region is operating on buffers, while linalg.generic
973 // expects scalar SSA values. Add some allocs around the original op to
974 // make it compatible.
975 auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
976 Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
977 Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
978 Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
979
980 // Now turn the existing signature
981 // (memref<X>, memref<X>, memref<X>) -> ()
982 // into
983 // (X, X) -> X
984 TypeConverter::SignatureConversion signature_converter(3);
985 signature_converter.remapInput(0, alloc_a);
986 signature_converter.remapInput(1, alloc_b);
987 signature_converter.remapInput(2, alloc_res);
988 signature_converter.addInputs(
989 {arg_type.getElementType(), arg_type.getElementType()});
990 Block* entry_block = rewriter.applySignatureConversion(
991 &linalg_op.region(), signature_converter);
992
993 // Store the arguments into the newly allocated buffers.
994 rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
995 rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
996 rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
997 rewriter.replaceOp(entry_block->getTerminator(), {});
998
999 // Load & yield the result.
1000 rewriter.setInsertionPointToEnd(entry_block);
1001 auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
1002 rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
1003 }
1004
1005 rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
1006 return success();
1007 }
1008 };
1009
1010 // TODO(b/156787842): Support the lowering for dynamic shapes.
1011 template <typename OpTy, bool isLHLO = true>
1012 class ReverseConverter
1013 : public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
1014 isLHLO> {
1015 public:
1016 using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
1017 isLHLO>::DataMovementOpConverter;
getIndexingMaps(OpTy op,Builder * b)1018 static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
1019 auto result_type =
1020 GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
1021 auto nloops = result_type.getRank();
1022 SmallVector<AffineExpr, 2> input_exprs;
1023 input_exprs.reserve(nloops);
1024 for (int i = 0; i < nloops; ++i)
1025 input_exprs.push_back(b->getAffineDimExpr(i));
1026 for (auto dim : op.dimensions()) {
1027 int i = dim.getZExtValue();
1028 if (result_type.isDynamicDim(i)) return {};
1029 int n = result_type.getShape()[i];
1030 input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
1031 }
1032 return {
1033 AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
1034 b->getMultiDimIdentityMap(nloops)};
1035 }
1036 };
1037
1038 class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
1039 public:
1040 using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
1041
matchAndRewrite(lmhlo::SliceOp slice_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1042 LogicalResult matchAndRewrite(
1043 lmhlo::SliceOp slice_op, ArrayRef<Value> args,
1044 ConversionPatternRewriter& rewriter) const final {
1045 auto loc = slice_op.getLoc();
1046 auto arg_type =
1047 slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
1048 if (!arg_type || !arg_type.hasRank()) {
1049 emitError(loc, "lhlo to linalg conversion expects known-rank args");
1050 return failure();
1051 }
1052
1053 SmallVector<OpFoldResult, 3> offsets, sizes, strides;
1054 for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
1055 offsets.push_back(rewriter.getI64IntegerAttr(
1056 slice_op.start_indices().getValue<int64_t>(i)));
1057 sizes.push_back(rewriter.getI64IntegerAttr(
1058 slice_op.limit_indices().getValue<int64_t>(i) -
1059 slice_op.start_indices().getValue<int64_t>(i)));
1060 strides.push_back(
1061 rewriter.getI64IntegerAttr(slice_op.strides().getValue<int64_t>(i)));
1062 }
1063 auto linalg_slice = rewriter.create<SubViewOp>(loc, slice_op.getOperand(0),
1064 offsets, sizes, strides);
1065 rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
1066 rewriter.eraseOp(slice_op);
1067 return success();
1068 }
1069 };
1070
// Shape class of an mhlo.dot, selecting the linalg named op used to lower it.
// Enumerators take the default values 0, 1, 2, 3 in declaration order.
enum class DotOperationType : int {
  kVectorDot,     // rank-1 x rank-1
  kMatrixVector,  // rank-2 x rank-1
  kMatrixMatrix,  // rank-2 x rank-2
  kUnsupported    // anything else
};
1077
GetDotOperationType(mhlo::DotOp dot_op)1078 DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
1079 ArrayRef<int64_t> lhs_shape =
1080 dot_op.lhs().getType().cast<ShapedType>().getShape();
1081 ArrayRef<int64_t> rhs_shape =
1082 dot_op.rhs().getType().cast<ShapedType>().getShape();
1083 auto shape_matches = [](int64_t a, int64_t b) {
1084 return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
1085 a == b;
1086 };
1087 if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
1088 shape_matches(lhs_shape[0], rhs_shape[0])) {
1089 return DotOperationType::kVectorDot;
1090 }
1091 if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
1092 shape_matches(lhs_shape[1], rhs_shape[0])) {
1093 return DotOperationType::kMatrixVector;
1094 }
1095 if (rhs_shape.size() == 2 && rhs_shape.size() == 2 &&
1096 shape_matches(lhs_shape[1], rhs_shape[0])) {
1097 return DotOperationType::kMatrixMatrix;
1098 }
1099 return DotOperationType::kUnsupported;
1100 }
1101
GetDotOpInitTensorDynSizes(OpBuilder & b,Location loc,Value lhs,Value rhs,DotOperationType type)1102 SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
1103 Value lhs, Value rhs,
1104 DotOperationType type) {
1105 SmallVector<Value, 2> dyn_shape;
1106 switch (type) {
1107 case DotOperationType::kMatrixMatrix: {
1108 if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
1109 dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
1110 if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
1111 dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
1112 break;
1113 }
1114 case DotOperationType::kMatrixVector: {
1115 if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
1116 dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
1117 break;
1118 }
1119 case DotOperationType::kVectorDot:
1120 case DotOperationType::kUnsupported:
1121 default: {
1122 break;
1123 }
1124 }
1125 return dyn_shape;
1126 }
1127
1128 class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
1129 public:
1130 using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
matchAndRewrite(mhlo::DotOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1131 LogicalResult matchAndRewrite(
1132 mhlo::DotOp op, ArrayRef<Value> args,
1133 ConversionPatternRewriter& rewriter) const final {
1134 if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
1135 return failure();
1136 }
1137 Location loc = op.getLoc();
1138 mhlo::DotOp::Adaptor adaptor(args);
1139 Type result_type = op.getResult().getType();
1140 auto shaped_type = result_type.cast<ShapedType>();
1141 DotOperationType op_type = GetDotOperationType(op);
1142 auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
1143 Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
1144 SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
1145 rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
1146 auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
1147 Value zero_tensor =
1148 rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
1149 linalg::LinalgOp linalg_op;
1150 switch (op_type) {
1151 case DotOperationType::kMatrixMatrix: {
1152 linalg_op = rewriter.create<linalg::MatmulOp>(
1153 loc, TypeRange{result_type},
1154 ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
1155 break;
1156 }
1157 case DotOperationType::kMatrixVector: {
1158 linalg_op = rewriter.create<linalg::MatvecOp>(
1159 loc, TypeRange{result_type},
1160 ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
1161 break;
1162 }
1163 case DotOperationType::kVectorDot: {
1164 linalg_op = rewriter.create<linalg::DotOp>(
1165 loc, TypeRange{result_type},
1166 ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
1167 break;
1168 }
1169 case DotOperationType::kUnsupported:
1170 default: {
1171 return op.emitError("unsupported dot operation type");
1172 }
1173 }
1174 rewriter.replaceOp(op, linalg_op->getResults());
1175 return success();
1176 }
1177 };
1178
GetDotGeneralOpInitTensorDynSizes(OpBuilder & b,Location loc,Value lhs,Value rhs,ShapedType result_type)1179 SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
1180 OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
1181 SmallVector<Value, 8> dyn_shape;
1182 if (result_type.isDynamicDim(0))
1183 dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
1184 if (result_type.isDynamicDim(1))
1185 dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
1186 if (result_type.isDynamicDim(2))
1187 dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
1188 return dyn_shape;
1189 }
1190
1191 class DotGeneralOpOnTensorsConversion
1192 : public OpConversionPattern<mhlo::DotGeneralOp> {
1193 public:
1194 using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
matchAndRewrite(mhlo::DotGeneralOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1195 LogicalResult matchAndRewrite(
1196 mhlo::DotGeneralOp op, ArrayRef<Value> args,
1197 ConversionPatternRewriter& rewriter) const final {
1198 if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
1199 return failure();
1200 }
1201 mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
1202 auto lhs_bathcing_dims =
1203 Extract1DVector(dim_numbers.lhs_batching_dimensions());
1204 auto rhs_bathcing_dims =
1205 Extract1DVector(dim_numbers.rhs_batching_dimensions());
1206 auto lhs_contracting_dims =
1207 Extract1DVector(dim_numbers.lhs_contracting_dimensions());
1208 auto rhs_contracting_dims =
1209 Extract1DVector(dim_numbers.rhs_contracting_dimensions());
1210 if (lhs_bathcing_dims.size() != 1 || lhs_bathcing_dims[0] != 0) {
1211 return rewriter.notifyMatchFailure(
1212 op, "expected lhs batching dimensions exactly {0}");
1213 }
1214 if (rhs_bathcing_dims.size() != 1 || rhs_bathcing_dims[0] != 0) {
1215 return rewriter.notifyMatchFailure(
1216 op, "expected rhs batching dimensions exactly {0}");
1217 }
1218 if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
1219 return rewriter.notifyMatchFailure(
1220 op, "expected lhs contracting dimensions exactly {2}");
1221 }
1222 if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
1223 return rewriter.notifyMatchFailure(
1224 op, "expected rhs contracting dimensions exactly {1}");
1225 }
1226 Location loc = op.getLoc();
1227 mhlo::DotGeneralOp::Adaptor adaptor(args);
1228 Type result_type = op.getResult().getType();
1229 auto shaped_type = result_type.cast<ShapedType>();
1230 SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
1231 rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
1232 auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
1233 Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
1234 auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
1235 Value zero_tensor =
1236 rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
1237 auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
1238 loc, /*resultTensorTypes=*/TypeRange{result_type},
1239 /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
1240 /*outputBuffers=*/ValueRange{zero_tensor});
1241 rewriter.replaceOp(op, linalg_op.getResults());
1242 return success();
1243 }
1244 };
1245
1246 template <typename OpTy>
1247 struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
1248 using OpConversionPattern<OpTy>::OpConversionPattern;
matchAndRewritemlir::__anon00bff3f50111::ReduceRegionXLAOpConversion1249 LogicalResult matchAndRewrite(
1250 OpTy op, ArrayRef<Value> args,
1251 ConversionPatternRewriter& rewriter) const final {
1252 // Only convert the body of reduction ops to std ops.
1253 auto parent_op = op.getOperation()->getParentRegion()->getParentOp();
1254 if (!isa<mhlo::ReduceOp, linalg::GenericOp, linalg::IndexedGenericOp>(
1255 parent_op)) {
1256 return failure();
1257 }
1258 if (!op.getResult().getType().template isa<TensorType>()) return failure();
1259 if (llvm::all_of(args, [](Value arg) {
1260 return arg.getType().template isa<TensorType>();
1261 })) {
1262 return failure();
1263 }
1264 Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
1265 args, &rewriter);
1266 rewriter.replaceOp(op, result);
1267 return success();
1268 }
1269 };
1270
GetReduceOpInitTensorDynSizes(OpBuilder & b,Location loc,Value arg,ShapedType result_type,ArrayRef<int64_t> reduction_dims)1271 SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
1272 OpBuilder& b, Location loc, Value arg, ShapedType result_type,
1273 ArrayRef<int64_t> reduction_dims) {
1274 llvm::SmallSetVector<int, 4> s;
1275 for (auto dim : reduction_dims) s.insert(dim);
1276
1277 SmallVector<unsigned, 4> parallel_dims;
1278 SmallVector<Value, 8> dyn_shape;
1279 int rank = arg.getType().cast<RankedTensorType>().getRank();
1280 for (int i = 0, j = 0; i < rank; ++i) {
1281 if (s.count(i)) continue;
1282 if (!result_type.isDynamicDim(j++)) continue;
1283 dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
1284 }
1285
1286 return dyn_shape;
1287 }
1288
1289 class ReduceRegionReturnOpConversion
1290 : public OpConversionPattern<mhlo::ReturnOp> {
1291 public:
1292 using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
matchAndRewrite(mhlo::ReturnOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1293 LogicalResult matchAndRewrite(
1294 mhlo::ReturnOp op, ArrayRef<Value> args,
1295 ConversionPatternRewriter& rewriter) const final {
1296 rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args);
1297 return success();
1298 }
1299 };
1300
1301 class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
1302 public:
1303 using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
matchAndRewrite(mhlo::ReduceOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1304 LogicalResult matchAndRewrite(
1305 mhlo::ReduceOp op, ArrayRef<Value> args,
1306 ConversionPatternRewriter& rewriter) const final {
1307 Location loc = op.getLoc();
1308 mhlo::ReduceOp::Adaptor adaptor(args);
1309 if (op.getNumOperands() != 2) {
1310 return op.emitError("expects exactly two operands");
1311 }
1312 Value src = adaptor.operands()[0];
1313 auto src_type = src.getType().cast<ShapedType>();
1314 int src_rank = src_type.getRank();
1315 if (!src_rank) {
1316 return rewriter.notifyMatchFailure(op, "expects known-rank args");
1317 }
1318
1319 // Check if init_value is constant. If so, inline the value into the region.
1320 Value init_value = adaptor.init_values()[0];
1321 Attribute init_const_val = GetInitValueAsConst(init_value);
1322 if (init_const_val) {
1323 init_value = rewriter.create<ConstantOp>(
1324 init_value.getDefiningOp()->getLoc(), init_const_val);
1325 } else {
1326 init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
1327 }
1328
1329 // Prepare indexing maps for linalg generic op. The elements are for src and
1330 // dst. Transpose `src` to make the reduction loops be the innermost,
1331 // because it's easier to fully utilize processors.
1332 SmallVector<AffineMap, 3> indexing_maps;
1333 SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
1334 indexing_maps.emplace_back(GetTransposeMapForReduction(
1335 rewriter.getContext(), src_rank, reduction_dims));
1336
1337 // The indexing map of `dst` should drop the reduction loops. Since the
1338 // reduction loops now are all in the innermost, drops
1339 // `reduction_dims.size()` dimensions. We don't need an inverse permutation
1340 // here because they are the same.
1341 SmallVector<AffineExpr, 4> exprs;
1342 for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
1343 exprs.push_back(rewriter.getAffineDimExpr(i));
1344 indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
1345 exprs, rewriter.getContext()));
1346
1347 SmallVector<Value, 2> inputs = {adaptor.operands()[0]};
1348 Type result_type = op.getResult(0).getType();
1349 auto shaped_type = result_type.cast<ShapedType>();
1350 SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
1351 rewriter, loc, adaptor.operands()[0], result_type.cast<ShapedType>(),
1352 reduction_dims);
1353 auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
1354 Value filled_tensor =
1355 rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
1356 .getResult(0);
1357
1358 auto linalg_op = rewriter.create<linalg::GenericOp>(
1359 loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
1360 /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
1361 GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
1362
1363 // Convert the signature of the body. The reduce op region apply function
1364 // has a signature (lhs, rhs) -> output, all of the same tensor type t.
1365 // This is converted to a function with the same signature but with
1366 // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
1367 // be converted to "(f32, f32, f32)".
1368 Region& region = linalg_op.region();
1369 rewriter.inlineRegionBefore(op.body(), region, region.end());
1370 TypeConverter::SignatureConversion signatureConverter(2);
1371 signatureConverter.addInputs(0, src_type.getElementType());
1372 signatureConverter.addInputs(1, src_type.getElementType());
1373 rewriter.applySignatureConversion(®ion, signatureConverter);
1374 rewriter.replaceOp(op, linalg_op.getResults());
1375 return success();
1376 }
1377 };
1378
populateLHLOToLinalgConversionPattern(MLIRContext * context,OwningRewritePatternList * patterns)1379 void populateLHLOToLinalgConversionPattern(MLIRContext* context,
1380 OwningRewritePatternList* patterns) {
1381 // clang-format off
1382 patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
1383 ConstConverter<lmhlo::ConstOp>,
1384 ConvToLinalgConverter,
1385 IotaConverter<lmhlo::IotaOp>,
1386 LhloBroadcastInDimConverter,
1387 PointwiseToLinalgConverter<lmhlo::AbsOp>,
1388 PointwiseToLinalgConverter<lmhlo::AddOp>,
1389 PointwiseToLinalgConverter<lmhlo::AndOp>,
1390 PointwiseToLinalgConverter<lmhlo::Atan2Op>,
1391 PointwiseToLinalgConverter<lmhlo::CeilOp>,
1392 PointwiseToLinalgConverter<lmhlo::ClampOp>,
1393 PointwiseToLinalgConverter<lmhlo::CompareOp>,
1394 PointwiseToLinalgConverter<lmhlo::ComplexOp>,
1395 PointwiseToLinalgConverter<lmhlo::ConvertOp>,
1396 // TODO(ataei): Remove this pattern, CopyOp is folded away.
1397 PointwiseToLinalgConverter<lmhlo::CopyOp>,
1398 PointwiseToLinalgConverter<lmhlo::CosOp>,
1399 PointwiseToLinalgConverter<lmhlo::DivOp>,
1400 PointwiseToLinalgConverter<lmhlo::ExpOp>,
1401 PointwiseToLinalgConverter<lmhlo::Expm1Op>,
1402 PointwiseToLinalgConverter<lmhlo::FloorOp>,
1403 PointwiseToLinalgConverter<lmhlo::ImagOp>,
1404 PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
1405 PointwiseToLinalgConverter<lmhlo::LogOp>,
1406 PointwiseToLinalgConverter<lmhlo::LogisticOp>,
1407 PointwiseToLinalgConverter<lmhlo::Log1pOp>,
1408 PointwiseToLinalgConverter<lmhlo::MaxOp>,
1409 PointwiseToLinalgConverter<lmhlo::MinOp>,
1410 PointwiseToLinalgConverter<lmhlo::MulOp>,
1411 PointwiseToLinalgConverter<lmhlo::NegOp>,
1412 PointwiseToLinalgConverter<lmhlo::NotOp>,
1413 PointwiseToLinalgConverter<lmhlo::OrOp>,
1414 PointwiseToLinalgConverter<lmhlo::PowOp>,
1415 PointwiseToLinalgConverter<lmhlo::RealOp>,
1416 PointwiseToLinalgConverter<lmhlo::RemOp>,
1417 PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
1418 PointwiseToLinalgConverter<lmhlo::SelectOp>,
1419 PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
1420 PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
1421 PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
1422 PointwiseToLinalgConverter<lmhlo::SignOp>,
1423 PointwiseToLinalgConverter<lmhlo::SinOp>,
1424 PointwiseToLinalgConverter<lmhlo::SqrtOp>,
1425 PointwiseToLinalgConverter<lmhlo::SubOp>,
1426 PointwiseToLinalgConverter<lmhlo::TanhOp>,
1427 PointwiseToLinalgConverter<lmhlo::XorOp>,
1428 ReduceConverter,
1429 ReshapeOpConverter<lmhlo::ReshapeOp>,
1430 ReverseConverter<lmhlo::ReverseOp>,
1431 ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
1432 ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
1433 SliceConverter,
1434 TransposeConverter<lmhlo::TransposeOp>
1435 >(context);
1436 // clang-format on
1437 }
1438
1439 // Converts LHLO ops to Linalg generic.
1440 // Sample result for lmhlo::AddOp.
1441 //
1442 // "lmhlo.add"(%arg1, %arg2, %out) :
1443 // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
1444 //
1445 // will be converted to
1446 //
1447 // #map0 = (d0, d1) -> (d0, d1)
1448 // "linalg.generic"(%arg1, %arg2, %out) ( {
1449 // ^bb0(%arg4: f32, %arg5: f32):
1450 // %0 = addf %arg4, %arg5 : f32
1451 // "linalg.yield"(%0) : (f32) -> ()
1452 // }) {
1453 // indexing_maps = [#map0, #map0, #map0],
1454 // iterator_types = ["parallel", "parallel"],
1455 // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
1456 struct LhloLegalizeToLinalgPass
1457 : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
getDependentDialectsmlir::__anon00bff3f50111::LhloLegalizeToLinalgPass1458 void getDependentDialects(DialectRegistry& registry) const override {
1459 registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
1460 }
1461
runOnFunctionmlir::__anon00bff3f50111::LhloLegalizeToLinalgPass1462 void runOnFunction() override {
1463 OwningRewritePatternList patterns;
1464 ConversionTarget target(getContext());
1465 target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
1466 math::MathDialect, StandardOpsDialect,
1467 AffineDialect>();
1468
1469 auto func = getFunction();
1470 populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
1471 if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
1472 signalPassFailure();
1473 }
1474 }
1475 };
1476
1477 struct HloLegalizeToLinalgPass
1478 : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
getDependentDialectsmlir::__anon00bff3f50111::HloLegalizeToLinalgPass1479 void getDependentDialects(DialectRegistry& registry) const override {
1480 registry.insert<linalg::LinalgDialect, scf::SCFDialect,
1481 complex::ComplexDialect, math::MathDialect>();
1482 }
1483
runOnFunctionmlir::__anon00bff3f50111::HloLegalizeToLinalgPass1484 void runOnFunction() override {
1485 OwningRewritePatternList patterns;
1486 ConversionTarget target(getContext());
1487 target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
1488 math::MathDialect, StandardOpsDialect,
1489 tensor::TensorDialect, scf::SCFDialect>();
1490
1491 auto func = getFunction();
1492 mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
1493 if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
1494 signalPassFailure();
1495 }
1496 }
1497 };
1498
1499 } // namespace
1500
1501 namespace lmhlo {
createLegalizeLhloToLinalgPass()1502 std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
1503 return std::make_unique<LhloLegalizeToLinalgPass>();
1504 }
1505 } // namespace lmhlo
1506
1507 namespace mhlo {
1508
populateHLOToLinalgConversionPattern(MLIRContext * context,OwningRewritePatternList * patterns)1509 void populateHLOToLinalgConversionPattern(MLIRContext* context,
1510 OwningRewritePatternList* patterns) {
1511 patterns->insert<
1512 BroadcastConverter<mhlo::BroadcastOp, false>,
1513 ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
1514 HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
1515 PointwiseToLinalgConverter<mhlo::AbsOp, false>,
1516 PointwiseToLinalgConverter<mhlo::AddOp, false>,
1517 PointwiseToLinalgConverter<mhlo::AndOp, false>,
1518 PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
1519 PointwiseToLinalgConverter<mhlo::CeilOp, false>,
1520 PointwiseToLinalgConverter<mhlo::ClampOp, false>,
1521 PointwiseToLinalgConverter<mhlo::CompareOp, false>,
1522 PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
1523 PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
1524 PointwiseToLinalgConverter<mhlo::CopyOp, false>,
1525 PointwiseToLinalgConverter<mhlo::CosOp, false>,
1526 PointwiseToLinalgConverter<mhlo::DivOp, false>,
1527 PointwiseToLinalgConverter<mhlo::ExpOp, false>,
1528 PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
1529 PointwiseToLinalgConverter<mhlo::FloorOp, false>,
1530 PointwiseToLinalgConverter<mhlo::ImagOp, false>,
1531 PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
1532 PointwiseToLinalgConverter<mhlo::LogOp, false>,
1533 PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
1534 PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
1535 PointwiseToLinalgConverter<mhlo::MaxOp, false>,
1536 PointwiseToLinalgConverter<mhlo::MinOp, false>,
1537 PointwiseToLinalgConverter<mhlo::MulOp, false>,
1538 PointwiseToLinalgConverter<mhlo::NegOp, false>,
1539 PointwiseToLinalgConverter<mhlo::NotOp, false>,
1540 PointwiseToLinalgConverter<mhlo::OrOp, false>,
1541 PointwiseToLinalgConverter<mhlo::PowOp, false>,
1542 PointwiseToLinalgConverter<mhlo::RealOp, false>,
1543 PointwiseToLinalgConverter<mhlo::RemOp, false>,
1544 PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
1545 PointwiseToLinalgConverter<mhlo::SelectOp, false>,
1546 PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
1547 PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
1548 PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
1549 PointwiseToLinalgConverter<mhlo::SignOp, false>,
1550 PointwiseToLinalgConverter<mhlo::SinOp, false>,
1551 PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
1552 PointwiseToLinalgConverter<mhlo::SubOp, false>,
1553 PointwiseToLinalgConverter<mhlo::TanhOp, false>,
1554 PointwiseToLinalgConverter<mhlo::XorOp, false>,
1555 ReshapeOpConverter<mhlo::ReshapeOp, false>,
1556 ReverseConverter<mhlo::ReverseOp, false>,
1557 TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
1558 DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
1559 patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
1560 ReduceRegionXLAOpConversion<mhlo::MinOp>,
1561 ReduceRegionXLAOpConversion<mhlo::MaxOp>,
1562 ReduceRegionReturnOpConversion>(context);
1563 }
1564
createLegalizeHloToLinalgPass()1565 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
1566 return std::make_unique<HloLegalizeToLinalgPass>();
1567 }
1568 } // namespace mhlo
1569 } // namespace mlir
1570