1 /*
2  * Copyright 2010-2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_export_type.h"
18 
#include <list>
#include <set>
#include <string>
#include <vector>
21 
22 #include "clang/AST/ASTContext.h"
23 #include "clang/AST/Attr.h"
24 #include "clang/AST/RecordLayout.h"
25 
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Type.h"
30 
31 #include "slang_assert.h"
32 #include "slang_rs_context.h"
33 #include "slang_rs_export_element.h"
34 #include "slang_version.h"
35 
// Makes the enclosing function return false when ParentClass::equals(E) is
// false; presumably placed at the top of equals() overrides so a subclass
// defers to its base-class comparison first.
// NOTE(review): the expansion is a bare `if` that already ends in `;`, so it
// is not safe to use as the body of an unbraced if/else.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  if (!ParentClass::equals(E))                \
    return false;
39 
40 namespace slang {
41 
42 namespace {
43 
/* For the data types we support, their category, names, and size (in bits).
 *
 * IMPORTANT: The data types in this table should be at the same index
 * as specified by the corresponding DataType enum.
 *
 * NOTE(review): column meanings are inferred from the entries themselves and
 * should be confirmed against the RSReflectionType definition:
 *   category, enum-style name, short RS name, size in bits, C/RS type name,
 *   Java-side type name, two further reflection name fragments, and a flag.
 * The flag is true only for UNSIGNED_8/16/32, i.e. exactly the unsigned types
 * whose Java-side name is the next wider signed type (uint8_t -> short,
 * uint16_t -> int, uint32_t -> long).
 */
