1 //===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR standard and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
16 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/Support/LogicalResult.h"
27 #include "mlir/Support/MathExtras.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "mlir/Transforms/Passes.h"
30 #include "mlir/Transforms/Utils.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/IR/DerivedTypes.h"
33 #include "llvm/IR/IRBuilder.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/FormatVariadic.h"
37 #include <functional>
38 
39 using namespace mlir;
40 
41 #define PASS_NAME "convert-std-to-llvm"
42 
43 // Extract an LLVM IR type from the LLVM IR dialect type.
unwrap(Type type)44 static LLVM::LLVMType unwrap(Type type) {
45   if (!type)
46     return nullptr;
47   auto *mlirContext = type.getContext();
48   auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
49   if (!wrappedLLVMType)
50     emitError(UnknownLoc::get(mlirContext),
51               "conversion resulted in a non-LLVM type");
52   return wrappedLLVMType;
53 }
54 
55 /// Callback to convert function argument types. It converts a MemRef function
56 /// argument to a list of non-aggregate types containing descriptor
57 /// information, and an UnrankedmemRef function argument to a list containing
58 /// the rank and a pointer to a descriptor struct.
structFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)59 LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
60                                                Type type,
61                                                SmallVectorImpl<Type> &result) {
62   if (auto memref = type.dyn_cast<MemRefType>()) {
63     // In signatures, Memref descriptors are expanded into lists of
64     // non-aggregate values.
65     auto converted =
66         converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
67     if (converted.empty())
68       return failure();
69     result.append(converted.begin(), converted.end());
70     return success();
71   }
72   if (type.isa<UnrankedMemRefType>()) {
73     auto converted = converter.getUnrankedMemRefDescriptorFields();
74     if (converted.empty())
75       return failure();
76     result.append(converted.begin(), converted.end());
77     return success();
78   }
79   auto converted = converter.convertType(type);
80   if (!converted)
81     return failure();
82   result.push_back(converted);
83   return success();
84 }
85 
86 /// Callback to convert function argument types. It converts MemRef function
87 /// arguments to bare pointers to the MemRef element type.
barePtrFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)88 LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
89                                                 Type type,
90                                                 SmallVectorImpl<Type> &result) {
91   auto llvmTy = converter.convertCallingConventionType(type);
92   if (!llvmTy)
93     return failure();
94 
95   result.push_back(llvmTy);
96   return success();
97 }
98 
99 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx)100 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
101     : LLVMTypeConverter(ctx, LowerToLLVMOptions::getDefaultOptions()) {}
102 
103 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx,const LowerToLLVMOptions & options)104 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
105                                      const LowerToLLVMOptions &options)
106     : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
107       options(options) {
108   assert(llvmDialect && "LLVM IR dialect is not registered");
109   if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
110     this->options.indexBitwidth = options.dataLayout.getPointerSizeInBits();
111 
112   // Register conversions for the builtin types.
113   addConversion([&](ComplexType type) { return convertComplexType(type); });
114   addConversion([&](FloatType type) { return convertFloatType(type); });
115   addConversion([&](FunctionType type) { return convertFunctionType(type); });
116   addConversion([&](IndexType type) { return convertIndexType(type); });
117   addConversion([&](IntegerType type) { return convertIntegerType(type); });
118   addConversion([&](MemRefType type) { return convertMemRefType(type); });
119   addConversion(
120       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
121   addConversion([&](VectorType type) { return convertVectorType(type); });
122 
123   // LLVMType is legal, so add a pass-through conversion.
124   addConversion([](LLVM::LLVMType type) { return type; });
125 
126   // Materialization for memrefs creates descriptor structs from individual
127   // values constituting them, when descriptors are used, i.e. more than one
128   // value represents a memref.
129   addArgumentMaterialization(
130       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
131           Location loc) -> Optional<Value> {
132         if (inputs.size() == 1)
133           return llvm::None;
134         return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
135                                               inputs);
136       });
137   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
138                                  ValueRange inputs,
139                                  Location loc) -> Optional<Value> {
140     if (inputs.size() == 1)
141       return llvm::None;
142     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
143   });
144   // Add generic source and target materializations to handle cases where
145   // non-LLVM types persist after an LLVM conversion.
146   addSourceMaterialization([&](OpBuilder &builder, Type resultType,
147                                ValueRange inputs,
148                                Location loc) -> Optional<Value> {
149     if (inputs.size() != 1)
150       return llvm::None;
151     // FIXME: These should check LLVM::DialectCastOp can actually be constructed
152     // from the input and result.
153     return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
154         .getResult();
155   });
156   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
157                                ValueRange inputs,
158                                Location loc) -> Optional<Value> {
159     if (inputs.size() != 1)
160       return llvm::None;
161     // FIXME: These should check LLVM::DialectCastOp can actually be constructed
162     // from the input and result.
163     return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
164         .getResult();
165   });
166 }
167 
168 /// Returns the MLIR context.
getContext()169 MLIRContext &LLVMTypeConverter::getContext() {
170   return *getDialect()->getContext();
171 }
172 
getIndexType()173 LLVM::LLVMType LLVMTypeConverter::getIndexType() {
174   return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
175 }
176 
getPointerBitwidth(unsigned addressSpace)177 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
178   return options.dataLayout.getPointerSizeInBits(addressSpace);
179 }
180 
convertIndexType(IndexType type)181 Type LLVMTypeConverter::convertIndexType(IndexType type) {
182   return getIndexType();
183 }
184 
convertIntegerType(IntegerType type)185 Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
186   return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
187 }
188 
convertFloatType(FloatType type)189 Type LLVMTypeConverter::convertFloatType(FloatType type) {
190   if (type.isa<Float32Type>())
191     return LLVM::LLVMType::getFloatTy(&getContext());
192   if (type.isa<Float64Type>())
193     return LLVM::LLVMType::getDoubleTy(&getContext());
194   if (type.isa<Float16Type>())
195     return LLVM::LLVMType::getHalfTy(&getContext());
196   if (type.isa<BFloat16Type>())
197     return LLVM::LLVMType::getBFloatTy(&getContext());
198   llvm_unreachable("non-float type in convertFloatType");
199 }
200 
201 // Convert a `ComplexType` to an LLVM type. The result is a complex number
202 // struct with entries for the
203 //   1. real part and for the
204 //   2. imaginary part.
205 static constexpr unsigned kRealPosInComplexNumberStruct = 0;
206 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
convertComplexType(ComplexType type)207 Type LLVMTypeConverter::convertComplexType(ComplexType type) {
208   auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
209   return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType});
210 }
211 
212 // Except for signatures, MLIR function types are converted into LLVM
213 // pointer-to-function types.
convertFunctionType(FunctionType type)214 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
215   SignatureConversion conversion(type.getNumInputs());
216   LLVM::LLVMType converted =
217       convertFunctionSignature(type, /*isVariadic=*/false, conversion);
218   return converted.getPointerTo();
219 }
220 
221 
222 // Function types are converted to LLVM Function types by recursively converting
223 // argument and result types.  If MLIR Function has zero results, the LLVM
224 // Function has one VoidType result.  If MLIR Function has more than one result,
225 // they are into an LLVM StructType in their order of appearance.
convertFunctionSignature(FunctionType funcTy,bool isVariadic,LLVMTypeConverter::SignatureConversion & result)226 LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
227     FunctionType funcTy, bool isVariadic,
228     LLVMTypeConverter::SignatureConversion &result) {
229   // Select the argument converter depending on the calling convention.
230   auto funcArgConverter = options.useBarePtrCallConv
231                               ? barePtrFuncArgTypeConverter
232                               : structFuncArgTypeConverter;
233   // Convert argument types one by one and check for errors.
234   for (auto &en : llvm::enumerate(funcTy.getInputs())) {
235     Type type = en.value();
236     SmallVector<Type, 8> converted;
237     if (failed(funcArgConverter(*this, type, converted)))
238       return {};
239     result.addInputs(en.index(), converted);
240   }
241 
242   SmallVector<LLVM::LLVMType, 8> argTypes;
243   argTypes.reserve(llvm::size(result.getConvertedTypes()));
244   for (Type type : result.getConvertedTypes())
245     argTypes.push_back(unwrap(type));
246 
247   // If function does not return anything, create the void result type,
248   // if it returns on element, convert it, otherwise pack the result types into
249   // a struct.
250   LLVM::LLVMType resultType =
251       funcTy.getNumResults() == 0
252           ? LLVM::LLVMType::getVoidTy(&getContext())
253           : unwrap(packFunctionResults(funcTy.getResults()));
254   if (!resultType)
255     return {};
256   return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
257 }
258 
259 /// Converts the function type to a C-compatible format, in particular using
260 /// pointers to memref descriptors for arguments.
261 LLVM::LLVMType
convertFunctionTypeCWrapper(FunctionType type)262 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
263   SmallVector<LLVM::LLVMType, 4> inputs;
264 
265   for (Type t : type.getInputs()) {
266     auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
267     if (!converted)
268       return {};
269     if (t.isa<MemRefType, UnrankedMemRefType>())
270       converted = converted.getPointerTo();
271     inputs.push_back(converted);
272   }
273 
274   LLVM::LLVMType resultType =
275       type.getNumResults() == 0
276           ? LLVM::LLVMType::getVoidTy(&getContext())
277           : unwrap(packFunctionResults(type.getResults()));
278   if (!resultType)
279     return {};
280 
281   return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
282 }
283 
284 static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
285 static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
286 static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
287 static constexpr unsigned kSizePosInMemRefDescriptor = 3;
288 static constexpr unsigned kStridePosInMemRefDescriptor = 4;
289 
290 /// Convert a memref type into a list of LLVM IR types that will form the
291 /// memref descriptor. The result contains the following types:
292 ///  1. The pointer to the allocated data buffer, followed by
293 ///  2. The pointer to the aligned data buffer, followed by
294 ///  3. A lowered `index`-type integer containing the distance between the
295 ///  beginning of the buffer and the first element to be accessed through the
296 ///  view, followed by
297 ///  4. An array containing as many `index`-type integers as the rank of the
298 ///  MemRef: the array represents the size, in number of elements, of the memref
299 ///  along the given dimension. For constant MemRef dimensions, the
300 ///  corresponding size entry is a constant whose runtime value must match the
301 ///  static value, followed by
302 ///  5. A second array containing as many `index`-type integers as the rank of
303 ///  the MemRef: the second array represents the "stride" (in tensor abstraction
304 ///  sense), i.e. the number of consecutive elements of the underlying buffer.
305 ///  TODO: add assertions for the static cases.
306 ///
307 ///  If `unpackAggregates` is set to true, the arrays described in (4) and (5)
308 ///  are expanded into individual index-type elements.
309 ///
310 ///  template <typename Elem, typename Index, size_t Rank>
311 ///  struct {
312 ///    Elem *allocatedPtr;
313 ///    Elem *alignedPtr;
314 ///    Index offset;
315 ///    Index sizes[Rank]; // omitted when rank == 0
316 ///    Index strides[Rank]; // omitted when rank == 0
317 ///  };
318 SmallVector<LLVM::LLVMType, 5>
getMemRefDescriptorFields(MemRefType type,bool unpackAggregates)319 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
320                                              bool unpackAggregates) {
321   assert(isStrided(type) &&
322          "Non-strided layout maps must have been normalized away");
323 
324   LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
325   if (!elementType)
326     return {};
327   auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
328   auto indexTy = getIndexType();
329 
330   SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
331   auto rank = type.getRank();
332   if (rank == 0)
333     return results;
334 
335   if (unpackAggregates)
336     results.insert(results.end(), 2 * rank, indexTy);
337   else
338     results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank));
339   return results;
340 }
341 
342 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
343 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
convertMemRefType(MemRefType type)344 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
345   // When converting a MemRefType to a struct with descriptor fields, do not
346   // unpack the `sizes` and `strides` arrays.
347   SmallVector<LLVM::LLVMType, 5> types =
348       getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
349   return LLVM::LLVMType::getStructTy(&getContext(), types);
350 }
351 
352 static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
353 static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
354 
355 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
356 /// that will form the unranked memref descriptor. In particular, the fields
357 /// for an unranked memref descriptor are:
358 /// 1. index-typed rank, the dynamic rank of this MemRef
359 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
360 ///    stack allocated (alloca) copy of a MemRef descriptor that got casted to
361 ///    be unranked.
362 SmallVector<LLVM::LLVMType, 2>
getUnrankedMemRefDescriptorFields()363 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
364   return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
365 }
366 
convertUnrankedMemRefType(UnrankedMemRefType type)367 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
368   return LLVM::LLVMType::getStructTy(&getContext(),
369                                      getUnrankedMemRefDescriptorFields());
370 }
371 
372 /// Convert a memref type to a bare pointer to the memref element type.
convertMemRefToBarePtr(BaseMemRefType type)373 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
374   if (type.isa<UnrankedMemRefType>())
375     // Unranked memref is not supported in the bare pointer calling convention.
376     return {};
377 
378   // Check that the memref has static shape, strides and offset. Otherwise, it
379   // cannot be lowered to a bare pointer.
380   auto memrefTy = type.cast<MemRefType>();
381   if (!memrefTy.hasStaticShape())
382     return {};
383 
384   int64_t offset = 0;
385   SmallVector<int64_t, 4> strides;
386   if (failed(getStridesAndOffset(memrefTy, strides, offset)))
387     return {};
388 
389   for (int64_t stride : strides)
390     if (ShapedType::isDynamicStrideOrOffset(stride))
391       return {};
392 
393   if (ShapedType::isDynamicStrideOrOffset(offset))
394     return {};
395 
396   LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
397   if (!elementType)
398     return {};
399   return elementType.getPointerTo(type.getMemorySpace());
400 }
401 
402 // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
403 // n > 1.
404 // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
405 // `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
convertVectorType(VectorType type)406 Type LLVMTypeConverter::convertVectorType(VectorType type) {
407   auto elementType = unwrap(convertType(type.getElementType()));
408   if (!elementType)
409     return {};
410   auto vectorType =
411       LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
412   auto shape = type.getShape();
413   for (int i = shape.size() - 2; i >= 0; --i)
414     vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
415   return vectorType;
416 }
417 
418 /// Convert a type in the context of the default or bare pointer calling
419 /// convention. Calling convention sensitive types, such as MemRefType and
420 /// UnrankedMemRefType, are converted following the specific rules for the
421 /// calling convention. Calling convention independent types are converted
422 /// following the default LLVM type conversions.
convertCallingConventionType(Type type)423 Type LLVMTypeConverter::convertCallingConventionType(Type type) {
424   if (options.useBarePtrCallConv)
425     if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
426       return convertMemRefToBarePtr(memrefTy);
427 
428   return convertType(type);
429 }
430 
431 /// Promote the bare pointers in 'values' that resulted from memrefs to
432 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
433 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
promoteBarePtrsToDescriptors(ConversionPatternRewriter & rewriter,Location loc,ArrayRef<Type> stdTypes,SmallVectorImpl<Value> & values)434 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
435     ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
436     SmallVectorImpl<Value> &values) {
437   assert(stdTypes.size() == values.size() &&
438          "The number of types and values doesn't match");
439   for (unsigned i = 0, end = values.size(); i < end; ++i)
440     if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
441       values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
442                                                     memrefTy, values[i]);
443 }
444 
ConvertToLLVMPattern(StringRef rootOpName,MLIRContext * context,LLVMTypeConverter & typeConverter,PatternBenefit benefit)445 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
446                                            MLIRContext *context,
447                                            LLVMTypeConverter &typeConverter,
448                                            PatternBenefit benefit)
449     : ConversionPattern(rootOpName, benefit, typeConverter, context) {}
450 
451 //===----------------------------------------------------------------------===//
452 // StructBuilder implementation
453 //===----------------------------------------------------------------------===//
454 
StructBuilder(Value v)455 StructBuilder::StructBuilder(Value v) : value(v) {
456   assert(value != nullptr && "value cannot be null");
457   structType = value.getType().dyn_cast<LLVM::LLVMType>();
458   assert(structType && "expected llvm type");
459 }
460 
extractPtr(OpBuilder & builder,Location loc,unsigned pos)461 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
462                                 unsigned pos) {
463   Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
464   return builder.create<LLVM::ExtractValueOp>(loc, type, value,
465                                               builder.getI64ArrayAttr(pos));
466 }
467 
setPtr(OpBuilder & builder,Location loc,unsigned pos,Value ptr)468 void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
469                            Value ptr) {
470   value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
471                                               builder.getI64ArrayAttr(pos));
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // ComplexStructBuilder implementation
476 //===----------------------------------------------------------------------===//
477 
undef(OpBuilder & builder,Location loc,Type type)478 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
479                                                  Location loc, Type type) {
480   Value val = builder.create<LLVM::UndefOp>(loc, type.cast<LLVM::LLVMType>());
481   return ComplexStructBuilder(val);
482 }
483 
setReal(OpBuilder & builder,Location loc,Value real)484 void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
485                                    Value real) {
486   setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
487 }
488 
real(OpBuilder & builder,Location loc)489 Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
490   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
491 }
492 
setImaginary(OpBuilder & builder,Location loc,Value imaginary)493 void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
494                                         Value imaginary) {
495   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
496 }
497 
imaginary(OpBuilder & builder,Location loc)498 Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
499   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // MemRefDescriptor implementation
504 //===----------------------------------------------------------------------===//
505 
506 /// Construct a helper for the given descriptor value.
MemRefDescriptor(Value descriptor)507 MemRefDescriptor::MemRefDescriptor(Value descriptor)
508     : StructBuilder(descriptor) {
509   assert(value != nullptr && "value cannot be null");
510   indexType = value.getType().cast<LLVM::LLVMType>().getStructElementType(
511       kOffsetPosInMemRefDescriptor);
512 }
513 
514 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)515 MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
516                                          Type descriptorType) {
517 
518   Value descriptor =
519       builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
520   return MemRefDescriptor(descriptor);
521 }
522 
523 /// Builds IR creating a MemRef descriptor that represents `type` and
524 /// populates it with static shape and stride information extracted from the
525 /// type.
526 MemRefDescriptor
fromStaticShape(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,MemRefType type,Value memory)527 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
528                                   LLVMTypeConverter &typeConverter,
529                                   MemRefType type, Value memory) {
530   assert(type.hasStaticShape() && "unexpected dynamic shape");
531 
532   // Extract all strides and offsets and verify they are static.
533   int64_t offset;
534   SmallVector<int64_t, 4> strides;
535   auto result = getStridesAndOffset(type, strides, offset);
536   (void)result;
537   assert(succeeded(result) && "unexpected failure in stride computation");
538   assert(offset != MemRefType::getDynamicStrideOrOffset() &&
539          "expected static offset");
540   assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
541          "expected static strides");
542 
543   auto convertedType = typeConverter.convertType(type);
544   assert(convertedType && "unexpected failure in memref type conversion");
545 
546   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
547   descr.setAllocatedPtr(builder, loc, memory);
548   descr.setAlignedPtr(builder, loc, memory);
549   descr.setConstantOffset(builder, loc, offset);
550 
551   // Fill in sizes and strides
552   for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
553     descr.setConstantSize(builder, loc, i, type.getDimSize(i));
554     descr.setConstantStride(builder, loc, i, strides[i]);
555   }
556   return descr;
557 }
558 
559 /// Builds IR extracting the allocated pointer from the descriptor.
allocatedPtr(OpBuilder & builder,Location loc)560 Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
561   return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
562 }
563 
564 /// Builds IR inserting the allocated pointer into the descriptor.
setAllocatedPtr(OpBuilder & builder,Location loc,Value ptr)565 void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
566                                        Value ptr) {
567   setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
568 }
569 
570 /// Builds IR extracting the aligned pointer from the descriptor.
alignedPtr(OpBuilder & builder,Location loc)571 Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
572   return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
573 }
574 
575 /// Builds IR inserting the aligned pointer into the descriptor.
setAlignedPtr(OpBuilder & builder,Location loc,Value ptr)576 void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
577                                      Value ptr) {
578   setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
579 }
580 
581 // Creates a constant Op producing a value of `resultType` from an index-typed
582 // integer attribute.
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)583 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
584                                      Type resultType, int64_t value) {
585   return builder.create<LLVM::ConstantOp>(
586       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
587 }
588 
589 /// Builds IR extracting the offset from the descriptor.
offset(OpBuilder & builder,Location loc)590 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
591   return builder.create<LLVM::ExtractValueOp>(
592       loc, indexType, value,
593       builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
594 }
595 
596 /// Builds IR inserting the offset into the descriptor.
setOffset(OpBuilder & builder,Location loc,Value offset)597 void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
598                                  Value offset) {
599   value = builder.create<LLVM::InsertValueOp>(
600       loc, structType, value, offset,
601       builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
602 }
603 
604 /// Builds IR inserting the offset into the descriptor.
setConstantOffset(OpBuilder & builder,Location loc,uint64_t offset)605 void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
606                                          uint64_t offset) {
607   setOffset(builder, loc,
608             createIndexAttrConstant(builder, loc, indexType, offset));
609 }
610 
611 /// Builds IR extracting the pos-th size from the descriptor.
size(OpBuilder & builder,Location loc,unsigned pos)612 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
613   return builder.create<LLVM::ExtractValueOp>(
614       loc, indexType, value,
615       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
616 }
617 
size(OpBuilder & builder,Location loc,Value pos,int64_t rank)618 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
619                              int64_t rank) {
620   auto indexTy = indexType.cast<LLVM::LLVMType>();
621   auto indexPtrTy = indexTy.getPointerTo();
622   auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
623   auto arrayPtrTy = arrayTy.getPointerTo();
624 
625   // Copy size values to stack-allocated memory.
626   auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
627   auto one = createIndexAttrConstant(builder, loc, indexType, 1);
628   auto sizes = builder.create<LLVM::ExtractValueOp>(
629       loc, arrayTy, value,
630       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
631   auto sizesPtr =
632       builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
633   builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
634 
635   // Load an return size value of interest.
636   auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
637                                                ValueRange({zero, pos}));
638   return builder.create<LLVM::LoadOp>(loc, resultPtr);
639 }
640 
641 /// Builds IR inserting the pos-th size into the descriptor
setSize(OpBuilder & builder,Location loc,unsigned pos,Value size)642 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
643                                Value size) {
644   value = builder.create<LLVM::InsertValueOp>(
645       loc, structType, value, size,
646       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
647 }
648 
setConstantSize(OpBuilder & builder,Location loc,unsigned pos,uint64_t size)649 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
650                                        unsigned pos, uint64_t size) {
651   setSize(builder, loc, pos,
652           createIndexAttrConstant(builder, loc, indexType, size));
653 }
654 
655 /// Builds IR extracting the pos-th stride from the descriptor.
stride(OpBuilder & builder,Location loc,unsigned pos)656 Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
657   return builder.create<LLVM::ExtractValueOp>(
658       loc, indexType, value,
659       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
660 }
661 
662 /// Builds IR inserting the pos-th stride into the descriptor
setStride(OpBuilder & builder,Location loc,unsigned pos,Value stride)663 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
664                                  Value stride) {
665   value = builder.create<LLVM::InsertValueOp>(
666       loc, structType, value, stride,
667       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
668 }
669 
setConstantStride(OpBuilder & builder,Location loc,unsigned pos,uint64_t stride)670 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
671                                          unsigned pos, uint64_t stride) {
672   setStride(builder, loc, pos,
673             createIndexAttrConstant(builder, loc, indexType, stride));
674 }
675 
getElementPtrType()676 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
677   return value.getType()
678       .cast<LLVM::LLVMType>()
679       .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
680       .cast<LLVM::LLVMPointerType>();
681 }
682 
683 /// Creates a MemRef descriptor structure from a list of individual values
684 /// composing that descriptor, in the following order:
685 /// - allocated pointer;
686 /// - aligned pointer;
687 /// - offset;
688 /// - <rank> sizes;
689 /// - <rank> shapes;
690 /// where <rank> is the MemRef rank as provided in `type`.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,MemRefType type,ValueRange values)691 Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
692                              LLVMTypeConverter &converter, MemRefType type,
693                              ValueRange values) {
694   Type llvmType = converter.convertType(type);
695   auto d = MemRefDescriptor::undef(builder, loc, llvmType);
696 
697   d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
698   d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
699   d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
700 
701   int64_t rank = type.getRank();
702   for (unsigned i = 0; i < rank; ++i) {
703     d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
704     d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
705   }
706 
707   return d;
708 }
709 
710 /// Builds IR extracting individual elements of a MemRef descriptor structure
711 /// and returning them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,MemRefType type,SmallVectorImpl<Value> & results)712 void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
713                               MemRefType type,
714                               SmallVectorImpl<Value> &results) {
715   int64_t rank = type.getRank();
716   results.reserve(results.size() + getNumUnpackedValues(type));
717 
718   MemRefDescriptor d(packed);
719   results.push_back(d.allocatedPtr(builder, loc));
720   results.push_back(d.alignedPtr(builder, loc));
721   results.push_back(d.offset(builder, loc));
722   for (int64_t i = 0; i < rank; ++i)
723     results.push_back(d.size(builder, loc, i));
724   for (int64_t i = 0; i < rank; ++i)
725     results.push_back(d.stride(builder, loc, i));
726 }
727 
728 /// Returns the number of non-aggregate values that would be produced by
729 /// `unpack`.
getNumUnpackedValues(MemRefType type)730 unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
731   // Two pointers, offset, <rank> sizes, <rank> shapes.
732   return 3 + 2 * type.getRank();
733 }
734 
735 //===----------------------------------------------------------------------===//
736 // MemRefDescriptorView implementation.
737 //===----------------------------------------------------------------------===//
738 
/// Constructs a view over `range`, the flat list of values that make up an
/// unpacked MemRef descriptor. The rank is derived from the length of the
/// list: the entries starting at `kSizePosInMemRefDescriptor` are split evenly
/// between sizes and strides.
MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
    : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
