1 //===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides a dialect conversion targeting the LLVM IR dialect.  By default, it
10 // converts Standard ops and types and provides hooks for dialect-specific
11 // extensions to the conversion.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
16 #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
17 
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 namespace llvm {
22 class IntegerType;
23 class LLVMContext;
24 class Module;
25 class Type;
26 } // namespace llvm
27 
28 namespace mlir {
29 
30 class BaseMemRefType;
31 class ComplexType;
32 class LLVMTypeConverter;
33 class UnrankedMemRefType;
34 
35 namespace LLVM {
36 class LLVMDialect;
37 class LLVMType;
38 class LLVMPointerType;
39 } // namespace LLVM
40 
41 /// Callback to convert function argument types. It converts a MemRef function
42 /// argument to a list of non-aggregate types containing descriptor
43 /// information, and an UnrankedmemRef function argument to a list containing
44 /// the rank and a pointer to a descriptor struct.
45 LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
46                                          Type type,
47                                          SmallVectorImpl<Type> &result);
48 
49 /// Callback to convert function argument types. It converts MemRef function
50 /// arguments to bare pointers to the MemRef element type.
51 LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
52                                           Type type,
53                                           SmallVectorImpl<Type> &result);
54 
55 /// Conversion from types in the Standard dialect to the LLVM IR dialect.
56 class LLVMTypeConverter : public TypeConverter {
57   /// Give structFuncArgTypeConverter access to memref-specific functions.
58   friend LogicalResult
59   structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
60                              SmallVectorImpl<Type> &result);
61 
62 public:
63   using TypeConverter::convertType;
64 
65   /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
66   LLVMTypeConverter(MLIRContext *ctx);
67 
68   /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
69   LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
70 
71   /// Convert a function type.  The arguments and results are converted one by
72   /// one and results are packed into a wrapped LLVM IR structure type. `result`
73   /// is populated with argument mapping.
74   LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic,
75                                           SignatureConversion &result);
76 
77   /// Convert a non-empty list of types to be returned from a function into a
78   /// supported LLVM IR type.  In particular, if more than one value is
79   /// returned, create an LLVM IR structure type with elements that correspond
80   /// to each of the MLIR types converted with `convertType`.
81   Type packFunctionResults(ArrayRef<Type> types);
82 
83   /// Convert a type in the context of the default or bare pointer calling
84   /// convention. Calling convention sensitive types, such as MemRefType and
85   /// UnrankedMemRefType, are converted following the specific rules for the
86   /// calling convention. Calling convention independent types are converted
87   /// following the default LLVM type conversions.
88   Type convertCallingConventionType(Type type);
89 
90   /// Promote the bare pointers in 'values' that resulted from memrefs to
91   /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
92   /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
93   void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
94                                     Location loc, ArrayRef<Type> stdTypes,
95                                     SmallVectorImpl<Value> &values);
96 
97   /// Returns the MLIR context.
98   MLIRContext &getContext();
99 
100   /// Returns the LLVM dialect.
getDialect()101   LLVM::LLVMDialect *getDialect() { return llvmDialect; }
102 
getOptions()103   const LowerToLLVMOptions &getOptions() const { return options; }
104 
105   /// Promote the LLVM representation of all operands including promoting MemRef
106   /// descriptors to stack and use pointers to struct to avoid the complexity
107   /// of the platform-specific C/C++ ABI lowering related to struct argument
108   /// passing.
109   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
110                                         ValueRange operands,
111                                         OpBuilder &builder);
112 
113   /// Promote the LLVM struct representation of one MemRef descriptor to stack
114   /// and use pointer to struct to avoid the complexity of the platform-specific
115   /// C/C++ ABI lowering related to struct argument passing.
116   Value promoteOneMemRefDescriptor(Location loc, Value operand,
117                                    OpBuilder &builder);
118 
119   /// Converts the function type to a C-compatible format, in particular using
120   /// pointers to memref descriptors for arguments.
121   LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
122 
123   /// Returns the data layout to use during and after conversion.
getDataLayout()124   const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
125 
126   /// Gets the LLVM representation of the index type. The returned type is an
127   /// integer type with the size configured for this type converter.
128   LLVM::LLVMType getIndexType();
129 
130   /// Gets the bitwidth of the index type when converted to LLVM.
getIndexTypeBitwidth()131   unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
132 
133   /// Gets the pointer bitwidth.
134   unsigned getPointerBitwidth(unsigned addressSpace = 0);
135 
136 protected:
137   /// Pointer to the LLVM dialect.
138   LLVM::LLVMDialect *llvmDialect;
139 
140 private:
141   /// Convert a function type.  The arguments and results are converted one by
142   /// one.  Additionally, if the function returns more than one value, pack the
143   /// results into an LLVM IR structure type so that the converted function type
144   /// returns at most one result.
145   Type convertFunctionType(FunctionType type);
146 
147   /// Convert the index type.  Uses llvmModule data layout to create an integer
148   /// of the pointer bitwidth.
149   Type convertIndexType(IndexType type);
150 
151   /// Convert an integer type `i*` to `!llvm<"i*">`.
152   Type convertIntegerType(IntegerType type);
153 
154   /// Convert a floating point type: `f16` to `!llvm.half`, `f32` to
155   /// `!llvm.float` and `f64` to `!llvm.double`.  `bf16` is not supported
156   /// by LLVM.
157   Type convertFloatType(FloatType type);
158 
159   /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
160   /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
161   /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
162   Type convertComplexType(ComplexType type);
163 
164   /// Convert a memref type into an LLVM type that captures the relevant data.
165   Type convertMemRefType(MemRefType type);
166 
167   /// Convert a memref type into a list of LLVM IR types that will form the
168   /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
169   /// arrays in the descriptors are unpacked to individual index-typed elements,
170   /// else they are are kept as rank-sized arrays of index type. In particular,
171   /// the list will contain:
172   /// - two pointers to the memref element type, followed by
173   /// - an index-typed offset, followed by
174   /// - (if unpackAggregates = true)
175   ///    - one index-typed size per dimension of the memref, followed by
176   ///    - one index-typed stride per dimension of the memref.
177   /// - (if unpackArrregates = false)
178   ///   - one rank-sized array of index-type for the size of each dimension
179   ///   - one rank-sized array of index-type for the stride of each dimension
180   ///
181   /// For example, memref<?x?xf32> is converted to the following list:
182   /// - `!llvm<"float*">` (allocated pointer),
183   /// - `!llvm<"float*">` (aligned pointer),
184   /// - `!llvm.i64` (offset),
185   /// - `!llvm.i64`, `!llvm.i64` (sizes),
186   /// - `!llvm.i64`, `!llvm.i64` (strides).
187   /// These types can be recomposed to a memref descriptor struct.
188   SmallVector<LLVM::LLVMType, 5>
189   getMemRefDescriptorFields(MemRefType type, bool unpackAggregates);
190 
191   /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
192   /// that will form the unranked memref descriptor. In particular, this list
193   /// contains:
194   /// - an integer rank, followed by
195   /// - a pointer to the memref descriptor struct.
196   /// For example, memref<*xf32> is converted to the following list:
197   /// !llvm.i64 (rank)
198   /// !llvm<"i8*"> (type-erased pointer).
199   /// These types can be recomposed to a unranked memref descriptor struct.
200   SmallVector<LLVM::LLVMType, 2> getUnrankedMemRefDescriptorFields();
201 
202   // Convert an unranked memref type to an LLVM type that captures the
203   // runtime rank and a pointer to the static ranked memref desc
204   Type convertUnrankedMemRefType(UnrankedMemRefType type);
205 
206   /// Convert a memref type to a bare pointer to the memref element type.
207   Type convertMemRefToBarePtr(BaseMemRefType type);
208 
209   // Convert a 1D vector type into an LLVM vector type.
210   Type convertVectorType(VectorType type);
211 
212   /// Options for customizing the llvm lowering.
213   LowerToLLVMOptions options;
214 };
215 
216 /// Helper class to produce LLVM dialect operations extracting or inserting
217 /// values to a struct.
218 class StructBuilder {
219 public:
220   /// Construct a helper for the given value.
221   explicit StructBuilder(Value v);
222   /// Builds IR creating an `undef` value of the descriptor type.
223   static StructBuilder undef(OpBuilder &builder, Location loc,
224                              Type descriptorType);
225 
Value()226   /*implicit*/ operator Value() { return value; }
227 
228 protected:
229   // LLVM value
230   Value value;
231   // Cached struct type.
232   Type structType;
233 
234 protected:
235   /// Builds IR to extract a value from the struct at position pos
236   Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
237   /// Builds IR to set a value in the struct at position pos
238   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
239 };
240 
241 class ComplexStructBuilder : public StructBuilder {
242 public:
243   /// Construct a helper for the given complex number value.
244   using StructBuilder::StructBuilder;
245   /// Build IR creating an `undef` value of the complex number type.
246   static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
247                                     Type type);
248 
249   // Build IR extracting the real value from the complex number struct.
250   Value real(OpBuilder &builder, Location loc);
251   // Build IR inserting the real value into the complex number struct.
252   void setReal(OpBuilder &builder, Location loc, Value real);
253 
254   // Build IR extracting the imaginary value from the complex number struct.
255   Value imaginary(OpBuilder &builder, Location loc);
256   // Build IR inserting the imaginary value into the complex number struct.
257   void setImaginary(OpBuilder &builder, Location loc, Value imaginary);
258 };
259 
260 /// Helper class to produce LLVM dialect operations extracting or inserting
261 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
262 /// The Value may be null, in which case none of the operations are valid.
263 class MemRefDescriptor : public StructBuilder {
264 public:
265   /// Construct a helper for the given descriptor value.
266   explicit MemRefDescriptor(Value descriptor);
267   /// Builds IR creating an `undef` value of the descriptor type.
268   static MemRefDescriptor undef(OpBuilder &builder, Location loc,
269                                 Type descriptorType);
270   /// Builds IR creating a MemRef descriptor that represents `type` and
271   /// populates it with static shape and stride information extracted from the
272   /// type.
273   static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
274                                           LLVMTypeConverter &typeConverter,
275                                           MemRefType type, Value memory);
276 
277   /// Builds IR extracting the allocated pointer from the descriptor.
278   Value allocatedPtr(OpBuilder &builder, Location loc);
279   /// Builds IR inserting the allocated pointer into the descriptor.
280   void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);
281 
282   /// Builds IR extracting the aligned pointer from the descriptor.
283   Value alignedPtr(OpBuilder &builder, Location loc);
284 
285   /// Builds IR inserting the aligned pointer into the descriptor.
286   void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
287 
288   /// Builds IR extracting the offset from the descriptor.
289   Value offset(OpBuilder &builder, Location loc);
290 
291   /// Builds IR inserting the offset into the descriptor.
292   void setOffset(OpBuilder &builder, Location loc, Value offset);
293   void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);
294 
295   /// Builds IR extracting the pos-th size from the descriptor.
296   Value size(OpBuilder &builder, Location loc, unsigned pos);
297   Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
298 
299   /// Builds IR inserting the pos-th size into the descriptor
300   void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
301   void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
302                        uint64_t size);
303 
304   /// Builds IR extracting the pos-th size from the descriptor.
305   Value stride(OpBuilder &builder, Location loc, unsigned pos);
306 
307   /// Builds IR inserting the pos-th stride into the descriptor
308   void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
309   void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
310                          uint64_t stride);
311 
312   /// Returns the (LLVM) pointer type this descriptor contains.
313   LLVM::LLVMPointerType getElementPtrType();
314 
315   /// Builds IR populating a MemRef descriptor structure from a list of
316   /// individual values composing that descriptor, in the following order:
317   /// - allocated pointer;
318   /// - aligned pointer;
319   /// - offset;
320   /// - <rank> sizes;
321   /// - <rank> shapes;
322   /// where <rank> is the MemRef rank as provided in `type`.
323   static Value pack(OpBuilder &builder, Location loc,
324                     LLVMTypeConverter &converter, MemRefType type,
325                     ValueRange values);
326 
327   /// Builds IR extracting individual elements of a MemRef descriptor structure
328   /// and returning them as `results` list.
329   static void unpack(OpBuilder &builder, Location loc, Value packed,
330                      MemRefType type, SmallVectorImpl<Value> &results);
331 
332   /// Returns the number of non-aggregate values that would be produced by
333   /// `unpack`.
334   static unsigned getNumUnpackedValues(MemRefType type);
335 
336 private:
337   // Cached index type.
338   Type indexType;
339 };
340 
341 /// Helper class allowing the user to access a range of Values that correspond
342 /// to an unpacked memref descriptor using named accessors. This does not own
343 /// the values.
344 class MemRefDescriptorView {
345 public:
346   /// Constructs the view from a range of values. Infers the rank from the size
347   /// of the range.
348   explicit MemRefDescriptorView(ValueRange range);
349 
350   /// Returns the allocated pointer Value.
351   Value allocatedPtr();
352 
353   /// Returns the aligned pointer Value.
354   Value alignedPtr();
355 
356   /// Returns the offset Value.
357   Value offset();
358 
359   /// Returns the pos-th size Value.
360   Value size(unsigned pos);
361 
362   /// Returns the pos-th stride Value.
363   Value stride(unsigned pos);
364 
365 private:
366   /// Rank of the memref the descriptor is pointing to.
367   int rank;
368   /// Underlying range of Values.
369   ValueRange elements;
370 };
371 
372 class UnrankedMemRefDescriptor : public StructBuilder {
373 public:
374   /// Construct a helper for the given descriptor value.
375   explicit UnrankedMemRefDescriptor(Value descriptor);
376   /// Builds IR creating an `undef` value of the descriptor type.
377   static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
378                                         Type descriptorType);
379 
380   /// Builds IR extracting the rank from the descriptor
381   Value rank(OpBuilder &builder, Location loc);
382   /// Builds IR setting the rank in the descriptor
383   void setRank(OpBuilder &builder, Location loc, Value value);
384   /// Builds IR extracting ranked memref descriptor ptr
385   Value memRefDescPtr(OpBuilder &builder, Location loc);
386   /// Builds IR setting ranked memref descriptor ptr
387   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
388 
389   /// Builds IR populating an unranked MemRef descriptor structure from a list
390   /// of individual constituent values in the following order:
391   /// - rank of the memref;
392   /// - pointer to the memref descriptor.
393   static Value pack(OpBuilder &builder, Location loc,
394                     LLVMTypeConverter &converter, UnrankedMemRefType type,
395                     ValueRange values);
396 
397   /// Builds IR extracting individual elements that compose an unranked memref
398   /// descriptor and returns them as `results` list.
399   static void unpack(OpBuilder &builder, Location loc, Value packed,
400                      SmallVectorImpl<Value> &results);
401 
402   /// Returns the number of non-aggregate values that would be produced by
403   /// `unpack`.
getNumUnpackedValues()404   static unsigned getNumUnpackedValues() { return 2; }
405 
406   /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
407   /// and appends the corresponding values into `sizes`.
408   static void computeSizes(OpBuilder &builder, Location loc,
409                            LLVMTypeConverter &typeConverter,
410                            ArrayRef<UnrankedMemRefDescriptor> values,
411                            SmallVectorImpl<Value> &sizes);
412 
413   /// TODO: The following accessors don't take alignment rules between elements
414   /// of the descriptor struct into account. For some architectures, it might be
415   /// necessary to extend them and to use `llvm::DataLayout` contained in
416   /// `LLVMTypeConverter`.
417 
418   /// Builds IR extracting the allocated pointer from the descriptor.
419   static Value allocatedPtr(OpBuilder &builder, Location loc,
420                             Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
421   /// Builds IR inserting the allocated pointer into the descriptor.
422   static void setAllocatedPtr(OpBuilder &builder, Location loc,
423                               Value memRefDescPtr,
424                               LLVM::LLVMType elemPtrPtrType,
425                               Value allocatedPtr);
426 
427   /// Builds IR extracting the aligned pointer from the descriptor.
428   static Value alignedPtr(OpBuilder &builder, Location loc,
429                           LLVMTypeConverter &typeConverter, Value memRefDescPtr,
430                           LLVM::LLVMType elemPtrPtrType);
431   /// Builds IR inserting the aligned pointer into the descriptor.
432   static void setAlignedPtr(OpBuilder &builder, Location loc,
433                             LLVMTypeConverter &typeConverter,
434                             Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType,
435                             Value alignedPtr);
436 
437   /// Builds IR extracting the offset from the descriptor.
438   static Value offset(OpBuilder &builder, Location loc,
439                       LLVMTypeConverter &typeConverter, Value memRefDescPtr,
440                       LLVM::LLVMType elemPtrPtrType);
441   /// Builds IR inserting the offset into the descriptor.
442   static void setOffset(OpBuilder &builder, Location loc,
443                         LLVMTypeConverter &typeConverter, Value memRefDescPtr,
444                         LLVM::LLVMType elemPtrPtrType, Value offset);
445 
446   /// Builds IR extracting the pointer to the first element of the size array.
447   static Value sizeBasePtr(OpBuilder &builder, Location loc,
448                            LLVMTypeConverter &typeConverter,
449                            Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
450   /// Builds IR extracting the size[index] from the descriptor.
451   static Value size(OpBuilder &builder, Location loc,
452                     LLVMTypeConverter typeConverter, Value sizeBasePtr,
453                     Value index);
454   /// Builds IR inserting the size[index] into the descriptor.
455   static void setSize(OpBuilder &builder, Location loc,
456                       LLVMTypeConverter typeConverter, Value sizeBasePtr,
457                       Value index, Value size);
458 
459   /// Builds IR extracting the pointer to the first element of the stride array.
460   static Value strideBasePtr(OpBuilder &builder, Location loc,
461                              LLVMTypeConverter &typeConverter,
462                              Value sizeBasePtr, Value rank);
463   /// Builds IR extracting the stride[index] from the descriptor.
464   static Value stride(OpBuilder &builder, Location loc,
465                       LLVMTypeConverter typeConverter, Value strideBasePtr,
466                       Value index, Value stride);
467   /// Builds IR inserting the stride[index] into the descriptor.
468   static void setStride(OpBuilder &builder, Location loc,
469                         LLVMTypeConverter typeConverter, Value strideBasePtr,
470                         Value index, Value stride);
471 };
472 
473 /// Base class for operation conversions targeting the LLVM IR dialect. It
474 /// provides the conversion patterns with access to the LLVMTypeConverter and
475 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
476 /// LowerToLLVMOptions by reference meaning the references have to remain alive
477 /// during the entire pattern lifetime.
478 class ConvertToLLVMPattern : public ConversionPattern {
479 public:
480   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
481                        LLVMTypeConverter &typeConverter,
482                        PatternBenefit benefit = 1);
483 
484 protected:
485   /// Returns the LLVM dialect.
486   LLVM::LLVMDialect &getDialect() const;
487 
488   LLVMTypeConverter *getTypeConverter() const;
489 
490   /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
491   /// defined by the used type converter.
492   LLVM::LLVMType getIndexType() const;
493 
494   /// Gets the MLIR type wrapping the LLVM integer type whose bit width
495   /// corresponds to that of a LLVM pointer type.
496   LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const;
497 
498   /// Gets the MLIR type wrapping the LLVM void type.
499   LLVM::LLVMType getVoidType() const;
500 
501   /// Get the MLIR type wrapping the LLVM i8* type.
502   LLVM::LLVMType getVoidPtrType() const;
503 
504   /// Create an LLVM dialect operation defining the given index constant.
505   Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
506                             uint64_t value) const;
507 
508   // This is a strided getElementPtr variant that linearizes subscripts as:
509   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
510   Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
511                              ValueRange indices,
512                              ConversionPatternRewriter &rewriter) const;
513 
514   // Forwards to getStridedElementPtr. TODO: remove.
515   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
516                    ValueRange indices,
517                    ConversionPatternRewriter &rewriter) const;
518 
519   /// Returns if the givem memref type is supported.
520   bool isSupportedMemRefType(MemRefType type) const;
521 
522   /// Returns the type of a pointer to an element of the memref.
523   Type getElementPtrType(MemRefType type) const;
524 
525   /// Computes sizes, strides and buffer size in bytes of `memRefType` with
526   /// identity layout. Emits constant ops for the static sizes of `memRefType`,
527   /// and uses `dynamicSizes` for the others. Emits instructions to compute
528   /// strides and buffer size from these sizes.
529   ///
530   /// For example, memref<4x?xf32> emits:
531   /// `sizes[0]`   = llvm.mlir.constant(4 : index) : !llvm.i64
532   /// `sizes[1]`   = `dynamicSizes[0]`
533   /// `strides[1]` = llvm.mlir.constant(1 : index) : !llvm.i64
534   /// `strides[0]` = `sizes[0]`
535   /// %size        = llvm.mul `sizes[0]`, `sizes[1]` : !llvm.i64
536   /// %nullptr     = llvm.mlir.null : !llvm.ptr<float>
537   /// %gep         = llvm.getelementptr %nullptr[%size]
538   ///                  : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
539   /// `sizeBytes`  = llvm.ptrtoint %gep : !llvm.ptr<float> to !llvm.i64
540   void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
541                                 ArrayRef<Value> dynamicSizes,
542                                 ConversionPatternRewriter &rewriter,
543                                 SmallVectorImpl<Value> &sizes,
544                                 SmallVectorImpl<Value> &strides,
545                                 Value &sizeBytes) const;
546 
547   /// Computes the size of type in bytes.
548   Value getSizeInBytes(Location loc, Type type,
549                        ConversionPatternRewriter &rewriter) const;
550 
551   /// Computes total number of elements for the given shape.
552   Value getNumElements(Location loc, ArrayRef<Value> shape,
553                        ConversionPatternRewriter &rewriter) const;
554 
555   /// Creates and populates a canonical memref descriptor struct.
556   MemRefDescriptor
557   createMemRefDescriptor(Location loc, MemRefType memRefType,
558                          Value allocatedPtr, Value alignedPtr,
559                          ArrayRef<Value> sizes, ArrayRef<Value> strides,
560                          ConversionPatternRewriter &rewriter) const;
561 };
562 
563 /// Utility class for operation conversions targeting the LLVM dialect that
564 /// match exactly one source operation.
565 template <typename SourceOp>
566 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
567 public:
568   ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
569                          PatternBenefit benefit = 1)
570       : ConvertToLLVMPattern(SourceOp::getOperationName(),
571                              &typeConverter.getContext(), typeConverter,
572                              benefit) {}
573 
574   /// Wrappers around the RewritePattern methods that pass the derived op type.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)575   void rewrite(Operation *op, ArrayRef<Value> operands,
576                ConversionPatternRewriter &rewriter) const final {
577     rewrite(cast<SourceOp>(op), operands, rewriter);
578   }
match(Operation * op)579   LogicalResult match(Operation *op) const final {
580     return match(cast<SourceOp>(op));
581   }
582   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)583   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
584                   ConversionPatternRewriter &rewriter) const final {
585     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
586   }
587 
588   /// Rewrite and Match methods that operate on the SourceOp type. These must be
589   /// overridden by the derived pattern class.
rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)590   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
591                        ConversionPatternRewriter &rewriter) const {
592     llvm_unreachable("must override rewrite or matchAndRewrite");
593   }
match(SourceOp op)594   virtual LogicalResult match(SourceOp op) const {
595     llvm_unreachable("must override match or matchAndRewrite");
596   }
597   virtual LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)598   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
599                   ConversionPatternRewriter &rewriter) const {
600     if (succeeded(match(op))) {
601       rewrite(op, operands, rewriter);
602       return success();
603     }
604     return failure();
605   }
606 
607 private:
608   using ConvertToLLVMPattern::match;
609   using ConvertToLLVMPattern::matchAndRewrite;
610 };
611 
612 namespace LLVM {
613 namespace detail {
614 /// Replaces the given operation "op" with a new operation of type "targetOp"
615 /// and given operands.
616 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
617                               ValueRange operands,
618                               LLVMTypeConverter &typeConverter,
619                               ConversionPatternRewriter &rewriter);
620 
621 LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
622                                     ValueRange operands,
623                                     LLVMTypeConverter &typeConverter,
624                                     ConversionPatternRewriter &rewriter);
625 } // namespace detail
626 } // namespace LLVM
627 
628 /// Generic implementation of one-to-one conversion from "SourceOp" to
629 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
630 /// Upholds a convention that multi-result operations get converted into an
631 /// operation returning the LLVM IR structure type, in which case individual
632 /// values must be extracted from using LLVM::ExtractValueOp before being used.
633 template <typename SourceOp, typename TargetOp>
634 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
635 public:
636   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
637   using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
638 
639   /// Converts the type of the result to an LLVM type, pass operands as is,
640   /// preserve attributes.
641   LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)642   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
643                   ConversionPatternRewriter &rewriter) const override {
644     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
645                                          operands, *this->getTypeConverter(),
646                                          rewriter);
647   }
648 };
649 
650 /// Basic lowering implementation to rewrite Ops with just one result to the
651 /// LLVM Dialect. This supports higher-dimensional vector types.
652 template <typename SourceOp, typename TargetOp>
653 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
654 public:
655   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
656   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
657 
658   LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)659   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
660                   ConversionPatternRewriter &rewriter) const override {
661     static_assert(
662         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
663         "expected single result op");
664     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
665                                   SourceOp>::value,
666                   "expected same operands and result type");
667     return LLVM::detail::vectorOneToOneRewrite(
668         op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
669         rewriter);
670   }
671 };
672 
673 /// Derived class that automatically populates legalization information for
674 /// different LLVM ops.
675 class LLVMConversionTarget : public ConversionTarget {
676 public:
677   explicit LLVMConversionTarget(MLIRContext &ctx);
678 };
679 
680 } // namespace mlir
681 
682 #endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
683