1 //===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides a dialect conversion targeting the LLVM IR dialect. By default, it 10 // converts Standard ops and types and provides hooks for dialect-specific 11 // extensions to the conversion. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H 16 #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H 17 18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 namespace llvm { 22 class IntegerType; 23 class LLVMContext; 24 class Module; 25 class Type; 26 } // namespace llvm 27 28 namespace mlir { 29 30 class BaseMemRefType; 31 class ComplexType; 32 class LLVMTypeConverter; 33 class UnrankedMemRefType; 34 35 namespace LLVM { 36 class LLVMDialect; 37 class LLVMType; 38 class LLVMPointerType; 39 } // namespace LLVM 40 41 /// Callback to convert function argument types. It converts a MemRef function 42 /// argument to a list of non-aggregate types containing descriptor 43 /// information, and an UnrankedmemRef function argument to a list containing 44 /// the rank and a pointer to a descriptor struct. 45 LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter, 46 Type type, 47 SmallVectorImpl<Type> &result); 48 49 /// Callback to convert function argument types. It converts MemRef function 50 /// arguments to bare pointers to the MemRef element type. 51 LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, 52 Type type, 53 SmallVectorImpl<Type> &result); 54 55 /// Conversion from types in the Standard dialect to the LLVM IR dialect. 56 class LLVMTypeConverter : public TypeConverter { 57 /// Give structFuncArgTypeConverter access to memref-specific functions. 58 friend LogicalResult 59 structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, 60 SmallVectorImpl<Type> &result); 61 62 public: 63 using TypeConverter::convertType; 64 65 /// Create an LLVMTypeConverter using the default LowerToLLVMOptions. 66 LLVMTypeConverter(MLIRContext *ctx); 67 68 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. 69 LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options); 70 71 /// Convert a function type. The arguments and results are converted one by 72 /// one and results are packed into a wrapped LLVM IR structure type. `result` 73 /// is populated with argument mapping. 74 LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic, 75 SignatureConversion &result); 76 77 /// Convert a non-empty list of types to be returned from a function into a 78 /// supported LLVM IR type. In particular, if more than one value is 79 /// returned, create an LLVM IR structure type with elements that correspond 80 /// to each of the MLIR types converted with `convertType`. 81 Type packFunctionResults(ArrayRef<Type> types); 82 83 /// Convert a type in the context of the default or bare pointer calling 84 /// convention. Calling convention sensitive types, such as MemRefType and 85 /// UnrankedMemRefType, are converted following the specific rules for the 86 /// calling convention. Calling convention independent types are converted 87 /// following the default LLVM type conversions. 88 Type convertCallingConventionType(Type type); 89 90 /// Promote the bare pointers in 'values' that resulted from memrefs to 91 /// descriptors. 'stdTypes' holds the types of 'values' before the conversion 92 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). 93 void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, 94 Location loc, ArrayRef<Type> stdTypes, 95 SmallVectorImpl<Value> &values); 96 97 /// Returns the MLIR context. 98 MLIRContext &getContext(); 99 100 /// Returns the LLVM dialect. getDialect()101 LLVM::LLVMDialect *getDialect() { return llvmDialect; } 102 getOptions()103 const LowerToLLVMOptions &getOptions() const { return options; } 104 105 /// Promote the LLVM representation of all operands including promoting MemRef 106 /// descriptors to stack and use pointers to struct to avoid the complexity 107 /// of the platform-specific C/C++ ABI lowering related to struct argument 108 /// passing. 109 SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands, 110 ValueRange operands, 111 OpBuilder &builder); 112 113 /// Promote the LLVM struct representation of one MemRef descriptor to stack 114 /// and use pointer to struct to avoid the complexity of the platform-specific 115 /// C/C++ ABI lowering related to struct argument passing. 116 Value promoteOneMemRefDescriptor(Location loc, Value operand, 117 OpBuilder &builder); 118 119 /// Converts the function type to a C-compatible format, in particular using 120 /// pointers to memref descriptors for arguments. 121 LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type); 122 123 /// Returns the data layout to use during and after conversion. getDataLayout()124 const llvm::DataLayout &getDataLayout() { return options.dataLayout; } 125 126 /// Gets the LLVM representation of the index type. The returned type is an 127 /// integer type with the size configured for this type converter. 128 LLVM::LLVMType getIndexType(); 129 130 /// Gets the bitwidth of the index type when converted to LLVM. getIndexTypeBitwidth()131 unsigned getIndexTypeBitwidth() { return options.indexBitwidth; } 132 133 /// Gets the pointer bitwidth. 134 unsigned getPointerBitwidth(unsigned addressSpace = 0); 135 136 protected: 137 /// Pointer to the LLVM dialect. 138 LLVM::LLVMDialect *llvmDialect; 139 140 private: 141 /// Convert a function type. The arguments and results are converted one by 142 /// one. Additionally, if the function returns more than one value, pack the 143 /// results into an LLVM IR structure type so that the converted function type 144 /// returns at most one result. 145 Type convertFunctionType(FunctionType type); 146 147 /// Convert the index type. Uses llvmModule data layout to create an integer 148 /// of the pointer bitwidth. 149 Type convertIndexType(IndexType type); 150 151 /// Convert an integer type `i*` to `!llvm<"i*">`. 152 Type convertIntegerType(IntegerType type); 153 154 /// Convert a floating point type: `f16` to `!llvm.half`, `f32` to 155 /// `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported 156 /// by LLVM. 157 Type convertFloatType(FloatType type); 158 159 /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`, 160 /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to 161 /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported. 162 Type convertComplexType(ComplexType type); 163 164 /// Convert a memref type into an LLVM type that captures the relevant data. 165 Type convertMemRefType(MemRefType type); 166 167 /// Convert a memref type into a list of LLVM IR types that will form the 168 /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` 169 /// arrays in the descriptors are unpacked to individual index-typed elements, 170 /// else they are are kept as rank-sized arrays of index type. In particular, 171 /// the list will contain: 172 /// - two pointers to the memref element type, followed by 173 /// - an index-typed offset, followed by 174 /// - (if unpackAggregates = true) 175 /// - one index-typed size per dimension of the memref, followed by 176 /// - one index-typed stride per dimension of the memref. 177 /// - (if unpackArrregates = false) 178 /// - one rank-sized array of index-type for the size of each dimension 179 /// - one rank-sized array of index-type for the stride of each dimension 180 /// 181 /// For example, memref<?x?xf32> is converted to the following list: 182 /// - `!llvm<"float*">` (allocated pointer), 183 /// - `!llvm<"float*">` (aligned pointer), 184 /// - `!llvm.i64` (offset), 185 /// - `!llvm.i64`, `!llvm.i64` (sizes), 186 /// - `!llvm.i64`, `!llvm.i64` (strides). 187 /// These types can be recomposed to a memref descriptor struct. 188 SmallVector<LLVM::LLVMType, 5> 189 getMemRefDescriptorFields(MemRefType type, bool unpackAggregates); 190 191 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types 192 /// that will form the unranked memref descriptor. In particular, this list 193 /// contains: 194 /// - an integer rank, followed by 195 /// - a pointer to the memref descriptor struct. 196 /// For example, memref<*xf32> is converted to the following list: 197 /// !llvm.i64 (rank) 198 /// !llvm<"i8*"> (type-erased pointer). 199 /// These types can be recomposed to a unranked memref descriptor struct. 200 SmallVector<LLVM::LLVMType, 2> getUnrankedMemRefDescriptorFields(); 201 202 // Convert an unranked memref type to an LLVM type that captures the 203 // runtime rank and a pointer to the static ranked memref desc 204 Type convertUnrankedMemRefType(UnrankedMemRefType type); 205 206 /// Convert a memref type to a bare pointer to the memref element type. 207 Type convertMemRefToBarePtr(BaseMemRefType type); 208 209 // Convert a 1D vector type into an LLVM vector type. 210 Type convertVectorType(VectorType type); 211 212 /// Options for customizing the llvm lowering. 213 LowerToLLVMOptions options; 214 }; 215 216 /// Helper class to produce LLVM dialect operations extracting or inserting 217 /// values to a struct. 218 class StructBuilder { 219 public: 220 /// Construct a helper for the given value. 221 explicit StructBuilder(Value v); 222 /// Builds IR creating an `undef` value of the descriptor type. 223 static StructBuilder undef(OpBuilder &builder, Location loc, 224 Type descriptorType); 225 Value()226 /*implicit*/ operator Value() { return value; } 227 228 protected: 229 // LLVM value 230 Value value; 231 // Cached struct type. 232 Type structType; 233 234 protected: 235 /// Builds IR to extract a value from the struct at position pos 236 Value extractPtr(OpBuilder &builder, Location loc, unsigned pos); 237 /// Builds IR to set a value in the struct at position pos 238 void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); 239 }; 240 241 class ComplexStructBuilder : public StructBuilder { 242 public: 243 /// Construct a helper for the given complex number value. 244 using StructBuilder::StructBuilder; 245 /// Build IR creating an `undef` value of the complex number type. 246 static ComplexStructBuilder undef(OpBuilder &builder, Location loc, 247 Type type); 248 249 // Build IR extracting the real value from the complex number struct. 250 Value real(OpBuilder &builder, Location loc); 251 // Build IR inserting the real value into the complex number struct. 252 void setReal(OpBuilder &builder, Location loc, Value real); 253 254 // Build IR extracting the imaginary value from the complex number struct. 255 Value imaginary(OpBuilder &builder, Location loc); 256 // Build IR inserting the imaginary value into the complex number struct. 257 void setImaginary(OpBuilder &builder, Location loc, Value imaginary); 258 }; 259 260 /// Helper class to produce LLVM dialect operations extracting or inserting 261 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. 262 /// The Value may be null, in which case none of the operations are valid. 263 class MemRefDescriptor : public StructBuilder { 264 public: 265 /// Construct a helper for the given descriptor value. 266 explicit MemRefDescriptor(Value descriptor); 267 /// Builds IR creating an `undef` value of the descriptor type. 268 static MemRefDescriptor undef(OpBuilder &builder, Location loc, 269 Type descriptorType); 270 /// Builds IR creating a MemRef descriptor that represents `type` and 271 /// populates it with static shape and stride information extracted from the 272 /// type. 273 static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, 274 LLVMTypeConverter &typeConverter, 275 MemRefType type, Value memory); 276 277 /// Builds IR extracting the allocated pointer from the descriptor. 278 Value allocatedPtr(OpBuilder &builder, Location loc); 279 /// Builds IR inserting the allocated pointer into the descriptor. 280 void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); 281 282 /// Builds IR extracting the aligned pointer from the descriptor. 283 Value alignedPtr(OpBuilder &builder, Location loc); 284 285 /// Builds IR inserting the aligned pointer into the descriptor. 286 void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); 287 288 /// Builds IR extracting the offset from the descriptor. 289 Value offset(OpBuilder &builder, Location loc); 290 291 /// Builds IR inserting the offset into the descriptor. 292 void setOffset(OpBuilder &builder, Location loc, Value offset); 293 void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); 294 295 /// Builds IR extracting the pos-th size from the descriptor. 296 Value size(OpBuilder &builder, Location loc, unsigned pos); 297 Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank); 298 299 /// Builds IR inserting the pos-th size into the descriptor 300 void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); 301 void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, 302 uint64_t size); 303 304 /// Builds IR extracting the pos-th size from the descriptor. 305 Value stride(OpBuilder &builder, Location loc, unsigned pos); 306 307 /// Builds IR inserting the pos-th stride into the descriptor 308 void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); 309 void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, 310 uint64_t stride); 311 312 /// Returns the (LLVM) pointer type this descriptor contains. 313 LLVM::LLVMPointerType getElementPtrType(); 314 315 /// Builds IR populating a MemRef descriptor structure from a list of 316 /// individual values composing that descriptor, in the following order: 317 /// - allocated pointer; 318 /// - aligned pointer; 319 /// - offset; 320 /// - <rank> sizes; 321 /// - <rank> shapes; 322 /// where <rank> is the MemRef rank as provided in `type`. 323 static Value pack(OpBuilder &builder, Location loc, 324 LLVMTypeConverter &converter, MemRefType type, 325 ValueRange values); 326 327 /// Builds IR extracting individual elements of a MemRef descriptor structure 328 /// and returning them as `results` list. 329 static void unpack(OpBuilder &builder, Location loc, Value packed, 330 MemRefType type, SmallVectorImpl<Value> &results); 331 332 /// Returns the number of non-aggregate values that would be produced by 333 /// `unpack`. 334 static unsigned getNumUnpackedValues(MemRefType type); 335 336 private: 337 // Cached index type. 338 Type indexType; 339 }; 340 341 /// Helper class allowing the user to access a range of Values that correspond 342 /// to an unpacked memref descriptor using named accessors. This does not own 343 /// the values. 344 class MemRefDescriptorView { 345 public: 346 /// Constructs the view from a range of values. Infers the rank from the size 347 /// of the range. 348 explicit MemRefDescriptorView(ValueRange range); 349 350 /// Returns the allocated pointer Value. 351 Value allocatedPtr(); 352 353 /// Returns the aligned pointer Value. 354 Value alignedPtr(); 355 356 /// Returns the offset Value. 357 Value offset(); 358 359 /// Returns the pos-th size Value. 360 Value size(unsigned pos); 361 362 /// Returns the pos-th stride Value. 363 Value stride(unsigned pos); 364 365 private: 366 /// Rank of the memref the descriptor is pointing to. 367 int rank; 368 /// Underlying range of Values. 369 ValueRange elements; 370 }; 371 372 class UnrankedMemRefDescriptor : public StructBuilder { 373 public: 374 /// Construct a helper for the given descriptor value. 375 explicit UnrankedMemRefDescriptor(Value descriptor); 376 /// Builds IR creating an `undef` value of the descriptor type. 377 static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, 378 Type descriptorType); 379 380 /// Builds IR extracting the rank from the descriptor 381 Value rank(OpBuilder &builder, Location loc); 382 /// Builds IR setting the rank in the descriptor 383 void setRank(OpBuilder &builder, Location loc, Value value); 384 /// Builds IR extracting ranked memref descriptor ptr 385 Value memRefDescPtr(OpBuilder &builder, Location loc); 386 /// Builds IR setting ranked memref descriptor ptr 387 void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); 388 389 /// Builds IR populating an unranked MemRef descriptor structure from a list 390 /// of individual constituent values in the following order: 391 /// - rank of the memref; 392 /// - pointer to the memref descriptor. 393 static Value pack(OpBuilder &builder, Location loc, 394 LLVMTypeConverter &converter, UnrankedMemRefType type, 395 ValueRange values); 396 397 /// Builds IR extracting individual elements that compose an unranked memref 398 /// descriptor and returns them as `results` list. 399 static void unpack(OpBuilder &builder, Location loc, Value packed, 400 SmallVectorImpl<Value> &results); 401 402 /// Returns the number of non-aggregate values that would be produced by 403 /// `unpack`. getNumUnpackedValues()404 static unsigned getNumUnpackedValues() { return 2; } 405 406 /// Builds IR computing the sizes in bytes (suitable for opaque allocation) 407 /// and appends the corresponding values into `sizes`. 408 static void computeSizes(OpBuilder &builder, Location loc, 409 LLVMTypeConverter &typeConverter, 410 ArrayRef<UnrankedMemRefDescriptor> values, 411 SmallVectorImpl<Value> &sizes); 412 413 /// TODO: The following accessors don't take alignment rules between elements 414 /// of the descriptor struct into account. For some architectures, it might be 415 /// necessary to extend them and to use `llvm::DataLayout` contained in 416 /// `LLVMTypeConverter`. 417 418 /// Builds IR extracting the allocated pointer from the descriptor. 419 static Value allocatedPtr(OpBuilder &builder, Location loc, 420 Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); 421 /// Builds IR inserting the allocated pointer into the descriptor. 422 static void setAllocatedPtr(OpBuilder &builder, Location loc, 423 Value memRefDescPtr, 424 LLVM::LLVMType elemPtrPtrType, 425 Value allocatedPtr); 426 427 /// Builds IR extracting the aligned pointer from the descriptor. 428 static Value alignedPtr(OpBuilder &builder, Location loc, 429 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 430 LLVM::LLVMType elemPtrPtrType); 431 /// Builds IR inserting the aligned pointer into the descriptor. 432 static void setAlignedPtr(OpBuilder &builder, Location loc, 433 LLVMTypeConverter &typeConverter, 434 Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, 435 Value alignedPtr); 436 437 /// Builds IR extracting the offset from the descriptor. 438 static Value offset(OpBuilder &builder, Location loc, 439 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 440 LLVM::LLVMType elemPtrPtrType); 441 /// Builds IR inserting the offset into the descriptor. 442 static void setOffset(OpBuilder &builder, Location loc, 443 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 444 LLVM::LLVMType elemPtrPtrType, Value offset); 445 446 /// Builds IR extracting the pointer to the first element of the size array. 447 static Value sizeBasePtr(OpBuilder &builder, Location loc, 448 LLVMTypeConverter &typeConverter, 449 Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); 450 /// Builds IR extracting the size[index] from the descriptor. 451 static Value size(OpBuilder &builder, Location loc, 452 LLVMTypeConverter typeConverter, Value sizeBasePtr, 453 Value index); 454 /// Builds IR inserting the size[index] into the descriptor. 455 static void setSize(OpBuilder &builder, Location loc, 456 LLVMTypeConverter typeConverter, Value sizeBasePtr, 457 Value index, Value size); 458 459 /// Builds IR extracting the pointer to the first element of the stride array. 460 static Value strideBasePtr(OpBuilder &builder, Location loc, 461 LLVMTypeConverter &typeConverter, 462 Value sizeBasePtr, Value rank); 463 /// Builds IR extracting the stride[index] from the descriptor. 464 static Value stride(OpBuilder &builder, Location loc, 465 LLVMTypeConverter typeConverter, Value strideBasePtr, 466 Value index, Value stride); 467 /// Builds IR inserting the stride[index] into the descriptor. 468 static void setStride(OpBuilder &builder, Location loc, 469 LLVMTypeConverter typeConverter, Value strideBasePtr, 470 Value index, Value stride); 471 }; 472 473 /// Base class for operation conversions targeting the LLVM IR dialect. It 474 /// provides the conversion patterns with access to the LLVMTypeConverter and 475 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the 476 /// LowerToLLVMOptions by reference meaning the references have to remain alive 477 /// during the entire pattern lifetime. 478 class ConvertToLLVMPattern : public ConversionPattern { 479 public: 480 ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, 481 LLVMTypeConverter &typeConverter, 482 PatternBenefit benefit = 1); 483 484 protected: 485 /// Returns the LLVM dialect. 486 LLVM::LLVMDialect &getDialect() const; 487 488 LLVMTypeConverter *getTypeConverter() const; 489 490 /// Gets the MLIR type wrapping the LLVM integer type whose bit width is 491 /// defined by the used type converter. 492 LLVM::LLVMType getIndexType() const; 493 494 /// Gets the MLIR type wrapping the LLVM integer type whose bit width 495 /// corresponds to that of a LLVM pointer type. 496 LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const; 497 498 /// Gets the MLIR type wrapping the LLVM void type. 499 LLVM::LLVMType getVoidType() const; 500 501 /// Get the MLIR type wrapping the LLVM i8* type. 502 LLVM::LLVMType getVoidPtrType() const; 503 504 /// Create an LLVM dialect operation defining the given index constant. 505 Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, 506 uint64_t value) const; 507 508 // This is a strided getElementPtr variant that linearizes subscripts as: 509 // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. 510 Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, 511 ValueRange indices, 512 ConversionPatternRewriter &rewriter) const; 513 514 // Forwards to getStridedElementPtr. TODO: remove. 515 Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, 516 ValueRange indices, 517 ConversionPatternRewriter &rewriter) const; 518 519 /// Returns if the givem memref type is supported. 520 bool isSupportedMemRefType(MemRefType type) const; 521 522 /// Returns the type of a pointer to an element of the memref. 523 Type getElementPtrType(MemRefType type) const; 524 525 /// Computes sizes, strides and buffer size in bytes of `memRefType` with 526 /// identity layout. Emits constant ops for the static sizes of `memRefType`, 527 /// and uses `dynamicSizes` for the others. Emits instructions to compute 528 /// strides and buffer size from these sizes. 529 /// 530 /// For example, memref<4x?xf32> emits: 531 /// `sizes[0]` = llvm.mlir.constant(4 : index) : !llvm.i64 532 /// `sizes[1]` = `dynamicSizes[0]` 533 /// `strides[1]` = llvm.mlir.constant(1 : index) : !llvm.i64 534 /// `strides[0]` = `sizes[0]` 535 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : !llvm.i64 536 /// %nullptr = llvm.mlir.null : !llvm.ptr<float> 537 /// %gep = llvm.getelementptr %nullptr[%size] 538 /// : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float> 539 /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<float> to !llvm.i64 540 void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, 541 ArrayRef<Value> dynamicSizes, 542 ConversionPatternRewriter &rewriter, 543 SmallVectorImpl<Value> &sizes, 544 SmallVectorImpl<Value> &strides, 545 Value &sizeBytes) const; 546 547 /// Computes the size of type in bytes. 548 Value getSizeInBytes(Location loc, Type type, 549 ConversionPatternRewriter &rewriter) const; 550 551 /// Computes total number of elements for the given shape. 552 Value getNumElements(Location loc, ArrayRef<Value> shape, 553 ConversionPatternRewriter &rewriter) const; 554 555 /// Creates and populates a canonical memref descriptor struct. 556 MemRefDescriptor 557 createMemRefDescriptor(Location loc, MemRefType memRefType, 558 Value allocatedPtr, Value alignedPtr, 559 ArrayRef<Value> sizes, ArrayRef<Value> strides, 560 ConversionPatternRewriter &rewriter) const; 561 }; 562 563 /// Utility class for operation conversions targeting the LLVM dialect that 564 /// match exactly one source operation. 565 template <typename SourceOp> 566 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { 567 public: 568 ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, 569 PatternBenefit benefit = 1) 570 : ConvertToLLVMPattern(SourceOp::getOperationName(), 571 &typeConverter.getContext(), typeConverter, 572 benefit) {} 573 574 /// Wrappers around the RewritePattern methods that pass the derived op type. rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)575 void rewrite(Operation *op, ArrayRef<Value> operands, 576 ConversionPatternRewriter &rewriter) const final { 577 rewrite(cast<SourceOp>(op), operands, rewriter); 578 } match(Operation * op)579 LogicalResult match(Operation *op) const final { 580 return match(cast<SourceOp>(op)); 581 } 582 LogicalResult matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)583 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 584 ConversionPatternRewriter &rewriter) const final { 585 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); 586 } 587 588 /// Rewrite and Match methods that operate on the SourceOp type. These must be 589 /// overridden by the derived pattern class. rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)590 virtual void rewrite(SourceOp op, ArrayRef<Value> operands, 591 ConversionPatternRewriter &rewriter) const { 592 llvm_unreachable("must override rewrite or matchAndRewrite"); 593 } match(SourceOp op)594 virtual LogicalResult match(SourceOp op) const { 595 llvm_unreachable("must override match or matchAndRewrite"); 596 } 597 virtual LogicalResult matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)598 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 599 ConversionPatternRewriter &rewriter) const { 600 if (succeeded(match(op))) { 601 rewrite(op, operands, rewriter); 602 return success(); 603 } 604 return failure(); 605 } 606 607 private: 608 using ConvertToLLVMPattern::match; 609 using ConvertToLLVMPattern::matchAndRewrite; 610 }; 611 612 namespace LLVM { 613 namespace detail { 614 /// Replaces the given operation "op" with a new operation of type "targetOp" 615 /// and given operands. 616 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, 617 ValueRange operands, 618 LLVMTypeConverter &typeConverter, 619 ConversionPatternRewriter &rewriter); 620 621 LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, 622 ValueRange operands, 623 LLVMTypeConverter &typeConverter, 624 ConversionPatternRewriter &rewriter); 625 } // namespace detail 626 } // namespace LLVM 627 628 /// Generic implementation of one-to-one conversion from "SourceOp" to 629 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent. 630 /// Upholds a convention that multi-result operations get converted into an 631 /// operation returning the LLVM IR structure type, in which case individual 632 /// values must be extracted from using LLVM::ExtractValueOp before being used. 633 template <typename SourceOp, typename TargetOp> 634 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 635 public: 636 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 637 using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>; 638 639 /// Converts the type of the result to an LLVM type, pass operands as is, 640 /// preserve attributes. 641 LogicalResult matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)642 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 643 ConversionPatternRewriter &rewriter) const override { 644 return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), 645 operands, *this->getTypeConverter(), 646 rewriter); 647 } 648 }; 649 650 /// Basic lowering implementation to rewrite Ops with just one result to the 651 /// LLVM Dialect. This supports higher-dimensional vector types. 652 template <typename SourceOp, typename TargetOp> 653 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 654 public: 655 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 656 using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; 657 658 LogicalResult matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)659 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 660 ConversionPatternRewriter &rewriter) const override { 661 static_assert( 662 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 663 "expected single result op"); 664 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, 665 SourceOp>::value, 666 "expected same operands and result type"); 667 return LLVM::detail::vectorOneToOneRewrite( 668 op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), 669 rewriter); 670 } 671 }; 672 673 /// Derived class that automatically populates legalization information for 674 /// different LLVM ops. 675 class LLVMConversionTarget : public ConversionTarget { 676 public: 677 explicit LLVMConversionTarget(MLIRContext &ctx); 678 }; 679 680 } // namespace mlir 681 682 #endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H 683