1 //===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Target/LLVMIR/TypeTranslation.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/MLIRContext.h"
12 
13 #include "llvm/ADT/TypeSwitch.h"
14 #include "llvm/IR/DataLayout.h"
15 #include "llvm/IR/DerivedTypes.h"
16 #include "llvm/IR/Type.h"
17 
18 using namespace mlir;
19 
20 namespace mlir {
21 namespace LLVM {
22 namespace detail {
23 /// Support for translating MLIR LLVM dialect types to LLVM IR.
24 class TypeToLLVMIRTranslatorImpl {
25 public:
26   /// Constructs a class creating types in the given LLVM context.
TypeToLLVMIRTranslatorImpl(llvm::LLVMContext & context)27   TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {}
28 
29   /// Translates a single type.
translateType(LLVM::LLVMType type)30   llvm::Type *translateType(LLVM::LLVMType type) {
31     // If the conversion is already known, just return it.
32     if (knownTranslations.count(type))
33       return knownTranslations.lookup(type);
34 
35     // Dispatch to an appropriate function.
36     llvm::Type *translated =
37         llvm::TypeSwitch<LLVM::LLVMType, llvm::Type *>(type)
38             .Case([this](LLVM::LLVMVoidType) {
39               return llvm::Type::getVoidTy(context);
40             })
41             .Case([this](LLVM::LLVMHalfType) {
42               return llvm::Type::getHalfTy(context);
43             })
44             .Case([this](LLVM::LLVMBFloatType) {
45               return llvm::Type::getBFloatTy(context);
46             })
47             .Case([this](LLVM::LLVMFloatType) {
48               return llvm::Type::getFloatTy(context);
49             })
50             .Case([this](LLVM::LLVMDoubleType) {
51               return llvm::Type::getDoubleTy(context);
52             })
53             .Case([this](LLVM::LLVMFP128Type) {
54               return llvm::Type::getFP128Ty(context);
55             })
56             .Case([this](LLVM::LLVMX86FP80Type) {
57               return llvm::Type::getX86_FP80Ty(context);
58             })
59             .Case([this](LLVM::LLVMPPCFP128Type) {
60               return llvm::Type::getPPC_FP128Ty(context);
61             })
62             .Case([this](LLVM::LLVMX86MMXType) {
63               return llvm::Type::getX86_MMXTy(context);
64             })
65             .Case([this](LLVM::LLVMTokenType) {
66               return llvm::Type::getTokenTy(context);
67             })
68             .Case([this](LLVM::LLVMLabelType) {
69               return llvm::Type::getLabelTy(context);
70             })
71             .Case([this](LLVM::LLVMMetadataType) {
72               return llvm::Type::getMetadataTy(context);
73             })
74             .Case<LLVM::LLVMArrayType, LLVM::LLVMIntegerType,
75                   LLVM::LLVMFunctionType, LLVM::LLVMPointerType,
76                   LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
77                   LLVM::LLVMScalableVectorType>(
78                 [this](auto type) { return this->translate(type); })
79             .Default([](LLVM::LLVMType t) -> llvm::Type * {
80               llvm_unreachable("unknown LLVM dialect type");
81             });
82 
83     // Cache the result of the conversion and return.
84     knownTranslations.try_emplace(type, translated);
85     return translated;
86   }
87 
88 private:
89   /// Translates the given array type.
translate(LLVM::LLVMArrayType type)90   llvm::Type *translate(LLVM::LLVMArrayType type) {
91     return llvm::ArrayType::get(translateType(type.getElementType()),
92                                 type.getNumElements());
93   }
94 
95   /// Translates the given function type.
translate(LLVM::LLVMFunctionType type)96   llvm::Type *translate(LLVM::LLVMFunctionType type) {
97     SmallVector<llvm::Type *, 8> paramTypes;
98     translateTypes(type.getParams(), paramTypes);
99     return llvm::FunctionType::get(translateType(type.getReturnType()),
100                                    paramTypes, type.isVarArg());
101   }
102 
103   /// Translates the given integer type.
translate(LLVM::LLVMIntegerType type)104   llvm::Type *translate(LLVM::LLVMIntegerType type) {
105     return llvm::IntegerType::get(context, type.getBitWidth());
106   }
107 
108   /// Translates the given pointer type.
translate(LLVM::LLVMPointerType type)109   llvm::Type *translate(LLVM::LLVMPointerType type) {
110     return llvm::PointerType::get(translateType(type.getElementType()),
111                                   type.getAddressSpace());
112   }
113 
114   /// Translates the given structure type, supports both identified and literal
115   /// structs. This will _create_ a new identified structure every time, use
116   /// `convertType` if a structure with the same name must be looked up instead.
translate(LLVM::LLVMStructType type)117   llvm::Type *translate(LLVM::LLVMStructType type) {
118     SmallVector<llvm::Type *, 8> subtypes;
119     if (!type.isIdentified()) {
120       translateTypes(type.getBody(), subtypes);
121       return llvm::StructType::get(context, subtypes, type.isPacked());
122     }
123 
124     llvm::StructType *structType =
125         llvm::StructType::create(context, type.getName());
126     // Mark the type we just created as known so that recursive calls can pick
127     // it up and use directly.
128     knownTranslations.try_emplace(type, structType);
129     if (type.isOpaque())
130       return structType;
131 
132     translateTypes(type.getBody(), subtypes);
133     structType->setBody(subtypes, type.isPacked());
134     return structType;
135   }
136 
137   /// Translates the given fixed-vector type.
translate(LLVM::LLVMFixedVectorType type)138   llvm::Type *translate(LLVM::LLVMFixedVectorType type) {
139     return llvm::FixedVectorType::get(translateType(type.getElementType()),
140                                       type.getNumElements());
141   }
142 
143   /// Translates the given scalable-vector type.
translate(LLVM::LLVMScalableVectorType type)144   llvm::Type *translate(LLVM::LLVMScalableVectorType type) {
145     return llvm::ScalableVectorType::get(translateType(type.getElementType()),
146                                          type.getMinNumElements());
147   }
148 
149   /// Translates a list of types.
translateTypes(ArrayRef<LLVM::LLVMType> types,SmallVectorImpl<llvm::Type * > & result)150   void translateTypes(ArrayRef<LLVM::LLVMType> types,
151                       SmallVectorImpl<llvm::Type *> &result) {
152     result.reserve(result.size() + types.size());
153     for (auto type : types)
154       result.push_back(translateType(type));
155   }
156 
157   /// Reference to the context in which the LLVM IR types are created.
158   llvm::LLVMContext &context;
159 
160   /// Map of known translation. This serves a double purpose: caches translation
161   /// results to avoid repeated recursive calls and makes sure identified
162   /// structs with the same name (that is, equal) are resolved to an existing
163   /// type instead of creating a new type.
164   llvm::DenseMap<LLVM::LLVMType, llvm::Type *> knownTranslations;
165 };
166 } // end namespace detail
167 } // end namespace LLVM
168 } // end namespace mlir
169 
TypeToLLVMIRTranslator(llvm::LLVMContext & context)170 LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context)
171     : impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {}
172 
~TypeToLLVMIRTranslator()173 LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() {}
174 
translateType(LLVM::LLVMType type)175 llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(LLVM::LLVMType type) {
176   return impl->translateType(type);
177 }
178 
getPreferredAlignment(LLVM::LLVMType type,const llvm::DataLayout & layout)179 unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
180     LLVM::LLVMType type, const llvm::DataLayout &layout) {
181   return layout.getPrefTypeAlignment(translateType(type));
182 }
183 
184 namespace mlir {
185 namespace LLVM {
186 namespace detail {
187 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
188 class TypeFromLLVMIRTranslatorImpl {
189 public:
190   /// Constructs a class creating types in the given MLIR context.
TypeFromLLVMIRTranslatorImpl(MLIRContext & context)191   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
192 
193   /// Translates the given type.
translateType(llvm::Type * type)194   LLVM::LLVMType translateType(llvm::Type *type) {
195     if (knownTranslations.count(type))
196       return knownTranslations.lookup(type);
197 
198     LLVM::LLVMType translated =
199         llvm::TypeSwitch<llvm::Type *, LLVM::LLVMType>(type)
200             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
201                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
202                   llvm::ScalableVectorType>(
203                 [this](auto *type) { return this->translate(type); })
204             .Default([this](llvm::Type *type) {
205               return translatePrimitiveType(type);
206             });
207     knownTranslations.try_emplace(type, translated);
208     return translated;
209   }
210 
211 private:
212   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
213   /// type.
translatePrimitiveType(llvm::Type * type)214   LLVM::LLVMType translatePrimitiveType(llvm::Type *type) {
215     if (type->isVoidTy())
216       return LLVM::LLVMVoidType::get(&context);
217     if (type->isHalfTy())
218       return LLVM::LLVMHalfType::get(&context);
219     if (type->isBFloatTy())
220       return LLVM::LLVMBFloatType::get(&context);
221     if (type->isFloatTy())
222       return LLVM::LLVMFloatType::get(&context);
223     if (type->isDoubleTy())
224       return LLVM::LLVMDoubleType::get(&context);
225     if (type->isFP128Ty())
226       return LLVM::LLVMFP128Type::get(&context);
227     if (type->isX86_FP80Ty())
228       return LLVM::LLVMX86FP80Type::get(&context);
229     if (type->isPPC_FP128Ty())
230       return LLVM::LLVMPPCFP128Type::get(&context);
231     if (type->isX86_MMXTy())
232       return LLVM::LLVMX86MMXType::get(&context);
233     if (type->isLabelTy())
234       return LLVM::LLVMLabelType::get(&context);
235     if (type->isMetadataTy())
236       return LLVM::LLVMMetadataType::get(&context);
237     llvm_unreachable("not a primitive type");
238   }
239 
240   /// Translates the given array type.
translate(llvm::ArrayType * type)241   LLVM::LLVMType translate(llvm::ArrayType *type) {
242     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
243                                     type->getNumElements());
244   }
245 
246   /// Translates the given function type.
translate(llvm::FunctionType * type)247   LLVM::LLVMType translate(llvm::FunctionType *type) {
248     SmallVector<LLVM::LLVMType, 8> paramTypes;
249     translateTypes(type->params(), paramTypes);
250     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
251                                        paramTypes, type->isVarArg());
252   }
253 
254   /// Translates the given integer type.
translate(llvm::IntegerType * type)255   LLVM::LLVMType translate(llvm::IntegerType *type) {
256     return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
257   }
258 
259   /// Translates the given pointer type.
translate(llvm::PointerType * type)260   LLVM::LLVMType translate(llvm::PointerType *type) {
261     return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
262                                       type->getAddressSpace());
263   }
264 
265   /// Translates the given structure type.
translate(llvm::StructType * type)266   LLVM::LLVMType translate(llvm::StructType *type) {
267     SmallVector<LLVM::LLVMType, 8> subtypes;
268     if (type->isLiteral()) {
269       translateTypes(type->subtypes(), subtypes);
270       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
271                                               type->isPacked());
272     }
273 
274     if (type->isOpaque())
275       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
276 
277     LLVM::LLVMStructType translated =
278         LLVM::LLVMStructType::getIdentified(&context, type->getName());
279     knownTranslations.try_emplace(type, translated);
280     translateTypes(type->subtypes(), subtypes);
281     LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
282     assert(succeeded(bodySet) &&
283            "could not set the body of an identified struct");
284     (void)bodySet;
285     return translated;
286   }
287 
288   /// Translates the given fixed-vector type.
translate(llvm::FixedVectorType * type)289   LLVM::LLVMType translate(llvm::FixedVectorType *type) {
290     return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
291                                           type->getNumElements());
292   }
293 
294   /// Translates the given scalable-vector type.
translate(llvm::ScalableVectorType * type)295   LLVM::LLVMType translate(llvm::ScalableVectorType *type) {
296     return LLVM::LLVMScalableVectorType::get(
297         translateType(type->getElementType()), type->getMinNumElements());
298   }
299 
300   /// Translates a list of types.
translateTypes(ArrayRef<llvm::Type * > types,SmallVectorImpl<LLVM::LLVMType> & result)301   void translateTypes(ArrayRef<llvm::Type *> types,
302                       SmallVectorImpl<LLVM::LLVMType> &result) {
303     result.reserve(result.size() + types.size());
304     for (llvm::Type *type : types)
305       result.push_back(translateType(type));
306   }
307 
308   /// Map of known translations. Serves as a cache and as recursion stopper for
309   /// translating recursive structs.
310   llvm::DenseMap<llvm::Type *, LLVM::LLVMType> knownTranslations;
311 
312   /// The context in which MLIR types are created.
313   MLIRContext &context;
314 };
315 } // end namespace detail
316 } // end namespace LLVM
317 } // end namespace mlir
318 
TypeFromLLVMIRTranslator(MLIRContext & context)319 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
320     : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
321 
~TypeFromLLVMIRTranslator()322 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
323 
translateType(llvm::Type * type)324 LLVM::LLVMType LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
325   return impl->translateType(type);
326 }
327