1 /*
2  * Copyright 2011-2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_export_foreach.h"
18 
19 #include <string>
20 
21 #include "clang/AST/ASTContext.h"
22 #include "clang/AST/Attr.h"
23 #include "clang/AST/Decl.h"
24 #include "clang/AST/TypeLoc.h"
25 
26 #include "llvm/IR/DerivedTypes.h"
27 
28 #include "slang_assert.h"
29 #include "slang_rs_context.h"
30 #include "slang_rs_export_type.h"
31 #include "slang_version.h"
32 
33 namespace slang {
34 
35 // This function takes care of additional validation and construction of
36 // parameters related to forEach_* reflection.
validateAndConstructParams(RSContext * Context,const clang::FunctionDecl * FD)37 bool RSExportForEach::validateAndConstructParams(
38     RSContext *Context, const clang::FunctionDecl *FD) {
39   slangAssert(Context && FD);
40   bool valid = true;
41 
42   numParams = FD->getNumParams();
43 
44   if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
45     // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
46     if (!isRootRSFunc(FD)) {
47       Context->ReportError(FD->getLocation(),
48                            "Non-root compute kernel %0() is "
49                            "not supported in SDK levels %1-%2")
50           << FD->getName() << SLANG_MINIMUM_TARGET_API
51           << (SLANG_JB_TARGET_API - 1);
52       return false;
53     }
54   }
55 
56   mResultType = FD->getReturnType().getCanonicalType();
57   // Compute kernel functions are defined differently when the
58   // "__attribute__((kernel))" is set.
59   if (FD->hasAttr<clang::KernelAttr>()) {
60     valid |= validateAndConstructKernelParams(Context, FD);
61   } else {
62     valid |= validateAndConstructOldStyleParams(Context, FD);
63   }
64 
65   valid |= setSignatureMetadata(Context, FD);
66   return valid;
67 }
68 
validateAndConstructOldStyleParams(RSContext * Context,const clang::FunctionDecl * FD)69 bool RSExportForEach::validateAndConstructOldStyleParams(
70     RSContext *Context, const clang::FunctionDecl *FD) {
71   slangAssert(Context && FD);
72   // If numParams is 0, we already marked this as a graphics root().
73   slangAssert(numParams > 0);
74 
75   bool valid = true;
76 
77   // Compute kernel functions of this style are required to return a void type.
78   clang::ASTContext &C = Context->getASTContext();
79   if (mResultType != C.VoidTy) {
80     Context->ReportError(FD->getLocation(),
81                          "Compute kernel %0() is required to return a "
82                          "void type")
83         << FD->getName();
84     valid = false;
85   }
86 
87   // Validate remaining parameter types
88   // TODO(all): Add support for LOD/face when we have them
89 
90   size_t IndexOfFirstIterator = numParams;
91   valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
92 
93   // Validate the non-iterator parameters, which should all be found before the
94   // first iterator.
95   for (size_t i = 0; i < IndexOfFirstIterator; i++) {
96     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
97     clang::QualType QT = PVD->getType().getCanonicalType();
98 
99     if (!QT->isPointerType()) {
100       Context->ReportError(PVD->getLocation(),
101                            "Compute kernel %0() cannot have non-pointer "
102                            "parameters besides 'x' and 'y'. Parameter '%1' is "
103                            "of type: '%2'")
104           << FD->getName() << PVD->getName() << PVD->getType().getAsString();
105       valid = false;
106       continue;
107     }
108 
109     // The only non-const pointer should be out.
110     if (!QT->getPointeeType().isConstQualified()) {
111       if (mOut == NULL) {
112         mOut = PVD;
113       } else {
114         Context->ReportError(PVD->getLocation(),
115                              "Compute kernel %0() can only have one non-const "
116                              "pointer parameter. Parameters '%1' and '%2' are "
117                              "both non-const.")
118             << FD->getName() << mOut->getName() << PVD->getName();
119         valid = false;
120       }
121     } else {
122       if (mIns.empty() && mOut == NULL) {
123         mIns.push_back(PVD);
124       } else if (mUsrData == NULL) {
125         mUsrData = PVD;
126       } else {
127         Context->ReportError(
128             PVD->getLocation(),
129             "Unexpected parameter '%0' for compute kernel %1()")
130             << PVD->getName() << FD->getName();
131         valid = false;
132       }
133     }
134   }
135 
136   if (mIns.empty() && !mOut) {
137     Context->ReportError(FD->getLocation(),
138                          "Compute kernel %0() must have at least one "
139                          "parameter for in or out")
140         << FD->getName();
141     valid = false;
142   }
143 
144   return valid;
145 }
146 
validateAndConstructKernelParams(RSContext * Context,const clang::FunctionDecl * FD)147 bool RSExportForEach::validateAndConstructKernelParams(
148     RSContext *Context, const clang::FunctionDecl *FD) {
149   slangAssert(Context && FD);
150   bool valid = true;
151   clang::ASTContext &C = Context->getASTContext();
152 
153   if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
154     Context->ReportError(FD->getLocation(),
155                          "Compute kernel %0() targeting SDK levels "
156                          "%1-%2 may not use pass-by-value with "
157                          "__attribute__((kernel))")
158         << FD->getName() << SLANG_MINIMUM_TARGET_API
159         << (SLANG_JB_MR1_TARGET_API - 1);
160     return false;
161   }
162 
163   // Denote that we are indeed a pass-by-value kernel.
164   mIsKernelStyle = true;
165   mHasReturnType = (mResultType != C.VoidTy);
166 
167   if (mResultType->isPointerType()) {
168     Context->ReportError(
169         FD->getTypeSpecStartLoc(),
170         "Compute kernel %0() cannot return a pointer type: '%1'")
171         << FD->getName() << mResultType.getAsString();
172     valid = false;
173   }
174 
175   // Validate remaining parameter types
176   // TODO(all): Add support for LOD/face when we have them
177 
178   size_t IndexOfFirstIterator = numParams;
179   valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
180 
181   // Validate the non-iterator parameters, which should all be found before the
182   // first iterator.
183   for (size_t i = 0; i < IndexOfFirstIterator; i++) {
184     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
185 
186     /*
187      * FIXME: Change this to a test against an actual API version when the
188      *        multi-input feature is officially supported.
189      */
190     if (Context->getTargetAPI() == SLANG_DEVELOPMENT_TARGET_API || i == 0) {
191       mIns.push_back(PVD);
192     } else {
193       Context->ReportError(PVD->getLocation(),
194                            "Invalid parameter '%0' for compute kernel %1(). "
195                            "Kernels targeting SDK levels %2-%3 may not use "
196                            "multiple input parameters.") << PVD->getName() <<
197                            FD->getName() << SLANG_MINIMUM_TARGET_API <<
198                            SLANG_MAXIMUM_TARGET_API;
199       valid = false;
200     }
201     clang::QualType QT = PVD->getType().getCanonicalType();
202     if (QT->isPointerType()) {
203       Context->ReportError(PVD->getLocation(),
204                            "Compute kernel %0() cannot have "
205                            "parameter '%1' of pointer type: '%2'")
206           << FD->getName() << PVD->getName() << PVD->getType().getAsString();
207       valid = false;
208     }
209   }
210 
211   // Check that we have at least one allocation to use for dimensions.
212   if (valid && mIns.empty() && !mHasReturnType) {
213     Context->ReportError(FD->getLocation(),
214                          "Compute kernel %0() must have at least one "
215                          "input parameter or a non-void return "
216                          "type")
217         << FD->getName();
218     valid = false;
219   }
220 
221   return valid;
222 }
223 
224 // Search for the optional x and y parameters.  Returns true if valid.   Also
225 // sets *IndexOfFirstIterator to the index of the first iterator parameter, or
226 // FD->getNumParams() if none are found.
validateIterationParameters(RSContext * Context,const clang::FunctionDecl * FD,size_t * IndexOfFirstIterator)227 bool RSExportForEach::validateIterationParameters(
228     RSContext *Context, const clang::FunctionDecl *FD,
229     size_t *IndexOfFirstIterator) {
230   slangAssert(IndexOfFirstIterator != NULL);
231   slangAssert(mX == NULL && mY == NULL);
232   clang::ASTContext &C = Context->getASTContext();
233 
234   // Find the x and y parameters if present.
235   size_t NumParams = FD->getNumParams();
236   *IndexOfFirstIterator = NumParams;
237   bool valid = true;
238   for (size_t i = 0; i < NumParams; i++) {
239     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
240     llvm::StringRef ParamName = PVD->getName();
241     if (ParamName.equals("x")) {
242       slangAssert(mX == NULL);  // We won't be invoked if two 'x' are present.
243       mX = PVD;
244       if (mY != NULL) {
245         Context->ReportError(PVD->getLocation(),
246                              "In compute kernel %0(), parameter 'x' should "
247                              "be defined before parameter 'y'")
248             << FD->getName();
249         valid = false;
250       }
251     } else if (ParamName.equals("y")) {
252       slangAssert(mY == NULL);  // We won't be invoked if two 'y' are present.
253       mY = PVD;
254     } else {
255       // It's neither x nor y.
256       if (*IndexOfFirstIterator < NumParams) {
257         Context->ReportError(PVD->getLocation(),
258                              "In compute kernel %0(), parameter '%1' cannot "
259                              "appear after the 'x' and 'y' parameters")
260             << FD->getName() << ParamName;
261         valid = false;
262       }
263       continue;
264     }
265     // Validate the data type of x and y.
266     clang::QualType QT = PVD->getType().getCanonicalType();
267     clang::QualType UT = QT.getUnqualifiedType();
268     if (UT != C.UnsignedIntTy && UT != C.IntTy) {
269       Context->ReportError(PVD->getLocation(),
270                            "Parameter '%0' must be of type 'int' or "
271                            "'unsigned int'. It is of type '%1'")
272           << ParamName << PVD->getType().getAsString();
273       valid = false;
274     }
275     // If this is the first time we find an iterator, save it.
276     if (*IndexOfFirstIterator >= NumParams) {
277       *IndexOfFirstIterator = i;
278     }
279   }
280   // Check that x and y have the same type.
281   if (mX != NULL and mY != NULL) {
282     clang::QualType XType = mX->getType();
283     clang::QualType YType = mY->getType();
284 
285     if (XType != YType) {
286       Context->ReportError(mY->getLocation(),
287                            "Parameter 'x' and 'y' must be of the same type. "
288                            "'x' is of type '%0' while 'y' is of type '%1'")
289           << XType.getAsString() << YType.getAsString();
290       valid = false;
291     }
292   }
293   return valid;
294 }
295 
setSignatureMetadata(RSContext * Context,const clang::FunctionDecl * FD)296 bool RSExportForEach::setSignatureMetadata(RSContext *Context,
297                                            const clang::FunctionDecl *FD) {
298   mSignatureMetadata = 0;
299   bool valid = true;
300 
301   if (mIsKernelStyle) {
302     slangAssert(mOut == NULL);
303     slangAssert(mUsrData == NULL);
304   } else {
305     slangAssert(!mHasReturnType);
306   }
307 
308   // Set up the bitwise metadata encoding for runtime argument passing.
309   // TODO: If this bit field is re-used from C++ code, define the values in a header.
310   const bool HasOut = mOut || mHasReturnType;
311   mSignatureMetadata |= (hasIns() ?       0x01 : 0);
312   mSignatureMetadata |= (HasOut ?         0x02 : 0);
313   mSignatureMetadata |= (mUsrData ?       0x04 : 0);
314   mSignatureMetadata |= (mX ?             0x08 : 0);
315   mSignatureMetadata |= (mY ?             0x10 : 0);
316   mSignatureMetadata |= (mIsKernelStyle ? 0x20 : 0);  // pass-by-value
317 
318   if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
319     // APIs before ICS cannot skip between parameters. It is ok, however, for
320     // them to omit further parameters (i.e. skipping X is ok if you skip Y).
321     if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
322         mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
323         mSignatureMetadata != 0x07 &&  // In, Out, UsrData
324         mSignatureMetadata != 0x03 &&  // In, Out
325         mSignatureMetadata != 0x01) {  // In
326       Context->ReportError(FD->getLocation(),
327                            "Compute kernel %0() targeting SDK levels "
328                            "%1-%2 may not skip parameters")
329           << FD->getName() << SLANG_MINIMUM_TARGET_API
330           << (SLANG_ICS_TARGET_API - 1);
331       valid = false;
332     }
333   }
334   return valid;
335 }
336 
Create(RSContext * Context,const clang::FunctionDecl * FD)337 RSExportForEach *RSExportForEach::Create(RSContext *Context,
338                                          const clang::FunctionDecl *FD) {
339   slangAssert(Context && FD);
340   llvm::StringRef Name = FD->getName();
341   RSExportForEach *FE;
342 
343   slangAssert(!Name.empty() && "Function must have a name");
344 
345   FE = new RSExportForEach(Context, Name);
346 
347   if (!FE->validateAndConstructParams(Context, FD)) {
348     return NULL;
349   }
350 
351   clang::ASTContext &Ctx = Context->getASTContext();
352 
353   std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
354 
355   // Extract the usrData parameter (if we have one)
356   if (FE->mUsrData) {
357     const clang::ParmVarDecl *PVD = FE->mUsrData;
358     clang::QualType QT = PVD->getType().getCanonicalType();
359     slangAssert(QT->isPointerType() &&
360                 QT->getPointeeType().isConstQualified());
361 
362     const clang::ASTContext &C = Context->getASTContext();
363     if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
364         C.VoidTy) {
365       // In the case of using const void*, we can't reflect an appopriate
366       // Java type, so we fall back to just reflecting the ain/aout parameters
367       FE->mUsrData = NULL;
368     } else {
369       clang::RecordDecl *RD =
370           clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
371                                     Ctx.getTranslationUnitDecl(),
372                                     clang::SourceLocation(),
373                                     clang::SourceLocation(),
374                                     &Ctx.Idents.get(Id));
375 
376       clang::FieldDecl *FD =
377           clang::FieldDecl::Create(Ctx,
378                                    RD,
379                                    clang::SourceLocation(),
380                                    clang::SourceLocation(),
381                                    PVD->getIdentifier(),
382                                    QT->getPointeeType(),
383                                    NULL,
384                                    /* BitWidth = */ NULL,
385                                    /* Mutable = */ false,
386                                    /* HasInit = */ clang::ICIS_NoInit);
387       RD->addDecl(FD);
388       RD->completeDefinition();
389 
390       // Create an export type iff we have a valid usrData type
391       clang::QualType T = Ctx.getTagDeclType(RD);
392       slangAssert(!T.isNull());
393 
394       RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
395 
396       if (ET == NULL) {
397         fprintf(stderr, "Failed to export the function %s. There's at least "
398                         "one parameter whose type is not supported by the "
399                         "reflection\n", FE->getName().c_str());
400         return NULL;
401       }
402 
403       slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
404                   "Parameter packet must be a record");
405 
406       FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
407     }
408   }
409 
410   if (FE->hasIns()) {
411 
412     for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
413       const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
414       RSExportType *InExportType = RSExportType::Create(Context, T);
415 
416       if (FE->mIsKernelStyle) {
417         slangAssert(InExportType != NULL);
418       }
419 
420       FE->mInTypes.push_back(InExportType);
421     }
422   }
423 
424   if (FE->mIsKernelStyle && FE->mHasReturnType) {
425     const clang::Type *T = FE->mResultType.getTypePtr();
426     FE->mOutType = RSExportType::Create(Context, T);
427     slangAssert(FE->mOutType);
428   } else if (FE->mOut) {
429     const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
430     FE->mOutType = RSExportType::Create(Context, T);
431   }
432 
433   return FE;
434 }
435 
CreateDummyRoot(RSContext * Context)436 RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
437   slangAssert(Context);
438   llvm::StringRef Name = "root";
439   RSExportForEach *FE = new RSExportForEach(Context, Name);
440   FE->mDummyRoot = true;
441   return FE;
442 }
443 
isGraphicsRootRSFunc(unsigned int targetAPI,const clang::FunctionDecl * FD)444 bool RSExportForEach::isGraphicsRootRSFunc(unsigned int targetAPI,
445                                            const clang::FunctionDecl *FD) {
446   if (FD->hasAttr<clang::KernelAttr>()) {
447     return false;
448   }
449 
450   if (!isRootRSFunc(FD)) {
451     return false;
452   }
453 
454   if (FD->getNumParams() == 0) {
455     // Graphics root function
456     return true;
457   }
458 
459   // Check for legacy graphics root function (with single parameter).
460   if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
461     const clang::QualType &IntType = FD->getASTContext().IntTy;
462     if (FD->getReturnType().getCanonicalType() == IntType) {
463       return true;
464     }
465   }
466 
467   return false;
468 }
469 
isRSForEachFunc(unsigned int targetAPI,slang::RSContext * Context,const clang::FunctionDecl * FD)470 bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
471                                       slang::RSContext* Context,
472                                       const clang::FunctionDecl *FD) {
473   slangAssert(Context && FD);
474   bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
475 
476   if (FD->getStorageClass() == clang::SC_Static) {
477     if (hasKernelAttr) {
478       Context->ReportError(FD->getLocation(),
479                            "Invalid use of attribute kernel with "
480                            "static function declaration: %0")
481           << FD->getName();
482     }
483     return false;
484   }
485 
486   // Anything tagged as a kernel is definitely used with ForEach.
487   if (hasKernelAttr) {
488     return true;
489   }
490 
491   if (isGraphicsRootRSFunc(targetAPI, FD)) {
492     return false;
493   }
494 
495   // Check if first parameter is a pointer (which is required for ForEach).
496   unsigned int numParams = FD->getNumParams();
497 
498   if (numParams > 0) {
499     const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
500     clang::QualType QT = PVD->getType().getCanonicalType();
501 
502     if (QT->isPointerType()) {
503       return true;
504     }
505 
506     // Any non-graphics root() is automatically a ForEach candidate.
507     // At this point, however, we know that it is not going to be a valid
508     // compute root() function (due to not having a pointer parameter). We
509     // still want to return true here, so that we can issue appropriate
510     // diagnostics.
511     if (isRootRSFunc(FD)) {
512       return true;
513     }
514   }
515 
516   return false;
517 }
518 
519 bool
validateSpecialFuncDecl(unsigned int targetAPI,slang::RSContext * Context,clang::FunctionDecl const * FD)520 RSExportForEach::validateSpecialFuncDecl(unsigned int targetAPI,
521                                          slang::RSContext *Context,
522                                          clang::FunctionDecl const *FD) {
523   slangAssert(Context && FD);
524   bool valid = true;
525   const clang::ASTContext &C = FD->getASTContext();
526   const clang::QualType &IntType = FD->getASTContext().IntTy;
527 
528   if (isGraphicsRootRSFunc(targetAPI, FD)) {
529     if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
530       // Legacy graphics root function
531       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
532       clang::QualType QT = PVD->getType().getCanonicalType();
533       if (QT != IntType) {
534         Context->ReportError(PVD->getLocation(),
535                              "invalid parameter type for legacy "
536                              "graphics root() function: %0")
537             << PVD->getType();
538         valid = false;
539       }
540     }
541 
542     // Graphics root function, so verify that it returns an int
543     if (FD->getReturnType().getCanonicalType() != IntType) {
544       Context->ReportError(FD->getLocation(),
545                            "root() is required to return "
546                            "an int for graphics usage");
547       valid = false;
548     }
549   } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
550     if (FD->getNumParams() != 0) {
551       Context->ReportError(FD->getLocation(),
552                            "%0(void) is required to have no "
553                            "parameters")
554           << FD->getName();
555       valid = false;
556     }
557 
558     if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
559       Context->ReportError(FD->getLocation(),
560                            "%0(void) is required to have a void "
561                            "return type")
562           << FD->getName();
563       valid = false;
564     }
565   } else {
566     slangAssert(false && "must be called on root, init or .rs.dtor function!");
567   }
568 
569   return valid;
570 }
571 
572 }  // namespace slang
573