1 /*
2  * Copyright 2010-2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_export_type.h"
18 
19 #include <list>
20 #include <vector>
21 
22 #include "clang/AST/ASTContext.h"
23 #include "clang/AST/Attr.h"
24 #include "clang/AST/RecordLayout.h"
25 
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Type.h"
30 
31 #include "slang_assert.h"
32 #include "slang_rs_context.h"
33 #include "slang_rs_export_element.h"
34 #include "slang_version.h"
35 
// Makes the enclosing function return false when ParentClass::matchODR(E, true)
// reports a mismatch.
// NOTE: this expands to a bare `if` statement (it is not wrapped in a
// do/while), so be careful when placing it directly in front of an `else`.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!ParentClass::matchODR(E, true))        \
    return false;
39 
40 namespace slang {
41 
42 namespace {
43 
44 // For the data types we support:
45 //  Category      - data type category
46 //  SName         - "common name" in script (C99)
47 //  RsType        - element name in RenderScript
48 //  RsShortType   - short element name in RenderScript
49 //  SizeInBits    - size in bits
50 //  CName         - reflected C name
51 //  JavaName      - reflected Java name
52 //  JavaArrayElementName - reflected name in Java arrays
53 //  CVecName      - prefix for C vector types
54 //  JavaVecName   - prefix for Java vector type
55 //  JavaPromotion - unsigned type undergoing Java promotion
56 //
57 // IMPORTANT: The data types in this table should be at the same index as
58 // specified by the corresponding DataType enum.
59 //
60 // TODO: Pull this information out into a separate file.
static RSReflectionType gReflectionTypes[] = {
// `_` is a local shorthand for nullptr; it marks a column that has no value
// for the given row. It is undefined again right after the table.
#define _ nullptr
  //      Category     SName              RsType       RsST Bits      CName         JN      JAEN       CVN       JVN     JP
{PrimitiveDataType,   "half",         "FLOAT_16",     "F16", 16,     "half",   "short",  "short",   "Half",  "Short", false},
{PrimitiveDataType,  "float",         "FLOAT_32",     "F32", 32,    "float",   "float",  "float",  "Float",  "Float", false},
{PrimitiveDataType, "double",         "FLOAT_64",     "F64", 64,   "double",  "double", "double", "Double", "Double", false},
{PrimitiveDataType,   "char",         "SIGNED_8",      "I8",  8,   "int8_t",    "byte",   "byte",   "Byte",   "Byte", false},
{PrimitiveDataType,  "short",        "SIGNED_16",     "I16", 16,  "int16_t",   "short",  "short",  "Short",  "Short", false},
{PrimitiveDataType,    "int",        "SIGNED_32",     "I32", 32,  "int32_t",     "int",    "int",    "Int",    "Int", false},
{PrimitiveDataType,   "long",        "SIGNED_64",     "I64", 64,  "int64_t",    "long",   "long",   "Long",   "Long", false},
{PrimitiveDataType,  "uchar",       "UNSIGNED_8",      "U8",  8,  "uint8_t",   "short",   "byte",  "UByte",  "Short",  true},
{PrimitiveDataType, "ushort",      "UNSIGNED_16",     "U16", 16, "uint16_t",     "int",  "short", "UShort",    "Int",  true},
{PrimitiveDataType,   "uint",      "UNSIGNED_32",     "U32", 32, "uint32_t",    "long",    "int",   "UInt",   "Long",  true},
{PrimitiveDataType,  "ulong",      "UNSIGNED_64",     "U64", 64, "uint64_t",    "long",   "long",  "ULong",   "Long", false},
{PrimitiveDataType,   "bool",          "BOOLEAN", "BOOLEAN",  8,     "bool", "boolean",   "byte",        _,        _, false},
// Packed pixel formats: only the RenderScript element name and size are set.
{PrimitiveDataType,        _,   "UNSIGNED_5_6_5",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType,        _, "UNSIGNED_5_5_5_1",         _, 16,          _,         _,        _,        _,        _, false},
{PrimitiveDataType,        _, "UNSIGNED_4_4_4_4",         _, 16,          _,         _,        _,        _,        _, false},

// Matrix sizes are written as (number of cells) * 32 bits.
{MatrixDataType, "rs_matrix2x2", "MATRIX_2X2", _,  4*32, "rs_matrix2x2", "Matrix2f", _, _, _, false},
{MatrixDataType, "rs_matrix3x3", "MATRIX_3X3", _,  9*32, "rs_matrix3x3", "Matrix3f", _, _, _, false},
{MatrixDataType, "rs_matrix4x4", "MATRIX_4X4", _, 16*32, "rs_matrix4x4", "Matrix4f", _, _, _, false},

// RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
// The 32 recorded here is therefore not the whole story; the difference is
// handled specially by the GetElementSizeInBits() method.
{ObjectDataType,          "rs_element",          "RS_ELEMENT",          "ELEMENT", 32,         "Element",         "Element", _, _, _, false},
{ObjectDataType,             "rs_type",             "RS_TYPE",             "TYPE", 32,            "Type",            "Type", _, _, _, false},
{ObjectDataType,       "rs_allocation",       "RS_ALLOCATION",       "ALLOCATION", 32,      "Allocation",      "Allocation", _, _, _, false},
{ObjectDataType,          "rs_sampler",          "RS_SAMPLER",          "SAMPLER", 32,         "Sampler",         "Sampler", _, _, _, false},
{ObjectDataType,           "rs_script",           "RS_SCRIPT",           "SCRIPT", 32,          "Script",          "Script", _, _, _, false},
{ObjectDataType,             "rs_mesh",             "RS_MESH",             "MESH", 32,            "Mesh",            "Mesh", _, _, _, false},
{ObjectDataType,             "rs_path",             "RS_PATH",             "PATH", 32,            "Path",            "Path", _, _, _, false},
{ObjectDataType, "rs_program_fragment", "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", _, _, _, false},
{ObjectDataType,   "rs_program_vertex",   "RS_PROGRAM_VERTEX",   "PROGRAM_VERTEX", 32,   "ProgramVertex",   "ProgramVertex", _, _, _, false},
{ObjectDataType,   "rs_program_raster",   "RS_PROGRAM_RASTER",   "PROGRAM_RASTER", 32,   "ProgramRaster",   "ProgramRaster", _, _, _, false},
{ObjectDataType,    "rs_program_store",    "RS_PROGRAM_STORE",    "PROGRAM_STORE", 32,    "ProgramStore",    "ProgramStore", _, _, _, false},
{ObjectDataType,             "rs_font",             "RS_FONT",             "FONT", 32,            "Font",            "Font", _, _, _, false},
#undef _
};
100 
// Number of names stored per builtin type in BuiltinInfo::cname.
const int kMaxVectorSize = 4;

// Describes how one clang builtin type kind is exported.
struct BuiltinInfo {
  // The clang builtin kind this entry describes (lookup key).
  clang::BuiltinType::Kind builtinTypeKind;
  // The data type the builtin kind is exported as.
  DataType type;
  /* TODO If we return std::string instead of llvm::StringRef, we could build
   * the name instead of duplicating the entries.
   */
  // As populated in BuiltinInfoTable, entry i holds the type name carrying the
  // suffix i + 1 (no suffix for i == 0), e.g. "int", "int2", "int3", "int4".
  const char *cname[kMaxVectorSize];
};
111 
112 
// Table of the clang builtin kinds that FindBuiltinType() recognizes. Several
// clang kinds share one DataType (for example ULong and ULongLong both map to
// DataTypeUnsigned64), so DataType values are not unique in this table.
BuiltinInfo BuiltinInfoTable[] = {
    {clang::BuiltinType::Bool, DataTypeBoolean,
     {"bool", "bool2", "bool3", "bool4"}},
    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::UChar, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::Char16, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Char32, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::UShort, DataTypeUnsigned16,
     {"ushort", "ushort2", "ushort3", "ushort4"}},
    {clang::BuiltinType::UInt, DataTypeUnsigned32,
     {"uint", "uint2", "uint3", "uint4"}},
    {clang::BuiltinType::ULong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},
    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},

    {clang::BuiltinType::Char_S, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::SChar, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::Short, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Int, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::Long, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::LongLong, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::Half, DataTypeFloat16,
     {"half", "half2", "half3", "half4"}},
    {clang::BuiltinType::Float, DataTypeFloat32,
     {"float", "float2", "float3", "float4"}},
    {clang::BuiltinType::Double, DataTypeFloat64,
     {"double", "double2", "double3", "double4"}},
};
// Number of entries in BuiltinInfoTable.
const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
153 
// Pairs a type name, as spelled in script source, with its DataType value.
struct NameAndPrimitiveType {
  const char *name;
  DataType dataType;
};