741 
allocatedPtr()742 Value MemRefDescriptorView::allocatedPtr() {
743   return elements[kAllocatedPtrPosInMemRefDescriptor];
744 }
745 
alignedPtr()746 Value MemRefDescriptorView::alignedPtr() {
747   return elements[kAlignedPtrPosInMemRefDescriptor];
748 }
749 
offset()750 Value MemRefDescriptorView::offset() {
751   return elements[kOffsetPosInMemRefDescriptor];
752 }
753 
size(unsigned pos)754 Value MemRefDescriptorView::size(unsigned pos) {
755   return elements[kSizePosInMemRefDescriptor + pos];
756 }
757 
stride(unsigned pos)758 Value MemRefDescriptorView::stride(unsigned pos) {
759   return elements[kSizePosInMemRefDescriptor + rank + pos];
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // UnrankedMemRefDescriptor implementation
764 //===----------------------------------------------------------------------===//
765 
/// Construct a helper for the given descriptor value. The value is handed to
/// the `StructBuilder` base class, which the accessors below rely on.
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
    : StructBuilder(descriptor) {}
769 
770 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)771 UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
772                                                          Location loc,
773                                                          Type descriptorType) {
774   Value descriptor =
775       builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
776   return UnrankedMemRefDescriptor(descriptor);
777 }
rank(OpBuilder & builder,Location loc)778 Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
779   return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
780 }
setRank(OpBuilder & builder,Location loc,Value v)781 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
782                                        Value v) {
783   setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
784 }
memRefDescPtr(OpBuilder & builder,Location loc)785 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
786                                               Location loc) {
787   return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
788 }
setMemRefDescPtr(OpBuilder & builder,Location loc,Value v)789 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
790                                                 Location loc, Value v) {
791   setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
792 }
793 
794 /// Builds IR populating an unranked MemRef descriptor structure from a list
795 /// of individual constituent values in the following order:
796 /// - rank of the memref;
797 /// - pointer to the memref descriptor.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,UnrankedMemRefType type,ValueRange values)798 Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
799                                      LLVMTypeConverter &converter,
800                                      UnrankedMemRefType type,
801                                      ValueRange values) {
802   Type llvmType = converter.convertType(type);
803   auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
804 
805   d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
806   d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
807   return d;
808 }
809 
810 /// Builds IR extracting individual elements that compose an unranked memref
811 /// descriptor and returns them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,SmallVectorImpl<Value> & results)812 void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
813                                       Value packed,
814                                       SmallVectorImpl<Value> &results) {
815   UnrankedMemRefDescriptor d(packed);
816   results.reserve(results.size() + 2);
817   results.push_back(d.rank(builder, loc));
818   results.push_back(d.memRefDescPtr(builder, loc));
819 }
820 
computeSizes(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,ArrayRef<UnrankedMemRefDescriptor> values,SmallVectorImpl<Value> & sizes)821 void UnrankedMemRefDescriptor::computeSizes(
822     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
823     ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
824   if (values.empty())
825     return;
826 
827   // Cache the index type.
828   LLVM::LLVMType indexType = typeConverter.getIndexType();
829 
830   // Initialize shared constants.
831   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
832   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
833   Value pointerSize = createIndexAttrConstant(
834       builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
835   Value indexSize =
836       createIndexAttrConstant(builder, loc, indexType,
837                               ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
838 
839   sizes.reserve(sizes.size() + values.size());
840   for (UnrankedMemRefDescriptor desc : values) {
841     // Emit IR computing the memory necessary to store the descriptor. This
842     // assumes the descriptor to be
843     //   { type*, type*, index, index[rank], index[rank] }
844     // and densely packed, so the total size is
845     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
846     // TODO: consider including the actual size (including eventual padding due
847     // to data layout) into the unranked descriptor.
848     Value doublePointerSize =
849         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
850 
851     // (1 + 2 * rank) * sizeof(index)
852     Value rank = desc.rank(builder, loc);
853     Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
854     Value doubleRankIncremented =
855         builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
856     Value rankIndexSize = builder.create<LLVM::MulOp>(
857         loc, indexType, doubleRankIncremented, indexSize);
858 
859     // Total allocation size.
860     Value allocationSize = builder.create<LLVM::AddOp>(
861         loc, indexType, doublePointerSize, rankIndexSize);
862     sizes.push_back(allocationSize);
863   }
864 }
865 
allocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType)866 Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
867                                              Value memRefDescPtr,
868                                              LLVM::LLVMType elemPtrPtrType) {
869 
870   Value elementPtrPtr =
871       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
872   return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
873 }
874 
setAllocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType,Value allocatedPtr)875 void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
876                                                Value memRefDescPtr,
877                                                LLVM::LLVMType elemPtrPtrType,
878                                                Value allocatedPtr) {
879   Value elementPtrPtr =
880       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
881   builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
882 }
883 
alignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType)884 Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
885                                            LLVMTypeConverter &typeConverter,
886                                            Value memRefDescPtr,
887                                            LLVM::LLVMType elemPtrPtrType) {
888   Value elementPtrPtr =
889       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
890 
891   Value one =
892       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
893   Value alignedGep = builder.create<LLVM::GEPOp>(
894       loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
895   return builder.create<LLVM::LoadOp>(loc, alignedGep);
896 }
897 
setAlignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType,Value alignedPtr)898 void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
899                                              LLVMTypeConverter &typeConverter,
900                                              Value memRefDescPtr,
901                                              LLVM::LLVMType elemPtrPtrType,
902                                              Value alignedPtr) {
903   Value elementPtrPtr =
904       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
905 
906   Value one =
907       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
908   Value alignedGep = builder.create<LLVM::GEPOp>(
909       loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
910   builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
911 }
912 
offset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType)913 Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
914                                        LLVMTypeConverter &typeConverter,
915                                        Value memRefDescPtr,
916                                        LLVM::LLVMType elemPtrPtrType) {
917   Value elementPtrPtr =
918       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
919 
920   Value two =
921       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
922   Value offsetGep = builder.create<LLVM::GEPOp>(
923       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
924   offsetGep = builder.create<LLVM::BitcastOp>(
925       loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
926   return builder.create<LLVM::LoadOp>(loc, offsetGep);
927 }
928 
setOffset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType,Value offset)929 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
930                                          LLVMTypeConverter &typeConverter,
931                                          Value memRefDescPtr,
932                                          LLVM::LLVMType elemPtrPtrType,
933                                          Value offset) {
934   Value elementPtrPtr =
935       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
936 
937   Value two =
938       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
939   Value offsetGep = builder.create<LLVM::GEPOp>(
940       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
941   offsetGep = builder.create<LLVM::BitcastOp>(
942       loc, typeConverter.getIndexType().getPointerTo(), offsetGep);
943   builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
944 }
945 
sizeBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMType elemPtrPtrType)946 Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
947                                             LLVMTypeConverter &typeConverter,
948                                             Value memRefDescPtr,
949                                             LLVM::LLVMType elemPtrPtrType) {
950   LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
951   LLVM::LLVMType indexTy = typeConverter.getIndexType();
952   LLVM::LLVMType structPtrTy =
953       LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)
954           .getPointerTo();
955   Value structPtr =
956       builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
957 
958   LLVM::LLVMType int32_type =
959       unwrap(typeConverter.convertType(builder.getI32Type()));
960   Value zero =
961       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
962   Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
963                                                  builder.getI32IntegerAttr(3));
964   return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(), structPtr,
965                                      ValueRange({zero, three}));
966 }
967 
size(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value sizeBasePtr,Value index)968 Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
969                                      LLVMTypeConverter typeConverter,
970                                      Value sizeBasePtr, Value index) {
971   LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
972   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
973                                                    ValueRange({index}));
974   return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
975 }
976 
setSize(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value sizeBasePtr,Value index,Value size)977 void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
978                                        LLVMTypeConverter typeConverter,
979                                        Value sizeBasePtr, Value index,
980                                        Value size) {
981   LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
982   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
983                                                    ValueRange({index}));
984   builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
985 }
986 
strideBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value rank)987 Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
988                                               LLVMTypeConverter &typeConverter,
989                                               Value sizeBasePtr, Value rank) {
990   LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
991   return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
992                                      ValueRange({rank}));
993 }
994 
stride(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value strideBasePtr,Value index,Value stride)995 Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
996                                        LLVMTypeConverter typeConverter,
997                                        Value strideBasePtr, Value index,
998                                        Value stride) {
999   LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
1000   Value strideStoreGep = builder.create<LLVM::GEPOp>(
1001       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
1002   return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
1003 }
1004 
setStride(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value strideBasePtr,Value index,Value stride)1005 void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
1006                                          LLVMTypeConverter typeConverter,
1007                                          Value strideBasePtr, Value index,
1008                                          Value stride) {
1009   LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
1010   Value strideStoreGep = builder.create<LLVM::GEPOp>(
1011       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
1012   builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
1013 }
1014 
getTypeConverter() const1015 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
1016   return static_cast<LLVMTypeConverter *>(
1017       ConversionPattern::getTypeConverter());
1018 }
1019 
getDialect() const1020 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
1021   return *getTypeConverter()->getDialect();
1022 }
1023 
getIndexType() const1024 LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
1025   return getTypeConverter()->getIndexType();
1026 }
1027 
1028 LLVM::LLVMType
getIntPtrType(unsigned addressSpace) const1029 ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
1030   return LLVM::LLVMType::getIntNTy(
1031       &getTypeConverter()->getContext(),
1032       getTypeConverter()->getPointerBitwidth(addressSpace));
1033 }
1034 
getVoidType() const1035 LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
1036   return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext());
1037 }
1038 
getVoidPtrType() const1039 LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
1040   return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext());
1041 }
1042 
createIndexConstant(ConversionPatternRewriter & builder,Location loc,uint64_t value) const1043 Value ConvertToLLVMPattern::createIndexConstant(
1044     ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
1045   return createIndexAttrConstant(builder, loc, getIndexType(), value);
1046 }
1047 
getStridedElementPtr(Location loc,MemRefType type,Value memRefDesc,ValueRange indices,ConversionPatternRewriter & rewriter) const1048 Value ConvertToLLVMPattern::getStridedElementPtr(
1049     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
1050     ConversionPatternRewriter &rewriter) const {
1051 
1052   int64_t offset;
1053   SmallVector<int64_t, 4> strides;
1054   auto successStrides = getStridesAndOffset(type, strides, offset);
1055   assert(succeeded(successStrides) && "unexpected non-strided memref");
1056   (void)successStrides;
1057 
1058   MemRefDescriptor memRefDescriptor(memRefDesc);
1059   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
1060 
1061   Value index;
1062   if (offset != 0) // Skip if offset is zero.
1063     index = offset == MemRefType::getDynamicStrideOrOffset()
1064                 ? memRefDescriptor.offset(rewriter, loc)
1065                 : createIndexConstant(rewriter, loc, offset);
1066 
1067   for (int i = 0, e = indices.size(); i < e; ++i) {
1068     Value increment = indices[i];
1069     if (strides[i] != 1) { // Skip if stride is 1.
1070       Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
1071                          ? memRefDescriptor.stride(rewriter, loc, i)
1072                          : createIndexConstant(rewriter, loc, strides[i]);
1073       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
1074     }
1075     index =
1076         index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
1077   }
1078 
1079   LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType();
1080   return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
1081                : base;
1082 }
1083 
getDataPtr(Location loc,MemRefType type,Value memRefDesc,ValueRange indices,ConversionPatternRewriter & rewriter) const1084 Value ConvertToLLVMPattern::getDataPtr(
1085     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
1086     ConversionPatternRewriter &rewriter) const {
1087   return getStridedElementPtr(loc, type, memRefDesc, indices, rewriter);
1088 }
1089 
1090 // Check if the MemRefType `type` is supported by the lowering. We currently
1091 // only support memrefs with identity maps.
isSupportedMemRefType(MemRefType type) const1092 bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
1093   if (!typeConverter->convertType(type.getElementType()))
1094     return false;
1095   return type.getAffineMaps().empty() ||
1096          llvm::all_of(type.getAffineMaps(),
1097                       [](AffineMap map) { return map.isIdentity(); });
1098 }
1099 
getElementPtrType(MemRefType type) const1100 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
1101   auto elementType = type.getElementType();
1102   auto structElementType = unwrap(typeConverter->convertType(elementType));
1103   return structElementType.getPointerTo(type.getMemorySpace());
1104 }
1105 
/// Builds IR computing the sizes, the strides and the total buffer size in
/// bytes of a memref of type `memRefType`.
/// - `dynamicSizes` provides the values of the dynamic dimensions, consumed
///   in the order in which dynamic dimensions appear in the shape.
/// - `sizes` receives one value per dimension (appended).
/// - `strides` is resized to the rank and filled with one value per dimension.
/// - `sizeBytes` receives the size of the buffer in bytes.
/// NOTE(review): `sizes` is appended to but later indexed as `sizes[i]`, so
/// it is presumably expected to be empty on entry; confirm against callers.
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
    Location loc, MemRefType memRefType, ArrayRef<Value> dynamicSizes,
    ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
    SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
  assert(isSupportedMemRefType(memRefType) &&
         "layout maps must have been normalized away");

  // Sizes: static dimensions become index constants, dynamic dimensions take
  // the next unused entry of `dynamicSizes`.
  sizes.reserve(memRefType.getRank());
  unsigned dynamicIndex = 0;
  for (int64_t size : memRefType.getShape()) {
    sizes.push_back(size == ShapedType::kDynamicSize
                        ? dynamicSizes[dynamicIndex++]
                        : createIndexConstant(rewriter, loc, size));
  }

  // Strides: iterate sizes in reverse order and multiply.
  // `stride` tracks the statically known running product (or kDynamicSize once
  // a dynamic dimension has been seen); `runningStride` is the IR value that
  // holds the same product.
  int64_t stride = 1;
  Value runningStride = createIndexConstant(rewriter, loc, 1);
  strides.resize(memRefType.getRank());
  for (auto i = memRefType.getRank(); i-- > 0;) {
    // The stride of dimension `i` is the product of the sizes visited so far.
    strides[i] = runningStride;

    int64_t size = memRefType.getShape()[i];
    // A zero-sized dimension leaves the running product untouched.
    if (size == 0)
      continue;
    // While the product is statically 1, the next product is the size itself
    // and no multiplication needs to be emitted.
    bool useSizeAsStride = stride == 1;
    if (size == ShapedType::kDynamicSize)
      stride = ShapedType::kDynamicSize;
    if (stride != ShapedType::kDynamicSize)
      stride *= size;

    if (useSizeAsStride)
      runningStride = sizes[i];
    else if (stride == ShapedType::kDynamicSize)
      // The product is only known at runtime: emit the multiplication.
      runningStride =
          rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
    else
      // The product is still static: materialize it as a constant.
      runningStride = createIndexConstant(rewriter, loc, stride);
  }

  // Buffer size in bytes.
  // After the loop `runningStride` holds the product over all visited
  // dimensions; indexing a null element pointer by it and converting the
  // result to an integer scales that count by the element size.
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}
1153 
getSizeInBytes(Location loc,Type type,ConversionPatternRewriter & rewriter) const1154 Value ConvertToLLVMPattern::getSizeInBytes(
1155     Location loc, Type type, ConversionPatternRewriter &rewriter) const {
1156   // Compute the size of an individual element. This emits the MLIR equivalent
1157   // of the following sizeof(...) implementation in LLVM IR:
1158   //   %0 = getelementptr %elementType* null, %indexType 1
1159   //   %1 = ptrtoint %elementType* %0 to %indexType
1160   // which is a common pattern of getting the size of a type in bytes.
1161   auto convertedPtrType =
1162       typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo();
1163   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
1164   auto gep = rewriter.create<LLVM::GEPOp>(
1165       loc, convertedPtrType,
1166       ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
1167   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
1168 }
1169 
getNumElements(Location loc,ArrayRef<Value> shape,ConversionPatternRewriter & rewriter) const1170 Value ConvertToLLVMPattern::getNumElements(
1171     Location loc, ArrayRef<Value> shape,
1172     ConversionPatternRewriter &rewriter) const {
1173   // Compute the total number of memref elements.
1174   Value numElements =
1175       shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
1176   for (unsigned i = 1, e = shape.size(); i < e; ++i)
1177     numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
1178   return numElements;
1179 }
1180 
1181 /// Creates and populates the memref descriptor struct given all its fields.
createMemRefDescriptor(Location loc,MemRefType memRefType,Value allocatedPtr,Value alignedPtr,ArrayRef<Value> sizes,ArrayRef<Value> strides,ConversionPatternRewriter & rewriter) const1182 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
1183     Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
1184     ArrayRef<Value> sizes, ArrayRef<Value> strides,
1185     ConversionPatternRewriter &rewriter) const {
1186   auto structType = typeConverter->convertType(memRefType);
1187   auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
1188 
1189   // Field 1: Allocated pointer, used for malloc/free.
1190   memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
1191 
1192   // Field 2: Actual aligned pointer to payload.
1193   memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
1194 
1195   // Field 3: Offset in aligned pointer.
1196   memRefDescriptor.setOffset(rewriter, loc,
1197                              createIndexConstant(rewriter, loc, 0));
1198 
1199   // Fields 4: Sizes.
1200   for (auto en : llvm::enumerate(sizes))
1201     memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
1202 
1203   // Field 5: Strides.
1204   for (auto en : llvm::enumerate(strides))
1205     memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
1206 
1207   return memRefDescriptor;
1208 }
1209 
1210 /// Only retain those attributes that are not constructed by
1211 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
1212 /// attributes.
filterFuncAttributes(ArrayRef<NamedAttribute> attrs,bool filterArgAttrs,SmallVectorImpl<NamedAttribute> & result)1213 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
1214                                  bool filterArgAttrs,
1215                                  SmallVectorImpl<NamedAttribute> &result) {
1216   for (const auto &attr : attrs) {
1217     if (attr.first == SymbolTable::getSymbolAttrName() ||
1218         attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" ||
1219         (filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
1220       continue;
1221     result.push_back(attr);
1222   }
1223 }
1224 
1225 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
1226 /// arguments instead of unpacked arguments. This function can be called from C
1227 /// by passing a pointer to a C struct corresponding to a memref descriptor.
1228 /// Internally, the auxiliary function unpacks the descriptor into individual
1229 /// components and forwards them to `newFuncOp`.
wrapForExternalCallers(OpBuilder & rewriter,Location loc,LLVMTypeConverter & typeConverter,FuncOp funcOp,LLVM::LLVMFuncOp newFuncOp)1230 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
1231                                    LLVMTypeConverter &typeConverter,
1232                                    FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
1233   auto type = funcOp.getType();
1234   SmallVector<NamedAttribute, 4> attributes;
1235   filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
1236   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
1237       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
1238       typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
1239       attributes);
1240 
1241   OpBuilder::InsertionGuard guard(rewriter);
1242   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
1243 
1244   SmallVector<Value, 8> args;
1245   for (auto &en : llvm::enumerate(type.getInputs())) {
1246     Value arg = wrapperFuncOp.getArgument(en.index());
1247     if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
1248       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
1249       MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
1250       continue;
1251     }
1252     if (en.value().isa<UnrankedMemRefType>()) {
1253       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
1254       UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
1255       continue;
1256     }
1257 
1258     args.push_back(wrapperFuncOp.getArgument(en.index()));
1259   }
1260   auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
1261   rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
1262 }
1263 
1264 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
1265 /// arguments instead of unpacked arguments. Creates a body for the (external)
1266 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
1267 /// individual arguments into this descriptor and passes a pointer to it into
1268 /// the auxiliary function. This auxiliary external function is now compatible
1269 /// with functions defined in C using pointers to C structs corresponding to a
1270 /// memref descriptor.
wrapExternalFunction(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,FuncOp funcOp,LLVM::LLVMFuncOp newFuncOp)1271 static void wrapExternalFunction(OpBuilder &builder, Location loc,
1272                                  LLVMTypeConverter &typeConverter,
1273                                  FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
1274   OpBuilder::InsertionGuard guard(builder);
1275 
1276   LLVM::LLVMType wrapperType =
1277       typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
1278   // This conversion can only fail if it could not convert one of the argument
1279   // types. But since it has been applies to a non-wrapper function before, it
1280   // should have failed earlier and not reach this point at all.
1281   assert(wrapperType && "unexpected type conversion failure");
1282 
1283   SmallVector<NamedAttribute, 4> attributes;
1284   filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes);
1285 
1286   // Create the auxiliary function.
1287   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
1288       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
1289       wrapperType, LLVM::Linkage::External, attributes);
1290 
1291   builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
1292 
1293   // Get a ValueRange containing arguments.
1294   FunctionType type = funcOp.getType();
1295   SmallVector<Value, 8> args;
1296   args.reserve(type.getNumInputs());
1297   ValueRange wrapperArgsRange(newFuncOp.getArguments());
1298 
1299   // Iterate over the inputs of the original function and pack values into
1300   // memref descriptors if the original type is a memref.
1301   for (auto &en : llvm::enumerate(type.getInputs())) {
1302     Value arg;
1303     int numToDrop = 1;
1304     auto memRefType = en.value().dyn_cast<MemRefType>();
1305     auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
1306     if (memRefType || unrankedMemRefType) {
1307       numToDrop = memRefType
1308                       ? MemRefDescriptor::getNumUnpackedValues(memRefType)
1309                       : UnrankedMemRefDescriptor::getNumUnpackedValues();
1310       Value packed =
1311           memRefType
1312               ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
1313                                        wrapperArgsRange.take_front(numToDrop))
1314               : UnrankedMemRefDescriptor::pack(
1315                     builder, loc, typeConverter, unrankedMemRefType,
1316                     wrapperArgsRange.take_front(numToDrop));
1317 
1318       auto ptrTy = packed.getType().cast<LLVM::LLVMType>().getPointerTo();
1319       Value one = builder.create<LLVM::ConstantOp>(
1320           loc, typeConverter.convertType(builder.getIndexType()),
1321           builder.getIntegerAttr(builder.getIndexType(), 1));
1322       Value allocated =
1323           builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
1324       builder.create<LLVM::StoreOp>(loc, packed, allocated);
1325       arg = allocated;
1326     } else {
1327       arg = wrapperArgsRange[0];
1328     }
1329 
1330     args.push_back(arg);
1331     wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
1332   }
1333   assert(wrapperArgsRange.empty() && "did not map some of the arguments");
1334 
1335   auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
1336   builder.create<LLVM::ReturnOp>(loc, call.getResults());
1337 }
1338 
1339 namespace {
1340 
1341 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
1342 protected:
1343   using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
1344 
1345   // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
1346   // to this legalization pattern.
1347   LLVM::LLVMFuncOp
convertFuncOpToLLVMFuncOp__anone5172fdb0f11::FuncOpConversionBase1348   convertFuncOpToLLVMFuncOp(FuncOp funcOp,
1349                             ConversionPatternRewriter &rewriter) const {
1350     // Convert the original function arguments. They are converted using the
1351     // LLVMTypeConverter provided to this legalization pattern.
1352     auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
1353     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
1354     auto llvmType = getTypeConverter()->convertFunctionSignature(
1355         funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
1356     if (!llvmType)
1357       return nullptr;
1358 
1359     // Propagate argument attributes to all converted arguments obtained after
1360     // converting a given original argument.
1361     SmallVector<NamedAttribute, 4> attributes;
1362     filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true,
1363                          attributes);
1364     for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
1365       auto attr = impl::getArgAttrDict(funcOp, i);
1366       if (!attr)
1367         continue;
1368 
1369       auto mapping = result.getInputMapping(i);
1370       assert(mapping.hasValue() && "unexpected deletion of function argument");
1371 
1372       SmallString<8> name;
1373       for (size_t j = 0; j < mapping->size; ++j) {
1374         impl::getArgAttrName(mapping->inputNo + j, name);
1375         attributes.push_back(rewriter.getNamedAttr(name, attr));
1376       }
1377     }
1378 
1379     // Create an LLVM function, use external linkage by default until MLIR
1380     // functions have linkage.
1381     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
1382         funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
1383         attributes);
1384     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1385                                 newFuncOp.end());
1386     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
1387                                            &result)))
1388       return nullptr;
1389 
1390     return newFuncOp;
1391   }
1392 };
1393 
1394 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
1395 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
1396 /// information.
1397 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
1398 struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion__anone5172fdb0f11::FuncOpConversion1399   FuncOpConversion(LLVMTypeConverter &converter)
1400       : FuncOpConversionBase(converter) {}
1401 
1402   LogicalResult
matchAndRewrite__anone5172fdb0f11::FuncOpConversion1403   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
1404                   ConversionPatternRewriter &rewriter) const override {
1405     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
1406     if (!newFuncOp)
1407       return failure();
1408 
1409     if (getTypeConverter()->getOptions().emitCWrappers ||
1410         funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
1411       if (newFuncOp.isExternal())
1412         wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
1413                              funcOp, newFuncOp);
1414       else
1415         wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
1416                                funcOp, newFuncOp);
1417     }
1418 
1419     rewriter.eraseOp(funcOp);
1420     return success();
1421   }
1422 };
1423 
/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
/// to the MemRef element type. This will impact the calling convention and ABI.
///
/// After the generic conversion done by `convertFuncOpToLLVMFuncOp`, every
/// entry-block argument whose original type was a ranked memref is rebuilt
/// into a MemRef descriptor at the top of the function body. Functions without
/// a body are only converted and erased.
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
  using FuncOpConversionBase::FuncOpConversionBase;

  LogicalResult
  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Store the type of memref-typed arguments before the conversion so that we
    // can promote them to MemRef descriptor at the beginning of the function.
    SmallVector<Type, 8> oldArgTypes =
        llvm::to_vector<8>(funcOp.getType().getInputs());

    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
    if (!newFuncOp)
      return failure();
    // No body means there are no argument uses to rewrite; we are done.
    if (newFuncOp.getBody().empty()) {
      rewriter.eraseOp(funcOp);
      return success();
    }

    // Promote bare pointers from memref arguments to memref descriptors at the
    // beginning of the function so that all the memrefs in the function have a
    // uniform representation.
    Block *entryBlock = &newFuncOp.getBody().front();
    auto blockArgs = entryBlock->getArguments();
    assert(blockArgs.size() == oldArgTypes.size() &&
           "The number of arguments and types doesn't match");

    // The descriptor-building code goes at the start of the entry block; the
    // guard restores the rewriter's previous insertion point on exit.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(entryBlock);
    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
      BlockArgument arg = std::get<0>(it);
      Type argTy = std::get<1>(it);

      // Unranked memrefs are not supported in the bare pointer calling
      // convention. We should have bailed out before in the presence of
      // unranked memrefs.
      assert(!argTy.isa<UnrankedMemRefType>() &&
             "Unranked memref is not supported");
      // Arguments that were not memrefs need no promotion.
      auto memrefTy = argTy.dyn_cast<MemRefType>();
      if (!memrefTy)
        continue;

      // Replace barePtr with a placeholder (undef), promote barePtr to a
      // ranked memref descriptor and replace the placeholder with the
      // resulting descriptor value.
      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
      // MemRef descriptor instructions. We may want to have a utility in the
      // rewriter to properly handle this use case.
      Location loc = funcOp.getLoc();
      auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
      rewriter.replaceUsesOfBlockArgument(arg, placeholder);

      // Built after the use replacement above, so the descriptor still refers
      // to the bare pointer `arg` itself.
      Value desc = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), memrefTy, arg);
      rewriter.replaceOp(placeholder, {desc});
    }

    rewriter.eraseOp(funcOp);
    return success();
  }
};
1487 
1488 //////////////// Support for Lowering operations on n-D vectors ////////////////
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
// 1-D LLVM vectors. It is filled in by `extractNDVectorTypeInfo`; the type
// members stay null when the corresponding type could not be determined.
struct NDVectorTypeInfo {
  // LLVM array struct which encodes n-D vectors.
  LLVM::LLVMType llvmArrayTy;
  // LLVM vector type which encodes the inner 1-D vector type.
  LLVM::LLVMType llvmVectorTy;
  // Multiplicity of llvmArrayTy to llvmVectorTy: the element count of each
  // nested array level, outermost level first.
  SmallVector<int64_t, 4> arraySizes;
};
1499 } // namespace
1500 
1501 // For >1-D vector types, extracts the necessary information to iterate over all
1502 // 1-D subvectors in the underlying llrepresentation of the n-D vector
1503 // Iterates on the llvm array type until we hit a non-array type (which is
1504 // asserted to be an llvm vector type).
extractNDVectorTypeInfo(VectorType vectorType,LLVMTypeConverter & converter)1505 static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
1506                                                 LLVMTypeConverter &converter) {
1507   assert(vectorType.getRank() > 1 && "expected >1D vector type");
1508   NDVectorTypeInfo info;
1509   info.llvmArrayTy =
1510       converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
1511   if (!info.llvmArrayTy)
1512     return info;
1513   info.arraySizes.reserve(vectorType.getRank() - 1);
1514   auto llvmTy = info.llvmArrayTy;
1515   while (llvmTy.isArrayTy()) {
1516     info.arraySizes.push_back(llvmTy.getArrayNumElements());
1517     llvmTy = llvmTy.getArrayElementType();
1518   }
1519   if (!llvmTy.isVectorTy())
1520     return info;
1521   info.llvmVectorTy = llvmTy;
1522   return info;
1523 }
1524 
1525 // Express `linearIndex` in terms of coordinates of `basis`.
1526 // Returns the empty vector when linearIndex is out of the range [0, P] where
1527 // P is the product of all the basis coordinates.
1528 //
1529 // Prerequisites:
1530 //   Basis is an array of nonnegative integers (signed type inherited from
1531 //   vector shape type).
getCoordinates(ArrayRef<int64_t> basis,unsigned linearIndex)1532 static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
1533                                               unsigned linearIndex) {
1534   SmallVector<int64_t, 4> res;
1535   res.reserve(basis.size());
1536   for (unsigned basisElement : llvm::reverse(basis)) {
1537     res.push_back(linearIndex % basisElement);
1538     linearIndex = linearIndex / basisElement;
1539   }
1540   if (linearIndex > 0)
1541     return {};
1542   std::reverse(res.begin(), res.end());
1543   return res;
1544 }
1545 
1546 // Iterate of linear index, convert to coords space and insert splatted 1-D
1547 // vector in each position.
1548 template <typename Lambda>
nDVectorIterate(const NDVectorTypeInfo & info,OpBuilder & builder,Lambda fun)1549 void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
1550                      Lambda fun) {
1551   unsigned ub = 1;
1552   for (auto s : info.arraySizes)
1553     ub *= s;
1554   for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
1555     auto coords = getCoordinates(info.arraySizes, linearIndex);
1556     // Linear index is out of bounds, we are done.
1557     if (coords.empty())
1558       break;
1559     assert(coords.size() == info.arraySizes.size());
1560     auto position = builder.getI64ArrayAttr(coords);
1561     fun(position);
1562   }
1563 }
1564 ////////////// End Support for Lowering operations on n-D vectors //////////////
1565 
1566 /// Replaces the given operation "op" with a new operation of type "targetOp"
1567 /// and given operands.
oneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)1568 LogicalResult LLVM::detail::oneToOneRewrite(
1569     Operation *op, StringRef targetOp, ValueRange operands,
1570     LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
1571   unsigned numResults = op->getNumResults();
1572 
1573   Type packedType;
1574   if (numResults != 0) {
1575     packedType = typeConverter.packFunctionResults(op->getResultTypes());
1576     if (!packedType)
1577       return failure();
1578   }
1579 
1580   // Create the operation through state since we don't know its C++ type.
1581   OperationState state(op->getLoc(), targetOp);
1582   state.addTypes(packedType);
1583   state.addOperands(operands);
1584   state.addAttributes(op->getAttrs());
1585   Operation *newOp = rewriter.createOperation(state);
1586 
1587   // If the operation produced 0 or 1 result, return them immediately.
1588   if (numResults == 0)
1589     return rewriter.eraseOp(op), success();
1590   if (numResults == 1)
1591     return rewriter.replaceOp(op, newOp->getResult(0)), success();
1592 
1593   // Otherwise, it had been converted to an operation producing a structure.
1594   // Extract individual results from the structure and return them as list.
1595   SmallVector<Value, 4> results;
1596   results.reserve(numResults);
1597   for (unsigned i = 0; i < numResults; ++i) {
1598     auto type = typeConverter.convertType(op->getResult(i).getType());
1599     results.push_back(rewriter.create<LLVM::ExtractValueOp>(
1600         op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
1601   }
1602   rewriter.replaceOp(op, results);
1603   return success();
1604 }
1605 
handleMultidimensionalVectors(Operation * op,ValueRange operands,LLVMTypeConverter & typeConverter,std::function<Value (LLVM::LLVMType,ValueRange)> createOperand,ConversionPatternRewriter & rewriter)1606 static LogicalResult handleMultidimensionalVectors(
1607     Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
1608     std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
1609     ConversionPatternRewriter &rewriter) {
1610   auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
1611   if (!vectorType)
1612     return failure();
1613   auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
1614   auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
1615   auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
1616   if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
1617     return failure();
1618 
1619   auto loc = op->getLoc();
1620   Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
1621   nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
1622     // For this unrolled `position` corresponding to the `linearIndex`^th
1623     // element, extract operand vectors
1624     SmallVector<Value, 4> extractedOperands;
1625     for (auto operand : operands)
1626       extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
1627           loc, llvmVectorTy, operand, position));
1628     Value newVal = createOperand(llvmVectorTy, extractedOperands);
1629     desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
1630                                                 position);
1631   });
1632   rewriter.replaceOp(op, desc);
1633   return success();
1634 }
1635 
vectorOneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)1636 LogicalResult LLVM::detail::vectorOneToOneRewrite(
1637     Operation *op, StringRef targetOp, ValueRange operands,
1638     LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
1639   assert(!operands.empty());
1640 
1641   // Cannot convert ops if their operands are not of LLVM type.
1642   if (!llvm::all_of(operands.getTypes(),
1643                     [](Type t) { return t.isa<LLVM::LLVMType>(); }))
1644     return failure();
1645 
1646   auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
1647   if (!llvmArrayTy.isArrayTy())
1648     return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
1649 
1650   auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
1651                                             ValueRange operands) {
1652     OperationState state(op->getLoc(), targetOp);
1653     state.addTypes(llvmVectorTy);
1654     state.addOperands(operands);
1655     state.addAttributes(op->getAttrs());
1656     return rewriter.createOperation(state)->getResult(0);
1657   };
1658 
1659   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
1660                                        rewriter);
1661 }
1662 
1663 namespace {
// Straightforward lowerings.
// Each alias pairs a standard dialect operation (first template argument)
// with the LLVM dialect operation it is lowered to (second template argument).
// Two pattern templates, both declared outside this file section, are used:
// `VectorConvertToLLVMPattern` and `OneToOneConvertToLLVMPattern`.
// NOTE(review): presumably these map onto `vectorOneToOneRewrite` and
// `oneToOneRewrite` respectively - confirm against the pattern declarations.
using AbsFOpLowering = VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using CopySignOpLowering =
    VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using ExpOpLowering = VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<Exp2Op, LLVM::Exp2Op>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using Log10OpLowering = VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op>;
using Log2OpLowering = VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op>;
using LogOpLowering = VectorConvertToLLVMPattern<LogOp, LLVM::LogOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
using ShiftLeftOpLowering =
    OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
using SignedDivIOpLowering =
    VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp>;
using SignedRemIOpLowering =
    VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp>;
using SignedShiftRightOpLowering =
    OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
using SinOpLowering = VectorConvertToLLVMPattern<SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
using UnsignedDivIOpLowering =
    VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
using UnsignedRemIOpLowering =
    VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp>;
using UnsignedShiftRightOpLowering =
    OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
1705 
1706 /// Lower `std.assert`. The default lowering calls the `abort` function if the
1707 /// assertion is violated and has no effect otherwise. The failure message is
1708 /// ignored by the default lowering but should be propagated by any custom
1709 /// lowering.
1710 struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
1711   using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
1712 
1713   LogicalResult
matchAndRewrite__anone5172fdb1311::AssertOpLowering1714   matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
1715                   ConversionPatternRewriter &rewriter) const override {
1716     auto loc = op.getLoc();
1717     AssertOp::Adaptor transformed(operands);
1718 
1719     // Insert the `abort` declaration if necessary.
1720     auto module = op->getParentOfType<ModuleOp>();
1721     auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
1722     if (!abortFunc) {
1723       OpBuilder::InsertionGuard guard(rewriter);
1724       rewriter.setInsertionPointToStart(module.getBody());
1725       auto abortFuncTy =
1726           LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false);
1727       abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
1728                                                     "abort", abortFuncTy);
1729     }
1730 
1731     // Split block at `assert` operation.
1732     Block *opBlock = rewriter.getInsertionBlock();
1733     auto opPosition = rewriter.getInsertionPoint();
1734     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
1735 
1736     // Generate IR to call `abort`.
1737     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
1738     rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
1739     rewriter.create<LLVM::UnreachableOp>(loc);
1740 
1741     // Generate assertion test.
1742     rewriter.setInsertionPointToEnd(opBlock);
1743     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
1744         op, transformed.arg(), continuationBlock, failureBlock);
1745 
1746     return success();
1747   }
1748 };
1749 
1750 // Lowerings for operations on complex numbers.
1751 
1752 struct CreateComplexOpLowering
1753     : public ConvertOpToLLVMPattern<CreateComplexOp> {
1754   using ConvertOpToLLVMPattern<CreateComplexOp>::ConvertOpToLLVMPattern;
1755 
1756   LogicalResult
matchAndRewrite__anone5172fdb1311::CreateComplexOpLowering1757   matchAndRewrite(CreateComplexOp op, ArrayRef<Value> operands,
1758                   ConversionPatternRewriter &rewriter) const override {
1759     auto complexOp = cast<CreateComplexOp>(op);
1760     CreateComplexOp::Adaptor transformed(operands);
1761 
1762     // Pack real and imaginary part in a complex number struct.
1763     auto loc = op.getLoc();
1764     auto structType = typeConverter->convertType(complexOp.getType());
1765     auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
1766     complexStruct.setReal(rewriter, loc, transformed.real());
1767     complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
1768 
1769     rewriter.replaceOp(op, {complexStruct});
1770     return success();
1771   }
1772 };
1773 
1774 struct ReOpLowering : public ConvertOpToLLVMPattern<ReOp> {
1775   using ConvertOpToLLVMPattern<ReOp>::ConvertOpToLLVMPattern;
1776 
1777   LogicalResult
matchAndRewrite__anone5172fdb1311::ReOpLowering1778   matchAndRewrite(ReOp op, ArrayRef<Value> operands,
1779                   ConversionPatternRewriter &rewriter) const override {
1780     ReOp::Adaptor transformed(operands);
1781 
1782     // Extract real part from the complex number struct.
1783     ComplexStructBuilder complexStruct(transformed.complex());
1784     Value real = complexStruct.real(rewriter, op.getLoc());
1785     rewriter.replaceOp(op, real);
1786 
1787     return success();
1788   }
1789 };
1790 
1791 struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
1792   using ConvertOpToLLVMPattern<ImOp>::ConvertOpToLLVMPattern;
1793 
1794   LogicalResult
matchAndRewrite__anone5172fdb1311::ImOpLowering1795   matchAndRewrite(ImOp op, ArrayRef<Value> operands,
1796                   ConversionPatternRewriter &rewriter) const override {
1797     ImOp::Adaptor transformed(operands);
1798 
1799     // Extract imaginary part from the complex number struct.
1800     ComplexStructBuilder complexStruct(transformed.complex());
1801     Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
1802     rewriter.replaceOp(op, imaginary);
1803 
1804     return success();
1805   }
1806 };
1807 
1808 struct BinaryComplexOperands {
1809   std::complex<Value> lhs, rhs;
1810 };
1811 
1812 template <typename OpTy>
1813 BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)1814 unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
1815                             ConversionPatternRewriter &rewriter) {
1816   auto bop = cast<OpTy>(op);
1817   auto loc = bop.getLoc();
1818   typename OpTy::Adaptor transformed(operands);
1819 
1820   // Extract real and imaginary values from operands.
1821   BinaryComplexOperands unpacked;
1822   ComplexStructBuilder lhs(transformed.lhs());
1823   unpacked.lhs.real(lhs.real(rewriter, loc));
1824   unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
1825   ComplexStructBuilder rhs(transformed.rhs());
1826   unpacked.rhs.real(rhs.real(rewriter, loc));
1827   unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
1828 
1829   return unpacked;
1830 }
1831 
1832 struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
1833   using ConvertOpToLLVMPattern<AddCFOp>::ConvertOpToLLVMPattern;
1834 
1835   LogicalResult
matchAndRewrite__anone5172fdb1311::AddCFOpLowering1836   matchAndRewrite(AddCFOp op, ArrayRef<Value> operands,
1837                   ConversionPatternRewriter &rewriter) const override {
1838     auto loc = op.getLoc();
1839     BinaryComplexOperands arg =
1840         unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
1841 
1842     // Initialize complex number struct for result.
1843     auto structType = typeConverter->convertType(op.getType());
1844     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
1845 
1846     // Emit IR to add complex numbers.
1847     Value real =
1848         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real());
1849     Value imag =
1850         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag());
1851     result.setReal(rewriter, loc, real);
1852     result.setImaginary(rewriter, loc, imag);
1853 
1854     rewriter.replaceOp(op, {result});
1855     return success();
1856   }
1857 };
1858 
1859 struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
1860   using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;
1861 
1862   LogicalResult
matchAndRewrite__anone5172fdb1311::SubCFOpLowering1863   matchAndRewrite(SubCFOp op, ArrayRef<Value> operands,
1864                   ConversionPatternRewriter &rewriter) const override {
1865     auto loc = op.getLoc();
1866     BinaryComplexOperands arg =
1867         unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
1868 
1869     // Initialize complex number struct for result.
1870     auto structType = typeConverter->convertType(op.getType());
1871     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
1872 
1873     // Emit IR to substract complex numbers.
1874     Value real =
1875         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real());
1876     Value imag =
1877         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag());
1878     result.setReal(rewriter, loc, real);
1879     result.setImaginary(rewriter, loc, imag);
1880 
1881     rewriter.replaceOp(op, {result});
1882     return success();
1883   }
1884 };
1885 
1886 struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
1887   using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
1888 
1889   LogicalResult
matchAndRewrite__anone5172fdb1311::ConstantOpLowering1890   matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
1891                   ConversionPatternRewriter &rewriter) const override {
1892     // If constant refers to a function, convert it to "addressof".
1893     if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
1894       auto type = typeConverter->convertType(op.getResult().getType())
1895                       .dyn_cast_or_null<LLVM::LLVMType>();
1896       if (!type)
1897         return rewriter.notifyMatchFailure(op, "failed to convert result type");
1898 
1899       MutableDictionaryAttr attrs(op.getAttrs());
1900       attrs.remove(rewriter.getIdentifier("value"));
1901       rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
1902           op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
1903           attrs.getAttrs());
1904       return success();
1905     }
1906 
1907     // Calling into other scopes (non-flat reference) is not supported in LLVM.
1908     if (op.getValue().isa<SymbolRefAttr>())
1909       return rewriter.notifyMatchFailure(
1910           op, "referring to a symbol outside of the current module");
1911 
1912     return LLVM::detail::oneToOneRewrite(
1913         op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
1914         rewriter);
1915   }
1916 };
1917 
1918 /// Lowering for AllocOp and AllocaOp.
1919 struct AllocLikeOpLowering : public ConvertToLLVMPattern {
1920   using ConvertToLLVMPattern::createIndexConstant;
1921   using ConvertToLLVMPattern::getIndexType;
1922   using ConvertToLLVMPattern::getVoidPtrType;
1923 
AllocLikeOpLowering__anone5172fdb1311::AllocLikeOpLowering1924   explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
1925       : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
1926 
1927 protected:
1928   // Returns 'input' aligned up to 'alignment'. Computes
1929   // bumped = input + alignement - 1
1930   // aligned = bumped - bumped % alignment
createAligned__anone5172fdb1311::AllocLikeOpLowering1931   static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
1932                              Value input, Value alignment) {
1933     Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
1934     Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
1935     Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
1936     Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
1937     return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
1938   }
1939 
1940   // Creates a call to an allocation function with params and casts the
1941   // resulting void pointer to ptrType.
createAllocCall__anone5172fdb1311::AllocLikeOpLowering1942   Value createAllocCall(Location loc, StringRef name, Type ptrType,
1943                         ArrayRef<Value> params, ModuleOp module,
1944                         ConversionPatternRewriter &rewriter) const {
1945     SmallVector<LLVM::LLVMType, 2> paramTypes;
1946     auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1947     if (!allocFuncOp) {
1948       for (Value param : params)
1949         paramTypes.push_back(param.getType().cast<LLVM::LLVMType>());
1950       auto allocFuncType =
1951           LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes,
1952                                         /*isVarArg=*/false);
1953       OpBuilder::InsertionGuard guard(rewriter);
1954       rewriter.setInsertionPointToStart(module.getBody());
1955       allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
1956                                                       name, allocFuncType);
1957     }
1958     auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
1959     auto allocatedPtr = rewriter
1960                             .create<LLVM::CallOp>(loc, getVoidPtrType(),
1961                                                   allocFuncSymbol, params)
1962                             .getResult(0);
1963     return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
1964   }
1965 
1966   /// Allocates the underlying buffer. Returns the allocated pointer and the
1967   /// aligned pointer.
1968   virtual std::tuple<Value, Value>
1969   allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
1970                  Value sizeBytes, Operation *op) const = 0;
1971 
1972 private:
getMemRefResultType__anone5172fdb1311::AllocLikeOpLowering1973   static MemRefType getMemRefResultType(Operation *op) {
1974     return op->getResult(0).getType().cast<MemRefType>();
1975   }
1976 
match__anone5172fdb1311::AllocLikeOpLowering1977   LogicalResult match(Operation *op) const override {
1978     MemRefType memRefType = getMemRefResultType(op);
1979     return success(isSupportedMemRefType(memRefType));
1980   }
1981 
1982   // An `alloc` is converted into a definition of a memref descriptor value and
1983   // a call to `malloc` to allocate the underlying data buffer.  The memref
1984   // descriptor is of the LLVM structure type where:
1985   //   1. the first element is a pointer to the allocated (typed) data buffer,
1986   //   2. the second element is a pointer to the (typed) payload, aligned to the
1987   //      specified alignment,
1988   //   3. the remaining elements serve to store all the sizes and strides of the
1989   //      memref using LLVM-converted `index` type.
1990   //
1991   // Alignment is performed by allocating `alignment` more bytes than
1992   // requested and shifting the aligned pointer relative to the allocated
1993   // memory. Note: `alignment - <minimum malloc alignment>` would actually be
1994   // sufficient. If alignment is unspecified, the two pointers are equal.
1995 
1996   // An `alloca` is converted into a definition of a memref descriptor value and
1997   // an llvm.alloca to allocate the underlying data buffer.
rewrite__anone5172fdb1311::AllocLikeOpLowering1998   void rewrite(Operation *op, ArrayRef<Value> operands,
1999                ConversionPatternRewriter &rewriter) const override {
2000     MemRefType memRefType = getMemRefResultType(op);
2001     auto loc = op->getLoc();
2002 
2003     // Get actual sizes of the memref as values: static sizes are constant
2004     // values and dynamic sizes are passed to 'alloc' as operands.  In case of
2005     // zero-dimensional memref, assume a scalar (size 1).
2006     SmallVector<Value, 4> sizes;
2007     SmallVector<Value, 4> strides;
2008     Value sizeBytes;
2009     this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
2010                                    strides, sizeBytes);
2011 
2012     // Allocate the underlying buffer.
2013     Value allocatedPtr;
2014     Value alignedPtr;
2015     std::tie(allocatedPtr, alignedPtr) =
2016         this->allocateBuffer(rewriter, loc, sizeBytes, op);
2017 
2018     // Create the MemRef descriptor.
2019     auto memRefDescriptor = this->createMemRefDescriptor(
2020         loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
2021 
2022     // Return the final value of the descriptor.
2023     rewriter.replaceOp(op, {memRefDescriptor});
2024   }
2025 };
2026 
2027 struct AllocOpLowering : public AllocLikeOpLowering {
AllocOpLowering__anone5172fdb1311::AllocOpLowering2028   AllocOpLowering(LLVMTypeConverter &converter)
2029       : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
2030 
allocateBuffer__anone5172fdb1311::AllocOpLowering2031   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
2032                                           Location loc, Value sizeBytes,
2033                                           Operation *op) const override {
2034     // Heap allocations.
2035     AllocOp allocOp = cast<AllocOp>(op);
2036     MemRefType memRefType = allocOp.getType();
2037 
2038     Value alignment;
2039     if (auto alignmentAttr = allocOp.alignment()) {
2040       alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
2041     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
2042       // In the case where no alignment is specified, we may want to override
2043       // `malloc's` behavior. `malloc` typically aligns at the size of the
2044       // biggest scalar on a target HW. For non-scalars, use the natural
2045       // alignment of the LLVM type given by the LLVM DataLayout.
2046       alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
2047     }
2048 
2049     if (alignment) {
2050       // Adjust the allocation size to consider alignment.
2051       sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
2052     }
2053 
2054     // Allocate the underlying buffer and store a pointer to it in the MemRef
2055     // descriptor.
2056     Type elementPtrType = this->getElementPtrType(memRefType);
2057     Value allocatedPtr =
2058         createAllocCall(loc, "malloc", elementPtrType, {sizeBytes},
2059                         allocOp->getParentOfType<ModuleOp>(), rewriter);
2060 
2061     Value alignedPtr = allocatedPtr;
2062     if (alignment) {
2063       auto intPtrType = getIntPtrType(memRefType.getMemorySpace());
2064       // Compute the aligned type pointer.
2065       Value allocatedInt =
2066           rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, allocatedPtr);
2067       Value alignmentInt =
2068           createAligned(rewriter, loc, allocatedInt, alignment);
2069       alignedPtr =
2070           rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
2071     }
2072 
2073     return std::make_tuple(allocatedPtr, alignedPtr);
2074   }
2075 };
2076 
2077 struct AlignedAllocOpLowering : public AllocLikeOpLowering {
AlignedAllocOpLowering__anone5172fdb1311::AlignedAllocOpLowering2078   AlignedAllocOpLowering(LLVMTypeConverter &converter)
2079       : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
2080 
2081   /// Returns the memref's element size in bytes.
2082   // TODO: there are other places where this is used. Expose publicly?
getMemRefEltSizeInBytes__anone5172fdb1311::AlignedAllocOpLowering2083   static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
2084     auto elementType = memRefType.getElementType();
2085 
2086     unsigned sizeInBits;
2087     if (elementType.isIntOrFloat()) {
2088       sizeInBits = elementType.getIntOrFloatBitWidth();
2089     } else {
2090       auto vectorType = elementType.cast<VectorType>();
2091       sizeInBits =
2092           vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
2093     }
2094     return llvm::divideCeil(sizeInBits, 8);
2095   }
2096 
2097   /// Returns true if the memref size in bytes is known to be a multiple of
2098   /// factor.
isMemRefSizeMultipleOf__anone5172fdb1311::AlignedAllocOpLowering2099   static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
2100     uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
2101     for (unsigned i = 0, e = type.getRank(); i < e; i++) {
2102       if (type.isDynamic(type.getDimSize(i)))
2103         continue;
2104       sizeDivisor = sizeDivisor * type.getDimSize(i);
2105     }
2106     return sizeDivisor % factor == 0;
2107   }
2108 
2109   /// Returns the alignment to be used for the allocation call itself.
2110   /// aligned_alloc requires the allocation size to be a power of two, and the
2111   /// allocation size to be a multiple of alignment,
getAllocationAlignment__anone5172fdb1311::AlignedAllocOpLowering2112   int64_t getAllocationAlignment(AllocOp allocOp) const {
2113     if (Optional<uint64_t> alignment = allocOp.alignment())
2114       return *alignment;
2115 
2116     // Whenever we don't have alignment set, we will use an alignment
2117     // consistent with the element type; since the allocation size has to be a
2118     // power of two, we will bump to the next power of two if it already isn't.
2119     auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
2120     return std::max(kMinAlignedAllocAlignment,
2121                     llvm::PowerOf2Ceil(eltSizeBytes));
2122   }
2123 
allocateBuffer__anone5172fdb1311::AlignedAllocOpLowering2124   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
2125                                           Location loc, Value sizeBytes,
2126                                           Operation *op) const override {
2127     // Heap allocations.
2128     AllocOp allocOp = cast<AllocOp>(op);
2129     MemRefType memRefType = allocOp.getType();
2130     int64_t alignment = getAllocationAlignment(allocOp);
2131     Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
2132 
2133     // aligned_alloc requires size to be a multiple of alignment; we will pad
2134     // the size to the next multiple if necessary.
2135     if (!isMemRefSizeMultipleOf(memRefType, alignment))
2136       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
2137 
2138     Type elementPtrType = this->getElementPtrType(memRefType);
2139     Value allocatedPtr = createAllocCall(
2140         loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes},
2141         allocOp->getParentOfType<ModuleOp>(), rewriter);
2142 
2143     return std::make_tuple(allocatedPtr, allocatedPtr);
2144   }
2145 
2146   /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
2147   static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
2148 };
2149 
2150 // Out of line definition, required till C++17.
2151 constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
2152 
2153 struct AllocaOpLowering : public AllocLikeOpLowering {
AllocaOpLowering__anone5172fdb1311::AllocaOpLowering2154   AllocaOpLowering(LLVMTypeConverter &converter)
2155       : AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {}
2156 
2157   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
2158   /// is set to null for stack allocations. `accessAlignment` is set if
2159   /// alignment is needed post allocation (for eg. in conjunction with malloc).
allocateBuffer__anone5172fdb1311::AllocaOpLowering2160   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
2161                                           Location loc, Value sizeBytes,
2162                                           Operation *op) const override {
2163 
2164     // With alloca, one gets a pointer to the element type right away.
2165     // For stack allocations.
2166     auto allocaOp = cast<AllocaOp>(op);
2167     auto elementPtrType = this->getElementPtrType(allocaOp.getType());
2168 
2169     auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
2170         loc, elementPtrType, sizeBytes,
2171         allocaOp.alignment() ? *allocaOp.alignment() : 0);
2172 
2173     return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
2174   }
2175 };
2176 
2177 /// Copies the shaped descriptor part to (if `toDynamic` is set) or from
2178 /// (otherwise) the dynamically allocated memory for any operands that were
2179 /// unranked descriptors originally.
copyUnrankedDescriptors(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,TypeRange origTypes,SmallVectorImpl<Value> & operands,bool toDynamic)2180 static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
2181                                              LLVMTypeConverter &typeConverter,
2182                                              TypeRange origTypes,
2183                                              SmallVectorImpl<Value> &operands,
2184                                              bool toDynamic) {
2185   assert(origTypes.size() == operands.size() &&
2186          "expected as may original types as operands");
2187 
2188   // Find operands of unranked memref type and store them.
2189   SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
2190   for (unsigned i = 0, e = operands.size(); i < e; ++i)
2191     if (origTypes[i].isa<UnrankedMemRefType>())
2192       unrankedMemrefs.emplace_back(operands[i]);
2193 
2194   if (unrankedMemrefs.empty())
2195     return success();
2196 
2197   // Compute allocation sizes.
2198   SmallVector<Value, 4> sizes;
2199   UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
2200                                          unrankedMemrefs, sizes);
2201 
2202   // Get frequently used types.
2203   MLIRContext *context = builder.getContext();
2204   auto voidType = LLVM::LLVMType::getVoidTy(context);
2205   auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
2206   auto i1Type = LLVM::LLVMType::getInt1Ty(context);
2207   LLVM::LLVMType indexType = typeConverter.getIndexType();
2208 
2209   // Find the malloc and free, or declare them if necessary.
2210   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
2211   auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
2212   if (!mallocFunc && toDynamic) {
2213     OpBuilder::InsertionGuard guard(builder);
2214     builder.setInsertionPointToStart(module.getBody());
2215     mallocFunc = builder.create<LLVM::LLVMFuncOp>(
2216         builder.getUnknownLoc(), "malloc",
2217         LLVM::LLVMType::getFunctionTy(
2218             voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false));
2219   }
2220   auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free");
2221   if (!freeFunc && !toDynamic) {
2222     OpBuilder::InsertionGuard guard(builder);
2223     builder.setInsertionPointToStart(module.getBody());
2224     freeFunc = builder.create<LLVM::LLVMFuncOp>(
2225         builder.getUnknownLoc(), "free",
2226         LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType),
2227                                       /*isVarArg=*/false));
2228   }
2229 
2230   // Initialize shared constants.
2231   Value zero =
2232       builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
2233 
2234   unsigned unrankedMemrefPos = 0;
2235   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
2236     Type type = origTypes[i];
2237     if (!type.isa<UnrankedMemRefType>())
2238       continue;
2239     Value allocationSize = sizes[unrankedMemrefPos++];
2240     UnrankedMemRefDescriptor desc(operands[i]);
2241 
2242     // Allocate memory, copy, and free the source if necessary.
2243     Value memory =
2244         toDynamic
2245             ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
2246                   .getResult(0)
2247             : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
2248                                              /*alignment=*/0);
2249 
2250     Value source = desc.memRefDescPtr(builder, loc);
2251     builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
2252     if (!toDynamic)
2253       builder.create<LLVM::CallOp>(loc, freeFunc, source);
2254 
2255     // Create a new descriptor. The same descriptor can be returned multiple
2256     // times, attempting to modify its pointer can lead to memory leaks
2257     // (allocated twice and overwritten) or double frees (the caller does not
2258     // know if the descriptor points to the same memory).
2259     Type descriptorType = typeConverter.convertType(type);
2260     if (!descriptorType)
2261       return failure();
2262     auto updatedDesc =
2263         UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
2264     Value rank = desc.rank(builder, loc);
2265     updatedDesc.setRank(builder, loc, rank);
2266     updatedDesc.setMemRefDescPtr(builder, loc, memory);
2267 
2268     operands[i] = updatedDesc;
2269   }
2270 
2271   return success();
2272 }
2273 
2274 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
2275 // passes the pointer to the MemRef across function boundaries.
2276 template <typename CallOpType>
2277 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
2278   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
2279   using Super = CallOpInterfaceLowering<CallOpType>;
2280   using Base = ConvertOpToLLVMPattern<CallOpType>;
2281 
2282   LogicalResult
matchAndRewrite__anone5172fdb1311::CallOpInterfaceLowering2283   matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
2284                   ConversionPatternRewriter &rewriter) const override {
2285     typename CallOpType::Adaptor transformed(operands);
2286 
2287     // Pack the result types into a struct.
2288     Type packedResult = nullptr;
2289     unsigned numResults = callOp.getNumResults();
2290     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
2291 
2292     if (numResults != 0) {
2293       if (!(packedResult =
2294                 this->getTypeConverter()->packFunctionResults(resultTypes)))
2295         return failure();
2296     }
2297 
2298     auto promoted = this->getTypeConverter()->promoteOperands(
2299         callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
2300         rewriter);
2301     auto newOp = rewriter.create<LLVM::CallOp>(
2302         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
2303         promoted, callOp.getAttrs());
2304 
2305     SmallVector<Value, 4> results;
2306     if (numResults < 2) {
2307       // If < 2 results, packing did not do anything and we can just return.
2308       results.append(newOp.result_begin(), newOp.result_end());
2309     } else {
2310       // Otherwise, it had been converted to an operation producing a structure.
2311       // Extract individual results from the structure and return them as list.
2312       results.reserve(numResults);
2313       for (unsigned i = 0; i < numResults; ++i) {
2314         auto type =
2315             this->typeConverter->convertType(callOp.getResult(i).getType());
2316         results.push_back(rewriter.create<LLVM::ExtractValueOp>(
2317             callOp.getLoc(), type, newOp->getResult(0),
2318             rewriter.getI64ArrayAttr(i)));
2319       }
2320     }
2321 
2322     if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
2323       // For the bare-ptr calling convention, promote memref results to
2324       // descriptors.
2325       assert(results.size() == resultTypes.size() &&
2326              "The number of arguments and types doesn't match");
2327       this->getTypeConverter()->promoteBarePtrsToDescriptors(
2328           rewriter, callOp.getLoc(), resultTypes, results);
2329     } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
2330                                               *this->getTypeConverter(),
2331                                               resultTypes, results,
2332                                               /*toDynamic=*/false))) {
2333       return failure();
2334     }
2335 
2336     rewriter.replaceOp(callOp, results);
2337     return success();
2338   }
2339 };
2340 
// Lowering pattern for CallOp; the implementation is provided entirely by
// CallOpInterfaceLowering.
struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
  // Inherit the constructors of the base pattern.
  using Super::Super;
};
2344 
// Lowering pattern for CallIndirectOp; the implementation is provided entirely
// by CallOpInterfaceLowering.
struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
  // Inherit the constructors of the base pattern.
  using Super::Super;
};
2348 
2349 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
2350 // The memref descriptor being an SSA value, there is no need to clean it up
2351 // in any way.
2352 struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
2353   using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
2354 
DeallocOpLowering__anone5172fdb1311::DeallocOpLowering2355   explicit DeallocOpLowering(LLVMTypeConverter &converter)
2356       : ConvertOpToLLVMPattern<DeallocOp>(converter) {}
2357 
2358   LogicalResult
matchAndRewrite__anone5172fdb1311::DeallocOpLowering2359   matchAndRewrite(DeallocOp op, ArrayRef<Value> operands,
2360                   ConversionPatternRewriter &rewriter) const override {
2361     assert(operands.size() == 1 && "dealloc takes one operand");
2362     DeallocOp::Adaptor transformed(operands);
2363 
2364     // Insert the `free` declaration if it is not already present.
2365     auto freeFunc =
2366         op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
2367     if (!freeFunc) {
2368       OpBuilder::InsertionGuard guard(rewriter);
2369       rewriter.setInsertionPointToStart(
2370           op->getParentOfType<ModuleOp>().getBody());
2371       freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
2372           rewriter.getUnknownLoc(), "free",
2373           LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
2374                                         /*isVarArg=*/false));
2375     }
2376 
2377     MemRefDescriptor memref(transformed.memref());
2378     Value casted = rewriter.create<LLVM::BitcastOp>(
2379         op.getLoc(), getVoidPtrType(),
2380         memref.allocatedPtr(rewriter, op.getLoc()));
2381     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
2382         op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
2383     return success();
2384   }
2385 };
2386 
2387 /// Returns the LLVM type of the global variable given the memref type `type`.
2388 static LLVM::LLVMType
convertGlobalMemrefTypeToLLVM(MemRefType type,LLVMTypeConverter & typeConverter)2389 convertGlobalMemrefTypeToLLVM(MemRefType type,
2390                               LLVMTypeConverter &typeConverter) {
2391   // LLVM type for a global memref will be a multi-dimension array. For
2392   // declarations or uninitialized global memrefs, we can potentially flatten
2393   // this to a 1D array. However, for global_memref's with an initial value,
2394   // we do not intend to flatten the ElementsAttribute when going from std ->
2395   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
2396   LLVM::LLVMType elementType =
2397       unwrap(typeConverter.convertType(type.getElementType()));
2398   LLVM::LLVMType arrayTy = elementType;
2399   // Shape has the outermost dim at index 0, so need to walk it backwards
2400   for (int64_t dim : llvm::reverse(type.getShape()))
2401     arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
2402   return arrayTy;
2403 }
2404 
2405 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
2406 struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
2407   using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
2408 
2409   LogicalResult
matchAndRewrite__anone5172fdb1311::GlobalMemrefOpLowering2410   matchAndRewrite(GlobalMemrefOp global, ArrayRef<Value> operands,
2411                   ConversionPatternRewriter &rewriter) const override {
2412     MemRefType type = global.type().cast<MemRefType>();
2413     if (!isSupportedMemRefType(type))
2414       return failure();
2415 
2416     LLVM::LLVMType arrayTy =
2417         convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
2418 
2419     LLVM::Linkage linkage =
2420         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
2421 
2422     Attribute initialValue = nullptr;
2423     if (!global.isExternal() && !global.isUninitialized()) {
2424       auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
2425       initialValue = elementsAttr;
2426 
2427       // For scalar memrefs, the global variable created is of the element type,
2428       // so unpack the elements attribute to extract the value.
2429       if (type.getRank() == 0)
2430         initialValue = elementsAttr.getValue({});
2431     }
2432 
2433     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
2434         global, arrayTy, global.constant(), linkage, global.sym_name(),
2435         initialValue, type.getMemorySpace());
2436     return success();
2437   }
2438 };
2439 
2440 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
2441 /// the first element stashed into the descriptor. This reuses
2442 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
2443 struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
GetGlobalMemrefOpLowering__anone5172fdb1311::GetGlobalMemrefOpLowering2444   GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
2445       : AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {}
2446 
2447   /// Buffer "allocation" for get_global_memref op is getting the address of
2448   /// the global variable referenced.
allocateBuffer__anone5172fdb1311::GetGlobalMemrefOpLowering2449   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
2450                                           Location loc, Value sizeBytes,
2451                                           Operation *op) const override {
2452     auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
2453     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
2454     unsigned memSpace = type.getMemorySpace();
2455 
2456     LLVM::LLVMType arrayTy =
2457         convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
2458     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
2459         loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());
2460 
2461     // Get the address of the first element in the array by creating a GEP with
2462     // the address of the GV as the base, and (rank + 1) number of 0 indices.
2463     LLVM::LLVMType elementType =
2464         unwrap(typeConverter->convertType(type.getElementType()));
2465     LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);
2466 
2467     SmallVector<Value, 4> operands = {addressOf};
2468     operands.insert(operands.end(), type.getRank() + 1,
2469                     createIndexConstant(rewriter, loc, 0));
2470     auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
2471 
2472     // We do not expect the memref obtained using `get_global_memref` to be
2473     // ever deallocated. Set the allocated pointer to be known bad value to
2474     // help debug if that ever happens.
2475     auto intPtrType = getIntPtrType(memSpace);
2476     Value deadBeefConst =
2477         createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
2478     auto deadBeefPtr =
2479         rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
2480 
2481     // Both allocated and aligned pointers are same. We could potentially stash
2482     // a nullptr for the allocated pointer since we do not expect any dealloc.
2483     return std::make_tuple(deadBeefPtr, gep);
2484   }
2485 };
2486 
2487 // A `rsqrt` is converted into `1 / sqrt`.
2488 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
2489   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
2490 
2491   LogicalResult
matchAndRewrite__anone5172fdb1311::RsqrtOpLowering2492   matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
2493                   ConversionPatternRewriter &rewriter) const override {
2494     RsqrtOp::Adaptor transformed(operands);
2495     auto operandType =
2496         transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
2497 
2498     if (!operandType)
2499       return failure();
2500 
2501     auto loc = op.getLoc();
2502     auto resultType = op.getResult().getType();
2503     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
2504     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
2505 
2506     if (!operandType.isArrayTy()) {
2507       LLVM::ConstantOp one;
2508       if (operandType.isVectorTy()) {
2509         one = rewriter.create<LLVM::ConstantOp>(
2510             loc, operandType,
2511             SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
2512       } else {
2513         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
2514       }
2515       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
2516       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
2517       return success();
2518     }
2519 
2520     auto vectorType = resultType.dyn_cast<VectorType>();
2521     if (!vectorType)
2522       return failure();
2523 
2524     return handleMultidimensionalVectors(
2525         op.getOperation(), operands, *getTypeConverter(),
2526         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
2527           auto splatAttr = SplatElementsAttr::get(
2528               mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
2529                                     floatType),
2530               floatOne);
2531           auto one =
2532               rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
2533           auto sqrt =
2534               rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
2535           return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
2536         },
2537         rewriter);
2538   }
2539 };
2540 
2541 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
2542   using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
2543 
match__anone5172fdb1311::MemRefCastOpLowering2544   LogicalResult match(MemRefCastOp memRefCastOp) const override {
2545     Type srcType = memRefCastOp.getOperand().getType();
2546     Type dstType = memRefCastOp.getType();
2547 
2548     // MemRefCastOp reduce to bitcast in the ranked MemRef case and can be used
2549     // for type erasure. For now they must preserve underlying element type and
2550     // require source and result type to have the same rank. Therefore, perform
2551     // a sanity check that the underlying structs are the same. Once op
2552     // semantics are relaxed we can revisit.
2553     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
2554       return success(typeConverter->convertType(srcType) ==
2555                      typeConverter->convertType(dstType));
2556 
2557     // At least one of the operands is unranked type
2558     assert(srcType.isa<UnrankedMemRefType>() ||
2559            dstType.isa<UnrankedMemRefType>());
2560 
2561     // Unranked to unranked cast is disallowed
2562     return !(srcType.isa<UnrankedMemRefType>() &&
2563              dstType.isa<UnrankedMemRefType>())
2564                ? success()
2565                : failure();
2566   }
2567 
rewrite__anone5172fdb1311::MemRefCastOpLowering2568   void rewrite(MemRefCastOp memRefCastOp, ArrayRef<Value> operands,
2569                ConversionPatternRewriter &rewriter) const override {
2570     MemRefCastOp::Adaptor transformed(operands);
2571 
2572     auto srcType = memRefCastOp.getOperand().getType();
2573     auto dstType = memRefCastOp.getType();
2574     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
2575     auto loc = memRefCastOp.getLoc();
2576 
2577     // For ranked/ranked case, just keep the original descriptor.
2578     if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
2579       return rewriter.replaceOp(memRefCastOp, {transformed.source()});
2580 
2581     if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
2582       // Casting ranked to unranked memref type
2583       // Set the rank in the destination from the memref type
2584       // Allocate space on the stack and copy the src memref descriptor
2585       // Set the ptr in the destination to the stack space
2586       auto srcMemRefType = srcType.cast<MemRefType>();
2587       int64_t rank = srcMemRefType.getRank();
2588       // ptr = AllocaOp sizeof(MemRefDescriptor)
2589       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
2590           loc, transformed.source(), rewriter);
2591       // voidptr = BitCastOp srcType* to void*
2592       auto voidPtr =
2593           rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
2594               .getResult();
2595       // rank = ConstantOp srcRank
2596       auto rankVal = rewriter.create<LLVM::ConstantOp>(
2597           loc, typeConverter->convertType(rewriter.getIntegerType(64)),
2598           rewriter.getI64IntegerAttr(rank));
2599       // undef = UndefOp
2600       UnrankedMemRefDescriptor memRefDesc =
2601           UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
2602       // d1 = InsertValueOp undef, rank, 0
2603       memRefDesc.setRank(rewriter, loc, rankVal);
2604       // d2 = InsertValueOp d1, voidptr, 1
2605       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
2606       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
2607 
2608     } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
2609       // Casting from unranked type to ranked.
2610       // The operation is assumed to be doing a correct cast. If the destination
2611       // type mismatches the unranked the type, it is undefined behavior.
2612       UnrankedMemRefDescriptor memRefDesc(transformed.source());
2613       // ptr = ExtractValueOp src, 1
2614       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
2615       // castPtr = BitCastOp i8* to structTy*
2616       auto castPtr =
2617           rewriter
2618               .create<LLVM::BitcastOp>(
2619                   loc, targetStructType.cast<LLVM::LLVMType>().getPointerTo(),
2620                   ptr)
2621               .getResult();
2622       // struct = LoadOp castPtr
2623       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
2624       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
2625     } else {
2626       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
2627     }
2628   }
2629 };
2630 
2631 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
2632 /// memref type. In unranked case, the fields are extracted from the underlying
2633 /// ranked descriptor.
extractPointersAndOffset(Location loc,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,Value originalOperand,Value convertedOperand,Value * allocatedPtr,Value * alignedPtr,Value * offset=nullptr)2634 static void extractPointersAndOffset(Location loc,
2635                                      ConversionPatternRewriter &rewriter,
2636                                      LLVMTypeConverter &typeConverter,
2637                                      Value originalOperand,
2638                                      Value convertedOperand,
2639                                      Value *allocatedPtr, Value *alignedPtr,
2640                                      Value *offset = nullptr) {
2641   Type operandType = originalOperand.getType();
2642   if (operandType.isa<MemRefType>()) {
2643     MemRefDescriptor desc(convertedOperand);
2644     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
2645     *alignedPtr = desc.alignedPtr(rewriter, loc);
2646     if (offset != nullptr)
2647       *offset = desc.offset(rewriter, loc);
2648     return;
2649   }
2650 
2651   unsigned memorySpace =
2652       operandType.cast<UnrankedMemRefType>().getMemorySpace();
2653   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
2654   LLVM::LLVMType llvmElementType =
2655       unwrap(typeConverter.convertType(elementType));
2656   LLVM::LLVMType elementPtrPtrType =
2657       llvmElementType.getPointerTo(memorySpace).getPointerTo();
2658 
2659   // Extract pointer to the underlying ranked memref descriptor and cast it to
2660   // ElemType**.
2661   UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
2662   Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
2663 
2664   *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
2665       rewriter, loc, underlyingDescPtr, elementPtrPtrType);
2666   *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
2667       rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
2668   if (offset != nullptr) {
2669     *offset = UnrankedMemRefDescriptor::offset(
2670         rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
2671   }
2672 }
2673 
2674 struct MemRefReinterpretCastOpLowering
2675     : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
2676   using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
2677 
2678   LogicalResult
matchAndRewrite__anone5172fdb1311::MemRefReinterpretCastOpLowering2679   matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef<Value> operands,
2680                   ConversionPatternRewriter &rewriter) const override {
2681     MemRefReinterpretCastOp::Adaptor adaptor(operands,
2682                                              castOp->getAttrDictionary());
2683     Type srcType = castOp.source().getType();
2684 
2685     Value descriptor;
2686     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
2687                                                adaptor, &descriptor)))
2688       return failure();
2689     rewriter.replaceOp(castOp, {descriptor});
2690     return success();
2691   }
2692 
2693 private:
2694   LogicalResult
convertSourceMemRefToDescriptor__anone5172fdb1311::MemRefReinterpretCastOpLowering2695   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
2696                                   Type srcType, MemRefReinterpretCastOp castOp,
2697                                   MemRefReinterpretCastOp::Adaptor adaptor,
2698                                   Value *descriptor) const {
2699     MemRefType targetMemRefType =
2700         castOp.getResult().getType().cast<MemRefType>();
2701     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
2702                                       .dyn_cast_or_null<LLVM::LLVMType>();
2703     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
2704       return failure();
2705 
2706     // Create descriptor.
2707     Location loc = castOp.getLoc();
2708     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
2709 
2710     // Set allocated and aligned pointers.
2711     Value allocatedPtr, alignedPtr;
2712     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
2713                              castOp.source(), adaptor.source(), &allocatedPtr,
2714                              &alignedPtr);
2715     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
2716     desc.setAlignedPtr(rewriter, loc, alignedPtr);
2717 
2718     // Set offset.
2719     if (castOp.isDynamicOffset(0))
2720       desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
2721     else
2722       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
2723 
2724     // Set sizes and strides.
2725     unsigned dynSizeId = 0;
2726     unsigned dynStrideId = 0;
2727     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
2728       if (castOp.isDynamicSize(i))
2729         desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
2730       else
2731         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
2732 
2733       if (castOp.isDynamicStride(i))
2734         desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
2735       else
2736         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
2737     }
2738     *descriptor = desc;
2739     return success();
2740   }
2741 };
2742 
2743 struct MemRefReshapeOpLowering
2744     : public ConvertOpToLLVMPattern<MemRefReshapeOp> {
2745   using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
2746 
2747   LogicalResult
matchAndRewrite__anone5172fdb1311::MemRefReshapeOpLowering2748   matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef<Value> operands,
2749                   ConversionPatternRewriter &rewriter) const override {
2750     auto *op = reshapeOp.getOperation();
2751     MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
2752     Type srcType = reshapeOp.source().getType();
2753 
2754     Value descriptor;
2755     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
2756                                                adaptor, &descriptor)))
2757       return failure();
2758     rewriter.replaceOp(op, {descriptor});
2759     return success();
2760   }
2761 
2762 private:
2763   LogicalResult
convertSourceMemRefToDescriptor__anone5172fdb1311::MemRefReshapeOpLowering2764   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
2765                                   Type srcType, MemRefReshapeOp reshapeOp,
2766                                   MemRefReshapeOp::Adaptor adaptor,
2767                                   Value *descriptor) const {
2768     // Conversion for statically-known shape args is performed via
2769     // `memref_reinterpret_cast`.
2770     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
2771     if (shapeMemRefType.hasStaticShape())
2772       return failure();
2773 
2774     // The shape is a rank-1 tensor with unknown length.
2775     Location loc = reshapeOp.getLoc();
2776     MemRefDescriptor shapeDesc(adaptor.shape());
2777     Value resultRank = shapeDesc.size(rewriter, loc, 0);
2778 
2779     // Extract address space and element type.
2780     auto targetType =
2781         reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
2782     unsigned addressSpace = targetType.getMemorySpace();
2783     Type elementType = targetType.getElementType();
2784 
2785     // Create the unranked memref descriptor that holds the ranked one. The
2786     // inner descriptor is allocated on stack.
2787     auto targetDesc = UnrankedMemRefDescriptor::undef(
2788         rewriter, loc, unwrap(typeConverter->convertType(targetType)));
2789     targetDesc.setRank(rewriter, loc, resultRank);
2790     SmallVector<Value, 4> sizes;
2791     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
2792                                            targetDesc, sizes);
2793     Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
2794         loc, getVoidPtrType(), sizes.front(), llvm::None);
2795     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
2796 
2797     // Extract pointers and offset from the source memref.
2798     Value allocatedPtr, alignedPtr, offset;
2799     extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
2800                              reshapeOp.source(), adaptor.source(),
2801                              &allocatedPtr, &alignedPtr, &offset);
2802 
2803     // Set pointers and offset.
2804     LLVM::LLVMType llvmElementType =
2805         unwrap(typeConverter->convertType(elementType));
2806     LLVM::LLVMType elementPtrPtrType =
2807         llvmElementType.getPointerTo(addressSpace).getPointerTo();
2808     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
2809                                               elementPtrPtrType, allocatedPtr);
2810     UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
2811                                             underlyingDescPtr,
2812                                             elementPtrPtrType, alignedPtr);
2813     UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
2814                                         underlyingDescPtr, elementPtrPtrType,
2815                                         offset);
2816 
2817     // Use the offset pointer as base for further addressing. Copy over the new
2818     // shape and compute strides. For this, we create a loop from rank-1 to 0.
2819     Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
2820         rewriter, loc, *getTypeConverter(), underlyingDescPtr,
2821         elementPtrPtrType);
2822     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
2823         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
2824     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
2825     Value oneIndex = createIndexConstant(rewriter, loc, 1);
2826     Value resultRankMinusOne =
2827         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
2828 
2829     Block *initBlock = rewriter.getInsertionBlock();
2830     LLVM::LLVMType indexType = getTypeConverter()->getIndexType();
2831     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
2832 
2833     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
2834                                             {indexType, indexType});
2835 
2836     // Iterate over the remaining ops in initBlock and move them to condBlock.
2837     BlockAndValueMapping map;
2838     for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) {
2839       rewriter.clone(*it, map);
2840       rewriter.eraseOp(&*it);
2841     }
2842 
2843     rewriter.setInsertionPointToEnd(initBlock);
2844     rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
2845                                 condBlock);
2846     rewriter.setInsertionPointToStart(condBlock);
2847     Value indexArg = condBlock->getArgument(0);
2848     Value strideArg = condBlock->getArgument(1);
2849 
2850     Value zeroIndex = createIndexConstant(rewriter, loc, 0);
2851     Value pred = rewriter.create<LLVM::ICmpOp>(
2852         loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
2853         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
2854 
2855     Block *bodyBlock =
2856         rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
2857     rewriter.setInsertionPointToStart(bodyBlock);
2858 
2859     // Copy size from shape to descriptor.
2860     LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo();
2861     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
2862         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
2863     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
2864     UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
2865                                       targetSizesBase, indexArg, size);
2866 
2867     // Write stride value and compute next one.
2868     UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
2869                                         targetStridesBase, indexArg, strideArg);
2870     Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
2871 
2872     // Decrement loop counter and branch back.
2873     Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
2874     rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
2875                                 condBlock);
2876 
2877     Block *remainder =
2878         rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
2879 
2880     // Hook up the cond exit to the remainder.
2881     rewriter.setInsertionPointToEnd(condBlock);
2882     rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
2883                                     llvm::None);
2884 
2885     // Reset position to beginning of new remainder block.
2886     rewriter.setInsertionPointToStart(remainder);
2887 
2888     *descriptor = targetDesc;
2889     return success();
2890   }
2891 };
2892 
2893 struct DialectCastOpLowering
2894     : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
2895   using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
2896 
2897   LogicalResult
matchAndRewrite__anone5172fdb1311::DialectCastOpLowering2898   matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef<Value> operands,
2899                   ConversionPatternRewriter &rewriter) const override {
2900     LLVM::DialectCastOp::Adaptor transformed(operands);
2901     if (transformed.in().getType() !=
2902         typeConverter->convertType(castOp.getType())) {
2903       return failure();
2904     }
2905     rewriter.replaceOp(castOp, transformed.in());
2906     return success();
2907   }
2908 };
2909 
2910 // A `dim` is converted to a constant for static sizes and to an access to the
2911 // size stored in the memref descriptor for dynamic sizes.
2912 struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
2913   using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
2914 
2915   LogicalResult
matchAndRewrite__anone5172fdb1311::DimOpLowering2916   matchAndRewrite(DimOp dimOp, ArrayRef<Value> operands,
2917                   ConversionPatternRewriter &rewriter) const override {
2918     Type operandType = dimOp.memrefOrTensor().getType();
2919     if (operandType.isa<UnrankedMemRefType>()) {
2920       rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
2921                                     operandType, dimOp, operands, rewriter)});
2922 
2923       return success();
2924     }
2925     if (operandType.isa<MemRefType>()) {
2926       rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
2927                                     operandType, dimOp, operands, rewriter)});
2928       return success();
2929     }
2930     return failure();
2931   }
2932 
2933 private:
extractSizeOfUnrankedMemRef__anone5172fdb1311::DimOpLowering2934   Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
2935                                     ArrayRef<Value> operands,
2936                                     ConversionPatternRewriter &rewriter) const {
2937     Location loc = dimOp.getLoc();
2938     DimOp::Adaptor transformed(operands);
2939 
2940     auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
2941     auto scalarMemRefType =
2942         MemRefType::get({}, unrankedMemRefType.getElementType());
2943     unsigned addressSpace = unrankedMemRefType.getMemorySpace();
2944 
2945     // Extract pointer to the underlying ranked descriptor and bitcast it to a
2946     // memref<element_type> descriptor pointer to minimize the number of GEP
2947     // operations.
2948     UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
2949     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
2950     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
2951         loc,
2952         typeConverter->convertType(scalarMemRefType)
2953             .cast<LLVM::LLVMType>()
2954             .getPointerTo(addressSpace),
2955         underlyingRankedDesc);
2956 
2957     // Get pointer to offset field of memref<element_type> descriptor.
2958     Type indexPtrTy =
2959         getTypeConverter()->getIndexType().getPointerTo(addressSpace);
2960     Value two = rewriter.create<LLVM::ConstantOp>(
2961         loc, typeConverter->convertType(rewriter.getI32Type()),
2962         rewriter.getI32IntegerAttr(2));
2963     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
2964         loc, indexPtrTy, scalarMemRefDescPtr,
2965         ValueRange({createIndexConstant(rewriter, loc, 0), two}));
2966 
2967     // The size value that we have to extract can be obtained using GEPop with
2968     // `dimOp.index() + 1` index argument.
2969     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
2970         loc, createIndexConstant(rewriter, loc, 1), transformed.index());
2971     Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
2972                                                  ValueRange({idxPlusOne}));
2973     return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
2974   }
2975 
extractSizeOfRankedMemRef__anone5172fdb1311::DimOpLowering2976   Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
2977                                   ArrayRef<Value> operands,
2978                                   ConversionPatternRewriter &rewriter) const {
2979     Location loc = dimOp.getLoc();
2980     DimOp::Adaptor transformed(operands);
2981     // Take advantage if index is constant.
2982     MemRefType memRefType = operandType.cast<MemRefType>();
2983     if (Optional<int64_t> index = dimOp.getConstantIndex()) {
2984       int64_t i = index.getValue();
2985       if (memRefType.isDynamicDim(i)) {
2986         // extract dynamic size from the memref descriptor.
2987         MemRefDescriptor descriptor(transformed.memrefOrTensor());
2988         return descriptor.size(rewriter, loc, i);
2989       }
2990       // Use constant for static size.
2991       int64_t dimSize = memRefType.getDimSize(i);
2992       return createIndexConstant(rewriter, loc, dimSize);
2993     }
2994     Value index = dimOp.index();
2995     int64_t rank = memRefType.getRank();
2996     MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
2997     return memrefDescriptor.size(rewriter, loc, index, rank);
2998   }
2999 };
3000 
3001 struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
3002   using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
3003 
3004   LogicalResult
matchAndRewrite__anone5172fdb1311::RankOpLowering3005   matchAndRewrite(RankOp op, ArrayRef<Value> operands,
3006                   ConversionPatternRewriter &rewriter) const override {
3007     Location loc = op.getLoc();
3008     Type operandType = op.memrefOrTensor().getType();
3009     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
3010       UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
3011       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
3012       return success();
3013     }
3014     if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
3015       rewriter.replaceOp(
3016           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
3017       return success();
3018     }
3019     return failure();
3020   }
3021 };
3022 
3023 // Common base for load and store operations on MemRefs.  Restricts the match
3024 // to supported MemRef types.  Provides functionality to emit code accessing a
3025 // specific element of the underlying data buffer.
3026 template <typename Derived>
3027 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
3028   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
3029   using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
3030   using Base = LoadStoreOpLowering<Derived>;
3031 
match__anone5172fdb1311::LoadStoreOpLowering3032   LogicalResult match(Derived op) const override {
3033     MemRefType type = op.getMemRefType();
3034     return isSupportedMemRefType(type) ? success() : failure();
3035   }
3036 };
3037 
3038 // Load operation is lowered to obtaining a pointer to the indexed element
3039 // and loading it.
3040 struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
3041   using Base::Base;
3042 
3043   LogicalResult
matchAndRewrite__anone5172fdb1311::LoadOpLowering3044   matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
3045                   ConversionPatternRewriter &rewriter) const override {
3046     LoadOp::Adaptor transformed(operands);
3047     auto type = loadOp.getMemRefType();
3048 
3049     Value dataPtr =
3050         getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
3051                              transformed.indices(), rewriter);
3052     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
3053     return success();
3054   }
3055 };
3056 
3057 // Store operation is lowered to obtaining a pointer to the indexed element,
3058 // and storing the given value to it.
3059 struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
3060   using Base::Base;
3061 
3062   LogicalResult
matchAndRewrite__anone5172fdb1311::StoreOpLowering3063   matchAndRewrite(StoreOp op, ArrayRef<Value> operands,
3064                   ConversionPatternRewriter &rewriter) const override {
3065     auto type = op.getMemRefType();
3066     StoreOp::Adaptor transformed(operands);
3067 
3068     Value dataPtr =
3069         getStridedElementPtr(op.getLoc(), type, transformed.memref(),
3070                              transformed.indices(), rewriter);
3071     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
3072                                                dataPtr);
3073     return success();
3074   }
3075 };
3076 
3077 // The prefetch operation is lowered in a way similar to the load operation
3078 // except that the llvm.prefetch operation is used for replacement.
3079 struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
3080   using Base::Base;
3081 
3082   LogicalResult
matchAndRewrite__anone5172fdb1311::PrefetchOpLowering3083   matchAndRewrite(PrefetchOp prefetchOp, ArrayRef<Value> operands,
3084                   ConversionPatternRewriter &rewriter) const override {
3085     PrefetchOp::Adaptor transformed(operands);
3086     auto type = prefetchOp.getMemRefType();
3087     auto loc = prefetchOp.getLoc();
3088 
3089     Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
3090                                          transformed.indices(), rewriter);
3091 
3092     // Replace with llvm.prefetch.
3093     auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
3094     auto isWrite = rewriter.create<LLVM::ConstantOp>(
3095         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
3096     auto localityHint = rewriter.create<LLVM::ConstantOp>(
3097         loc, llvmI32Type,
3098         rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
3099     auto isData = rewriter.create<LLVM::ConstantOp>(
3100         loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
3101 
3102     rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
3103                                                 localityHint, isData);
3104     return success();
3105   }
3106 };
3107 
3108 // The lowering of index_cast becomes an integer conversion since index becomes
3109 // an integer.  If the bit width of the source and target integer types is the
3110 // same, just erase the cast.  If the target type is wider, sign-extend the
3111 // value, otherwise truncate it.
3112 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
3113   using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
3114 
3115   LogicalResult
matchAndRewrite__anone5172fdb1311::IndexCastOpLowering3116   matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
3117                   ConversionPatternRewriter &rewriter) const override {
3118     IndexCastOpAdaptor transformed(operands);
3119 
3120     auto targetType =
3121         typeConverter->convertType(indexCastOp.getResult().getType())
3122             .cast<LLVM::LLVMType>();
3123     auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
3124     unsigned targetBits = targetType.getIntegerBitWidth();
3125     unsigned sourceBits = sourceType.getIntegerBitWidth();
3126 
3127     if (targetBits == sourceBits)
3128       rewriter.replaceOp(indexCastOp, transformed.in());
3129     else if (targetBits < sourceBits)
3130       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
3131                                                  transformed.in());
3132     else
3133       rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
3134                                                 transformed.in());
3135     return success();
3136   }
3137 };
3138 
// Maps a std.cmp predicate onto the LLVM dialect predicate enum. Both enums
// use the same numerical values, so the conversion is a plain value cast.
template <typename LLVMPredType, typename StdPredType>
static LLVMPredType convertCmpPredicate(StdPredType pred) {
  const LLVMPredType converted = static_cast<LLVMPredType>(pred);
  return converted;
}
3145 
3146 struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
3147   using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
3148 
3149   LogicalResult
matchAndRewrite__anone5172fdb1311::CmpIOpLowering3150   matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
3151                   ConversionPatternRewriter &rewriter) const override {
3152     CmpIOpAdaptor transformed(operands);
3153 
3154     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
3155         cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
3156         rewriter.getI64IntegerAttr(static_cast<int64_t>(
3157             convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
3158         transformed.lhs(), transformed.rhs());
3159 
3160     return success();
3161   }
3162 };
3163 
3164 struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
3165   using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
3166 
3167   LogicalResult
matchAndRewrite__anone5172fdb1311::CmpFOpLowering3168   matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
3169                   ConversionPatternRewriter &rewriter) const override {
3170     CmpFOpAdaptor transformed(operands);
3171 
3172     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
3173         cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
3174         rewriter.getI64IntegerAttr(static_cast<int64_t>(
3175             convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
3176         transformed.lhs(), transformed.rhs());
3177 
3178     return success();
3179   }
3180 };
3181 
// Maps `SIToFPOp` onto `LLVM::SIToFPOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct SIToFPLowering
    : public OneToOneConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp> {
  using Super::Super;
};
3186 
// Maps `UIToFPOp` onto `LLVM::UIToFPOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct UIToFPLowering
    : public OneToOneConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp> {
  using Super::Super;
};
3191 
// Maps `FPExtOp` onto `LLVM::FPExtOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct FPExtLowering
    : public OneToOneConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp> {
  using Super::Super;
};
3196 
// Maps `FPToSIOp` onto `LLVM::FPToSIOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct FPToSILowering
    : public OneToOneConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp> {
  using Super::Super;
};
3201 
// Maps `FPToUIOp` onto `LLVM::FPToUIOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct FPToUILowering
    : public OneToOneConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp> {
  using Super::Super;
};
3206 
// Maps `FPTruncOp` onto `LLVM::FPTruncOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct FPTruncLowering
    : public OneToOneConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp> {
  using Super::Super;
};
3211 
// Maps `SignExtendIOp` onto `LLVM::SExtOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct SignExtendIOpLowering
    : public OneToOneConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp> {
  using Super::Super;
};
3216 
// Maps `TruncateIOp` onto `LLVM::TruncOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct TruncateIOpLowering
    : public OneToOneConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp> {
  using Super::Super;
};
3221 
// Maps `ZeroExtendIOp` onto `LLVM::ZExtOp` through the one-to-one conversion
// pattern base; only the base-class constructors are inherited.
struct ZeroExtendIOpLowering
    : public OneToOneConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp> {
  using Super::Super;
};
3226 
3227 // Base class for LLVM IR lowering terminator operations with successors.
3228 template <typename SourceOp, typename TargetOp>
3229 struct OneToOneLLVMTerminatorLowering
3230     : public ConvertOpToLLVMPattern<SourceOp> {
3231   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3232   using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
3233 
3234   LogicalResult
matchAndRewrite__anone5172fdb1311::OneToOneLLVMTerminatorLowering3235   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
3236                   ConversionPatternRewriter &rewriter) const override {
3237     rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
3238                                           op.getAttrs());
3239     return success();
3240   }
3241 };
3242 
3243 // Special lowering pattern for `ReturnOps`.  Unlike all other operations,
3244 // `ReturnOp` interacts with the function signature and must have as many
3245 // operands as the function has return values.  Because in LLVM IR, functions
3246 // can only return 0 or 1 value, we pack multiple values into a structure type.
3247 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
3248 // necessary before returning it
3249 struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
3250   using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
3251 
3252   LogicalResult
matchAndRewrite__anone5172fdb1311::ReturnOpLowering3253   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
3254                   ConversionPatternRewriter &rewriter) const override {
3255     Location loc = op.getLoc();
3256     unsigned numArguments = op.getNumOperands();
3257     SmallVector<Value, 4> updatedOperands;
3258 
3259     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
3260       // For the bare-ptr calling convention, extract the aligned pointer to
3261       // be returned from the memref descriptor.
3262       for (auto it : llvm::zip(op->getOperands(), operands)) {
3263         Type oldTy = std::get<0>(it).getType();
3264         Value newOperand = std::get<1>(it);
3265         if (oldTy.isa<MemRefType>()) {
3266           MemRefDescriptor memrefDesc(newOperand);
3267           newOperand = memrefDesc.alignedPtr(rewriter, loc);
3268         } else if (oldTy.isa<UnrankedMemRefType>()) {
3269           // Unranked memref is not supported in the bare pointer calling
3270           // convention.
3271           return failure();
3272         }
3273         updatedOperands.push_back(newOperand);
3274       }
3275     } else {
3276       updatedOperands = llvm::to_vector<4>(operands);
3277       copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(),
3278                               op.getOperands().getTypes(), updatedOperands,
3279                               /*toDynamic=*/true);
3280     }
3281 
3282     // If ReturnOp has 0 or 1 operand, create it and return immediately.
3283     if (numArguments == 0) {
3284       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
3285                                                   op.getAttrs());
3286       return success();
3287     }
3288     if (numArguments == 1) {
3289       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
3290           op, TypeRange(), updatedOperands, op.getAttrs());
3291       return success();
3292     }
3293 
3294     // Otherwise, we need to pack the arguments into an LLVM struct type before
3295     // returning.
3296     auto packedType = getTypeConverter()->packFunctionResults(
3297         llvm::to_vector<4>(op.getOperandTypes()));
3298 
3299     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
3300     for (unsigned i = 0; i < numArguments; ++i) {
3301       packed = rewriter.create<LLVM::InsertValueOp>(
3302           loc, packedType, packed, updatedOperands[i],
3303           rewriter.getI64ArrayAttr(i));
3304     }
3305     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
3306                                                 op.getAttrs());
3307     return success();
3308   }
3309 };
3310 
// FIXME: this should be tablegen'ed as well.
// Lowers `BranchOp` to `LLVM::BrOp` using the shared terminator lowering base;
// only the base-class constructors are inherited.
struct BranchOpLowering
    : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
  using Super::Super;
};
// Lowers `CondBranchOp` to `LLVM::CondBrOp` using the shared terminator
// lowering base; only the base-class constructors are inherited.
struct CondBranchOpLowering
    : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
  using Super::Super;
};
3320 
3321 // The Splat operation is lowered to an insertelement + a shufflevector
3322 // operation. Splat to only 1-d vector result types are lowered.
3323 struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
3324   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
3325 
3326   LogicalResult
matchAndRewrite__anone5172fdb1311::SplatOpLowering3327   matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
3328                   ConversionPatternRewriter &rewriter) const override {
3329     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
3330     if (!resultType || resultType.getRank() != 1)
3331       return failure();
3332 
3333     // First insert it into an undef vector so we can shuffle it.
3334     auto vectorType = typeConverter->convertType(splatOp.getType());
3335     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
3336     auto zero = rewriter.create<LLVM::ConstantOp>(
3337         splatOp.getLoc(),
3338         typeConverter->convertType(rewriter.getIntegerType(32)),
3339         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
3340 
3341     auto v = rewriter.create<LLVM::InsertElementOp>(
3342         splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero);
3343 
3344     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
3345     SmallVector<int32_t, 4> zeroValues(width, 0);
3346 
3347     // Shuffle the value across the desired number of elements.
3348     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
3349     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
3350                                                        zeroAttrs);
3351     return success();
3352   }
3353 };
3354 
3355 // The Splat operation is lowered to an insertelement + a shufflevector
3356 // operation. Splat to only 2+-d vector result types are lowered by the
3357 // SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
3358 struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
3359   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
3360 
3361   LogicalResult
matchAndRewrite__anone5172fdb1311::SplatNdOpLowering3362   matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
3363                   ConversionPatternRewriter &rewriter) const override {
3364     SplatOp::Adaptor adaptor(operands);
3365     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
3366     if (!resultType || resultType.getRank() == 1)
3367       return failure();
3368 
3369     // First insert it into an undef vector so we can shuffle it.
3370     auto loc = splatOp.getLoc();
3371     auto vectorTypeInfo =
3372         extractNDVectorTypeInfo(resultType, *getTypeConverter());
3373     auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
3374     auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
3375     if (!llvmArrayTy || !llvmVectorTy)
3376       return failure();
3377 
3378     // Construct returned value.
3379     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
3380 
3381     // Construct a 1-D vector with the splatted value that we insert in all the
3382     // places within the returned descriptor.
3383     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
3384     auto zero = rewriter.create<LLVM::ConstantOp>(
3385         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
3386         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
3387     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
3388                                                      adaptor.input(), zero);
3389 
3390     // Shuffle the value across the desired number of elements.
3391     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
3392     SmallVector<int32_t, 4> zeroValues(width, 0);
3393     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
3394     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
3395 
3396     // Iterate of linear index, convert to coords space and insert splatted 1-D
3397     // vector in each position.
3398     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
3399       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
3400                                                   position);
3401     });
3402     rewriter.replaceOp(splatOp, desc);
3403     return success();
3404   }
3405 };
3406 
3407 /// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
extractFromI64ArrayAttr(Attribute attr)3408 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
3409   return llvm::to_vector<4>(
3410       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
3411         return a.cast<IntegerAttr>().getInt();
3412       }));
3413 }
3414 
3415 /// Conversion pattern that transforms a subview op into:
3416 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
3417 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
3418 ///      and stride.
3419 /// The subview op is replaced by the descriptor.
3420 struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
3421   using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
3422 
3423   LogicalResult
matchAndRewrite__anone5172fdb1311::SubViewOpLowering3424   matchAndRewrite(SubViewOp subViewOp, ArrayRef<Value> operands,
3425                   ConversionPatternRewriter &rewriter) const override {
3426     auto loc = subViewOp.getLoc();
3427 
3428     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
3429     auto sourceElementTy =
3430         typeConverter->convertType(sourceMemRefType.getElementType())
3431             .dyn_cast_or_null<LLVM::LLVMType>();
3432 
3433     auto viewMemRefType = subViewOp.getType();
3434     auto inferredType = SubViewOp::inferResultType(
3435                             subViewOp.getSourceType(),
3436                             extractFromI64ArrayAttr(subViewOp.static_offsets()),
3437                             extractFromI64ArrayAttr(subViewOp.static_sizes()),
3438                             extractFromI64ArrayAttr(subViewOp.static_strides()))
3439                             .cast<MemRefType>();
3440     auto targetElementTy =
3441         typeConverter->convertType(viewMemRefType.getElementType())
3442             .dyn_cast<LLVM::LLVMType>();
3443     auto targetDescTy = typeConverter->convertType(viewMemRefType)
3444                             .dyn_cast_or_null<LLVM::LLVMType>();
3445     if (!sourceElementTy || !targetDescTy)
3446       return failure();
3447 
3448     // Extract the offset and strides from the type.
3449     int64_t offset;
3450     SmallVector<int64_t, 4> strides;
3451     auto successStrides = getStridesAndOffset(inferredType, strides, offset);
3452     if (failed(successStrides))
3453       return failure();
3454 
3455     // Create the descriptor.
3456     if (!operands.front().getType().isa<LLVM::LLVMType>())
3457       return failure();
3458     MemRefDescriptor sourceMemRef(operands.front());
3459     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
3460 
3461     // Copy the buffer pointer from the old descriptor to the new one.
3462     Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
3463     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
3464         loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
3465         extracted);
3466     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
3467 
3468     // Copy the buffer pointer from the old descriptor to the new one.
3469     extracted = sourceMemRef.alignedPtr(rewriter, loc);
3470     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
3471         loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
3472         extracted);
3473     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
3474 
3475     auto shape = viewMemRefType.getShape();
3476     auto inferredShape = inferredType.getShape();
3477     size_t inferredShapeRank = inferredShape.size();
3478     size_t resultShapeRank = shape.size();
3479     SmallVector<bool, 4> mask =
3480         computeRankReductionMask(inferredShape, shape).getValue();
3481 
3482     // Extract strides needed to compute offset.
3483     SmallVector<Value, 4> strideValues;
3484     strideValues.reserve(inferredShapeRank);
3485     for (unsigned i = 0; i < inferredShapeRank; ++i)
3486       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
3487 
3488     // Offset.
3489     auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
3490     if (!ShapedType::isDynamicStrideOrOffset(offset)) {
3491       targetMemRef.setConstantOffset(rewriter, loc, offset);
3492     } else {
3493       Value baseOffset = sourceMemRef.offset(rewriter, loc);
3494       for (unsigned i = 0; i < inferredShapeRank; ++i) {
3495         Value offset =
3496             subViewOp.isDynamicOffset(i)
3497                 ? operands[subViewOp.getIndexOfDynamicOffset(i)]
3498                 : rewriter.create<LLVM::ConstantOp>(
3499                       loc, llvmIndexType,
3500                       rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
3501         Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
3502         baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
3503       }
3504       targetMemRef.setOffset(rewriter, loc, baseOffset);
3505     }
3506 
3507     // Update sizes and strides.
3508     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
3509          i >= 0 && j >= 0; --i) {
3510       if (!mask[i])
3511         continue;
3512 
3513       Value size =
3514           subViewOp.isDynamicSize(i)
3515               ? operands[subViewOp.getIndexOfDynamicSize(i)]
3516               : rewriter.create<LLVM::ConstantOp>(
3517                     loc, llvmIndexType,
3518                     rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
3519       targetMemRef.setSize(rewriter, loc, j, size);
3520       Value stride;
3521       if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
3522         stride = rewriter.create<LLVM::ConstantOp>(
3523             loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
3524       } else {
3525         stride =
3526             subViewOp.isDynamicStride(i)
3527                 ? operands[subViewOp.getIndexOfDynamicStride(i)]
3528                 : rewriter.create<LLVM::ConstantOp>(
3529                       loc, llvmIndexType,
3530                       rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i)));
3531         stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
3532       }
3533       targetMemRef.setStride(rewriter, loc, j, stride);
3534       j--;
3535     }
3536 
3537     rewriter.replaceOp(subViewOp, {targetMemRef});
3538     return success();
3539   }
3540 };
3541 
3542 /// Conversion pattern that transforms a transpose op into:
3543 ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
3544 ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
3545 ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
3546 ///      and stride. Size and stride are permutations of the original values.
3547 ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
3548 /// The transpose op is replaced by the alloca'ed pointer.
3549 class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
3550 public:
3551   using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
3552 
3553   LogicalResult
matchAndRewrite(TransposeOp transposeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const3554   matchAndRewrite(TransposeOp transposeOp, ArrayRef<Value> operands,
3555                   ConversionPatternRewriter &rewriter) const override {
3556     auto loc = transposeOp.getLoc();
3557     TransposeOpAdaptor adaptor(operands);
3558     MemRefDescriptor viewMemRef(adaptor.in());
3559 
3560     // No permutation, early exit.
3561     if (transposeOp.permutation().isIdentity())
3562       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
3563 
3564     auto targetMemRef = MemRefDescriptor::undef(
3565         rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
3566 
3567     // Copy the base and aligned pointers from the old descriptor to the new
3568     // one.
3569     targetMemRef.setAllocatedPtr(rewriter, loc,
3570                                  viewMemRef.allocatedPtr(rewriter, loc));
3571     targetMemRef.setAlignedPtr(rewriter, loc,
3572                                viewMemRef.alignedPtr(rewriter, loc));
3573 
3574     // Copy the offset pointer from the old descriptor to the new one.
3575     targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
3576 
3577     // Iterate over the dimensions and apply size/stride permutation.
3578     for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
3579       int sourcePos = en.index();
3580       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
3581       targetMemRef.setSize(rewriter, loc, targetPos,
3582                            viewMemRef.size(rewriter, loc, sourcePos));
3583       targetMemRef.setStride(rewriter, loc, targetPos,
3584                              viewMemRef.stride(rewriter, loc, sourcePos));
3585     }
3586 
3587     rewriter.replaceOp(transposeOp, {targetMemRef});
3588     return success();
3589   }
3590 };
3591 
3592 /// Conversion pattern that transforms an op into:
3593 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
3594 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
3595 ///      and stride.
3596 /// The view op is replaced by the descriptor.
3597 struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
3598   using ConvertOpToLLVMPattern<ViewOp>::ConvertOpToLLVMPattern;
3599 
3600   // Build and return the value for the idx^th shape dimension, either by
3601   // returning the constant shape dimension or counting the proper dynamic size.
getSize__anone5172fdb1311::ViewOpLowering3602   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
3603                 ArrayRef<int64_t> shape, ValueRange dynamicSizes,
3604                 unsigned idx) const {
3605     assert(idx < shape.size());
3606     if (!ShapedType::isDynamic(shape[idx]))
3607       return createIndexConstant(rewriter, loc, shape[idx]);
3608     // Count the number of dynamic dims in range [0, idx]
3609     unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
3610       return ShapedType::isDynamic(v);
3611     });
3612     return dynamicSizes[nDynamic];
3613   }
3614 
3615   // Build and return the idx^th stride, either by returning the constant stride
3616   // or by computing the dynamic stride from the current `runningStride` and
3617   // `nextSize`. The caller should keep a running stride and update it with the
3618   // result returned by this function.
getStride__anone5172fdb1311::ViewOpLowering3619   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
3620                   ArrayRef<int64_t> strides, Value nextSize,
3621                   Value runningStride, unsigned idx) const {
3622     assert(idx < strides.size());
3623     if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
3624       return createIndexConstant(rewriter, loc, strides[idx]);
3625     if (nextSize)
3626       return runningStride
3627                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
3628                  : nextSize;
3629     assert(!runningStride);
3630     return createIndexConstant(rewriter, loc, 1);
3631   }
3632 
3633   LogicalResult
matchAndRewrite__anone5172fdb1311::ViewOpLowering3634   matchAndRewrite(ViewOp viewOp, ArrayRef<Value> operands,
3635                   ConversionPatternRewriter &rewriter) const override {
3636     auto loc = viewOp.getLoc();
3637     ViewOpAdaptor adaptor(operands);
3638 
3639     auto viewMemRefType = viewOp.getType();
3640     auto targetElementTy =
3641         typeConverter->convertType(viewMemRefType.getElementType())
3642             .dyn_cast<LLVM::LLVMType>();
3643     auto targetDescTy =
3644         typeConverter->convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
3645     if (!targetDescTy)
3646       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
3647              failure();
3648 
3649     int64_t offset;
3650     SmallVector<int64_t, 4> strides;
3651     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
3652     if (failed(successStrides))
3653       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
3654     assert(offset == 0 && "expected offset to be 0");
3655 
3656     // Create the descriptor.
3657     MemRefDescriptor sourceMemRef(adaptor.source());
3658     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
3659 
3660     // Field 1: Copy the allocated pointer, used for malloc/free.
3661     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
3662     auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
3663     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
3664         loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
3665         allocatedPtr);
3666     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
3667 
3668     // Field 2: Copy the actual aligned pointer to payload.
3669     Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
3670     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
3671                                               alignedPtr, adaptor.byte_shift());
3672     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
3673         loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
3674         alignedPtr);
3675     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
3676 
3677     // Field 3: The offset in the resulting type must be 0. This is because of
3678     // the type change: an offset on srcType* may not be expressible as an
3679     // offset on dstType*.
3680     targetMemRef.setOffset(rewriter, loc,
3681                            createIndexConstant(rewriter, loc, offset));
3682 
3683     // Early exit for 0-D corner case.
3684     if (viewMemRefType.getRank() == 0)
3685       return rewriter.replaceOp(viewOp, {targetMemRef}), success();
3686 
3687     // Fields 4 and 5: Update sizes and strides.
3688     if (strides.back() != 1)
3689       return viewOp.emitWarning("cannot cast to non-contiguous shape"),
3690              failure();
3691     Value stride = nullptr, nextSize = nullptr;
3692     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
3693       // Update size.
3694       Value size =
3695           getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
3696       targetMemRef.setSize(rewriter, loc, i, size);
3697       // Update stride.
3698       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
3699       targetMemRef.setStride(rewriter, loc, i, stride);
3700       nextSize = size;
3701     }
3702 
3703     rewriter.replaceOp(viewOp, {targetMemRef});
3704     return success();
3705   }
3706 };
3707 
3708 struct AssumeAlignmentOpLowering
3709     : public ConvertOpToLLVMPattern<AssumeAlignmentOp> {
3710   using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
3711 
3712   LogicalResult
matchAndRewrite__anone5172fdb1311::AssumeAlignmentOpLowering3713   matchAndRewrite(AssumeAlignmentOp op, ArrayRef<Value> operands,
3714                   ConversionPatternRewriter &rewriter) const override {
3715     AssumeAlignmentOp::Adaptor transformed(operands);
3716     Value memref = transformed.memref();
3717     unsigned alignment = op.alignment();
3718     auto loc = op.getLoc();
3719 
3720     MemRefDescriptor memRefDescriptor(memref);
3721     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
3722 
3723     // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
3724     // the asserted memref.alignedPtr isn't used anywhere else, as the real
3725     // users like load/store/views always re-extract memref.alignedPtr as they
3726     // get lowered.
3727     //
3728     // This relies on LLVM's CSE optimization (potentially after SROA), since
3729     // after CSE all memref.alignedPtr instances get de-duplicated into the same
3730     // pointer SSA value.
3731     auto intPtrType =
3732         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
3733     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
3734     Value mask =
3735         createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
3736     Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
3737     rewriter.create<LLVM::AssumeOp>(
3738         loc, rewriter.create<LLVM::ICmpOp>(
3739                  loc, LLVM::ICmpPredicate::eq,
3740                  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
3741 
3742     rewriter.eraseOp(op);
3743     return success();
3744   }
3745 };
3746 
3747 } // namespace
3748 
3749 /// Try to match the kind of a std.atomic_rmw to determine whether to use a
3750 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
matchSimpleAtomicOp(AtomicRMWOp atomicOp)3751 static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
3752   switch (atomicOp.kind()) {
3753   case AtomicRMWKind::addf:
3754     return LLVM::AtomicBinOp::fadd;
3755   case AtomicRMWKind::addi:
3756     return LLVM::AtomicBinOp::add;
3757   case AtomicRMWKind::assign:
3758     return LLVM::AtomicBinOp::xchg;
3759   case AtomicRMWKind::maxs:
3760     return LLVM::AtomicBinOp::max;
3761   case AtomicRMWKind::maxu:
3762     return LLVM::AtomicBinOp::umax;
3763   case AtomicRMWKind::mins:
3764     return LLVM::AtomicBinOp::min;
3765   case AtomicRMWKind::minu:
3766     return LLVM::AtomicBinOp::umin;
3767   default:
3768     return llvm::None;
3769   }
3770   llvm_unreachable("Invalid AtomicRMWKind");
3771 }
3772 
3773 namespace {
3774 
3775 struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
3776   using Base::Base;
3777 
3778   LogicalResult
matchAndRewrite__anone5172fdb1811::AtomicRMWOpLowering3779   matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
3780                   ConversionPatternRewriter &rewriter) const override {
3781     if (failed(match(atomicOp)))
3782       return failure();
3783     auto maybeKind = matchSimpleAtomicOp(atomicOp);
3784     if (!maybeKind)
3785       return failure();
3786     AtomicRMWOp::Adaptor adaptor(operands);
3787     auto resultType = adaptor.value().getType();
3788     auto memRefType = atomicOp.getMemRefType();
3789     auto dataPtr =
3790         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
3791                              adaptor.indices(), rewriter);
3792     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
3793         atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
3794         LLVM::AtomicOrdering::acq_rel);
3795     return success();
3796   }
3797 };
3798 
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   br loop(%loaded)              |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
/// The pair produced by the cmpxchg is a struct of (value, i1): element 0 is
/// the value extracted as `%new` and element 1 is the flag extracted as `%ok`.
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
  using Base::Base;

  /// Rewrites `atomicOp` into the three-block structure pictured above:
  /// the initial block loads the current value and branches to the loop, the
  /// loop block holds a clone of the op's body followed by a cmpxchg and a
  /// conditional branch, and the end block receives the operations that
  /// originally followed `atomicOp`. `operands` are the already-converted
  /// operands of the op. Always returns success().
  LogicalResult
  matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {

    auto loc = atomicOp.getLoc();
    GenericAtomicRMWOp::Adaptor adaptor(operands);
    // Converted type of the op result; also used as the type of the loop
    // block argument and of the first element of the cmpxchg result pair.
    LLVM::LLVMType valueType =
        typeConverter->convertType(atomicOp.getResult().getType())
            .cast<LLVM::LLVMType>();

    // Split the block into initial, loop, and ending parts.
    // NOTE(review): `initBlock` is the rewriter's current insertion block,
    // presumably the block that contains `atomicOp` -- confirm with the driver.
    auto *initBlock = rewriter.getInsertionBlock();
    // The loop block is placed right after `initBlock` and takes a single
    // argument of `valueType` carrying the most recently observed value.
    auto *loopBlock =
        rewriter.createBlock(initBlock->getParent(),
                             std::next(Region::iterator(initBlock)), valueType);
    // The end block is placed right after the loop block and has no arguments.
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    // Both iterators are captured before the load/branch below are appended to
    // `initBlock`, so `opsToMoveEnd` refers to the last pre-existing operation.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    // The op's current-value argument is remapped to the loop block argument,
    // and every cloned op's results are recorded so later ops see the clones.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    // The value to store is the (remapped) first operand of the body's
    // terminator.
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    // It compares memory at `dataPtr` against `loopArgument` and, on a match,
    // stores `result`.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext());
    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    // Index 0 holds the value, index 1 holds the i1 flag.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    // On retry, the freshly extracted value becomes the next loop argument.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    // Re-create in `endBlock` everything that followed `atomicOp` in the
    // initial block, substituting `newLoaded` for the op's result. Advancing
    // both captured iterators by one turns the inclusive pair into the
    // half-open range (atomicOp, last original op].
    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  // The clones are created at the rewriter's current insertion point, with
  // uses of `oldResult` remapped to `newResult`. Originals are collected first
  // and erased only after the whole range has been cloned.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};
3919 
3920 } // namespace
3921 
3922 /// Collect a set of patterns to convert from the Standard dialect to LLVM.
populateStdToLLVMNonMemoryConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)3923 void mlir::populateStdToLLVMNonMemoryConversionPatterns(
3924     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
3925   // FIXME: this should be tablegen'ed
3926   // clang-format off
3927   patterns.insert<
3928       AbsFOpLowering,
3929       AddCFOpLowering,
3930       AddFOpLowering,
3931       AddIOpLowering,
3932       AllocaOpLowering,
3933       AndOpLowering,
3934       AssertOpLowering,
3935       AtomicRMWOpLowering,
3936       BranchOpLowering,
3937       CallIndirectOpLowering,
3938       CallOpLowering,
3939       CeilFOpLowering,
3940       CmpFOpLowering,
3941       CmpIOpLowering,
3942       CondBranchOpLowering,
3943       CopySignOpLowering,
3944       CosOpLowering,
3945       ConstantOpLowering,
3946       CreateComplexOpLowering,
3947       DialectCastOpLowering,
3948       DivFOpLowering,
3949       ExpOpLowering,
3950       Exp2OpLowering,
3951       FloorFOpLowering,
3952       GenericAtomicRMWOpLowering,
3953       LogOpLowering,
3954       Log10OpLowering,
3955       Log2OpLowering,
3956       FPExtLowering,
3957       FPToSILowering,
3958       FPToUILowering,
3959       FPTruncLowering,
3960       ImOpLowering,
3961       IndexCastOpLowering,
3962       MulFOpLowering,
3963       MulIOpLowering,
3964       NegFOpLowering,
3965       OrOpLowering,
3966       PrefetchOpLowering,
3967       ReOpLowering,
3968       RemFOpLowering,
3969       ReturnOpLowering,
3970       RsqrtOpLowering,
3971       SIToFPLowering,
3972       SelectOpLowering,
3973       ShiftLeftOpLowering,
3974       SignExtendIOpLowering,
3975       SignedDivIOpLowering,
3976       SignedRemIOpLowering,
3977       SignedShiftRightOpLowering,
3978       SinOpLowering,
3979       SplatOpLowering,
3980       SplatNdOpLowering,
3981       SqrtOpLowering,
3982       SubCFOpLowering,
3983       SubFOpLowering,
3984       SubIOpLowering,
3985       TruncateIOpLowering,
3986       UIToFPLowering,
3987       UnsignedDivIOpLowering,
3988       UnsignedRemIOpLowering,
3989       UnsignedShiftRightOpLowering,
3990       XOrOpLowering,
3991       ZeroExtendIOpLowering>(converter);
3992   // clang-format on
3993 }
3994 
populateStdToLLVMMemoryConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)3995 void mlir::populateStdToLLVMMemoryConversionPatterns(
3996     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
3997   // clang-format off
3998   patterns.insert<
3999       AssumeAlignmentOpLowering,
4000       DeallocOpLowering,
4001       DimOpLowering,
4002       GlobalMemrefOpLowering,
4003       GetGlobalMemrefOpLowering,
4004       LoadOpLowering,
4005       MemRefCastOpLowering,
4006       MemRefReinterpretCastOpLowering,
4007       MemRefReshapeOpLowering,
4008       RankOpLowering,
4009       StoreOpLowering,
4010       SubViewOpLowering,
4011       TransposeOpLowering,
4012       ViewOpLowering>(converter);
4013   // clang-format on
4014   if (converter.getOptions().useAlignedAlloc)
4015     patterns.insert<AlignedAllocOpLowering>(converter);
4016   else
4017     patterns.insert<AllocOpLowering>(converter);
4018 }
4019 
populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)4020 void mlir::populateStdToLLVMFuncOpConversionPattern(
4021     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
4022   if (converter.getOptions().useBarePtrCallConv)
4023     patterns.insert<BarePtrFuncOpConversion>(converter);
4024   else
4025     patterns.insert<FuncOpConversion>(converter);
4026 }
4027 
populateStdToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)4028 void mlir::populateStdToLLVMConversionPatterns(
4029     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
4030   populateStdToLLVMFuncOpConversionPattern(converter, patterns);
4031   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
4032   populateStdToLLVMMemoryConversionPatterns(converter, patterns);
4033 }
4034 
4035 /// Convert a non-empty list of types to be returned from a function into a
4036 /// supported LLVM IR type.  In particular, if more than one value is returned,
4037 /// create an LLVM IR structure type with elements that correspond to each of
4038 /// the MLIR types converted with `convertType`.
packFunctionResults(ArrayRef<Type> types)4039 Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
4040   assert(!types.empty() && "expected non-empty list of type");
4041 
4042   if (types.size() == 1)
4043     return convertCallingConventionType(types.front());
4044 
4045   SmallVector<LLVM::LLVMType, 8> resultTypes;
4046   resultTypes.reserve(types.size());
4047   for (auto t : types) {
4048     auto converted =
4049         convertCallingConventionType(t).dyn_cast_or_null<LLVM::LLVMType>();
4050     if (!converted)
4051       return {};
4052     resultTypes.push_back(converted);
4053   }
4054 
4055   return LLVM::LLVMType::getStructTy(&getContext(), resultTypes);
4056 }
4057 
promoteOneMemRefDescriptor(Location loc,Value operand,OpBuilder & builder)4058 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
4059                                                     OpBuilder &builder) {
4060   auto *context = builder.getContext();
4061   auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext());
4062   auto indexType = IndexType::get(context);
4063   // Alloca with proper alignment. We do not expect optimizations of this
4064   // alloca op and so we omit allocating at the entry block.
4065   auto ptrType = operand.getType().cast<LLVM::LLVMType>().getPointerTo();
4066   Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
4067                                                IntegerAttr::get(indexType, 1));
4068   Value allocated =
4069       builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
4070   // Store into the alloca'ed descriptor.
4071   builder.create<LLVM::StoreOp>(loc, operand, allocated);
4072   return allocated;
4073 }
4074 
promoteOperands(Location loc,ValueRange opOperands,ValueRange operands,OpBuilder & builder)4075 SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
4076                                                          ValueRange opOperands,
4077                                                          ValueRange operands,
4078                                                          OpBuilder &builder) {
4079   SmallVector<Value, 4> promotedOperands;
4080   promotedOperands.reserve(operands.size());
4081   for (auto it : llvm::zip(opOperands, operands)) {
4082     auto operand = std::get<0>(it);
4083     auto llvmOperand = std::get<1>(it);
4084 
4085     if (options.useBarePtrCallConv) {
4086       // For the bare-ptr calling convention, we only have to extract the
4087       // aligned pointer of a memref.
4088       if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
4089         MemRefDescriptor desc(llvmOperand);
4090         llvmOperand = desc.alignedPtr(builder, loc);
4091       } else if (operand.getType().isa<UnrankedMemRefType>()) {
4092         llvm_unreachable("Unranked memrefs are not supported");
4093       }
4094     } else {
4095       if (operand.getType().isa<UnrankedMemRefType>()) {
4096         UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
4097                                          promotedOperands);
4098         continue;
4099       }
4100       if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
4101         MemRefDescriptor::unpack(builder, loc, llvmOperand,
4102                                  operand.getType().cast<MemRefType>(),
4103                                  promotedOperands);
4104         continue;
4105       }
4106     }
4107 
4108     promotedOperands.push_back(llvmOperand);
4109   }
4110   return promotedOperands;
4111 }
4112 
4113 namespace {
4114 /// A pass converting MLIR operations into the LLVM IR dialect.
4115 struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
4116   LLVMLoweringPass() = default;
LLVMLoweringPass__anone5172fdb1911::LLVMLoweringPass4117   LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
4118                    unsigned indexBitwidth, bool useAlignedAlloc,
4119                    const llvm::DataLayout &dataLayout) {
4120     this->useBarePtrCallConv = useBarePtrCallConv;
4121     this->emitCWrappers = emitCWrappers;
4122     this->indexBitwidth = indexBitwidth;
4123     this->useAlignedAlloc = useAlignedAlloc;
4124     this->dataLayout = dataLayout.getStringRepresentation();
4125   }
4126 
4127   /// Run the dialect converter on the module.
runOnOperation__anone5172fdb1911::LLVMLoweringPass4128   void runOnOperation() override {
4129     if (useBarePtrCallConv && emitCWrappers) {
4130       getOperation().emitError()
4131           << "incompatible conversion options: bare-pointer calling convention "
4132              "and C wrapper emission";
4133       signalPassFailure();
4134       return;
4135     }
4136     if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
4137             this->dataLayout, [this](const Twine &message) {
4138               getOperation().emitError() << message.str();
4139             }))) {
4140       signalPassFailure();
4141       return;
4142     }
4143 
4144     ModuleOp m = getOperation();
4145 
4146     LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers,
4147                                   indexBitwidth, useAlignedAlloc,
4148                                   llvm::DataLayout(this->dataLayout)};
4149     LLVMTypeConverter typeConverter(&getContext(), options);
4150 
4151     OwningRewritePatternList patterns;
4152     populateStdToLLVMConversionPatterns(typeConverter, patterns);
4153 
4154     LLVMConversionTarget target(getContext());
4155     if (failed(applyPartialConversion(m, target, std::move(patterns))))
4156       signalPassFailure();
4157     m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
4158               StringAttr::get(this->dataLayout, m.getContext()));
4159   }
4160 };
4161 } // end namespace
4162 
LLVMConversionTarget(MLIRContext & ctx)4163 mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
4164     : ConversionTarget(ctx) {
4165   this->addLegalDialect<LLVM::LLVMDialect>();
4166   this->addIllegalOp<LLVM::DialectCastOp>();
4167   this->addIllegalOp<TanhOp>();
4168 }
4169 
4170 std::unique_ptr<OperationPass<ModuleOp>>
createLowerToLLVMPass(const LowerToLLVMOptions & options)4171 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
4172   return std::make_unique<LLVMLoweringPass>(
4173       options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth,
4174       options.useAlignedAlloc, options.dataLayout);
4175 }
4176