static RSReflectionType gReflectionTypes[] = {
    {PrimitiveDataType, "FLOAT_16", "F16", 16, "half", "short", "Half", "Half", false},
    {PrimitiveDataType, "FLOAT_32", "F32", 32, "float", "float", "Float", "Float", false},
    {PrimitiveDataType, "FLOAT_64", "F64", 64, "double", "double", "Double", "Double",false},
    {PrimitiveDataType, "SIGNED_8", "I8", 8, "int8_t", "byte", "Byte", "Byte", false},
    {PrimitiveDataType, "SIGNED_16", "I16", 16, "int16_t", "short", "Short", "Short", false},
    {PrimitiveDataType, "SIGNED_32", "I32", 32, "int32_t", "int", "Int", "Int", false},
    {PrimitiveDataType, "SIGNED_64", "I64", 64, "int64_t", "long", "Long", "Long", false},
    {PrimitiveDataType, "UNSIGNED_8", "U8", 8, "uint8_t", "short", "UByte", "Short", true},
    {PrimitiveDataType, "UNSIGNED_16", "U16", 16, "uint16_t", "int", "UShort", "Int", true},
    {PrimitiveDataType, "UNSIGNED_32", "U32", 32, "uint32_t", "long", "UInt", "Long", true},
    // Java has no primitive wider than long, which is presumably why uint64_t
    // maps to long with the flag left false.
    {PrimitiveDataType, "UNSIGNED_64", "U64", 64, "uint64_t", "long", "ULong", "Long", false},

    {PrimitiveDataType, "BOOLEAN", "BOOLEAN", 8, "bool", "boolean", nullptr, nullptr, false},

    // Packed 16-bit pixel formats: no C or Java type names are provided.
    {PrimitiveDataType, "UNSIGNED_5_6_5", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_5_5_5_1", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},
    {PrimitiveDataType, "UNSIGNED_4_4_4_4", nullptr, 16, nullptr, nullptr, nullptr, nullptr, false},

    // Square float matrices: N*N elements of 32 bits each.
    {MatrixDataType, "MATRIX_2X2", nullptr, 4*32, "rsMatrix_2x2", "Matrix2f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_3X3", nullptr, 9*32, "rsMatrix_3x3", "Matrix3f", nullptr, nullptr, false},
    {MatrixDataType, "MATRIX_4X4", nullptr, 16*32, "rsMatrix_4x4", "Matrix4f", nullptr, nullptr, false},

    // RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
    // This is handled specially by the GetSizeInBits() method.
    {ObjectDataType, "RS_ELEMENT", "ELEMENT", 32, "Element", "Element", nullptr, nullptr, false},
    {ObjectDataType, "RS_TYPE", "TYPE", 32, "Type", "Type", nullptr, nullptr, false},
    {ObjectDataType, "RS_ALLOCATION", "ALLOCATION", 32, "Allocation", "Allocation", nullptr, nullptr, false},
    {ObjectDataType, "RS_SAMPLER", "SAMPLER", 32, "Sampler", "Sampler", nullptr, nullptr, false},
    {ObjectDataType, "RS_SCRIPT", "SCRIPT", 32, "Script", "Script", nullptr, nullptr, false},
    {ObjectDataType, "RS_MESH", "MESH", 32, "Mesh", "Mesh", nullptr, nullptr, false},
    {ObjectDataType, "RS_PATH", "PATH", 32, "Path", "Path", nullptr, nullptr, false},

    {ObjectDataType, "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_VERTEX", "PROGRAM_VERTEX", 32, "ProgramVertex", "ProgramVertex", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_RASTER", "PROGRAM_RASTER", 32, "ProgramRaster", "ProgramRaster", nullptr, nullptr, false},
    {ObjectDataType, "RS_PROGRAM_STORE", "PROGRAM_STORE", 32, "ProgramStore", "ProgramStore", nullptr, nullptr, false},
    {ObjectDataType, "RS_FONT", "FONT", 32, "Font", "Font", nullptr, nullptr, false}
};
88 
// Largest vector width that has a name in BuiltinInfo::cname.
const int kMaxVectorSize = 4;

// Describes how one clang builtin type is represented in RenderScript.
struct BuiltinInfo {
  // The clang builtin kind this entry matches.
  clang::BuiltinType::Kind builtinTypeKind;
  // The RS data type the builtin maps to.
  DataType type;
  // cname[I] is the RS spelling of the (I + 1)-element vector of this type;
  // cname[0] is therefore the scalar name.
  /* TODO If we return std::string instead of llvm::StringRef, we could build
   * the name instead of duplicating the entries.
   */
  const char *cname[kMaxVectorSize];
};
99 
100 
// Every clang builtin kind that RenderScript accepts, with its RS DataType
// and scalar/vector spellings. Builtin kinds not listed here are rejected by
// FindBuiltinType() callers. Several kinds intentionally share one RS type:
// Char_U/UChar -> uchar, Char_S/SChar -> char, ULong/ULongLong -> ulong,
// Long/LongLong -> long; Char16 and Char32 are mapped to the signed
// short/int types.
BuiltinInfo BuiltinInfoTable[] = {
    {clang::BuiltinType::Bool, DataTypeBoolean,
     {"bool", "bool2", "bool3", "bool4"}},
    {clang::BuiltinType::Char_U, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::UChar, DataTypeUnsigned8,
     {"uchar", "uchar2", "uchar3", "uchar4"}},
    {clang::BuiltinType::Char16, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Char32, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::UShort, DataTypeUnsigned16,
     {"ushort", "ushort2", "ushort3", "ushort4"}},
    {clang::BuiltinType::UInt, DataTypeUnsigned32,
     {"uint", "uint2", "uint3", "uint4"}},
    {clang::BuiltinType::ULong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},
    {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
     {"ulong", "ulong2", "ulong3", "ulong4"}},

    {clang::BuiltinType::Char_S, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::SChar, DataTypeSigned8,
     {"char", "char2", "char3", "char4"}},
    {clang::BuiltinType::Short, DataTypeSigned16,
     {"short", "short2", "short3", "short4"}},
    {clang::BuiltinType::Int, DataTypeSigned32,
     {"int", "int2", "int3", "int4"}},
    {clang::BuiltinType::Long, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::LongLong, DataTypeSigned64,
     {"long", "long2", "long3", "long4"}},
    {clang::BuiltinType::Half, DataTypeFloat16,
     {"half", "half2", "half3", "half4"}},
    {clang::BuiltinType::Float, DataTypeFloat32,
     {"float", "float2", "float3", "float4"}},
    {clang::BuiltinType::Double, DataTypeFloat64,
     {"double", "double2", "double3", "double4"}},
};
// Number of entries in BuiltinInfoTable.
const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
141 
// Associates an RS type name with the DataType it denotes.
struct NameAndPrimitiveType {
  const char *name;
  DataType dataType;
};

// Names of the RS matrix and object types and the DataType each maps to.
// NOTE(review): presumably consulted when classifying RS-specific record
// types; the lookup code is not in view here.
static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
    {"rs_matrix2x2", DataTypeRSMatrix2x2},
    {"rs_matrix3x3", DataTypeRSMatrix3x3},
    {"rs_matrix4x4", DataTypeRSMatrix4x4},
    {"rs_element", DataTypeRSElement},
    {"rs_type", DataTypeRSType},
    {"rs_allocation", DataTypeRSAllocation},
    {"rs_sampler", DataTypeRSSampler},
    {"rs_script", DataTypeRSScript},
    {"rs_mesh", DataTypeRSMesh},
    {"rs_path", DataTypeRSPath},
    {"rs_program_fragment", DataTypeRSProgramFragment},
    {"rs_program_vertex", DataTypeRSProgramVertex},
    {"rs_program_raster", DataTypeRSProgramRaster},
    {"rs_program_store", DataTypeRSProgramStore},
    {"rs_font", DataTypeRSFont},
};

// Number of entries in MatrixAndObjectDataTypes.
const int MatrixAndObjectDataTypesCount =
    sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
167 
// Forward declaration: TypeExportableHelper() and
// ConstantArrayTypeExportableHelper() are mutually recursive.
static const clang::Type *TypeExportableHelper(
    const clang::Type *T,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord);
174 
175 template <unsigned N>
ReportTypeError(slang::RSContext * Context,const clang::NamedDecl * ND,const clang::RecordDecl * TopLevelRecord,const char (& Message)[N],unsigned int TargetAPI=0)176 static void ReportTypeError(slang::RSContext *Context,
177                             const clang::NamedDecl *ND,
178                             const clang::RecordDecl *TopLevelRecord,
179                             const char (&Message)[N],
180                             unsigned int TargetAPI = 0) {
181   // Attempt to use the type declaration first (if we have one).
182   // Fall back to the variable definition, if we are looking at something
183   // like an array declaration that can't be exported.
184   if (TopLevelRecord) {
185     Context->ReportError(TopLevelRecord->getLocation(), Message)
186         << TopLevelRecord->getName() << TargetAPI;
187   } else if (ND) {
188     Context->ReportError(ND->getLocation(), Message) << ND->getName()
189                                                      << TargetAPI;
190   } else {
191     slangAssert(false && "Variables should be validated before exporting");
192   }
193 }
194 
// Decides whether the constant array type CAT can be exported.
//
// Rejected (with a diagnostic): arrays whose elements are themselves arrays,
// arrays of vectors whose element type is not primitive, and arrays of more
// than one 3-element vector. Otherwise the element type is checked through
// TypeExportableHelper(). Returns CAT on success and nullptr on failure.
static const clang::Type *ConstantArrayTypeExportableHelper(
    const clang::ConstantArrayType *CAT,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord) {
  const clang::Type *ElemTy = GetConstantArrayElementType(CAT);

  if (ElemTy->isArrayType()) {
    ReportTypeError(Context, VD, TopLevelRecord,
                    "multidimensional arrays cannot be exported: '%0'");
    return nullptr;
  }

  if (ElemTy->isExtVectorType()) {
    const auto *VecTy = static_cast<const clang::ExtVectorType *>(ElemTy);

    const clang::Type *Scalar = GetExtVectorElementType(VecTy);
    if (!RSExportPrimitiveType::IsPrimitiveType(Scalar)) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "vectors of non-primitive types cannot be exported: '%0'");
      return nullptr;
    }

    // A width-3 vector is only accepted as a one-element array.
    const bool IsWidth3 = VecTy->getNumElements() == 3;
    if (IsWidth3 && CAT->getSize() != 1) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "arrays of width 3 vector types cannot be exported: '%0'");
      return nullptr;
    }
  }

  const clang::Type *Checked =
      TypeExportableHelper(ElemTy, SPS, Context, VD, TopLevelRecord);
  return Checked != nullptr ? CAT : nullptr;
}
233 
FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind)234 BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
235   for (int i = 0; i < BuiltinInfoTableCount; i++) {
236     if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
237       return &BuiltinInfoTable[i];
238     }
239   }
240   return nullptr;
241 }
242 
TypeExportableHelper(clang::Type const * T,llvm::SmallPtrSet<clang::Type const *,8> & SPS,slang::RSContext * Context,clang::VarDecl const * VD,clang::RecordDecl const * TopLevelRecord)243 static const clang::Type *TypeExportableHelper(
244     clang::Type const *T,
245     llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
246     slang::RSContext *Context,
247     clang::VarDecl const *VD,
248     clang::RecordDecl const *TopLevelRecord) {
249   // Normalize first
250   if ((T = GetCanonicalType(T)) == nullptr)
251     return nullptr;
252 
253   if (SPS.count(T))
254     return T;
255 
256   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
257 
258   switch (T->getTypeClass()) {
259     case clang::Type::Builtin: {
260       const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
261       return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
262     }
263     case clang::Type::Record: {
264       if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
265         return T;  // RS object type, no further checks are needed
266       }
267 
268       // Check internal struct
269       if (T->isUnionType()) {
270         ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
271                         "unions cannot be exported: '%0'");
272         return nullptr;
273       } else if (!T->isStructureType()) {
274         slangAssert(false && "Unknown type cannot be exported");
275         return nullptr;
276       }
277 
278       clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
279       if (RD != nullptr) {
280         RD = RD->getDefinition();
281         if (RD == nullptr) {
282           ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
283                           "struct is not defined in this module");
284           return nullptr;
285         }
286       }
287 
288       if (!TopLevelRecord) {
289         TopLevelRecord = RD;
290       }
291       if (RD->getName().empty()) {
292         ReportTypeError(Context, nullptr, RD,
293                         "anonymous structures cannot be exported");
294         return nullptr;
295       }
296 
297       // Fast check
298       if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
299         return nullptr;
300 
301       // Insert myself into checking set
302       SPS.insert(T);
303 
304       // Check all element
305       for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
306                FE = RD->field_end();
307            FI != FE;
308            FI++) {
309         const clang::FieldDecl *FD = *FI;
310         const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
311         FT = GetCanonicalType(FT);
312 
313         if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord)) {
314           return nullptr;
315         }
316 
317         // We don't support bit fields yet
318         //
319         // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
320         if (FD->isBitField()) {
321           Context->ReportError(
322               FD->getLocation(),
323               "bit fields are not able to be exported: '%0.%1'")
324               << RD->getName() << FD->getName();
325           return nullptr;
326         }
327       }
328 
329       return T;
330     }
331     case clang::Type::Pointer: {
332       if (TopLevelRecord) {
333         ReportTypeError(Context, VD, TopLevelRecord,
334             "structures containing pointers cannot be used as the type of "
335             "an exported global variable or the parameter to an exported "
336             "function: '%0'");
337         return nullptr;
338       }
339 
340       const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
341       const clang::Type *PointeeType = GetPointeeType(PT);
342 
343       if (PointeeType->getTypeClass() == clang::Type::Pointer) {
344         ReportTypeError(Context, VD, TopLevelRecord,
345             "multiple levels of pointers cannot be exported: '%0'");
346         return nullptr;
347       }
348       // We don't support pointer with array-type pointee or unsupported pointee
349       // type
350       if (PointeeType->isArrayType() ||
351           (TypeExportableHelper(PointeeType, SPS, Context, VD,
352                                 TopLevelRecord) == nullptr))
353         return nullptr;
354       else
355         return T;
356     }
357     case clang::Type::ExtVector: {
358       const clang::ExtVectorType *EVT =
359               static_cast<const clang::ExtVectorType*>(CTI);
360       // Only vector with size 2, 3 and 4 are supported.
361       if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
362         return nullptr;
363 
364       // Check base element type
365       const clang::Type *ElementType = GetExtVectorElementType(EVT);
366 
367       if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
368           (TypeExportableHelper(ElementType, SPS, Context, VD,
369                                 TopLevelRecord) == nullptr))
370         return nullptr;
371       else
372         return T;
373     }
374     case clang::Type::ConstantArray: {
375       const clang::ConstantArrayType *CAT =
376               static_cast<const clang::ConstantArrayType*>(CTI);
377 
378       return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
379                                                TopLevelRecord);
380     }
381     case clang::Type::Enum: {
382       // FIXME: We currently convert enums to integers, rather than reflecting
383       // a more complete (and nicer type-safe Java version).
384       return Context->getASTContext().IntTy.getTypePtr();
385     }
386     default: {
387       slangAssert(false && "Unknown type cannot be validated");
388       return nullptr;
389     }
390   }
391 }
392 
393 // Return the type that can be used to create RSExportType, will always return
394 // the canonical type.
395 //
396 // If the Type T is not exportable, this function returns nullptr. DiagEngine is
397 // used to generate proper Clang diagnostic messages when a non-exportable type
398 // is detected. TopLevelRecord is used to capture the highest struct (in the
399 // case of a nested hierarchy) for detecting other types that cannot be exported
400 // (mostly pointers within a struct).
TypeExportable(const clang::Type * T,slang::RSContext * Context,const clang::VarDecl * VD)401 static const clang::Type *TypeExportable(const clang::Type *T,
402                                          slang::RSContext *Context,
403                                          const clang::VarDecl *VD) {
404   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
405       llvm::SmallPtrSet<const clang::Type*, 8>();
406 
407   return TypeExportableHelper(T, SPS, Context, VD, nullptr);
408 }
409 
ValidateRSObjectInVarDecl(slang::RSContext * Context,clang::VarDecl * VD,bool InCompositeType,unsigned int TargetAPI)410 static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
411                                       clang::VarDecl *VD, bool InCompositeType,
412                                       unsigned int TargetAPI) {
413   if (TargetAPI < SLANG_JB_TARGET_API) {
414     // Only if we are already in a composite type (like an array or structure).
415     if (InCompositeType) {
416       // Only if we are actually exported (i.e. non-static).
417       if (VD->hasLinkage() &&
418           (VD->getFormalLinkage() == clang::ExternalLinkage)) {
419         // Only if we are not a pointer to an object.
420         const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
421         if (T->getTypeClass() != clang::Type::Pointer) {
422           ReportTypeError(Context, VD, nullptr,
423                           "arrays/structures containing RS object types "
424                           "cannot be exported in target API < %1: '%0'",
425                           SLANG_JB_TARGET_API);
426           return false;
427         }
428       }
429     }
430   }
431 
432   return true;
433 }
434 
// Helper function for ValidateType(). We do a recursive descent on the
// type hierarchy to ensure that we can properly export/handle the
// declaration.
// \return true if the variable declaration is valid,
//         false if it is invalid (along with proper diagnostics).
//
// Context - RSContext that receives the diagnostics.
// C - ASTContext (for diagnostics + builtin types).
// T - sub-type that we are validating. Taken by reference: it is replaced by
//     its canonical form on entry.
// ND - (optional) top-level named declaration that we are validating.
// Loc - source location reported by the Filterscript and pointer diagnostics.
// SPS - set of types we have already seen/validated.
// InCompositeType - true if we are within an outer composite type.
// UnionDecl - set if we are in a sub-type of a union.
// TargetAPI - target SDK API level.
// IsFilterscript - whether or not we are compiling for Filterscript
// IsExtern - is this type externally visible (i.e. extern global or parameter
//                                             to an extern function)
//
// NOTE(review): some rejection paths return false without emitting a
// diagnostic: the "Fast check" on flexible-array/object members, and the
// slangAssert branch for records that are neither struct nor union.
static bool ValidateTypeHelper(
    slang::RSContext *Context,
    clang::ASTContext &C,
    const clang::Type *&T,
    clang::NamedDecl *ND,
    clang::SourceLocation Loc,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    bool InCompositeType,
    clang::RecordDecl *UnionDecl,
    unsigned int TargetAPI,
    bool IsFilterscript,
    bool IsExtern) {
  // Nothing to validate when no canonical type is available.
  if ((T = GetCanonicalType(T)) == nullptr)
    return true;

  // Already visited (e.g. a self-referential struct): accept.
  if (SPS.count(T))
    return true;

  const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();

  switch (T->getTypeClass()) {
    case clang::Type::Record: {
      // RS objects inside composites are restricted for older target APIs.
      if (RSExportPrimitiveType::IsRSObjectType(T)) {
        clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
        if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
                                             TargetAPI)) {
          return false;
        }
      }

      // RS-specific types are accepted outside unions. Inside a union, RS
      // object types are an error; other RS-specific types fall through to
      // the field walk below.
      if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
        if (!UnionDecl) {
          return true;
        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
          ReportTypeError(Context, nullptr, UnionDecl,
              "unions containing RS object types are not allowed");
          return false;
        }
      }

      clang::RecordDecl *RD = nullptr;

      // Check internal struct
      if (T->isUnionType()) {
        // Remember the union so nested members are checked against it.
        RD = T->getAsUnionType()->getDecl();
        UnionDecl = RD;
      } else if (T->isStructureType()) {
        RD = T->getAsStructureType()->getDecl();
      } else {
        slangAssert(false && "Unknown type cannot be exported");
        return false;
      }

      if (RD != nullptr) {
        RD = RD->getDefinition();
        if (RD == nullptr) {
          // FIXME
          // Declared but not defined here: accepted without inspection.
          return true;
        }
      }

      // Fast check
      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
        return false;

      // Insert myself into checking set
      SPS.insert(T);

      // Check all elements
      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
               FE = RD->field_end();
           FI != FE;
           FI++) {
        const clang::FieldDecl *FD = *FI;
        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
        FT = GetCanonicalType(FT);

        if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
                                TargetAPI, IsFilterscript, IsExtern)) {
          return false;
        }
      }

      return true;
    }

    case clang::Type::Builtin: {
      // Filterscript forbids double, long double, long and long long.
      if (IsFilterscript) {
        clang::QualType QT = T->getCanonicalTypeInternal();
        if (QT == C.DoubleTy ||
            QT == C.LongDoubleTy ||
            QT == C.LongTy ||
            QT == C.LongLongTy) {
          if (ND) {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript: '%0'")
                << ND->getName();
          } else {
            Context->ReportError(
                Loc,
                "Builtin types > 32 bits in size are forbidden in "
                "Filterscript");
          }
          return false;
        }
      }
      break;
    }

    case clang::Type::Pointer: {
      if (IsFilterscript) {
        if (ND) {
          Context->ReportError(Loc,
                               "Pointers are forbidden in Filterscript: '%0'")
              << ND->getName();
          return false;
        } else {
          // TODO(srhines): Find a better way to handle expressions (i.e. no
          // NamedDecl) involving pointers in FS that should be allowed.
          // An example would be calls to library functions like
          // rsMatrixMultiply() that take rs_matrixNxN * types.
        }
      }

      // Forbid pointers in structures that are externally visible.
      if (InCompositeType && IsExtern) {
        if (ND) {
          Context->ReportError(Loc,
              "structures containing pointers cannot be used as the type of "
              "an exported global variable or the parameter to an exported "
              "function: '%0'")
            << ND->getName();
        } else {
          Context->ReportError(Loc,
              "structures containing pointers cannot be used as the type of "
              "an exported global variable or the parameter to an exported "
              "function");
        }
        return false;
      }

      // Validate the pointee with the same composite/union state.
      const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
      const clang::Type *PointeeType = GetPointeeType(PT);

      return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
                                InCompositeType, UnionDecl, TargetAPI,
                                IsFilterscript, IsExtern);
    }

    case clang::Type::ExtVector: {
      const clang::ExtVectorType *EVT =
              static_cast<const clang::ExtVectorType*>(CTI);
      const clang::Type *ElementType = GetExtVectorElementType(EVT);
      // Below the ICS target API, 3-element vectors inside composites of an
      // externally linked declaration are rejected.
      if (TargetAPI < SLANG_ICS_TARGET_API &&
          InCompositeType &&
          EVT->getNumElements() == 3 &&
          ND &&
          ND->getFormalLinkage() == clang::ExternalLinkage) {
        ReportTypeError(Context, ND, nullptr,
                        "structs containing vectors of dimension 3 cannot "
                        "be exported at this API level: '%0'");
        return false;
      }
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
    }

    case clang::Type::ConstantArray: {
      // Array elements are validated as members of a composite type.
      const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
      const clang::Type *ElementType = GetConstantArrayElementType(CAT);
      return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
                                UnionDecl, TargetAPI, IsFilterscript, IsExtern);
    }

    default: {
      // Other type classes are not restricted here.
      break;
    }
  }

  return true;
}
634 
635 }  // namespace
636 
// Builds a placeholder name of the form "<type>", or "<type:name>" when name
// is non-empty.
//
// Plain string concatenation replaces the previous std::stringstream, whose
// <sstream> header this file never included directly. A null `type` (UB with
// operator<< on a stream) is treated as an empty string.
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string Result(1, '<');
  if (type != nullptr) {
    Result += type;
  }
  if (!name.empty()) {
    Result += ':';
    Result += name;
  }
  Result += '>';
  return Result;
}
646 
647 /****************************** RSExportType ******************************/
NormalizeType(const clang::Type * & T,llvm::StringRef & TypeName,RSContext * Context,const clang::VarDecl * VD)648 bool RSExportType::NormalizeType(const clang::Type *&T,
649                                  llvm::StringRef &TypeName,
650                                  RSContext *Context,
651                                  const clang::VarDecl *VD) {
652   if ((T = TypeExportable(T, Context, VD)) == nullptr) {
653     return false;
654   }
655   // Get type name
656   TypeName = RSExportType::GetTypeName(T);
657   if (Context && TypeName.empty()) {
658     if (VD) {
659       Context->ReportError(VD->getLocation(),
660                            "anonymous types cannot be exported");
661     } else {
662       Context->ReportError("anonymous types cannot be exported");
663     }
664     return false;
665   }
666 
667   return true;
668 }
669 
ValidateType(slang::RSContext * Context,clang::ASTContext & C,clang::QualType QT,clang::NamedDecl * ND,clang::SourceLocation Loc,unsigned int TargetAPI,bool IsFilterscript,bool IsExtern)670 bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
671                                 clang::QualType QT, clang::NamedDecl *ND,
672                                 clang::SourceLocation Loc,
673                                 unsigned int TargetAPI, bool IsFilterscript,
674                                 bool IsExtern) {
675   const clang::Type *T = QT.getTypePtr();
676   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
677       llvm::SmallPtrSet<const clang::Type*, 8>();
678 
679   // If this is an externally visible variable declaration, we check if the
680   // type is able to be exported first.
681   if (auto VD = llvm::dyn_cast_or_null<clang::VarDecl>(ND)) {
682     if (VD->getFormalLinkage() == clang::ExternalLinkage) {
683       if (!TypeExportable(T, Context, VD)) {
684         return false;
685       }
686     }
687   }
688   return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
689                             IsFilterscript, IsExtern);
690 }
691 
ValidateVarDecl(slang::RSContext * Context,clang::VarDecl * VD,unsigned int TargetAPI,bool IsFilterscript)692 bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
693                                    clang::VarDecl *VD, unsigned int TargetAPI,
694                                    bool IsFilterscript) {
695   return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
696                       VD->getLocation(), TargetAPI, IsFilterscript,
697                       (VD->getFormalLinkage() == clang::ExternalLinkage));
698 }
699 
700 const clang::Type
GetTypeOfDecl(const clang::DeclaratorDecl * DD)701 *RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
702   if (DD) {
703     clang::QualType T = DD->getType();
704 
705     if (T.isNull())
706       return nullptr;
707     else
708       return T.getTypePtr();
709   }
710   return nullptr;
711 }
712 
GetTypeName(const clang::Type * T)713 llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
714   T = GetCanonicalType(T);
715   if (T == nullptr)
716     return llvm::StringRef();
717 
718   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
719 
720   switch (T->getTypeClass()) {
721     case clang::Type::Builtin: {
722       const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
723       BuiltinInfo *info = FindBuiltinType(BT->getKind());
724       if (info != nullptr) {
725         return info->cname[0];
726       }
727       slangAssert(false && "Unknown data type of the builtin");
728       break;
729     }
730     case clang::Type::Record: {
731       clang::RecordDecl *RD;
732       if (T->isStructureType()) {
733         RD = T->getAsStructureType()->getDecl();
734       } else {
735         break;
736       }
737 
738       llvm::StringRef Name = RD->getName();
739       if (Name.empty()) {
740         if (RD->getTypedefNameForAnonDecl() != nullptr) {
741           Name = RD->getTypedefNameForAnonDecl()->getName();
742         }
743 
744         if (Name.empty()) {
745           // Try to find a name from redeclaration (i.e. typedef)
746           for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
747                    RE = RD->redecls_end();
748                RI != RE;
749                RI++) {
750             slangAssert(*RI != nullptr && "cannot be NULL object");
751 
752             Name = (*RI)->getName();
753             if (!Name.empty())
754               break;
755           }
756         }
757       }
758       return Name;
759     }
760     case clang::Type::Pointer: {
761       // "*" plus pointee name
762       const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
763       const clang::Type *PT = GetPointeeType(P);
764       llvm::StringRef PointeeName;
765       if (NormalizeType(PT, PointeeName, nullptr, nullptr)) {
766         char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
767         Name[0] = '*';
768         memcpy(Name + 1, PointeeName.data(), PointeeName.size());
769         Name[PointeeName.size() + 1] = '\0';
770         return Name;
771       }
772       break;
773     }
774     case clang::Type::ExtVector: {
775       const clang::ExtVectorType *EVT =
776               static_cast<const clang::ExtVectorType*>(CTI);
777       return RSExportVectorType::GetTypeName(EVT);
778       break;
779     }
780     case clang::Type::ConstantArray : {
781       // Construct name for a constant array is too complicated.
782       return "<ConstantArray>";
783     }
784     default: {
785       break;
786     }
787   }
788 
789   return llvm::StringRef();
790 }
791 
792 
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName)793 RSExportType *RSExportType::Create(RSContext *Context,
794                                    const clang::Type *T,
795                                    const llvm::StringRef &TypeName) {
796   // Lookup the context to see whether the type was processed before.
797   // Newly created RSExportType will insert into context
798   // in RSExportType::RSExportType()
799   RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
800 
801   if (ETI != Context->export_types_end())
802     return ETI->second;
803 
804   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
805 
806   RSExportType *ET = nullptr;
807   switch (T->getTypeClass()) {
808     case clang::Type::Record: {
809       DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
810       switch (dt) {
811         case DataTypeUnknown: {
812           // User-defined types
813           ET = RSExportRecordType::Create(Context,
814                                           T->getAsStructureType(),
815                                           TypeName);
816           break;
817         }
818         case DataTypeRSMatrix2x2: {
819           // 2 x 2 Matrix type
820           ET = RSExportMatrixType::Create(Context,
821                                           T->getAsStructureType(),
822                                           TypeName,
823                                           2);
824           break;
825         }
826         case DataTypeRSMatrix3x3: {
827           // 3 x 3 Matrix type
828           ET = RSExportMatrixType::Create(Context,
829                                           T->getAsStructureType(),
830                                           TypeName,
831                                           3);
832           break;
833         }
834         case DataTypeRSMatrix4x4: {
835           // 4 x 4 Matrix type
836           ET = RSExportMatrixType::Create(Context,
837                                           T->getAsStructureType(),
838                                           TypeName,
839                                           4);
840           break;
841         }
842         default: {
843           // Others are primitive types
844           ET = RSExportPrimitiveType::Create(Context, T, TypeName);
845           break;
846         }
847       }
848       break;
849     }
850     case clang::Type::Builtin: {
851       ET = RSExportPrimitiveType::Create(Context, T, TypeName);
852       break;
853     }
854     case clang::Type::Pointer: {
855       ET = RSExportPointerType::Create(Context,
856                                        static_cast<const clang::PointerType*>(CTI),
857                                        TypeName);
858       // FIXME: free the name (allocated in RSExportType::GetTypeName)
859       delete [] TypeName.data();
860       break;
861     }
862     case clang::Type::ExtVector: {
863       ET = RSExportVectorType::Create(Context,
864                                       static_cast<const clang::ExtVectorType*>(CTI),
865                                       TypeName);
866       break;
867     }
868     case clang::Type::ConstantArray: {
869       ET = RSExportConstantArrayType::Create(
870               Context,
871               static_cast<const clang::ConstantArrayType*>(CTI));
872       break;
873     }
874     default: {
875       Context->ReportError("unknown type cannot be exported: '%0'")
876           << T->getTypeClassName();
877       break;
878     }
879   }
880 
881   return ET;
882 }
883 
Create(RSContext * Context,const clang::Type * T)884 RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
885   llvm::StringRef TypeName;
886   if (NormalizeType(T, TypeName, Context, nullptr)) {
887     return Create(Context, T, TypeName);
888   } else {
889     return nullptr;
890   }
891 }
892 
CreateFromDecl(RSContext * Context,const clang::VarDecl * VD)893 RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
894                                            const clang::VarDecl *VD) {
895   return RSExportType::Create(Context, GetTypeOfDecl(VD));
896 }
897 
getStoreSize() const898 size_t RSExportType::getStoreSize() const {
899   return getRSContext()->getDataLayout()->getTypeStoreSize(getLLVMType());
900 }
901 
getAllocSize() const902 size_t RSExportType::getAllocSize() const {
903     return getRSContext()->getDataLayout()->getTypeAllocSize(getLLVMType());
904 }
905 
RSExportType(RSContext * Context,ExportClass Class,const llvm::StringRef & Name)906 RSExportType::RSExportType(RSContext *Context,
907                            ExportClass Class,
908                            const llvm::StringRef &Name)
909     : RSExportable(Context, RSExportable::EX_TYPE),
910       mClass(Class),
911       // Make a copy on Name since memory stored @Name is either allocated in
912       // ASTContext or allocated in GetTypeName which will be destroyed later.
913       mName(Name.data(), Name.size()),
914       mLLVMType(nullptr) {
915   // Don't cache the type whose name start with '<'. Those type failed to
916   // get their name since constructing their name in GetTypeName() requiring
917   // complicated work.
918   if (!IsDummyName(Name)) {
919     // TODO(zonr): Need to check whether the insertion is successful or not.
920     Context->insertExportType(llvm::StringRef(Name), this);
921   }
922 
923 }
924 
keep()925 bool RSExportType::keep() {
926   if (!RSExportable::keep())
927     return false;
928   // Invalidate converted LLVM type.
929   mLLVMType = nullptr;
930   return true;
931 }
932 
equals(const RSExportable * E) const933 bool RSExportType::equals(const RSExportable *E) const {
934   CHECK_PARENT_EQUALITY(RSExportable, E);
935   return (static_cast<const RSExportType*>(E)->getClass() == getClass());
936 }
937 
~RSExportType()938 RSExportType::~RSExportType() {
939 }
940 
941 /************************** RSExportPrimitiveType **************************/
942 llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
943 RSExportPrimitiveType::RSSpecificTypeMap;
944 
IsPrimitiveType(const clang::Type * T)945 bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
946   if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
947     return true;
948   else
949     return false;
950 }
951 
952 DataType
GetRSSpecificType(const llvm::StringRef & TypeName)953 RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
954   if (TypeName.empty())
955     return DataTypeUnknown;
956 
957   if (RSSpecificTypeMap->empty()) {
958     for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
959       (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
960           MatrixAndObjectDataTypes[i].dataType;
961     }
962   }
963 
964   RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
965   if (I == RSSpecificTypeMap->end())
966     return DataTypeUnknown;
967   else
968     return I->getValue();
969 }
970 
GetRSSpecificType(const clang::Type * T)971 DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
972   T = GetCanonicalType(T);
973   if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
974     return DataTypeUnknown;
975 
976   return GetRSSpecificType( RSExportType::GetTypeName(T) );
977 }
978 
IsRSMatrixType(DataType DT)979 bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
980     if (DT < 0 || DT >= DataTypeMax) {
981         return false;
982     }
983     return gReflectionTypes[DT].category == MatrixDataType;
984 }
985 
IsRSObjectType(DataType DT)986 bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
987     if (DT < 0 || DT >= DataTypeMax) {
988         return false;
989     }
990     return gReflectionTypes[DT].category == ObjectDataType;
991 }
992 
IsStructureTypeWithRSObject(const clang::Type * T)993 bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
994   bool RSObjectTypeSeen = false;
995   while (T && T->isArrayType()) {
996     T = T->getArrayElementTypeNoTypeQual();
997   }
998 
999   const clang::RecordType *RT = T->getAsStructureType();
1000   if (!RT) {
1001     return false;
1002   }
1003 
1004   const clang::RecordDecl *RD = RT->getDecl();
1005   if (RD) {
1006     RD = RD->getDefinition();
1007   }
1008   if (!RD) {
1009     return false;
1010   }
1011 
1012   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1013          FE = RD->field_end();
1014        FI != FE;
1015        FI++) {
1016     // We just look through all field declarations to see if we find a
1017     // declaration for an RS object type (or an array of one).
1018     const clang::FieldDecl *FD = *FI;
1019     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1020     while (FT && FT->isArrayType()) {
1021       FT = FT->getArrayElementTypeNoTypeQual();
1022     }
1023 
1024     DataType DT = GetRSSpecificType(FT);
1025     if (IsRSObjectType(DT)) {
1026       // RS object types definitely need to be zero-initialized
1027       RSObjectTypeSeen = true;
1028     } else {
1029       switch (DT) {
1030         case DataTypeRSMatrix2x2:
1031         case DataTypeRSMatrix3x3:
1032         case DataTypeRSMatrix4x4:
1033           // Matrix types should get zero-initialized as well
1034           RSObjectTypeSeen = true;
1035           break;
1036         default:
1037           // Ignore all other primitive types
1038           break;
1039       }
1040       while (FT && FT->isArrayType()) {
1041         FT = FT->getArrayElementTypeNoTypeQual();
1042       }
1043       if (FT->isStructureType()) {
1044         // Recursively handle structs of structs (even though these can't
1045         // be exported, it is possible for a user to have them internally).
1046         RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1047       }
1048     }
1049   }
1050 
1051   return RSObjectTypeSeen;
1052 }
1053 
GetSizeInBits(const RSExportPrimitiveType * EPT)1054 size_t RSExportPrimitiveType::GetSizeInBits(const RSExportPrimitiveType *EPT) {
1055   int type = EPT->getType();
1056   slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1057               "RSExportPrimitiveType::GetSizeInBits : unknown data type");
1058   // All RS object types are 256 bits in 64-bit RS.
1059   if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1060     return 256;
1061   }
1062   return gReflectionTypes[type].size_in_bits;
1063 }
1064 
1065 DataType
GetDataType(RSContext * Context,const clang::Type * T)1066 RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1067   if (T == nullptr)
1068     return DataTypeUnknown;
1069 
1070   switch (T->getTypeClass()) {
1071     case clang::Type::Builtin: {
1072       const clang::BuiltinType *BT =
1073               static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1074       BuiltinInfo *info = FindBuiltinType(BT->getKind());
1075       if (info != nullptr) {
1076         return info->type;
1077       }
1078       // The size of type WChar depend on platform so we abandon the support
1079       // to them.
1080       Context->ReportError("built-in type cannot be exported: '%0'")
1081           << T->getTypeClassName();
1082       break;
1083     }
1084     case clang::Type::Record: {
1085       // must be RS object type
1086       return RSExportPrimitiveType::GetRSSpecificType(T);
1087     }
1088     default: {
1089       Context->ReportError("primitive type cannot be exported: '%0'")
1090           << T->getTypeClassName();
1091       break;
1092     }
1093   }
1094 
1095   return DataTypeUnknown;
1096 }
1097 
1098 RSExportPrimitiveType
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName,bool Normalized)1099 *RSExportPrimitiveType::Create(RSContext *Context,
1100                                const clang::Type *T,
1101                                const llvm::StringRef &TypeName,
1102                                bool Normalized) {
1103   DataType DT = GetDataType(Context, T);
1104 
1105   if ((DT == DataTypeUnknown) || TypeName.empty())
1106     return nullptr;
1107   else
1108     return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1109                                      DT, Normalized);
1110 }
1111 
Create(RSContext * Context,const clang::Type * T)1112 RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1113                                                      const clang::Type *T) {
1114   llvm::StringRef TypeName;
1115   if (RSExportType::NormalizeType(T, TypeName, Context, nullptr)
1116       && IsPrimitiveType(T)) {
1117     return Create(Context, T, TypeName);
1118   } else {
1119     return nullptr;
1120   }
1121 }
1122 
convertToLLVMType() const1123 llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1124   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1125 
1126   if (isRSObjectType()) {
1127     // struct {
1128     //   int *p;
1129     // } __attribute__((packed, aligned(pointer_size)))
1130     //
1131     // which is
1132     //
1133     // <{ [1 x i32] }> in LLVM
1134     //
1135     std::vector<llvm::Type *> Elements;
1136     if (getRSContext()->is64Bit()) {
1137       // 64-bit path
1138       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1139       return llvm::StructType::get(C, Elements, true);
1140     } else {
1141       // 32-bit legacy path
1142       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1143       return llvm::StructType::get(C, Elements, true);
1144     }
1145   }
1146 
1147   switch (mType) {
1148     case DataTypeFloat16: {
1149       return llvm::Type::getHalfTy(C);
1150       break;
1151     }
1152     case DataTypeFloat32: {
1153       return llvm::Type::getFloatTy(C);
1154       break;
1155     }
1156     case DataTypeFloat64: {
1157       return llvm::Type::getDoubleTy(C);
1158       break;
1159     }
1160     case DataTypeBoolean: {
1161       return llvm::Type::getInt1Ty(C);
1162       break;
1163     }
1164     case DataTypeSigned8:
1165     case DataTypeUnsigned8: {
1166       return llvm::Type::getInt8Ty(C);
1167       break;
1168     }
1169     case DataTypeSigned16:
1170     case DataTypeUnsigned16:
1171     case DataTypeUnsigned565:
1172     case DataTypeUnsigned5551:
1173     case DataTypeUnsigned4444: {
1174       return llvm::Type::getInt16Ty(C);
1175       break;
1176     }
1177     case DataTypeSigned32:
1178     case DataTypeUnsigned32: {
1179       return llvm::Type::getInt32Ty(C);
1180       break;
1181     }
1182     case DataTypeSigned64:
1183     case DataTypeUnsigned64: {
1184       return llvm::Type::getInt64Ty(C);
1185       break;
1186     }
1187     default: {
1188       slangAssert(false && "Unknown data type");
1189     }
1190   }
1191 
1192   return nullptr;
1193 }
1194 
equals(const RSExportable * E) const1195 bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1196   CHECK_PARENT_EQUALITY(RSExportType, E);
1197   return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1198 }
1199 
getRSReflectionType(DataType DT)1200 RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1201   if (DT > DataTypeUnknown && DT < DataTypeMax) {
1202     return &gReflectionTypes[DT];
1203   } else {
1204     return nullptr;
1205   }
1206 }
1207 
1208 /**************************** RSExportPointerType ****************************/
1209 
1210 RSExportPointerType
Create(RSContext * Context,const clang::PointerType * PT,const llvm::StringRef & TypeName)1211 *RSExportPointerType::Create(RSContext *Context,
1212                              const clang::PointerType *PT,
1213                              const llvm::StringRef &TypeName) {
1214   const clang::Type *PointeeType = GetPointeeType(PT);
1215   const RSExportType *PointeeET;
1216 
1217   if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1218     PointeeET = RSExportType::Create(Context, PointeeType);
1219   } else {
1220     // Double or higher dimension of pointer, export as int*
1221     PointeeET = RSExportPrimitiveType::Create(Context,
1222                     Context->getASTContext().IntTy.getTypePtr());
1223   }
1224 
1225   if (PointeeET == nullptr) {
1226     // Error diagnostic is emitted for corresponding pointee type
1227     return nullptr;
1228   }
1229 
1230   return new RSExportPointerType(Context, TypeName, PointeeET);
1231 }
1232 
convertToLLVMType() const1233 llvm::Type *RSExportPointerType::convertToLLVMType() const {
1234   llvm::Type *PointeeType = mPointeeType->getLLVMType();
1235   return llvm::PointerType::getUnqual(PointeeType);
1236 }
1237 
keep()1238 bool RSExportPointerType::keep() {
1239   if (!RSExportType::keep())
1240     return false;
1241   const_cast<RSExportType*>(mPointeeType)->keep();
1242   return true;
1243 }
1244 
equals(const RSExportable * E) const1245 bool RSExportPointerType::equals(const RSExportable *E) const {
1246   CHECK_PARENT_EQUALITY(RSExportType, E);
1247   return (static_cast<const RSExportPointerType*>(E)
1248               ->getPointeeType()->equals(getPointeeType()));
1249 }
1250 
1251 /***************************** RSExportVectorType *****************************/
1252 llvm::StringRef
GetTypeName(const clang::ExtVectorType * EVT)1253 RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1254   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1255   llvm::StringRef name;
1256 
1257   if ((ElementType->getTypeClass() != clang::Type::Builtin))
1258     return name;
1259 
1260   const clang::BuiltinType *BT =
1261           static_cast<const clang::BuiltinType*>(
1262               ElementType->getCanonicalTypeInternal().getTypePtr());
1263 
1264   if ((EVT->getNumElements() < 1) ||
1265       (EVT->getNumElements() > 4))
1266     return name;
1267 
1268   BuiltinInfo *info = FindBuiltinType(BT->getKind());
1269   if (info != nullptr) {
1270     int I = EVT->getNumElements() - 1;
1271     if (I < kMaxVectorSize) {
1272       name = info->cname[I];
1273     } else {
1274       slangAssert(false && "Max vector is 4");
1275     }
1276   }
1277   return name;
1278 }
1279 
Create(RSContext * Context,const clang::ExtVectorType * EVT,const llvm::StringRef & TypeName,bool Normalized)1280 RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1281                                                const clang::ExtVectorType *EVT,
1282                                                const llvm::StringRef &TypeName,
1283                                                bool Normalized) {
1284   slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1285 
1286   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1287   DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1288 
1289   if (DT != DataTypeUnknown)
1290     return new RSExportVectorType(Context,
1291                                   TypeName,
1292                                   DT,
1293                                   Normalized,
1294                                   EVT->getNumElements());
1295   else
1296     return nullptr;
1297 }
1298 
convertToLLVMType() const1299 llvm::Type *RSExportVectorType::convertToLLVMType() const {
1300   llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1301   return llvm::VectorType::get(ElementType, getNumElement());
1302 }
1303 
equals(const RSExportable * E) const1304 bool RSExportVectorType::equals(const RSExportable *E) const {
1305   CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1306   return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1307               == getNumElement());
1308 }
1309 
1310 /***************************** RSExportMatrixType *****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,unsigned Dim)1311 RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1312                                                const clang::RecordType *RT,
1313                                                const llvm::StringRef &TypeName,
1314                                                unsigned Dim) {
1315   slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1316   slangAssert((Dim > 1) && "Invalid dimension of matrix");
1317 
1318   // Check whether the struct rs_matrix is in our expected form (but assume it's
1319   // correct if we're not sure whether it's correct or not)
1320   const clang::RecordDecl* RD = RT->getDecl();
1321   RD = RD->getDefinition();
1322   if (RD != nullptr) {
1323     // Find definition, perform further examination
1324     if (RD->field_empty()) {
1325       Context->ReportError(
1326           RD->getLocation(),
1327           "invalid matrix struct: must have 1 field for saving values: '%0'")
1328           << RD->getName();
1329       return nullptr;
1330     }
1331 
1332     clang::RecordDecl::field_iterator FIT = RD->field_begin();
1333     const clang::FieldDecl *FD = *FIT;
1334     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1335     if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1336       Context->ReportError(RD->getLocation(),
1337                            "invalid matrix struct: first field should"
1338                            " be an array with constant size: '%0'")
1339           << RD->getName();
1340       return nullptr;
1341     }
1342     const clang::ConstantArrayType *CAT =
1343       static_cast<const clang::ConstantArrayType *>(FT);
1344     const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1345     if ((ElementType == nullptr) ||
1346         (ElementType->getTypeClass() != clang::Type::Builtin) ||
1347         (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1348          clang::BuiltinType::Float)) {
1349       Context->ReportError(RD->getLocation(),
1350                            "invalid matrix struct: first field "
1351                            "should be a float array: '%0'")
1352           << RD->getName();
1353       return nullptr;
1354     }
1355 
1356     if (CAT->getSize() != Dim * Dim) {
1357       Context->ReportError(RD->getLocation(),
1358                            "invalid matrix struct: first field "
1359                            "should be an array with size %0: '%1'")
1360           << (Dim * Dim) << (RD->getName());
1361       return nullptr;
1362     }
1363 
1364     FIT++;
1365     if (FIT != RD->field_end()) {
1366       Context->ReportError(RD->getLocation(),
1367                            "invalid matrix struct: must have "
1368                            "exactly 1 field: '%0'")
1369           << RD->getName();
1370       return nullptr;
1371     }
1372   }
1373 
1374   return new RSExportMatrixType(Context, TypeName, Dim);
1375 }
1376 
convertToLLVMType() const1377 llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1378   // Construct LLVM type:
1379   // struct {
1380   //  float X[mDim * mDim];
1381   // }
1382 
1383   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1384   llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1385                                             mDim * mDim);
1386   return llvm::StructType::get(C, X, false);
1387 }
1388 
equals(const RSExportable * E) const1389 bool RSExportMatrixType::equals(const RSExportable *E) const {
1390   CHECK_PARENT_EQUALITY(RSExportType, E);
1391   return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1392 }
1393 
1394 /************************* RSExportConstantArrayType *************************/
1395 RSExportConstantArrayType
Create(RSContext * Context,const clang::ConstantArrayType * CAT)1396 *RSExportConstantArrayType::Create(RSContext *Context,
1397                                    const clang::ConstantArrayType *CAT) {
1398   slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1399 
1400   slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1401 
1402   unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1403   slangAssert((Size > 0) && "Constant array should have size greater than 0");
1404 
1405   const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1406   RSExportType *ElementET = RSExportType::Create(Context, ElementType);
1407 
1408   if (ElementET == nullptr) {
1409     return nullptr;
1410   }
1411 
1412   return new RSExportConstantArrayType(Context,
1413                                        ElementET,
1414                                        Size);
1415 }
1416 
convertToLLVMType() const1417 llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1418   return llvm::ArrayType::get(mElementType->getLLVMType(), getSize());
1419 }
1420 
keep()1421 bool RSExportConstantArrayType::keep() {
1422   if (!RSExportType::keep())
1423     return false;
1424   const_cast<RSExportType*>(mElementType)->keep();
1425   return true;
1426 }
1427 
equals(const RSExportable * E) const1428 bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1429   CHECK_PARENT_EQUALITY(RSExportType, E);
1430   const RSExportConstantArrayType *RHS =
1431       static_cast<const RSExportConstantArrayType*>(E);
1432   return ((getSize() == RHS->getSize()) &&
1433           (getElementType()->equals(RHS->getElementType())));
1434 }
1435 
1436 /**************************** RSExportRecordType ****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,bool mIsArtificial)1437 RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1438                                                const clang::RecordType *RT,
1439                                                const llvm::StringRef &TypeName,
1440                                                bool mIsArtificial) {
1441   slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);
1442 
1443   const clang::RecordDecl *RD = RT->getDecl();
1444   slangAssert(RD->isStruct());
1445 
1446   RD = RD->getDefinition();
1447   if (RD == nullptr) {
1448     slangAssert(false && "struct is not defined in this module");
1449     return nullptr;
1450   }
1451 
1452   // Struct layout construct by clang. We rely on this for obtaining the
1453   // alloc size of a struct and offset of every field in that struct.
1454   const clang::ASTRecordLayout *RL =
1455       &Context->getASTContext().getASTRecordLayout(RD);
1456   slangAssert((RL != nullptr) &&
1457       "Failed to retrieve the struct layout from Clang.");
1458 
1459   RSExportRecordType *ERT =
1460       new RSExportRecordType(Context,
1461                              TypeName,
1462                              RD->hasAttr<clang::PackedAttr>(),
1463                              mIsArtificial,
1464                              RL->getDataSize().getQuantity(),
1465                              RL->getSize().getQuantity());
1466   unsigned int Index = 0;
1467 
1468   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1469            FE = RD->field_end();
1470        FI != FE;
1471        FI++, Index++) {
1472 
1473     // FIXME: All fields should be primitive type
1474     slangAssert(FI->getKind() == clang::Decl::Field);
1475     clang::FieldDecl *FD = *FI;
1476 
1477     if (FD->isBitField()) {
1478       return nullptr;
1479     }
1480 
1481     // Type
1482     RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1483 
1484     if (ET != nullptr) {
1485       ERT->mFields.push_back(
1486           new Field(ET, FD->getName(), ERT,
1487                     static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1488     } else {
1489       Context->ReportError(RD->getLocation(),
1490                            "field type cannot be exported: '%0.%1'")
1491           << RD->getName() << FD->getName();
1492       return nullptr;
1493     }
1494   }
1495 
1496   return ERT;
1497 }
1498 
convertToLLVMType() const1499 llvm::Type *RSExportRecordType::convertToLLVMType() const {
1500   // Create an opaque type since struct may reference itself recursively.
1501 
1502   // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1503   std::vector<llvm::Type*> FieldTypes;
1504 
1505   for (const_field_iterator FI = fields_begin(), FE = fields_end();
1506        FI != FE;
1507        FI++) {
1508     const Field *F = *FI;
1509     const RSExportType *FET = F->getType();
1510 
1511     FieldTypes.push_back(FET->getLLVMType());
1512   }
1513 
1514   llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1515                                                FieldTypes,
1516                                                mIsPacked);
1517   if (ST != nullptr) {
1518     return ST;
1519   } else {
1520     return nullptr;
1521   }
1522 }
1523 
keep()1524 bool RSExportRecordType::keep() {
1525   if (!RSExportType::keep())
1526     return false;
1527   for (std::list<const Field*>::iterator I = mFields.begin(),
1528           E = mFields.end();
1529        I != E;
1530        I++) {
1531     const_cast<RSExportType*>((*I)->getType())->keep();
1532   }
1533   return true;
1534 }
1535 
equals(const RSExportable * E) const1536 bool RSExportRecordType::equals(const RSExportable *E) const {
1537   CHECK_PARENT_EQUALITY(RSExportType, E);
1538 
1539   const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1540 
1541   if (ERT->getFields().size() != getFields().size())
1542     return false;
1543 
1544   const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1545 
1546   for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1547     if (!(*AI)->getType()->equals((*BI)->getType()))
1548       return false;
1549     AI++;
1550     BI++;
1551   }
1552 
1553   return true;
1554 }
1555 
convertToRTD(RSReflectionTypeData * rtd) const1556 void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1557     memset(rtd, 0, sizeof(*rtd));
1558     rtd->vecSize = 1;
1559 
1560     switch(getClass()) {
1561     case RSExportType::ExportClassPrimitive: {
1562             const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1563             rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1564             return;
1565         }
1566     case RSExportType::ExportClassPointer: {
1567             const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1568             const RSExportType *PointeeType = EPT->getPointeeType();
1569             PointeeType->convertToRTD(rtd);
1570             rtd->isPointer = true;
1571             return;
1572         }
1573     case RSExportType::ExportClassVector: {
1574             const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1575             rtd->type = EVT->getRSReflectionType(EVT);
1576             rtd->vecSize = EVT->getNumElement();
1577             return;
1578         }
1579     case RSExportType::ExportClassMatrix: {
1580             const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1581             unsigned Dim = EMT->getDim();
1582             slangAssert((Dim >= 2) && (Dim <= 4));
1583             rtd->type = &gReflectionTypes[15 + Dim-2];
1584             return;
1585         }
1586     case RSExportType::ExportClassConstantArray: {
1587             const RSExportConstantArrayType* CAT =
1588               static_cast<const RSExportConstantArrayType*>(this);
1589             CAT->getElementType()->convertToRTD(rtd);
1590             rtd->arraySize = CAT->getSize();
1591             return;
1592         }
1593     case RSExportType::ExportClassRecord: {
1594             slangAssert(!"RSExportType::ExportClassRecord not implemented");
1595             return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1596         }
1597     default: {
1598             slangAssert(false && "Unknown class of type");
1599         }
1600     }
1601 }
1602 
1603 
1604 }  // namespace slang
1605