// Names of the RS matrix types (first three rows) and RS object types
// (remaining rows) together with their DataType values.
static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
    {"rs_matrix2x2", DataTypeRSMatrix2x2},
    {"rs_matrix3x3", DataTypeRSMatrix3x3},
    {"rs_matrix4x4", DataTypeRSMatrix4x4},
    {"rs_element", DataTypeRSElement},
    {"rs_type", DataTypeRSType},
    {"rs_allocation", DataTypeRSAllocation},
    {"rs_sampler", DataTypeRSSampler},
    {"rs_script", DataTypeRSScript},
    {"rs_mesh", DataTypeRSMesh},
    {"rs_path", DataTypeRSPath},
    {"rs_program_fragment", DataTypeRSProgramFragment},
    {"rs_program_vertex", DataTypeRSProgramVertex},
    {"rs_program_raster", DataTypeRSProgramRaster},
    {"rs_program_store", DataTypeRSProgramStore},
    {"rs_font", DataTypeRSFont},
};

// Number of entries in MatrixAndObjectDataTypes.
const int MatrixAndObjectDataTypesCount =
    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
179 
// Forward declaration: ConstantArrayTypeExportableHelper() below calls this
// function for the array element type, and this function in turn calls
// ConstantArrayTypeExportableHelper() for constant array types.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord,
    ExportKind EK);
