1 /*
2  * Copyright 2011-2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_export_foreach.h"
18 
19 #include <string>
20 
21 #include "clang/AST/ASTContext.h"
22 #include "clang/AST/Attr.h"
23 #include "clang/AST/Decl.h"
24 #include "clang/AST/TypeLoc.h"
25 
26 #include "llvm/IR/DerivedTypes.h"
27 
28 #include "bcinfo/MetadataExtractor.h"
29 
30 #include "slang_assert.h"
31 #include "slang_rs_context.h"
32 #include "slang_rs_export_type.h"
33 #include "slang_version.h"
34 
35 namespace {
36 
37 const size_t RS_KERNEL_INPUT_LIMIT = 8; // see frameworks/base/libs/rs/cpu_ref/rsCpuCoreRuntime.h
38 
39 enum SpecialParameterKind {
40   SPK_LOCATION, // 'int' or 'unsigned int'
41   SPK_CONTEXT,  // rs_kernel_context
42 };
43 
44 struct SpecialParameter {
45   const char *name;
46   bcinfo::MetadataSignatureBitval bitval;
47   SpecialParameterKind kind;
48   SlangTargetAPI minAPI;
49 };
50 
51 // Table entries are in the order parameters must occur in a kernel parameter list.
52 const SpecialParameter specialParameterTable[] = {
53   { "context", bcinfo::MD_SIG_Ctxt, SPK_CONTEXT, SLANG_M_TARGET_API },
54   { "x", bcinfo::MD_SIG_X, SPK_LOCATION, SLANG_MINIMUM_TARGET_API },
55   { "y", bcinfo::MD_SIG_Y, SPK_LOCATION, SLANG_MINIMUM_TARGET_API },
56   { "z", bcinfo::MD_SIG_Z, SPK_LOCATION, SLANG_M_TARGET_API },
57   { nullptr, bcinfo::MD_SIG_None, SPK_LOCATION, SLANG_MINIMUM_TARGET_API }, // marks end of table
58 };
59 
60 // If the specified name matches the name of an entry in
61 // specialParameterTable, return the corresponding table index.
62 // Return -1 if not found.
lookupSpecialParameter(const llvm::StringRef name)63 int lookupSpecialParameter(const llvm::StringRef name) {
64   for (int i = 0; specialParameterTable[i].name != nullptr; ++i) {
65     if (name.equals(specialParameterTable[i].name)) {
66       return i;
67     }
68   }
69 
70   return -1;
71 }
72 
73 // Return a comma-separated list of names in specialParameterTable
74 // that are available at the specified API level.
listSpecialParameters(unsigned int api)75 std::string listSpecialParameters(unsigned int api) {
76   std::string ret;
77   bool first = true;
78   for (int i = 0; specialParameterTable[i].name != nullptr; ++i) {
79     if (specialParameterTable[i].minAPI > api)
80       continue;
81     if (first)
82       first = false;
83     else
84       ret += ", ";
85     ret += "'";
86     ret += specialParameterTable[i].name;
87     ret += "'";
88   }
89   return ret;
90 }
91 
92 }
93 
94 namespace slang {
95 
96 // This function takes care of additional validation and construction of
97 // parameters related to forEach_* reflection.
validateAndConstructParams(RSContext * Context,const clang::FunctionDecl * FD)98 bool RSExportForEach::validateAndConstructParams(
99     RSContext *Context, const clang::FunctionDecl *FD) {
100   slangAssert(Context && FD);
101   bool valid = true;
102 
103   numParams = FD->getNumParams();
104 
105   if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
106     // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
107     if (!isRootRSFunc(FD)) {
108       Context->ReportError(FD->getLocation(),
109                            "Non-root compute kernel %0() is "
110                            "not supported in SDK levels %1-%2")
111           << FD->getName() << SLANG_MINIMUM_TARGET_API
112           << (SLANG_JB_TARGET_API - 1);
113       return false;
114     }
115   }
116 
117   mResultType = FD->getReturnType().getCanonicalType();
118   // Compute kernel functions are defined differently when the
119   // "__attribute__((kernel))" is set.
120   if (FD->hasAttr<clang::KernelAttr>()) {
121     valid |= validateAndConstructKernelParams(Context, FD);
122   } else {
123     valid |= validateAndConstructOldStyleParams(Context, FD);
124   }
125 
126   valid |= setSignatureMetadata(Context, FD);
127   return valid;
128 }
129 
validateAndConstructOldStyleParams(RSContext * Context,const clang::FunctionDecl * FD)130 bool RSExportForEach::validateAndConstructOldStyleParams(
131     RSContext *Context, const clang::FunctionDecl *FD) {
132   slangAssert(Context && FD);
133   // If numParams is 0, we already marked this as a graphics root().
134   slangAssert(numParams > 0);
135 
136   bool valid = true;
137 
138   // Compute kernel functions of this style are required to return a void type.
139   clang::ASTContext &C = Context->getASTContext();
140   if (mResultType != C.VoidTy) {
141     Context->ReportError(FD->getLocation(),
142                          "Compute kernel %0() is required to return a "
143                          "void type")
144         << FD->getName();
145     valid = false;
146   }
147 
148   // Validate remaining parameter types
149 
150   size_t IndexOfFirstSpecialParameter = numParams;
151   valid |= processSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
152 
153   // Validate the non-special parameters, which should all be found before the
154   // first special parameter.
155   for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
156     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
157     clang::QualType QT = PVD->getType().getCanonicalType();
158 
159     if (!QT->isPointerType()) {
160       Context->ReportError(PVD->getLocation(),
161                            "Compute kernel %0() cannot have non-pointer "
162                            "parameters besides special parameters (%1). Parameter '%2' is "
163                            "of type: '%3'")
164           << FD->getName() << listSpecialParameters(Context->getTargetAPI())
165           << PVD->getName() << PVD->getType().getAsString();
166       valid = false;
167       continue;
168     }
169 
170     // The only non-const pointer should be out.
171     if (!QT->getPointeeType().isConstQualified()) {
172       if (mOut == nullptr) {
173         mOut = PVD;
174       } else {
175         Context->ReportError(PVD->getLocation(),
176                              "Compute kernel %0() can only have one non-const "
177                              "pointer parameter. Parameters '%1' and '%2' are "
178                              "both non-const.")
179             << FD->getName() << mOut->getName() << PVD->getName();
180         valid = false;
181       }
182     } else {
183       if (mIns.empty() && mOut == nullptr) {
184         mIns.push_back(PVD);
185       } else if (mUsrData == nullptr) {
186         mUsrData = PVD;
187       } else {
188         Context->ReportError(
189             PVD->getLocation(),
190             "Unexpected parameter '%0' for compute kernel %1()")
191             << PVD->getName() << FD->getName();
192         valid = false;
193       }
194     }
195   }
196 
197   if (mIns.empty() && !mOut) {
198     Context->ReportError(FD->getLocation(),
199                          "Compute kernel %0() must have at least one "
200                          "parameter for in or out")
201         << FD->getName();
202     valid = false;
203   }
204 
205   return valid;
206 }
207 
validateAndConstructKernelParams(RSContext * Context,const clang::FunctionDecl * FD)208 bool RSExportForEach::validateAndConstructKernelParams(
209     RSContext *Context, const clang::FunctionDecl *FD) {
210   slangAssert(Context && FD);
211   bool valid = true;
212   clang::ASTContext &C = Context->getASTContext();
213 
214   if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
215     Context->ReportError(FD->getLocation(),
216                          "Compute kernel %0() targeting SDK levels "
217                          "%1-%2 may not use pass-by-value with "
218                          "__attribute__((kernel))")
219         << FD->getName() << SLANG_MINIMUM_TARGET_API
220         << (SLANG_JB_MR1_TARGET_API - 1);
221     return false;
222   }
223 
224   // Denote that we are indeed a pass-by-value kernel.
225   mIsKernelStyle = true;
226   mHasReturnType = (mResultType != C.VoidTy);
227 
228   if (mResultType->isPointerType()) {
229     Context->ReportError(
230         FD->getTypeSpecStartLoc(),
231         "Compute kernel %0() cannot return a pointer type: '%1'")
232         << FD->getName() << mResultType.getAsString();
233     valid = false;
234   }
235 
236   // Validate remaining parameter types
237 
238   size_t IndexOfFirstSpecialParameter = numParams;
239   valid |= processSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
240 
241   // Validate the non-special parameters, which should all be found before the
242   // first special.
243   for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
244     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
245 
246     if (Context->getTargetAPI() >= SLANG_M_TARGET_API || i == 0) {
247       if (i >= RS_KERNEL_INPUT_LIMIT) {
248         Context->ReportError(PVD->getLocation(),
249                              "Invalid parameter '%0' for compute kernel %1(). "
250                              "Kernels targeting SDK levels %2+ may not use "
251                              "more than %3 input parameters.") << PVD->getName() <<
252                              FD->getName() << SLANG_M_TARGET_API <<
253                              int(RS_KERNEL_INPUT_LIMIT);
254 
255       } else {
256         mIns.push_back(PVD);
257       }
258     } else {
259       Context->ReportError(PVD->getLocation(),
260                            "Invalid parameter '%0' for compute kernel %1(). "
261                            "Kernels targeting SDK levels %2-%3 may not use "
262                            "multiple input parameters.") << PVD->getName() <<
263                            FD->getName() << SLANG_MINIMUM_TARGET_API <<
264                            (SLANG_M_TARGET_API - 1);
265       valid = false;
266     }
267     clang::QualType QT = PVD->getType().getCanonicalType();
268     if (QT->isPointerType()) {
269       Context->ReportError(PVD->getLocation(),
270                            "Compute kernel %0() cannot have "
271                            "parameter '%1' of pointer type: '%2'")
272           << FD->getName() << PVD->getName() << PVD->getType().getAsString();
273       valid = false;
274     }
275   }
276 
277   // Check that we have at least one allocation to use for dimensions.
278   if (valid && mIns.empty() && !mHasReturnType && Context->getTargetAPI() < SLANG_M_TARGET_API) {
279     Context->ReportError(FD->getLocation(),
280                          "Compute kernel %0() targeting SDK levels "
281                          "%1-%2 must have at least one "
282                          "input parameter or a non-void return "
283                          "type")
284         << FD->getName() << SLANG_MINIMUM_TARGET_API
285         << (SLANG_M_TARGET_API - 1);
286     valid = false;
287   }
288 
289   return valid;
290 }
291 
292 // Process the optional special parameters:
293 // - Sets *IndexOfFirstSpecialParameter to the index of the first special parameter, or
294 //     FD->getNumParams() if none are found.
295 // - Sets mSpecialParameterSignatureMetadata for the found special parameters.
296 // Returns true if no errors.
processSpecialParameters(RSContext * Context,const clang::FunctionDecl * FD,size_t * IndexOfFirstSpecialParameter)297 bool RSExportForEach::processSpecialParameters(
298     RSContext *Context, const clang::FunctionDecl *FD,
299     size_t *IndexOfFirstSpecialParameter) {
300   slangAssert(IndexOfFirstSpecialParameter != nullptr);
301   slangAssert(mSpecialParameterSignatureMetadata == 0);
302   clang::ASTContext &C = Context->getASTContext();
303 
304   // Find all special parameters if present.
305   int LastSpecialParameterIdx = -1;     // index into specialParameterTable
306   int FirstLocationSpecialParameterIdx = -1; // index into specialParameterTable
307   clang::QualType FirstLocationSpecialParameterType;
308   size_t NumParams = FD->getNumParams();
309   *IndexOfFirstSpecialParameter = NumParams;
310   bool valid = true;
311   for (size_t i = 0; i < NumParams; i++) {
312     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
313     const llvm::StringRef ParamName = PVD->getName();
314     const clang::QualType Type = PVD->getType();
315     const clang::QualType QT = Type.getCanonicalType();
316     const clang::QualType UT = QT.getUnqualifiedType();
317     int SpecialParameterIdx = lookupSpecialParameter(ParamName);
318 
319     static const char KernelContextUnqualifiedTypeName[] =
320         "const struct rs_kernel_context_t *";
321     static const char KernelContextTypeName[] = "rs_kernel_context";
322 
323     // If the type is rs_context, it should have been named "context" and classified
324     // as a special parameter.
325     if (SpecialParameterIdx < 0 && UT.getAsString() == KernelContextUnqualifiedTypeName) {
326       Context->ReportError(
327           PVD->getLocation(),
328           "The special parameter of type '%0' must be called "
329           "'context' instead of '%1'.")
330           << KernelContextTypeName << ParamName;
331       SpecialParameterIdx = lookupSpecialParameter("context");
332     }
333 
334     // If it's not a special parameter, check that it appears before any special
335     // parameter.
336     if (SpecialParameterIdx < 0) {
337       if (*IndexOfFirstSpecialParameter < NumParams) {
338         Context->ReportError(PVD->getLocation(),
339                              "In compute kernel %0(), parameter '%1' cannot "
340                              "appear after any of the special parameters (%2).")
341             << FD->getName() << ParamName << listSpecialParameters(Context->getTargetAPI());
342         valid = false;
343       }
344       continue;
345     }
346 
347     const SpecialParameter &SP = specialParameterTable[SpecialParameterIdx];
348 
349     // Verify that this special parameter is OK for the current API level.
350     if (Context->getTargetAPI() < SP.minAPI) {
351       Context->ReportError(PVD->getLocation(),
352                            "Compute kernel %0() targeting SDK levels "
353                            "%1-%2 may not use special parameter '%3'.")
354           << FD->getName() << SLANG_MINIMUM_TARGET_API << (SP.minAPI - 1)
355           << SP.name;
356       valid = false;
357     }
358 
359     // Check that the order of the special parameters is correct.
360     if (SpecialParameterIdx < LastSpecialParameterIdx) {
361       Context->ReportError(
362           PVD->getLocation(),
363           "In compute kernel %0(), special parameter '%1' must "
364           "be defined before special parameter '%2'.")
365           << FD->getName() << SP.name
366           << specialParameterTable[LastSpecialParameterIdx].name;
367       valid = false;
368     }
369 
370     // Validate the data type of the special parameter.
371     switch (SP.kind) {
372     case SPK_LOCATION: {
373       // Location special parameters can only be int or uint.
374       if (UT != C.UnsignedIntTy && UT != C.IntTy) {
375         Context->ReportError(PVD->getLocation(),
376                              "Special parameter '%0' must be of type 'int' or "
377                              "'unsigned int'. It is of type '%1'.")
378             << ParamName << Type.getAsString();
379         valid = false;
380       }
381 
382       // Ensure that all location special parameters have the same type.
383       if (FirstLocationSpecialParameterIdx >= 0) {
384         if (Type != FirstLocationSpecialParameterType) {
385           Context->ReportError(
386               PVD->getLocation(),
387               "Special parameters '%0' and '%1' must be of the same type. "
388               "'%0' is of type '%2' while '%1' is of type '%3'.")
389               << specialParameterTable[FirstLocationSpecialParameterIdx].name
390               << SP.name << FirstLocationSpecialParameterType.getAsString()
391               << Type.getAsString();
392           valid = false;
393         }
394       } else {
395         FirstLocationSpecialParameterIdx = SpecialParameterIdx;
396         FirstLocationSpecialParameterType = Type;
397       }
398     } break;
399     case SPK_CONTEXT: {
400       // Check that variables named "context" are of type rs_context.
401       if (UT.getAsString() != KernelContextUnqualifiedTypeName) {
402         Context->ReportError(PVD->getLocation(),
403                              "Special parameter '%0' must be of type '%1'. "
404                              "It is of type '%2'.")
405             << ParamName << KernelContextTypeName
406             << Type.getAsString();
407         valid = false;
408       }
409     } break;
410     default:
411       slangAssert(!"Unexpected special parameter type");
412     }
413 
414     // We should not be invoked if two parameters of the same name are present.
415     slangAssert(!(mSpecialParameterSignatureMetadata & SP.bitval));
416     mSpecialParameterSignatureMetadata |= SP.bitval;
417 
418     LastSpecialParameterIdx = SpecialParameterIdx;
419     // If this is the first time we find a special parameter, save it.
420     if (*IndexOfFirstSpecialParameter >= NumParams) {
421       *IndexOfFirstSpecialParameter = i;
422     }
423   }
424   return valid;
425 }
426 
setSignatureMetadata(RSContext * Context,const clang::FunctionDecl * FD)427 bool RSExportForEach::setSignatureMetadata(RSContext *Context,
428                                            const clang::FunctionDecl *FD) {
429   mSignatureMetadata = 0;
430   bool valid = true;
431 
432   if (mIsKernelStyle) {
433     slangAssert(mOut == nullptr);
434     slangAssert(mUsrData == nullptr);
435   } else {
436     slangAssert(!mHasReturnType);
437   }
438 
439   // Set up the bitwise metadata encoding for runtime argument passing.
440   const bool HasOut = mOut || mHasReturnType;
441   mSignatureMetadata |= (hasIns() ?       bcinfo::MD_SIG_In     : 0);
442   mSignatureMetadata |= (HasOut ?         bcinfo::MD_SIG_Out    : 0);
443   mSignatureMetadata |= (mUsrData ?       bcinfo::MD_SIG_Usr    : 0);
444   mSignatureMetadata |= (mIsKernelStyle ? bcinfo::MD_SIG_Kernel : 0);  // pass-by-value
445   mSignatureMetadata |= mSpecialParameterSignatureMetadata;
446 
447   if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
448     // APIs before ICS cannot skip between parameters. It is ok, however, for
449     // them to omit further parameters (i.e. skipping X is ok if you skip Y).
450     if (mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
451                                bcinfo::MD_SIG_X | bcinfo::MD_SIG_Y) &&
452         mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
453                                bcinfo::MD_SIG_X) &&
454         mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr) &&
455         mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out) &&
456         mSignatureMetadata != (bcinfo::MD_SIG_In)) {
457       Context->ReportError(FD->getLocation(),
458                            "Compute kernel %0() targeting SDK levels "
459                            "%1-%2 may not skip parameters")
460           << FD->getName() << SLANG_MINIMUM_TARGET_API
461           << (SLANG_ICS_TARGET_API - 1);
462       valid = false;
463     }
464   }
465   return valid;
466 }
467 
Create(RSContext * Context,const clang::FunctionDecl * FD)468 RSExportForEach *RSExportForEach::Create(RSContext *Context,
469                                          const clang::FunctionDecl *FD) {
470   slangAssert(Context && FD);
471   llvm::StringRef Name = FD->getName();
472   RSExportForEach *FE;
473 
474   slangAssert(!Name.empty() && "Function must have a name");
475 
476   FE = new RSExportForEach(Context, Name);
477 
478   if (!FE->validateAndConstructParams(Context, FD)) {
479     return nullptr;
480   }
481 
482   clang::ASTContext &Ctx = Context->getASTContext();
483 
484   std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
485 
486   // Extract the usrData parameter (if we have one)
487   if (FE->mUsrData) {
488     const clang::ParmVarDecl *PVD = FE->mUsrData;
489     clang::QualType QT = PVD->getType().getCanonicalType();
490     slangAssert(QT->isPointerType() &&
491                 QT->getPointeeType().isConstQualified());
492 
493     const clang::ASTContext &C = Context->getASTContext();
494     if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
495         C.VoidTy) {
496       // In the case of using const void*, we can't reflect an appopriate
497       // Java type, so we fall back to just reflecting the ain/aout parameters
498       FE->mUsrData = nullptr;
499     } else {
500       clang::RecordDecl *RD =
501           clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
502                                     Ctx.getTranslationUnitDecl(),
503                                     clang::SourceLocation(),
504                                     clang::SourceLocation(),
505                                     &Ctx.Idents.get(Id));
506 
507       clang::FieldDecl *FD =
508           clang::FieldDecl::Create(Ctx,
509                                    RD,
510                                    clang::SourceLocation(),
511                                    clang::SourceLocation(),
512                                    PVD->getIdentifier(),
513                                    QT->getPointeeType(),
514                                    nullptr,
515                                    /* BitWidth = */ nullptr,
516                                    /* Mutable = */ false,
517                                    /* HasInit = */ clang::ICIS_NoInit);
518       RD->addDecl(FD);
519       RD->completeDefinition();
520 
521       // Create an export type iff we have a valid usrData type
522       clang::QualType T = Ctx.getTagDeclType(RD);
523       slangAssert(!T.isNull());
524 
525       RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
526 
527       slangAssert(ET && "Failed to export a kernel");
528 
529       slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
530                   "Parameter packet must be a record");
531 
532       FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
533     }
534   }
535 
536   if (FE->hasIns()) {
537 
538     for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
539       const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
540       RSExportType *InExportType = RSExportType::Create(Context, T);
541 
542       if (FE->mIsKernelStyle) {
543         slangAssert(InExportType != nullptr);
544       }
545 
546       FE->mInTypes.push_back(InExportType);
547     }
548   }
549 
550   if (FE->mIsKernelStyle && FE->mHasReturnType) {
551     const clang::Type *T = FE->mResultType.getTypePtr();
552     FE->mOutType = RSExportType::Create(Context, T);
553     slangAssert(FE->mOutType);
554   } else if (FE->mOut) {
555     const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
556     FE->mOutType = RSExportType::Create(Context, T);
557   }
558 
559   return FE;
560 }
561 
CreateDummyRoot(RSContext * Context)562 RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
563   slangAssert(Context);
564   llvm::StringRef Name = "root";
565   RSExportForEach *FE = new RSExportForEach(Context, Name);
566   FE->mDummyRoot = true;
567   return FE;
568 }
569 
isGraphicsRootRSFunc(unsigned int targetAPI,const clang::FunctionDecl * FD)570 bool RSExportForEach::isGraphicsRootRSFunc(unsigned int targetAPI,
571                                            const clang::FunctionDecl *FD) {
572   if (FD->hasAttr<clang::KernelAttr>()) {
573     return false;
574   }
575 
576   if (!isRootRSFunc(FD)) {
577     return false;
578   }
579 
580   if (FD->getNumParams() == 0) {
581     // Graphics root function
582     return true;
583   }
584 
585   // Check for legacy graphics root function (with single parameter).
586   if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
587     const clang::QualType &IntType = FD->getASTContext().IntTy;
588     if (FD->getReturnType().getCanonicalType() == IntType) {
589       return true;
590     }
591   }
592 
593   return false;
594 }
595 
isRSForEachFunc(unsigned int targetAPI,slang::RSContext * Context,const clang::FunctionDecl * FD)596 bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
597                                       slang::RSContext* Context,
598                                       const clang::FunctionDecl *FD) {
599   slangAssert(Context && FD);
600   bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
601 
602   if (FD->getStorageClass() == clang::SC_Static) {
603     if (hasKernelAttr) {
604       Context->ReportError(FD->getLocation(),
605                            "Invalid use of attribute kernel with "
606                            "static function declaration: %0")
607           << FD->getName();
608     }
609     return false;
610   }
611 
612   // Anything tagged as a kernel is definitely used with ForEach.
613   if (hasKernelAttr) {
614     return true;
615   }
616 
617   if (isGraphicsRootRSFunc(targetAPI, FD)) {
618     return false;
619   }
620 
621   // Check if first parameter is a pointer (which is required for ForEach).
622   unsigned int numParams = FD->getNumParams();
623 
624   if (numParams > 0) {
625     const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
626     clang::QualType QT = PVD->getType().getCanonicalType();
627 
628     if (QT->isPointerType()) {
629       return true;
630     }
631 
632     // Any non-graphics root() is automatically a ForEach candidate.
633     // At this point, however, we know that it is not going to be a valid
634     // compute root() function (due to not having a pointer parameter). We
635     // still want to return true here, so that we can issue appropriate
636     // diagnostics.
637     if (isRootRSFunc(FD)) {
638       return true;
639     }
640   }
641 
642   return false;
643 }
644 
645 bool
validateSpecialFuncDecl(unsigned int targetAPI,slang::RSContext * Context,clang::FunctionDecl const * FD)646 RSExportForEach::validateSpecialFuncDecl(unsigned int targetAPI,
647                                          slang::RSContext *Context,
648                                          clang::FunctionDecl const *FD) {
649   slangAssert(Context && FD);
650   bool valid = true;
651   const clang::ASTContext &C = FD->getASTContext();
652   const clang::QualType &IntType = FD->getASTContext().IntTy;
653 
654   if (isGraphicsRootRSFunc(targetAPI, FD)) {
655     if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
656       // Legacy graphics root function
657       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
658       clang::QualType QT = PVD->getType().getCanonicalType();
659       if (QT != IntType) {
660         Context->ReportError(PVD->getLocation(),
661                              "invalid parameter type for legacy "
662                              "graphics root() function: %0")
663             << PVD->getType();
664         valid = false;
665       }
666     }
667 
668     // Graphics root function, so verify that it returns an int
669     if (FD->getReturnType().getCanonicalType() != IntType) {
670       Context->ReportError(FD->getLocation(),
671                            "root() is required to return "
672                            "an int for graphics usage");
673       valid = false;
674     }
675   } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
676     if (FD->getNumParams() != 0) {
677       Context->ReportError(FD->getLocation(),
678                            "%0(void) is required to have no "
679                            "parameters")
680           << FD->getName();
681       valid = false;
682     }
683 
684     if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
685       Context->ReportError(FD->getLocation(),
686                            "%0(void) is required to have a void "
687                            "return type")
688           << FD->getName();
689       valid = false;
690     }
691   } else {
692     slangAssert(false && "must be called on root, init or .rs.dtor function!");
693   }
694 
695   return valid;
696 }
697 
698 }  // namespace slang
699