187 
188 template <unsigned N>
ReportTypeError(slang::RSContext * Context,const clang::NamedDecl * ND,const clang::RecordDecl * TopLevelRecord,const char (& Message)[N],unsigned int TargetAPI=0)189 static void ReportTypeError(slang::RSContext *Context,
190                             const clang::NamedDecl *ND,
191                             const clang::RecordDecl *TopLevelRecord,
192                             const char (&Message)[N],
193                             unsigned int TargetAPI = 0) {
194   // Attempt to use the type declaration first (if we have one).
195   // Fall back to the variable definition, if we are looking at something
196   // like an array declaration that can't be exported.
197   if (TopLevelRecord) {
198     Context->ReportError(TopLevelRecord->getLocation(), Message)
199         << TopLevelRecord->getName() << TargetAPI;
200   } else if (ND) {
201     Context->ReportError(ND->getLocation(), Message) << ND->getName()
202                                                      << TargetAPI;
203   } else {
204     slangAssert(false && "Variables should be validated before exporting");
205   }
206 }
207 
// Checks whether the constant array type |CAT| can be exported.
//
// Rejected (with a diagnostic): arrays whose element type is itself an array,
// arrays of vectors whose base type is not primitive, and arrays of 3-wide
// vectors unless the array has exactly one element. Otherwise the element
// type is checked recursively through TypeExportableHelper().
//
// Returns |CAT| when the array is exportable and nullptr otherwise.
static const clang::Type *ConstantArrayTypeExportableHelper(
    const clang::ConstantArrayType *CAT,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord,
    ExportKind EK) {
  const clang::Type *ElemTy = GetConstantArrayElementType(CAT);

  if (ElemTy->isArrayType()) {
    ReportTypeError(Context, VD, TopLevelRecord,
                    "multidimensional arrays cannot be exported: '%0'");
    return nullptr;
  }

  if (ElemTy->isExtVectorType()) {
    const auto *VecTy = static_cast<const clang::ExtVectorType*>(ElemTy);
    const unsigned Width = VecTy->getNumElements();

    const clang::Type *ScalarTy = GetExtVectorElementType(VecTy);
    if (!RSExportPrimitiveType::IsPrimitiveType(ScalarTy)) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "vectors of non-primitive types cannot be exported: '%0'");
      return nullptr;
    }

    // A 3-wide vector is only tolerated in an array of exactly one element.
    if (Width == 3 && CAT->getSize() != 1) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "arrays of width 3 vector types cannot be exported: '%0'");
      return nullptr;
    }
  }

  const clang::Type *CheckedElem =
      TypeExportableHelper(ElemTy, SPS, Context, VD, TopLevelRecord, EK);
  if (CheckedElem == nullptr) {
    return nullptr;
  }
  return CAT;
}
247 
FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind)248 BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
249   for (int i = 0; i < BuiltinInfoTableCount; i++) {
250     if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
251       return &BuiltinInfoTable[i];
252     }
253   }
254   return nullptr;
255 }
256 
TypeExportableHelper(clang::Type const * T,llvm::SmallPtrSet<clang::Type const *,8> & SPS,slang::RSContext * Context,clang::VarDecl const * VD,clang::RecordDecl const * TopLevelRecord,ExportKind EK)257 static const clang::Type *TypeExportableHelper(
258     clang::Type const *T,
259     llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
260     slang::RSContext *Context,
261     clang::VarDecl const *VD,
262     clang::RecordDecl const *TopLevelRecord,
263     ExportKind EK) {
264   // Normalize first
265   if ((T = GetCanonicalType(T)) == nullptr)
266     return nullptr;
267 
268   if (SPS.count(T))
269     return T;
270 
271   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
272 
273   switch (T->getTypeClass()) {
274     case clang::Type::Builtin: {
275       const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
276       return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
277     }
278     case clang::Type::Record: {
279       if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
280         return T;  // RS object type, no further checks are needed
281       }
282 
283       // Check internal struct
284       if (T->isUnionType()) {
285         ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
286                         "unions cannot be exported: '%0'");
287         return nullptr;
288       } else if (!T->isStructureType()) {
289         slangAssert(false && "Unknown type cannot be exported");
290         return nullptr;
291       }
292 
293       clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
294       slangAssert(RD);
295       RD = RD->getDefinition();
296       if (RD == nullptr) {
297         ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
298                         "struct is not defined in this module");
299         return nullptr;
300       }
301 
302       if (!TopLevelRecord) {
303         TopLevelRecord = RD;
304       }
305       if (RD->getName().empty()) {
306         ReportTypeError(Context, nullptr, RD,
307                         "anonymous structures cannot be exported");
308         return nullptr;
309       }
310 
311       // Fast check
312       if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
313         return nullptr;
314 
315       // Insert myself into checking set
316       SPS.insert(T);
317 
318       // Check all element
319       for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
320                FE = RD->field_end();
321            FI != FE;
322            FI++) {
323         const clang::FieldDecl *FD = *FI;
324         const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
325         FT = GetCanonicalType(FT);
326 
327         if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord,
328                                   EK)) {
329           return nullptr;
330         }
331 
332         // We don't support bit fields yet
333         //
334         // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
335         if (FD->isBitField()) {
336           // Context can be null from NormalizeType?
337           slangAssert(Context);
338           Context->ReportError(
339               FD->getLocation(),
340               "bit fields are not able to be exported: '%0.%1'")
341               << RD->getName() << FD->getName();
342           return nullptr;
343         }
344       }
345 
346       return T;
347     }
348     case clang::Type::FunctionProto:
349     case clang::Type::FunctionNoProto:
350       ReportTypeError(Context, VD, TopLevelRecord,
351                       "function types cannot be exported: '%0'");
352       return nullptr;
353     case clang::Type::Pointer: {
354       if (TopLevelRecord) {
355         ReportTypeError(Context, VD, TopLevelRecord,
356             "structures containing pointers cannot be used as the type of "
357             "an exported global variable or the parameter to an exported "
358             "function: '%0'");
359         return nullptr;
360       }
361 
362       const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
363       const clang::Type *PointeeType = GetPointeeType(PT);
364 
365       if (PointeeType->getTypeClass() == clang::Type::Pointer) {
366         ReportTypeError(Context, VD, TopLevelRecord,
367             "multiple levels of pointers cannot be exported: '%0'");
368         return nullptr;
369       }
370 
371       // Void pointers are forbidden for export, although we must accept
372       // void pointers that come in as arguments to a legacy kernel.
373       if (PointeeType->isVoidType() && EK != LegacyKernelArgument) {
374         ReportTypeError(Context, VD, TopLevelRecord,
375             "void pointers cannot be exported: '%0'");
376         return nullptr;
377       }
378 
379       // We don't support pointer with array-type pointee
380       if (PointeeType->isArrayType()) {
381         ReportTypeError(Context, VD, TopLevelRecord,
382             "pointers to arrays cannot be exported: '%0'");
383         return nullptr;
384       }
385 
386       // Check for unsupported pointee type
387       if (TypeExportableHelper(PointeeType, SPS, Context, VD,
388                                 TopLevelRecord, EK) == nullptr)
389         return nullptr;
390       else
391         return T;
392     }
393     case clang::Type::ExtVector: {
394       const clang::ExtVectorType *EVT =
395               static_cast<const clang::ExtVectorType*>(CTI);
396       // Only vector with size 2, 3 and 4 are supported.
397       if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
398         return nullptr;
399 
400       // Check base element type
401       const clang::Type *ElementType = GetExtVectorElementType(EVT);
402 
403       if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
404           (TypeExportableHelper(ElementType, SPS, Context, VD,
405                                 TopLevelRecord, EK) == nullptr))
406         return nullptr;
407       else
408         return T;
409     }
410     case clang::Type::ConstantArray: {
411       const clang::ConstantArrayType *CAT =
412               static_cast<const clang::ConstantArrayType*>(CTI);
413 
414       return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
415                                                TopLevelRecord, EK);
416     }
417     case clang::Type::Enum: {
418       // FIXME: We currently convert enums to integers, rather than reflecting
419       // a more complete (and nicer type-safe Java version).
420       // Context can be null from NormalizeType?
421       slangAssert(Context);
422       return Context->getASTContext().IntTy.getTypePtr();
423     }
424     default: {
425       slangAssert(false && "Unknown type cannot be validated");
426       return nullptr;
427     }
428   }
429 }
430 
431 // Return the type that can be used to create RSExportType, will always return
432 // the canonical type.
433 //
434 // If the Type T is not exportable, this function returns nullptr. DiagEngine is
435 // used to generate proper Clang diagnostic messages when a non-exportable type
436 // is detected. TopLevelRecord is used to capture the highest struct (in the
437 // case of a nested hierarchy) for detecting other types that cannot be exported
438 // (mostly pointers within a struct).
TypeExportable(const clang::Type * T,slang::RSContext * Context,const clang::VarDecl * VD,ExportKind EK)439 static const clang::Type *TypeExportable(const clang::Type *T,
440                                          slang::RSContext *Context,
441                                          const clang::VarDecl *VD,
442                                          ExportKind EK) {
443   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
444       llvm::SmallPtrSet<const clang::Type*, 8>();
445 
446   return TypeExportableHelper(T, SPS, Context, VD, nullptr, EK);
447 }
448 
ValidateRSObjectInVarDecl(slang::RSContext * Context,const clang::VarDecl * VD,bool InCompositeType,unsigned int TargetAPI)449 static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
450                                       const clang::VarDecl *VD, bool InCompositeType,
451                                       unsigned int TargetAPI) {
452   if (TargetAPI < SLANG_JB_TARGET_API) {
453     // Only if we are already in a composite type (like an array or structure).
454     if (InCompositeType) {
455       // Only if we are actually exported (i.e. non-static).
456       if (VD->hasLinkage() &&
457           (VD->getFormalLinkage() == clang::ExternalLinkage)) {
458         // Only if we are not a pointer to an object.
459         const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
460         if (T->getTypeClass() != clang::Type::Pointer) {
461           ReportTypeError(Context, VD, nullptr,
462                           "arrays/structures containing RS object types "
463                           "cannot be exported in target API < %1: '%0'",
464                           SLANG_JB_TARGET_API);
465           return false;
466         }
467       }
468     }
469   }
470 
471   return true;
472 }
473 
474 // Helper function for ValidateType(). We do a recursive descent on the
475 // type hierarchy to ensure that we can properly export/handle the
476 // declaration.
477 // \return true if the variable declaration is valid,
478 //         false if it is invalid (along with proper diagnostics).
479 //
480 // C - ASTContext (for diagnostics + builtin types).
481 // T - sub-type that we are validating.
482 // ND - (optional) top-level named declaration that we are validating.
483 // SPS - set of types we have already seen/validated.
484 // InCompositeType - true if we are within an outer composite type.
485 // UnionDecl - set if we are in a sub-type of a union.
486 // TargetAPI - target SDK API level.
487 // IsFilterscript - whether or not we are compiling for Filterscript
488 // IsExtern - is this type externally visible (i.e. extern global or parameter
489 //                                             to an extern function)
ValidateTypeHelper(slang::RSContext * Context,clang::ASTContext & C,const clang::Type * & T,const clang::NamedDecl * ND,clang::SourceLocation Loc,llvm::SmallPtrSet<const clang::Type *,8> & SPS,bool InCompositeType,clang::RecordDecl * UnionDecl,unsigned int TargetAPI,bool IsFilterscript,bool IsExtern)490 static bool ValidateTypeHelper(
491     slang::RSContext *Context,
492     clang::ASTContext &C,
493     const clang::Type *&T,
494     const clang::NamedDecl *ND,
495     clang::SourceLocation Loc,
496     llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
497     bool InCompositeType,
498     clang::RecordDecl *UnionDecl,
499     unsigned int TargetAPI,
500     bool IsFilterscript,
501     bool IsExtern) {
502   if ((T = GetCanonicalType(T)) == nullptr)
503     return true;
504 
505   if (SPS.count(T))
506     return true;
507 
508   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
509 
510   switch (T->getTypeClass()) {
511     case clang::Type::Record: {
512       if (RSExportPrimitiveType::IsRSObjectType(T)) {
513         const clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
514         if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
515                                              TargetAPI)) {
516           return false;
517         }
518       }
519 
520       if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
521         if (!UnionDecl) {
522           return true;
523         } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
524           ReportTypeError(Context, nullptr, UnionDecl,
525               "unions containing RS object types are not allowed");
526           return false;
527         }
528       }
529 
530       clang::RecordDecl *RD = nullptr;
531 
532       // Check internal struct
533       if (T->isUnionType()) {
534         RD = T->getAsUnionType()->getDecl();
535         UnionDecl = RD;
536       } else if (T->isStructureType()) {
537         RD = T->getAsStructureType()->getDecl();
538       } else {
539         slangAssert(false && "Unknown type cannot be exported");
540         return false;
541       }
542 
543       slangAssert(RD);
544       RD = RD->getDefinition();
545       if (RD == nullptr) {
546         // FIXME
547         return true;
548       }
549 
550       // Fast check
551       if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
552         return false;
553 
554       // Insert myself into checking set
555       SPS.insert(T);
556 
557       // Check all elements
558       for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
559                FE = RD->field_end();
560            FI != FE;
561            FI++) {
562         const clang::FieldDecl *FD = *FI;
563         const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
564         FT = GetCanonicalType(FT);
565 
566         if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
567                                 TargetAPI, IsFilterscript, IsExtern)) {
568           return false;
569         }
570       }
571 
572       return true;
573     }
574 
575     case clang::Type::Builtin: {
576       if (IsFilterscript) {
577         clang::QualType QT = T->getCanonicalTypeInternal();
578         if (QT == C.DoubleTy ||
579             QT == C.LongDoubleTy ||
580             QT == C.LongTy ||
581             QT == C.LongLongTy) {
582           if (ND) {
583             Context->ReportError(
584                 Loc,
585                 "Builtin types > 32 bits in size are forbidden in "
586                 "Filterscript: '%0'")
587                 << ND->getName();
588           } else {
589             Context->ReportError(
590                 Loc,
591                 "Builtin types > 32 bits in size are forbidden in "
592                 "Filterscript");
593           }
594           return false;
595         }
596       }
597       break;
598     }
599 
600     case clang::Type::Pointer: {
601       if (IsFilterscript) {
602         if (ND) {
603           Context->ReportError(Loc,
604                                "Pointers are forbidden in Filterscript: '%0'")
605               << ND->getName();
606           return false;
607         } else {
608           // TODO(srhines): Find a better way to handle expressions (i.e. no
609           // NamedDecl) involving pointers in FS that should be allowed.
610           // An example would be calls to library functions like
611           // rsMatrixMultiply() that take rs_matrixNxN * types.
612         }
613       }
614 
615       // Forbid pointers in structures that are externally visible.
616       if (InCompositeType && IsExtern) {
617         if (ND) {
618           Context->ReportError(Loc,
619               "structures containing pointers cannot be used as the type of "
620               "an exported global variable or the parameter to an exported "
621               "function: '%0'")
622             << ND->getName();
623         } else {
624           Context->ReportError(Loc,
625               "structures containing pointers cannot be used as the type of "
626               "an exported global variable or the parameter to an exported "
627               "function");
628         }
629         return false;
630       }
631 
632       const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
633       const clang::Type *PointeeType = GetPointeeType(PT);
634 
635       return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
636                                 InCompositeType, UnionDecl, TargetAPI,
637                                 IsFilterscript, IsExtern);
638     }
639 
640     case clang::Type::ExtVector: {
641       const clang::ExtVectorType *EVT =
642               static_cast<const clang::ExtVectorType*>(CTI);
643       const clang::Type *ElementType = GetExtVectorElementType(EVT);
644       if (TargetAPI < SLANG_ICS_TARGET_API &&
645           InCompositeType &&
646           EVT->getNumElements() == 3 &&
647           ND &&
648           ND->getFormalLinkage() == clang::ExternalLinkage) {
649         ReportTypeError(Context, ND, nullptr,
650                         "structs containing vectors of dimension 3 cannot "
651                         "be exported at this API level: '%0'");
652         return false;
653       }
654       return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
655                                 UnionDecl, TargetAPI, IsFilterscript, IsExtern);
656     }
657 
658     case clang::Type::ConstantArray: {
659       const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
660       const clang::Type *ElementType = GetConstantArrayElementType(CAT);
661       return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
662                                 UnionDecl, TargetAPI, IsFilterscript, IsExtern);
663     }
664 
665     default: {
666       break;
667     }
668   }
669 
670   return true;
671 }
672 
673 }  // namespace
674 
// Builds a placeholder name of the form "<type>" or, when |name| is not
// empty, "<type:name>".
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string Result("<");
  Result += type;
  if (!name.empty()) {
    Result += ':';
    Result += name;
  }
  Result += '>';
  return Result;
}
684 
685 /****************************** RSExportType ******************************/
NormalizeType(const clang::Type * & T,llvm::StringRef & TypeName,RSContext * Context,const clang::VarDecl * VD,ExportKind EK)686 bool RSExportType::NormalizeType(const clang::Type *&T,
687                                  llvm::StringRef &TypeName,
688                                  RSContext *Context,
689                                  const clang::VarDecl *VD,
690                                  ExportKind EK) {
691   if ((T = TypeExportable(T, Context, VD, EK)) == nullptr) {
692     return false;
693   }
694   // Get type name
695   TypeName = RSExportType::GetTypeName(T);
696   if (Context && TypeName.empty()) {
697     if (VD) {
698       Context->ReportError(VD->getLocation(),
699                            "anonymous types cannot be exported");
700     } else {
701       Context->ReportError("anonymous types cannot be exported");
702     }
703     return false;
704   }
705 
706   return true;
707 }
708 
ValidateType(slang::RSContext * Context,clang::ASTContext & C,clang::QualType QT,const clang::NamedDecl * ND,clang::SourceLocation Loc,unsigned int TargetAPI,bool IsFilterscript,bool IsExtern)709 bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
710                                 clang::QualType QT, const clang::NamedDecl *ND,
711                                 clang::SourceLocation Loc,
712                                 unsigned int TargetAPI, bool IsFilterscript,
713                                 bool IsExtern) {
714   const clang::Type *T = QT.getTypePtr();
715   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
716       llvm::SmallPtrSet<const clang::Type*, 8>();
717 
718   // If this is an externally visible variable declaration, we check if the
719   // type is able to be exported first.
720   if (auto VD = llvm::dyn_cast_or_null<clang::VarDecl>(ND)) {
721     if (VD->getFormalLinkage() == clang::ExternalLinkage) {
722       if (!TypeExportable(T, Context, VD, NotLegacyKernelArgument)) {
723         return false;
724       }
725     }
726   }
727   return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
728                             IsFilterscript, IsExtern);
729 }
730 
ValidateVarDecl(slang::RSContext * Context,clang::VarDecl * VD,unsigned int TargetAPI,bool IsFilterscript)731 bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
732                                    clang::VarDecl *VD, unsigned int TargetAPI,
733                                    bool IsFilterscript) {
734   return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
735                       VD->getLocation(), TargetAPI, IsFilterscript,
736                       (VD->getFormalLinkage() == clang::ExternalLinkage));
737 }
738 
739 const clang::Type
GetTypeOfDecl(const clang::DeclaratorDecl * DD)740 *RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
741   if (DD) {
742     clang::QualType T = DD->getType();
743 
744     if (T.isNull())
745       return nullptr;
746     else
747       return T.getTypePtr();
748   }
749   return nullptr;
750 }
751 
GetTypeName(const clang::Type * T)752 llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
753   T = GetCanonicalType(T);
754   if (T == nullptr)
755     return llvm::StringRef();
756 
757   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
758 
759   switch (T->getTypeClass()) {
760     case clang::Type::Builtin: {
761       const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
762       BuiltinInfo *info = FindBuiltinType(BT->getKind());
763       if (info != nullptr) {
764         return info->cname[0];
765       }
766       slangAssert(false && "Unknown data type of the builtin");
767       break;
768     }
769     case clang::Type::Record: {
770       clang::RecordDecl *RD;
771       if (T->isStructureType()) {
772         RD = T->getAsStructureType()->getDecl();
773       } else {
774         break;
775       }
776 
777       llvm::StringRef Name = RD->getName();
778       if (Name.empty()) {
779         if (RD->getTypedefNameForAnonDecl() != nullptr) {
780           Name = RD->getTypedefNameForAnonDecl()->getName();
781         }
782 
783         if (Name.empty()) {
784           // Try to find a name from redeclaration (i.e. typedef)
785           for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
786                    RE = RD->redecls_end();
787                RI != RE;
788                RI++) {
789             slangAssert(*RI != nullptr && "cannot be NULL object");
790 
791             Name = (*RI)->getName();
792             if (!Name.empty())
793               break;
794           }
795         }
796       }
797       return Name;
798     }
799     case clang::Type::Pointer: {
800       // "*" plus pointee name
801       const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
802       const clang::Type *PT = GetPointeeType(P);
803       llvm::StringRef PointeeName;
804       // Passing nullptr as Context to NormalizeType can cause TypeExportableHelper
805       // to dereference a null Context?
806       if (NormalizeType(PT, PointeeName, nullptr, nullptr,
807                         NotLegacyKernelArgument)) {
808         char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
809         Name[0] = '*';
810         memcpy(Name + 1, PointeeName.data(), PointeeName.size());
811         Name[PointeeName.size() + 1] = '\0';
812         return Name;
813       }
814       break;
815     }
816     case clang::Type::ExtVector: {
817       const clang::ExtVectorType *EVT =
818               static_cast<const clang::ExtVectorType*>(CTI);
819       return RSExportVectorType::GetTypeName(EVT);
820       break;
821     }
822     case clang::Type::ConstantArray : {
823       // Construct name for a constant array is too complicated.
824       return "<ConstantArray>";
825     }
826     default: {
827       break;
828     }
829   }
830 
831   return llvm::StringRef();
832 }
833 
834 
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName,ExportKind EK)835 RSExportType *RSExportType::Create(RSContext *Context,
836                                    const clang::Type *T,
837                                    const llvm::StringRef &TypeName,
838                                    ExportKind EK) {
839   // Lookup the context to see whether the type was processed before.
840   // Newly created RSExportType will insert into context
841   // in RSExportType::RSExportType()
842   RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
843 
844   if (ETI != Context->export_types_end())
845     return ETI->second;
846 
847   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
848 
849   RSExportType *ET = nullptr;
850   switch (T->getTypeClass()) {
851     case clang::Type::Record: {
852       DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
853       switch (dt) {
854         case DataTypeUnknown: {
855           // User-defined types
856           ET = RSExportRecordType::Create(Context,
857                                           T->getAsStructureType(),
858                                           TypeName);
859           break;
860         }
861         case DataTypeRSMatrix2x2: {
862           // 2 x 2 Matrix type
863           ET = RSExportMatrixType::Create(Context,
864                                           T->getAsStructureType(),
865                                           TypeName,
866                                           2);
867           break;
868         }
869         case DataTypeRSMatrix3x3: {
870           // 3 x 3 Matrix type
871           ET = RSExportMatrixType::Create(Context,
872                                           T->getAsStructureType(),
873                                           TypeName,
874                                           3);
875           break;
876         }
877         case DataTypeRSMatrix4x4: {
878           // 4 x 4 Matrix type
879           ET = RSExportMatrixType::Create(Context,
880                                           T->getAsStructureType(),
881                                           TypeName,
882                                           4);
883           break;
884         }
885         default: {
886           // Others are primitive types
887           ET = RSExportPrimitiveType::Create(Context, T, TypeName);
888           break;
889         }
890       }
891       break;
892     }
893     case clang::Type::Builtin: {
894       ET = RSExportPrimitiveType::Create(Context, T, TypeName);
895       break;
896     }
897     case clang::Type::Pointer: {
898       ET = RSExportPointerType::Create(Context,
899                                        static_cast<const clang::PointerType*>(CTI),
900                                        TypeName);
901       // FIXME: free the name (allocated in RSExportType::GetTypeName)
902       delete [] TypeName.data();
903       break;
904     }
905     case clang::Type::ExtVector: {
906       ET = RSExportVectorType::Create(Context,
907                                       static_cast<const clang::ExtVectorType*>(CTI),
908                                       TypeName);
909       break;
910     }
911     case clang::Type::ConstantArray: {
912       ET = RSExportConstantArrayType::Create(
913               Context,
914               static_cast<const clang::ConstantArrayType*>(CTI));
915       break;
916     }
917     default: {
918       Context->ReportError("unknown type cannot be exported: '%0'")
919           << T->getTypeClassName();
920       break;
921     }
922   }
923 
924   return ET;
925 }
926 
Create(RSContext * Context,const clang::Type * T,ExportKind EK,const clang::VarDecl * VD)927 RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T,
928                                    ExportKind EK, const clang::VarDecl *VD) {
929   llvm::StringRef TypeName;
930   if (NormalizeType(T, TypeName, Context, VD, EK)) {
931     return Create(Context, T, TypeName, EK);
932   } else {
933     return nullptr;
934   }
935 }
936 
CreateFromDecl(RSContext * Context,const clang::VarDecl * VD)937 RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
938                                            const clang::VarDecl *VD) {
939   return RSExportType::Create(Context, GetTypeOfDecl(VD),
940                               NotLegacyKernelArgument, VD);
941 }
942 
getStoreSize() const943 size_t RSExportType::getStoreSize() const {
944   return getRSContext()->getDataLayout().getTypeStoreSize(getLLVMType());
945 }
946 
getAllocSize() const947 size_t RSExportType::getAllocSize() const {
948     return getRSContext()->getDataLayout().getTypeAllocSize(getLLVMType());
949 }
950 
RSExportType(RSContext * Context,ExportClass Class,const llvm::StringRef & Name,clang::SourceLocation Loc)951 RSExportType::RSExportType(RSContext *Context,
952                            ExportClass Class,
953                            const llvm::StringRef &Name, clang::SourceLocation Loc)
954     : RSExportable(Context, RSExportable::EX_TYPE, Loc),
955       mClass(Class),
956       // Make a copy on Name since memory stored @Name is either allocated in
957       // ASTContext or allocated in GetTypeName which will be destroyed later.
958       mName(Name.data(), Name.size()),
959       mLLVMType(nullptr) {
960   // Don't cache the type whose name start with '<'. Those type failed to
961   // get their name since constructing their name in GetTypeName() requiring
962   // complicated work.
963   if (!IsDummyName(Name)) {
964     // TODO(zonr): Need to check whether the insertion is successful or not.
965     Context->insertExportType(llvm::StringRef(Name), this);
966   }
967 
968 }
969 
keep()970 bool RSExportType::keep() {
971   if (!RSExportable::keep())
972     return false;
973   // Invalidate converted LLVM type.
974   mLLVMType = nullptr;
975   return true;
976 }
977 
matchODR(const RSExportType * E,bool) const978 bool RSExportType::matchODR(const RSExportType *E, bool /* LookInto */) const {
979   return (E->getClass() == getClass());
980 }
981 
~RSExportType()982 RSExportType::~RSExportType() {
983 }
984 
985 /************************** RSExportPrimitiveType **************************/
// Definition of the static name -> DataType map. It is filled on first use by
// GetRSSpecificType(const llvm::StringRef &) from MatrixAndObjectDataTypes.
llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
RSExportPrimitiveType::RSSpecificTypeMap;
988 
IsPrimitiveType(const clang::Type * T)989 bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
990   if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
991     return true;
992   else
993     return false;
994 }
995 
996 DataType
GetRSSpecificType(const llvm::StringRef & TypeName)997 RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
998   if (TypeName.empty())
999     return DataTypeUnknown;
1000 
1001   if (RSSpecificTypeMap->empty()) {
1002     for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
1003       (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
1004           MatrixAndObjectDataTypes[i].dataType;
1005     }
1006   }
1007 
1008   RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
1009   if (I == RSSpecificTypeMap->end())
1010     return DataTypeUnknown;
1011   else
1012     return I->getValue();
1013 }
1014 
GetRSSpecificType(const clang::Type * T)1015 DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
1016   T = GetCanonicalType(T);
1017   if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
1018     return DataTypeUnknown;
1019 
1020   return GetRSSpecificType( RSExportType::GetTypeName(T) );
1021 }
1022 
IsRSMatrixType(DataType DT)1023 bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
1024     if (DT < 0 || DT >= DataTypeMax) {
1025         return false;
1026     }
1027     return gReflectionTypes[DT].category == MatrixDataType;
1028 }
1029 
IsRSObjectType(DataType DT)1030 bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
1031     if (DT < 0 || DT >= DataTypeMax) {
1032         return false;
1033     }
1034     return gReflectionTypes[DT].category == ObjectDataType;
1035 }
1036 
IsStructureTypeWithRSObject(const clang::Type * T)1037 bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
1038   bool RSObjectTypeSeen = false;
1039   slangAssert(T);
1040   while (T->isArrayType()) {
1041     T = T->getArrayElementTypeNoTypeQual();
1042     slangAssert(T);
1043   }
1044 
1045   const clang::RecordType *RT = T->getAsStructureType();
1046   if (!RT) {
1047     return false;
1048   }
1049 
1050   const clang::RecordDecl *RD = RT->getDecl();
1051   if (RD) {
1052     RD = RD->getDefinition();
1053   }
1054   if (!RD) {
1055     return false;
1056   }
1057 
1058   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1059          FE = RD->field_end();
1060        FI != FE;
1061        FI++) {
1062     // We just look through all field declarations to see if we find a
1063     // declaration for an RS object type (or an array of one).
1064     const clang::FieldDecl *FD = *FI;
1065     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1066     slangAssert(FT);
1067     while (FT->isArrayType()) {
1068       FT = FT->getArrayElementTypeNoTypeQual();
1069       slangAssert(FT);
1070     }
1071 
1072     DataType DT = GetRSSpecificType(FT);
1073     if (IsRSObjectType(DT)) {
1074       // RS object types definitely need to be zero-initialized
1075       RSObjectTypeSeen = true;
1076     } else {
1077       switch (DT) {
1078         case DataTypeRSMatrix2x2:
1079         case DataTypeRSMatrix3x3:
1080         case DataTypeRSMatrix4x4:
1081           // Matrix types should get zero-initialized as well
1082           RSObjectTypeSeen = true;
1083           break;
1084         default:
1085           // Ignore all other primitive types
1086           break;
1087       }
1088       if (FT->isStructureType()) {
1089         // Recursively handle structs of structs (even though these can't
1090         // be exported, it is possible for a user to have them internally).
1091         RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1092       }
1093     }
1094   }
1095 
1096   return RSObjectTypeSeen;
1097 }
1098 
GetElementSizeInBits(const RSExportPrimitiveType * EPT)1099 size_t RSExportPrimitiveType::GetElementSizeInBits(const RSExportPrimitiveType *EPT) {
1100   int type = EPT->getType();
1101   slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1102               "RSExportPrimitiveType::GetElementSizeInBits : unknown data type");
1103   // All RS object types are 256 bits in 64-bit RS.
1104   if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1105     return 256;
1106   }
1107   return gReflectionTypes[type].size_in_bits;
1108 }
1109 
1110 DataType
GetDataType(RSContext * Context,const clang::Type * T)1111 RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1112   if (T == nullptr)
1113     return DataTypeUnknown;
1114 
1115   switch (T->getTypeClass()) {
1116     case clang::Type::Builtin: {
1117       const clang::BuiltinType *BT =
1118               static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1119       BuiltinInfo *info = FindBuiltinType(BT->getKind());
1120       if (info != nullptr) {
1121         return info->type;
1122       }
1123       // The size of type WChar depend on platform so we abandon the support
1124       // to them.
1125       Context->ReportError("built-in type cannot be exported: '%0'")
1126           << T->getTypeClassName();
1127       break;
1128     }
1129     case clang::Type::Record: {
1130       // must be RS object type
1131       return RSExportPrimitiveType::GetRSSpecificType(T);
1132     }
1133     default: {
1134       Context->ReportError("primitive type cannot be exported: '%0'")
1135           << T->getTypeClassName();
1136       break;
1137     }
1138   }
1139 
1140   return DataTypeUnknown;
1141 }
1142 
1143 RSExportPrimitiveType
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName,bool Normalized)1144 *RSExportPrimitiveType::Create(RSContext *Context,
1145                                const clang::Type *T,
1146                                const llvm::StringRef &TypeName,
1147                                bool Normalized) {
1148   DataType DT = GetDataType(Context, T);
1149 
1150   if ((DT == DataTypeUnknown) || TypeName.empty())
1151     return nullptr;
1152   else
1153     return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1154                                      DT, Normalized);
1155 }
1156 
Create(RSContext * Context,const clang::Type * T)1157 RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1158                                                      const clang::Type *T) {
1159   llvm::StringRef TypeName;
1160   if (RSExportType::NormalizeType(T, TypeName, Context, nullptr,
1161                                   NotLegacyKernelArgument) &&
1162       IsPrimitiveType(T)) {
1163     return Create(Context, T, TypeName);
1164   } else {
1165     return nullptr;
1166   }
1167 }
1168 
convertToLLVMType() const1169 llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1170   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1171 
1172   if (isRSObjectType()) {
1173     // struct {
1174     //   int *p;
1175     // } __attribute__((packed, aligned(pointer_size)))
1176     //
1177     // which is
1178     //
1179     // <{ [1 x i32] }> in LLVM
1180     //
1181     std::vector<llvm::Type *> Elements;
1182     if (getRSContext()->is64Bit()) {
1183       // 64-bit path
1184       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1185       return llvm::StructType::get(C, Elements, true);
1186     } else {
1187       // 32-bit legacy path
1188       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1189       return llvm::StructType::get(C, Elements, true);
1190     }
1191   }
1192 
1193   switch (mType) {
1194     case DataTypeFloat16: {
1195       return llvm::Type::getHalfTy(C);
1196       break;
1197     }
1198     case DataTypeFloat32: {
1199       return llvm::Type::getFloatTy(C);
1200       break;
1201     }
1202     case DataTypeFloat64: {
1203       return llvm::Type::getDoubleTy(C);
1204       break;
1205     }
1206     case DataTypeBoolean: {
1207       return llvm::Type::getInt1Ty(C);
1208       break;
1209     }
1210     case DataTypeSigned8:
1211     case DataTypeUnsigned8: {
1212       return llvm::Type::getInt8Ty(C);
1213       break;
1214     }
1215     case DataTypeSigned16:
1216     case DataTypeUnsigned16:
1217     case DataTypeUnsigned565:
1218     case DataTypeUnsigned5551:
1219     case DataTypeUnsigned4444: {
1220       return llvm::Type::getInt16Ty(C);
1221       break;
1222     }
1223     case DataTypeSigned32:
1224     case DataTypeUnsigned32: {
1225       return llvm::Type::getInt32Ty(C);
1226       break;
1227     }
1228     case DataTypeSigned64:
1229     case DataTypeUnsigned64: {
1230       return llvm::Type::getInt64Ty(C);
1231       break;
1232     }
1233     default: {
1234       slangAssert(false && "Unknown data type");
1235     }
1236   }
1237 
1238   return nullptr;
1239 }
1240 
matchODR(const RSExportType * E,bool) const1241 bool RSExportPrimitiveType::matchODR(const RSExportType *E,
1242                                      bool /* LookInto */) const {
1243   CHECK_PARENT_EQUALITY(RSExportType, E);
1244   return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1245 }
1246 
getRSReflectionType(DataType DT)1247 RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1248   if (DT > DataTypeUnknown && DT < DataTypeMax) {
1249     return &gReflectionTypes[DT];
1250   } else {
1251     return nullptr;
1252   }
1253 }
1254 
1255 /**************************** RSExportPointerType ****************************/
1256 
1257 RSExportPointerType
Create(RSContext * Context,const clang::PointerType * PT,const llvm::StringRef & TypeName)1258 *RSExportPointerType::Create(RSContext *Context,
1259                              const clang::PointerType *PT,
1260                              const llvm::StringRef &TypeName) {
1261   const clang::Type *PointeeType = GetPointeeType(PT);
1262   const RSExportType *PointeeET;
1263 
1264   if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1265     PointeeET = RSExportType::Create(Context, PointeeType,
1266                                      NotLegacyKernelArgument);
1267   } else {
1268     // Double or higher dimension of pointer, export as int*
1269     PointeeET = RSExportPrimitiveType::Create(Context,
1270                     Context->getASTContext().IntTy.getTypePtr());
1271   }
1272 
1273   if (PointeeET == nullptr) {
1274     // Error diagnostic is emitted for corresponding pointee type
1275     return nullptr;
1276   }
1277 
1278   return new RSExportPointerType(Context, TypeName, PointeeET);
1279 }
1280 
convertToLLVMType() const1281 llvm::Type *RSExportPointerType::convertToLLVMType() const {
1282   llvm::Type *PointeeType = mPointeeType->getLLVMType();
1283   return llvm::PointerType::getUnqual(PointeeType);
1284 }
1285 
keep()1286 bool RSExportPointerType::keep() {
1287   if (!RSExportType::keep())
1288     return false;
1289   const_cast<RSExportType*>(mPointeeType)->keep();
1290   return true;
1291 }
1292 
// ODR matching is not expected to be requested for pointer types: this
// asserts and reports a mismatch.
bool RSExportPointerType::matchODR(const RSExportType *E,
                                   bool /* LookInto */) const {
  // Exported types cannot contain pointers
  slangAssert(false && "Not supposed to perform ODR check on pointers");
  return false;
}
1299 
1300 /***************************** RSExportVectorType *****************************/
1301 llvm::StringRef
GetTypeName(const clang::ExtVectorType * EVT)1302 RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1303   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1304   llvm::StringRef name;
1305 
1306   if ((ElementType->getTypeClass() != clang::Type::Builtin))
1307     return name;
1308 
1309   const clang::BuiltinType *BT =
1310           static_cast<const clang::BuiltinType*>(
1311               ElementType->getCanonicalTypeInternal().getTypePtr());
1312 
1313   if ((EVT->getNumElements() < 1) ||
1314       (EVT->getNumElements() > 4))
1315     return name;
1316 
1317   BuiltinInfo *info = FindBuiltinType(BT->getKind());
1318   if (info != nullptr) {
1319     int I = EVT->getNumElements() - 1;
1320     if (I < kMaxVectorSize) {
1321       name = info->cname[I];
1322     } else {
1323       slangAssert(false && "Max vector is 4");
1324     }
1325   }
1326   return name;
1327 }
1328 
Create(RSContext * Context,const clang::ExtVectorType * EVT,const llvm::StringRef & TypeName,bool Normalized)1329 RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1330                                                const clang::ExtVectorType *EVT,
1331                                                const llvm::StringRef &TypeName,
1332                                                bool Normalized) {
1333   slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1334 
1335   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1336   DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1337 
1338   if (DT != DataTypeUnknown)
1339     return new RSExportVectorType(Context,
1340                                   TypeName,
1341                                   DT,
1342                                   Normalized,
1343                                   EVT->getNumElements());
1344   else
1345     return nullptr;
1346 }
1347 
convertToLLVMType() const1348 llvm::Type *RSExportVectorType::convertToLLVMType() const {
1349   llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1350   return llvm::VectorType::get(ElementType, getNumElement());
1351 }
1352 
matchODR(const RSExportType * E,bool) const1353 bool RSExportVectorType::matchODR(const RSExportType *E,
1354                                   bool /* LookInto*/) const {
1355   CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1356   return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1357               == getNumElement());
1358 }
1359 
1360 /***************************** RSExportMatrixType *****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,unsigned Dim)1361 RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1362                                                const clang::RecordType *RT,
1363                                                const llvm::StringRef &TypeName,
1364                                                unsigned Dim) {
1365   slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1366   slangAssert((Dim > 1) && "Invalid dimension of matrix");
1367 
1368   // Check whether the struct rs_matrix is in our expected form (but assume it's
1369   // correct if we're not sure whether it's correct or not)
1370   const clang::RecordDecl* RD = RT->getDecl();
1371   RD = RD->getDefinition();
1372   if (RD != nullptr) {
1373     // Find definition, perform further examination
1374     if (RD->field_empty()) {
1375       Context->ReportError(
1376           RD->getLocation(),
1377           "invalid matrix struct: must have 1 field for saving values: '%0'")
1378           << RD->getName();
1379       return nullptr;
1380     }
1381 
1382     clang::RecordDecl::field_iterator FIT = RD->field_begin();
1383     const clang::FieldDecl *FD = *FIT;
1384     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1385     if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1386       Context->ReportError(RD->getLocation(),
1387                            "invalid matrix struct: first field should"
1388                            " be an array with constant size: '%0'")
1389           << RD->getName();
1390       return nullptr;
1391     }
1392     const clang::ConstantArrayType *CAT =
1393       static_cast<const clang::ConstantArrayType *>(FT);
1394     const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1395     if ((ElementType == nullptr) ||
1396         (ElementType->getTypeClass() != clang::Type::Builtin) ||
1397         (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1398          clang::BuiltinType::Float)) {
1399       Context->ReportError(RD->getLocation(),
1400                            "invalid matrix struct: first field "
1401                            "should be a float array: '%0'")
1402           << RD->getName();
1403       return nullptr;
1404     }
1405 
1406     if (CAT->getSize() != Dim * Dim) {
1407       Context->ReportError(RD->getLocation(),
1408                            "invalid matrix struct: first field "
1409                            "should be an array with size %0: '%1'")
1410           << (Dim * Dim) << (RD->getName());
1411       return nullptr;
1412     }
1413 
1414     FIT++;
1415     if (FIT != RD->field_end()) {
1416       Context->ReportError(RD->getLocation(),
1417                            "invalid matrix struct: must have "
1418                            "exactly 1 field: '%0'")
1419           << RD->getName();
1420       return nullptr;
1421     }
1422   }
1423 
1424   return new RSExportMatrixType(Context, TypeName, Dim);
1425 }
1426 
convertToLLVMType() const1427 llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1428   // Construct LLVM type:
1429   // struct {
1430   //  float X[mDim * mDim];
1431   // }
1432 
1433   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1434   llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1435                                             mDim * mDim);
1436   return llvm::StructType::get(C, X, false);
1437 }
1438 
matchODR(const RSExportType * E,bool) const1439 bool RSExportMatrixType::matchODR(const RSExportType *E,
1440                                   bool /* LookInto */) const {
1441   CHECK_PARENT_EQUALITY(RSExportType, E);
1442   return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1443 }
1444 
1445 /************************* RSExportConstantArrayType *************************/
1446 RSExportConstantArrayType
Create(RSContext * Context,const clang::ConstantArrayType * CAT)1447 *RSExportConstantArrayType::Create(RSContext *Context,
1448                                    const clang::ConstantArrayType *CAT) {
1449   slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1450 
1451   slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1452 
1453   unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1454   slangAssert((Size > 0) && "Constant array should have size greater than 0");
1455 
1456   const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1457   RSExportType *ElementET = RSExportType::Create(Context, ElementType,
1458                                                  NotLegacyKernelArgument);
1459 
1460   if (ElementET == nullptr) {
1461     return nullptr;
1462   }
1463 
1464   return new RSExportConstantArrayType(Context,
1465                                        ElementET,
1466                                        Size);
1467 }
1468 
convertToLLVMType() const1469 llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1470   return llvm::ArrayType::get(mElementType->getLLVMType(), getNumElement());
1471 }
1472 
keep()1473 bool RSExportConstantArrayType::keep() {
1474   if (!RSExportType::keep())
1475     return false;
1476   const_cast<RSExportType*>(mElementType)->keep();
1477   return true;
1478 }
1479 
matchODR(const RSExportType * E,bool LookInto) const1480 bool RSExportConstantArrayType::matchODR(const RSExportType *E,
1481                                          bool LookInto) const {
1482   CHECK_PARENT_EQUALITY(RSExportType, E);
1483   const RSExportConstantArrayType *RHS =
1484       static_cast<const RSExportConstantArrayType*>(E);
1485   return ((getNumElement() == RHS->getNumElement()) &&
1486           (getElementType()->matchODR(RHS->getElementType(), LookInto)));
1487 }
1488 
1489 /**************************** RSExportRecordType ****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,bool mIsArtificial)1490 RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1491                                                const clang::RecordType *RT,
1492                                                const llvm::StringRef &TypeName,
1493                                                bool mIsArtificial) {
1494   slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);
1495 
1496   const clang::RecordDecl *RD = RT->getDecl();
1497   slangAssert(RD->isStruct());
1498 
1499   RD = RD->getDefinition();
1500   if (RD == nullptr) {
1501     slangAssert(false && "struct is not defined in this module");
1502     return nullptr;
1503   }
1504 
1505   // Struct layout construct by clang. We rely on this for obtaining the
1506   // alloc size of a struct and offset of every field in that struct.
1507   const clang::ASTRecordLayout *RL =
1508       &Context->getASTContext().getASTRecordLayout(RD);
1509   slangAssert((RL != nullptr) &&
1510       "Failed to retrieve the struct layout from Clang.");
1511 
1512   RSExportRecordType *ERT =
1513       new RSExportRecordType(Context,
1514                              TypeName,
1515                              RD->getLocation(),
1516                              RD->hasAttr<clang::PackedAttr>(),
1517                              mIsArtificial,
1518                              RL->getDataSize().getQuantity(),
1519                              RL->getSize().getQuantity());
1520   unsigned int Index = 0;
1521 
1522   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1523            FE = RD->field_end();
1524        FI != FE;
1525        FI++, Index++) {
1526 
1527     // FIXME: All fields should be primitive type
1528     slangAssert(FI->getKind() == clang::Decl::Field);
1529     clang::FieldDecl *FD = *FI;
1530 
1531     if (FD->isBitField()) {
1532       return nullptr;
1533     }
1534 
1535     if (FD->isImplicit() && (FD->getName() == RS_PADDING_FIELD_NAME))
1536       continue;
1537 
1538     // Type
1539     RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1540 
1541     if (ET != nullptr) {
1542       ERT->mFields.push_back(
1543           new Field(ET, FD->getName(), ERT,
1544                     static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1545     } else {
1546       // clang static analysis complains about a potential memory leak
1547       // for the memory pointed by ERT at the end of this basic
1548       // block. This is a false warning because the compiler does not
1549       // see that the pointer to this memory is saved away in the
1550       // constructor for RSExportRecordType by calling
1551       // RSContext::newExportable(this). So, we disable this
1552       // particular instance of the warning.
1553       Context->ReportError(RD->getLocation(),
1554                            "field type cannot be exported: '%0.%1'")
1555           << RD->getName() << FD->getName(); // NOLINT
1556       return nullptr;
1557     }
1558   }
1559 
1560   return ERT;
1561 }
1562 
convertToLLVMType() const1563 llvm::Type *RSExportRecordType::convertToLLVMType() const {
1564   // Create an opaque type since struct may reference itself recursively.
1565 
1566   // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1567   std::vector<llvm::Type*> FieldTypes;
1568 
1569   for (const_field_iterator FI = fields_begin(), FE = fields_end();
1570        FI != FE;
1571        FI++) {
1572     const Field *F = *FI;
1573     const RSExportType *FET = F->getType();
1574 
1575     FieldTypes.push_back(FET->getLLVMType());
1576   }
1577 
1578   llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1579                                                FieldTypes,
1580                                                mIsPacked);
1581   if (ST != nullptr) {
1582     return ST;
1583   } else {
1584     return nullptr;
1585   }
1586 }
1587 
keep()1588 bool RSExportRecordType::keep() {
1589   if (!RSExportType::keep())
1590     return false;
1591   for (std::list<const Field*>::iterator I = mFields.begin(),
1592           E = mFields.end();
1593        I != E;
1594        I++) {
1595     const_cast<RSExportType*>((*I)->getType())->keep();
1596   }
1597   return true;
1598 }
1599 
matchODR(const RSExportType * E,bool LookInto) const1600 bool RSExportRecordType::matchODR(const RSExportType *E, bool LookInto) const {
1601   CHECK_PARENT_EQUALITY(RSExportType, E);
1602   // Enforce ODR checking - the type E represents must hold
1603   // *exactly* the same "definition" as the one defined previously. We
1604   // say two record types A and B have the same definition iff:
1605   //
1606   //  struct A {              struct B {
1607   //    Type(a1) a1,            Type(b1) b1,
1608   //    Type(a2) a2,            Type(b1) b2,
1609   //    ...                     ...
1610   //    Type(aN) aN             Type(bM) bM,
1611   //  };                      }
1612   //  Cond. #0. A = B;
1613   //  Cond. #1. They have same number of fields, i.e., N = M;
1614   //  Cond. #2. for (i := 1 to N)
1615   //              Type(ai).matchODR(Type(bi)) must hold;
1616   //  Cond. #3. for (i := 1 to N)
1617   //              Name(ai) = Name(bi) must hold;
1618   //
1619   // where,
1620   //  Type(F) = the type of field F and
1621   //  Name(F) = the field name.
1622 
1623 
1624   const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1625   // Cond. #0.
1626   if (getName() != ERT->getName())
1627     return false;
1628 
1629   // Examine fields - types and names
1630   if (LookInto) {
1631     // Cond. #1
1632     if (ERT->getFields().size() != getFields().size())
1633       return false;
1634 
1635     for (RSExportRecordType::const_field_iterator AI = fields_begin(),
1636          BI = ERT->fields_begin(), AE = fields_end(); AI != AE; ++AI, ++BI) {
1637       const RSExportType *AITy = (*AI)->getType();
1638       const RSExportType *BITy = (*BI)->getType();
1639       // Cond. #3; field names must agree
1640       if ((*AI)->getName() != (*BI)->getName())
1641         return false;
1642 
1643       // Cond. #2; field types must agree recursively until we see another
1644       // next level of RSExportRecordType - such field types will be
1645       // examined and reported later when checkODR() encounters them.
1646       if (!AITy->matchODR(BITy, false))
1647         return false;
1648     }
1649   }
1650   return true;
1651 }
1652 
convertToRTD(RSReflectionTypeData * rtd) const1653 void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1654     memset(rtd, 0, sizeof(*rtd));
1655     rtd->vecSize = 1;
1656 
1657     switch(getClass()) {
1658     case RSExportType::ExportClassPrimitive: {
1659             const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1660             rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1661             return;
1662         }
1663     case RSExportType::ExportClassPointer: {
1664             const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1665             const RSExportType *PointeeType = EPT->getPointeeType();
1666             PointeeType->convertToRTD(rtd);
1667             rtd->isPointer = true;
1668             return;
1669         }
1670     case RSExportType::ExportClassVector: {
1671             const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1672             rtd->type = EVT->getRSReflectionType(EVT);
1673             rtd->vecSize = EVT->getNumElement();
1674             return;
1675         }
1676     case RSExportType::ExportClassMatrix: {
1677             const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1678             unsigned Dim = EMT->getDim();
1679             slangAssert((Dim >= 2) && (Dim <= 4));
1680             rtd->type = &gReflectionTypes[15 + Dim-2];
1681             return;
1682         }
1683     case RSExportType::ExportClassConstantArray: {
1684             const RSExportConstantArrayType* CAT =
1685               static_cast<const RSExportConstantArrayType*>(this);
1686             CAT->getElementType()->convertToRTD(rtd);
1687             rtd->arraySize = CAT->getNumElement();
1688             return;
1689         }
1690     case RSExportType::ExportClassRecord: {
1691             slangAssert(!"RSExportType::ExportClassRecord not implemented");
1692             return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1693         }
1694     default: {
1695             slangAssert(false && "Unknown class of type");
1696         }
1697     }
1698 }
1699 
1700 
1701 }  // namespace slang
1702