1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_object_ref_count.h"
18 
19 #include "clang/AST/DeclGroup.h"
20 #include "clang/AST/Expr.h"
21 #include "clang/AST/NestedNameSpecifier.h"
22 #include "clang/AST/OperationKinds.h"
23 #include "clang/AST/RecursiveASTVisitor.h"
24 #include "clang/AST/Stmt.h"
25 #include "clang/AST/StmtVisitor.h"
26 
27 #include "slang_assert.h"
28 #include "slang.h"
29 #include "slang_rs_ast_replace.h"
30 #include "slang_rs_export_type.h"
31 
32 namespace slang {
33 
34 /* Even though those two arrays are of size DataTypeMax, only entries that
35  * correspond to object types will be set.
36  */
37 clang::FunctionDecl *
38 RSObjectRefCount::RSSetObjectFD[DataTypeMax];
39 clang::FunctionDecl *
40 RSObjectRefCount::RSClearObjectFD[DataTypeMax];
41 
GetRSRefCountingFunctions(clang::ASTContext & C)42 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43   for (unsigned i = 0; i < DataTypeMax; i++) {
44     RSSetObjectFD[i] = nullptr;
45     RSClearObjectFD[i] = nullptr;
46   }
47 
48   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
49 
50   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
51           E = TUDecl->decls_end(); I != E; I++) {
52     if ((I->getKind() >= clang::Decl::firstFunction) &&
53         (I->getKind() <= clang::Decl::lastFunction)) {
54       clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
55 
56       // points to RSSetObjectFD or RSClearObjectFD
57       clang::FunctionDecl **RSObjectFD;
58 
59       if (FD->getName() == "rsSetObject") {
60         slangAssert((FD->getNumParams() == 2) &&
61                     "Invalid rsSetObject function prototype (# params)");
62         RSObjectFD = RSSetObjectFD;
63       } else if (FD->getName() == "rsClearObject") {
64         slangAssert((FD->getNumParams() == 1) &&
65                     "Invalid rsClearObject function prototype (# params)");
66         RSObjectFD = RSClearObjectFD;
67       } else {
68         continue;
69       }
70 
71       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
72       clang::QualType PVT = PVD->getOriginalType();
73       // The first parameter must be a pointer like rs_allocation*
74       slangAssert(PVT->isPointerType() &&
75           "Invalid rs{Set,Clear}Object function prototype (pointer param)");
76 
77       // The rs object type passed to the FD
78       clang::QualType RST = PVT->getPointeeType();
79       DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
80       slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
81              && "must be RS object type");
82 
83       if (DT >= 0 && DT < DataTypeMax) {
84           RSObjectFD[DT] = FD;
85       } else {
86           slangAssert(false && "incorrect type");
87       }
88     }
89   }
90 }
91 
92 namespace {
93 
94 unsigned CountRSObjectTypes(const clang::Type *T);
95 
96 clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
97                                      clang::Expr *DstExpr,
98                                      clang::Expr *SrcExpr,
99                                      clang::SourceLocation StartLoc,
100                                      clang::SourceLocation Loc);
101 
102 // This function constructs a new CompoundStmt from the input StmtList.
BuildCompoundStmt(clang::ASTContext & C,std::vector<clang::Stmt * > & StmtList,clang::SourceLocation Loc)103 clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
104       std::vector<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
105   unsigned NewStmtCount = StmtList.size();
106   unsigned CompoundStmtCount = 0;
107 
108   clang::Stmt **CompoundStmtList;
109   CompoundStmtList = new clang::Stmt*[NewStmtCount];
110 
111   std::vector<clang::Stmt*>::const_iterator I = StmtList.begin();
112   std::vector<clang::Stmt*>::const_iterator E = StmtList.end();
113   for ( ; I != E; I++) {
114     CompoundStmtList[CompoundStmtCount++] = *I;
115   }
116   slangAssert(CompoundStmtCount == NewStmtCount);
117 
118   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
119       C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
120 
121   delete [] CompoundStmtList;
122 
123   return CS;
124 }
125 
AppendAfterStmt(clang::ASTContext & C,clang::CompoundStmt * CS,clang::Stmt * S,std::list<clang::Stmt * > & StmtList)126 void AppendAfterStmt(clang::ASTContext &C,
127                      clang::CompoundStmt *CS,
128                      clang::Stmt *S,
129                      std::list<clang::Stmt*> &StmtList) {
130   slangAssert(CS);
131   clang::CompoundStmt::body_iterator bI = CS->body_begin();
132   clang::CompoundStmt::body_iterator bE = CS->body_end();
133   clang::Stmt **UpdatedStmtList =
134       new clang::Stmt*[CS->size() + StmtList.size()];
135 
136   unsigned UpdatedStmtCount = 0;
137   unsigned Once = 0;
138   for ( ; bI != bE; bI++) {
139     if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
140       // If we come across a return here, we don't have anything we can
141       // reasonably replace. We should have already inserted our destructor
142       // code in the proper spot, so we just clean up and return.
143       delete [] UpdatedStmtList;
144 
145       return;
146     }
147 
148     UpdatedStmtList[UpdatedStmtCount++] = *bI;
149 
150     if ((*bI == S) && !Once) {
151       Once++;
152       std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
153       std::list<clang::Stmt*>::const_iterator E = StmtList.end();
154       for ( ; I != E; I++) {
155         UpdatedStmtList[UpdatedStmtCount++] = *I;
156       }
157     }
158   }
159   slangAssert(Once <= 1);
160 
161   // When S is nullptr, we are appending to the end of the CompoundStmt.
162   if (!S) {
163     slangAssert(Once == 0);
164     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
165     std::list<clang::Stmt*>::const_iterator E = StmtList.end();
166     for ( ; I != E; I++) {
167       UpdatedStmtList[UpdatedStmtCount++] = *I;
168     }
169   }
170 
171   CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
172 
173   delete [] UpdatedStmtList;
174 }
175 
176 // This class visits a compound statement and collects a list of all the exiting
177 // statements, such as any return statement in any sub-block, and any
178 // break/continue statement that would resume outside the current scope.
179 // We do not handle the case for goto statements that leave a local scope.
180 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
181  private:
182   // The loop depth of the currently visited node.
183   int mLoopDepth;
184 
185   // The switch statement depth of the currently visited node.
186   // Note that this is tracked separately from the loop depth because
187   // SwitchStmt-contained ContinueStmt's should have destructors for the
188   // corresponding loop scope.
189   int mSwitchDepth;
190 
191   // Output of the visitor: the statements that should be replaced by compound
192   // statements, each of which contains rsClearObject() calls followed by the
193   // original statement.
194   std::vector<clang::Stmt*> mExitingStmts;
195 
196  public:
DestructorVisitor()197   DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
198 
getExitingStmts() const199   const std::vector<clang::Stmt*>& getExitingStmts() const {
200     return mExitingStmts;
201   }
202 
203   void VisitStmt(clang::Stmt *S);
204   void VisitBreakStmt(clang::BreakStmt *BS);
205   void VisitContinueStmt(clang::ContinueStmt *CS);
206   void VisitDoStmt(clang::DoStmt *DS);
207   void VisitForStmt(clang::ForStmt *FS);
208   void VisitReturnStmt(clang::ReturnStmt *RS);
209   void VisitSwitchStmt(clang::SwitchStmt *SS);
210   void VisitWhileStmt(clang::WhileStmt *WS);
211 };
212 
VisitStmt(clang::Stmt * S)213 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
214   for (clang::Stmt* Child : S->children()) {
215     if (Child) {
216       Visit(Child);
217     }
218   }
219 }
220 
VisitBreakStmt(clang::BreakStmt * BS)221 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
222   VisitStmt(BS);
223   if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
224     mExitingStmts.push_back(BS);
225   }
226 }
227 
VisitContinueStmt(clang::ContinueStmt * CS)228 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
229   VisitStmt(CS);
230   if (mLoopDepth == 0) {
231     // Switch statements can have nested continues.
232     mExitingStmts.push_back(CS);
233   }
234 }
235 
VisitDoStmt(clang::DoStmt * DS)236 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
237   mLoopDepth++;
238   VisitStmt(DS);
239   mLoopDepth--;
240 }
241 
VisitForStmt(clang::ForStmt * FS)242 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
243   mLoopDepth++;
244   VisitStmt(FS);
245   mLoopDepth--;
246 }
247 
VisitReturnStmt(clang::ReturnStmt * RS)248 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
249   mExitingStmts.push_back(RS);
250 }
251 
VisitSwitchStmt(clang::SwitchStmt * SS)252 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
253   mSwitchDepth++;
254   VisitStmt(SS);
255   mSwitchDepth--;
256 }
257 
VisitWhileStmt(clang::WhileStmt * WS)258 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
259   mLoopDepth++;
260   VisitStmt(WS);
261   mLoopDepth--;
262 }
263 
ClearSingleRSObject(clang::ASTContext & C,clang::Expr * RefRSVar,clang::SourceLocation Loc)264 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
265                                  clang::Expr *RefRSVar,
266                                  clang::SourceLocation Loc) {
267   slangAssert(RefRSVar);
268   const clang::Type *T = RefRSVar->getType().getTypePtr();
269   slangAssert(!T->isArrayType() &&
270               "Should not be destroying arrays with this function");
271 
272   clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
273   slangAssert((ClearObjectFD != nullptr) &&
274               "rsClearObject doesn't cover all RS object types");
275 
276   clang::QualType ClearObjectFDType = ClearObjectFD->getType();
277   clang::QualType ClearObjectFDArgType =
278       ClearObjectFD->getParamDecl(0)->getOriginalType();
279 
280   // Example destructor for "rs_font localFont;"
281   //
282   // (CallExpr 'void'
283   //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
284   //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
285   //   (UnaryOperator 'rs_font *' prefix '&'
286   //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
287 
288   // Get address of targeted RS object
289   clang::Expr *AddrRefRSVar =
290       new(C) clang::UnaryOperator(RefRSVar,
291                                   clang::UO_AddrOf,
292                                   ClearObjectFDArgType,
293                                   clang::VK_RValue,
294                                   clang::OK_Ordinary,
295                                   Loc);
296 
297   clang::Expr *RefRSClearObjectFD =
298       clang::DeclRefExpr::Create(C,
299                                  clang::NestedNameSpecifierLoc(),
300                                  clang::SourceLocation(),
301                                  ClearObjectFD,
302                                  false,
303                                  ClearObjectFD->getLocation(),
304                                  ClearObjectFDType,
305                                  clang::VK_RValue,
306                                  nullptr);
307 
308   clang::Expr *RSClearObjectFP =
309       clang::ImplicitCastExpr::Create(C,
310                                       C.getPointerType(ClearObjectFDType),
311                                       clang::CK_FunctionToPointerDecay,
312                                       RefRSClearObjectFD,
313                                       nullptr,
314                                       clang::VK_RValue);
315 
316   llvm::SmallVector<clang::Expr*, 1> ArgList;
317   ArgList.push_back(AddrRefRSVar);
318 
319   clang::CallExpr *RSClearObjectCall =
320       new(C) clang::CallExpr(C,
321                              RSClearObjectFP,
322                              ArgList,
323                              ClearObjectFD->getCallResultType(),
324                              clang::VK_RValue,
325                              Loc);
326 
327   return RSClearObjectCall;
328 }
329 
ArrayDim(const clang::Type * T)330 static int ArrayDim(const clang::Type *T) {
331   if (!T || !T->isArrayType()) {
332     return 0;
333   }
334 
335   const clang::ConstantArrayType *CAT =
336     static_cast<const clang::ConstantArrayType *>(T);
337   return static_cast<int>(CAT->getSize().getSExtValue());
338 }
339 
340 clang::Stmt *ClearStructRSObject(
341     clang::ASTContext &C,
342     clang::DeclContext *DC,
343     clang::Expr *RefRSStruct,
344     clang::SourceLocation StartLoc,
345     clang::SourceLocation Loc);
346 
ClearArrayRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSArr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)347 clang::Stmt *ClearArrayRSObject(
348     clang::ASTContext &C,
349     clang::DeclContext *DC,
350     clang::Expr *RefRSArr,
351     clang::SourceLocation StartLoc,
352     clang::SourceLocation Loc) {
353   const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
354   slangAssert(BaseType->isArrayType());
355 
356   int NumArrayElements = ArrayDim(BaseType);
357   // Actually extract out the base RS object type for use later
358   BaseType = BaseType->getArrayElementTypeNoTypeQual();
359 
360   if (NumArrayElements <= 0) {
361     return nullptr;
362   }
363 
364   // Example destructor loop for "rs_font fontArr[10];"
365   //
366   // (ForStmt
367   //   (DeclStmt
368   //     (VarDecl used rsIntIter 'int' cinit
369   //       (IntegerLiteral 'int' 0)))
370   //   (BinaryOperator 'int' '<'
371   //     (ImplicitCastExpr int LValueToRValue
372   //       (DeclRefExpr 'int' Var='rsIntIter'))
373   //     (IntegerLiteral 'int' 10)
374   //   nullptr << CondVar >>
375   //   (UnaryOperator 'int' postfix '++'
376   //     (DeclRefExpr 'int' Var='rsIntIter'))
377   //   (CallExpr 'void'
378   //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
379   //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
380   //     (UnaryOperator 'rs_font *' prefix '&'
381   //       (ArraySubscriptExpr 'rs_font':'rs_font'
382   //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
383   //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
384   //         (DeclRefExpr 'int' Var='rsIntIter'))))))
385 
386   // Create helper variable for iterating through elements
387   static unsigned sIterCounter = 0;
388   std::stringstream UniqueIterName;
389   UniqueIterName << "rsIntIter" << sIterCounter++;
390   clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
391   clang::VarDecl *IIVD =
392       clang::VarDecl::Create(C,
393                              DC,
394                              StartLoc,
395                              Loc,
396                              II,
397                              C.IntTy,
398                              C.getTrivialTypeSourceInfo(C.IntTy),
399                              clang::SC_None);
400   // Mark "rsIntIter" as used
401   IIVD->markUsed(C);
402 
403   // Form the actual destructor loop
404   // for (Init; Cond; Inc)
405   //   RSClearObjectCall;
406 
407   // Init -> "int rsIntIter = 0"
408   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
409       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
410   IIVD->setInit(Int0);
411 
412   clang::Decl *IID = (clang::Decl *)IIVD;
413   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
414   clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);
415 
416   // Cond -> "rsIntIter < NumArrayElements"
417   clang::DeclRefExpr *RefrsIntIterLValue =
418       clang::DeclRefExpr::Create(C,
419                                  clang::NestedNameSpecifierLoc(),
420                                  clang::SourceLocation(),
421                                  IIVD,
422                                  false,
423                                  Loc,
424                                  C.IntTy,
425                                  clang::VK_LValue,
426                                  nullptr);
427 
428   clang::Expr *RefrsIntIterRValue =
429       clang::ImplicitCastExpr::Create(C,
430                                       RefrsIntIterLValue->getType(),
431                                       clang::CK_LValueToRValue,
432                                       RefrsIntIterLValue,
433                                       nullptr,
434                                       clang::VK_RValue);
435 
436   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
437       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
438 
439   clang::BinaryOperator *Cond =
440       new(C) clang::BinaryOperator(RefrsIntIterRValue,
441                                    NumArrayElementsExpr,
442                                    clang::BO_LT,
443                                    C.IntTy,
444                                    clang::VK_RValue,
445                                    clang::OK_Ordinary,
446                                    Loc,
447                                    false);
448 
449   // Inc -> "rsIntIter++"
450   clang::UnaryOperator *Inc =
451       new(C) clang::UnaryOperator(RefrsIntIterLValue,
452                                   clang::UO_PostInc,
453                                   C.IntTy,
454                                   clang::VK_RValue,
455                                   clang::OK_Ordinary,
456                                   Loc);
457 
458   // Body -> "rsClearObject(&VD[rsIntIter]);"
459   // Destructor loop operates on individual array elements
460 
461   clang::Expr *RefRSArrPtr =
462       clang::ImplicitCastExpr::Create(C,
463           C.getPointerType(BaseType->getCanonicalTypeInternal()),
464           clang::CK_ArrayToPointerDecay,
465           RefRSArr,
466           nullptr,
467           clang::VK_RValue);
468 
469   clang::Expr *RefRSArrPtrSubscript =
470       new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
471                                        RefrsIntIterRValue,
472                                        BaseType->getCanonicalTypeInternal(),
473                                        clang::VK_RValue,
474                                        clang::OK_Ordinary,
475                                        Loc);
476 
477   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
478 
479   clang::Stmt *RSClearObjectCall = nullptr;
480   if (BaseType->isArrayType()) {
481     RSClearObjectCall =
482         ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
483   } else if (DT == DataTypeUnknown) {
484     RSClearObjectCall =
485         ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
486   } else {
487     RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
488   }
489 
490   clang::ForStmt *DestructorLoop =
491       new(C) clang::ForStmt(C,
492                             Init,
493                             Cond,
494                             nullptr,  // no condVar
495                             Inc,
496                             RSClearObjectCall,
497                             Loc,
498                             Loc,
499                             Loc);
500 
501   return DestructorLoop;
502 }
503 
CountRSObjectTypes(const clang::Type * T)504 unsigned CountRSObjectTypes(const clang::Type *T) {
505   slangAssert(T);
506   unsigned RSObjectCount = 0;
507 
508   if (T->isArrayType()) {
509     return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
510   }
511 
512   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
513   if (DT != DataTypeUnknown) {
514     return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
515   }
516 
517   if (T->isUnionType()) {
518     clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
519     RD = RD->getDefinition();
520     for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
521            FE = RD->field_end();
522          FI != FE;
523          FI++) {
524       const clang::FieldDecl *FD = *FI;
525       const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
526       if (CountRSObjectTypes(FT)) {
527         slangAssert(false && "can't have unions with RS object types!");
528         return 0;
529       }
530     }
531   }
532 
533   if (!T->isStructureType()) {
534     return 0;
535   }
536 
537   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
538   RD = RD->getDefinition();
539   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
540          FE = RD->field_end();
541        FI != FE;
542        FI++) {
543     const clang::FieldDecl *FD = *FI;
544     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
545     if (CountRSObjectTypes(FT)) {
546       // Sub-structs should only count once (as should arrays, etc.)
547       RSObjectCount++;
548     }
549   }
550 
551   return RSObjectCount;
552 }
553 
ClearStructRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSStruct,clang::SourceLocation StartLoc,clang::SourceLocation Loc)554 clang::Stmt *ClearStructRSObject(
555     clang::ASTContext &C,
556     clang::DeclContext *DC,
557     clang::Expr *RefRSStruct,
558     clang::SourceLocation StartLoc,
559     clang::SourceLocation Loc) {
560   const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
561 
562   slangAssert(!BaseType->isArrayType());
563 
564   // Structs should show up as unknown primitive types
565   slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
566               DataTypeUnknown);
567 
568   unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
569   slangAssert(FieldsToDestroy != 0);
570 
571   unsigned StmtCount = 0;
572   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
573   for (unsigned i = 0; i < FieldsToDestroy; i++) {
574     StmtArray[i] = nullptr;
575   }
576 
577   // Populate StmtArray by creating a destructor for each RS object field
578   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
579   RD = RD->getDefinition();
580   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
581          FE = RD->field_end();
582        FI != FE;
583        FI++) {
584     // We just look through all field declarations to see if we find a
585     // declaration for an RS object type (or an array of one).
586     bool IsArrayType = false;
587     clang::FieldDecl *FD = *FI;
588     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
589     slangAssert(FT);
590     const clang::Type *OrigType = FT;
591     while (FT->isArrayType()) {
592       FT = FT->getArrayElementTypeNoTypeQual();
593       slangAssert(FT);
594       IsArrayType = true;
595     }
596 
597     // Pass a DeclarationNameInfo with a valid DeclName, since name equality
598     // gets asserted during CodeGen.
599     clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
600                                               FD->getLocation());
601 
602     if (RSExportPrimitiveType::IsRSObjectType(FT)) {
603       clang::DeclAccessPair FoundDecl =
604           clang::DeclAccessPair::make(FD, clang::AS_none);
605       clang::MemberExpr *RSObjectMember =
606           clang::MemberExpr::Create(C,
607                                     RefRSStruct,
608                                     false,
609                                     clang::SourceLocation(),
610                                     clang::NestedNameSpecifierLoc(),
611                                     clang::SourceLocation(),
612                                     FD,
613                                     FoundDecl,
614                                     FDDeclNameInfo,
615                                     nullptr,
616                                     OrigType->getCanonicalTypeInternal(),
617                                     clang::VK_RValue,
618                                     clang::OK_Ordinary);
619 
620       slangAssert(StmtCount < FieldsToDestroy);
621 
622       if (IsArrayType) {
623         StmtArray[StmtCount++] = ClearArrayRSObject(C,
624                                                     DC,
625                                                     RSObjectMember,
626                                                     StartLoc,
627                                                     Loc);
628       } else {
629         StmtArray[StmtCount++] = ClearSingleRSObject(C,
630                                                      RSObjectMember,
631                                                      Loc);
632       }
633     } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
634       // In this case, we have a nested struct. We may not end up filling all
635       // of the spaces in StmtArray (sub-structs should handle themselves
636       // with separate compound statements).
637       clang::DeclAccessPair FoundDecl =
638           clang::DeclAccessPair::make(FD, clang::AS_none);
639       clang::MemberExpr *RSObjectMember =
640           clang::MemberExpr::Create(C,
641                                     RefRSStruct,
642                                     false,
643                                     clang::SourceLocation(),
644                                     clang::NestedNameSpecifierLoc(),
645                                     clang::SourceLocation(),
646                                     FD,
647                                     FoundDecl,
648                                     clang::DeclarationNameInfo(),
649                                     nullptr,
650                                     OrigType->getCanonicalTypeInternal(),
651                                     clang::VK_RValue,
652                                     clang::OK_Ordinary);
653 
654       if (IsArrayType) {
655         StmtArray[StmtCount++] = ClearArrayRSObject(C,
656                                                     DC,
657                                                     RSObjectMember,
658                                                     StartLoc,
659                                                     Loc);
660       } else {
661         StmtArray[StmtCount++] = ClearStructRSObject(C,
662                                                      DC,
663                                                      RSObjectMember,
664                                                      StartLoc,
665                                                      Loc);
666       }
667     }
668   }
669 
670   slangAssert(StmtCount > 0);
671   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
672       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
673 
674   delete [] StmtArray;
675 
676   return CS;
677 }
678 
CreateSingleRSSetObject(clang::ASTContext & C,clang::Expr * DstExpr,clang::Expr * SrcExpr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)679 clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
680                                      clang::Expr *DstExpr,
681                                      clang::Expr *SrcExpr,
682                                      clang::SourceLocation StartLoc,
683                                      clang::SourceLocation Loc) {
684   const clang::Type *T = DstExpr->getType().getTypePtr();
685   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
686   slangAssert((SetObjectFD != nullptr) &&
687               "rsSetObject doesn't cover all RS object types");
688 
689   clang::QualType SetObjectFDType = SetObjectFD->getType();
690   clang::QualType SetObjectFDArgType[2];
691   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
692   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
693 
694   clang::Expr *RefRSSetObjectFD =
695       clang::DeclRefExpr::Create(C,
696                                  clang::NestedNameSpecifierLoc(),
697                                  clang::SourceLocation(),
698                                  SetObjectFD,
699                                  false,
700                                  Loc,
701                                  SetObjectFDType,
702                                  clang::VK_RValue,
703                                  nullptr);
704 
705   clang::Expr *RSSetObjectFP =
706       clang::ImplicitCastExpr::Create(C,
707                                       C.getPointerType(SetObjectFDType),
708                                       clang::CK_FunctionToPointerDecay,
709                                       RefRSSetObjectFD,
710                                       nullptr,
711                                       clang::VK_RValue);
712 
713   llvm::SmallVector<clang::Expr*, 2> ArgList;
714   ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
715                                                 clang::UO_AddrOf,
716                                                 SetObjectFDArgType[0],
717                                                 clang::VK_RValue,
718                                                 clang::OK_Ordinary,
719                                                 Loc));
720   ArgList.push_back(SrcExpr);
721 
722   clang::CallExpr *RSSetObjectCall =
723       new(C) clang::CallExpr(C,
724                              RSSetObjectFP,
725                              ArgList,
726                              SetObjectFD->getCallResultType(),
727                              clang::VK_RValue,
728                              Loc);
729 
730   return RSSetObjectCall;
731 }
732 
733 clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
734                                      clang::Expr *LHS,
735                                      clang::Expr *RHS,
736                                      clang::SourceLocation StartLoc,
737                                      clang::SourceLocation Loc);
738 
739 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
740                                            clang::Expr *DstArr,
741                                            clang::Expr *SrcArr,
742                                            clang::SourceLocation StartLoc,
743                                            clang::SourceLocation Loc) {
744   clang::DeclContext *DC = nullptr;
745   const clang::Type *BaseType = DstArr->getType().getTypePtr();
746   slangAssert(BaseType->isArrayType());
747 
748   int NumArrayElements = ArrayDim(BaseType);
749   // Actually extract out the base RS object type for use later
750   BaseType = BaseType->getArrayElementTypeNoTypeQual();
751 
752   clang::Stmt *StmtArray[2] = {nullptr};
753   int StmtCtr = 0;
754 
755   if (NumArrayElements <= 0) {
756     return nullptr;
757   }
758 
759   // Create helper variable for iterating through elements
760   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
761   clang::VarDecl *IIVD =
762       clang::VarDecl::Create(C,
763                              DC,
764                              StartLoc,
765                              Loc,
766                              &II,
767                              C.IntTy,
768                              C.getTrivialTypeSourceInfo(C.IntTy),
769                              clang::SC_None,
770                              clang::SC_None);
771   clang::Decl *IID = (clang::Decl *)IIVD;
772 
773   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
774   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
775 
776   // Form the actual loop
777   // for (Init; Cond; Inc)
778   //   RSSetObjectCall;
779 
780   // Init -> "rsIntIter = 0"
781   clang::DeclRefExpr *RefrsIntIter =
782       clang::DeclRefExpr::Create(C,
783                                  clang::NestedNameSpecifierLoc(),
784                                  IIVD,
785                                  Loc,
786                                  C.IntTy,
787                                  clang::VK_RValue,
788                                  nullptr);
789 
790   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
791       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
792 
793   clang::BinaryOperator *Init =
794       new(C) clang::BinaryOperator(RefrsIntIter,
795                                    Int0,
796                                    clang::BO_Assign,
797                                    C.IntTy,
798                                    clang::VK_RValue,
799                                    clang::OK_Ordinary,
800                                    Loc);
801 
802   // Cond -> "rsIntIter < NumArrayElements"
803   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
804       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
805 
806   clang::BinaryOperator *Cond =
807       new(C) clang::BinaryOperator(RefrsIntIter,
808                                    NumArrayElementsExpr,
809                                    clang::BO_LT,
810                                    C.IntTy,
811                                    clang::VK_RValue,
812                                    clang::OK_Ordinary,
813                                    Loc);
814 
815   // Inc -> "rsIntIter++"
816   clang::UnaryOperator *Inc =
817       new(C) clang::UnaryOperator(RefrsIntIter,
818                                   clang::UO_PostInc,
819                                   C.IntTy,
820                                   clang::VK_RValue,
821                                   clang::OK_Ordinary,
822                                   Loc);
823 
824   // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
825   // Loop operates on individual array elements
826 
827   clang::Expr *DstArrPtr =
828       clang::ImplicitCastExpr::Create(C,
829           C.getPointerType(BaseType->getCanonicalTypeInternal()),
830           clang::CK_ArrayToPointerDecay,
831           DstArr,
832           nullptr,
833           clang::VK_RValue);
834 
835   clang::Expr *DstArrPtrSubscript =
836       new(C) clang::ArraySubscriptExpr(DstArrPtr,
837                                        RefrsIntIter,
838                                        BaseType->getCanonicalTypeInternal(),
839                                        clang::VK_RValue,
840                                        clang::OK_Ordinary,
841                                        Loc);
842 
843   clang::Expr *SrcArrPtr =
844       clang::ImplicitCastExpr::Create(C,
845           C.getPointerType(BaseType->getCanonicalTypeInternal()),
846           clang::CK_ArrayToPointerDecay,
847           SrcArr,
848           nullptr,
849           clang::VK_RValue);
850 
851   clang::Expr *SrcArrPtrSubscript =
852       new(C) clang::ArraySubscriptExpr(SrcArrPtr,
853                                        RefrsIntIter,
854                                        BaseType->getCanonicalTypeInternal(),
855                                        clang::VK_RValue,
856                                        clang::OK_Ordinary,
857                                        Loc);
858 
859   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
860 
861   clang::Stmt *RSSetObjectCall = nullptr;
862   if (BaseType->isArrayType()) {
863     RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
864                                              SrcArrPtrSubscript,
865                                              StartLoc, Loc);
866   } else if (DT == DataTypeUnknown) {
867     RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
868                                               SrcArrPtrSubscript,
869                                               StartLoc, Loc);
870   } else {
871     RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
872                                               SrcArrPtrSubscript,
873                                               StartLoc, Loc);
874   }
875 
876   clang::ForStmt *DestructorLoop =
877       new(C) clang::ForStmt(C,
878                             Init,
879                             Cond,
880                             nullptr,  // no condVar
881                             Inc,
882                             RSSetObjectCall,
883                             Loc,
884                             Loc,
885                             Loc);
886 
887   StmtArray[StmtCtr++] = DestructorLoop;
888   slangAssert(StmtCtr == 2);
889 
890   clang::CompoundStmt *CS =
891       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
892 
893   return CS;
894 } */
895 
CreateStructRSSetObject(clang::ASTContext & C,clang::Expr * LHS,clang::Expr * RHS,clang::SourceLocation StartLoc,clang::SourceLocation Loc)896 clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
897                                      clang::Expr *LHS,
898                                      clang::Expr *RHS,
899                                      clang::SourceLocation StartLoc,
900                                      clang::SourceLocation Loc) {
901   clang::QualType QT = LHS->getType();
902   const clang::Type *T = QT.getTypePtr();
903   slangAssert(T->isStructureType());
904   slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
905 
906   // Keep an extra slot for the original copy (memcpy)
907   unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
908 
909   unsigned StmtCount = 0;
910   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
911   for (unsigned i = 0; i < FieldsToSet; i++) {
912     StmtArray[i] = nullptr;
913   }
914 
915   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
916   RD = RD->getDefinition();
917   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
918          FE = RD->field_end();
919        FI != FE;
920        FI++) {
921     bool IsArrayType = false;
922     clang::FieldDecl *FD = *FI;
923     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
924     const clang::Type *OrigType = FT;
925 
926     if (!CountRSObjectTypes(FT)) {
927       // Skip to next if we don't have any viable RS object types
928       continue;
929     }
930 
931     clang::DeclAccessPair FoundDecl =
932         clang::DeclAccessPair::make(FD, clang::AS_none);
933     clang::MemberExpr *DstMember =
934         clang::MemberExpr::Create(C,
935                                   LHS,
936                                   false,
937                                   clang::SourceLocation(),
938                                   clang::NestedNameSpecifierLoc(),
939                                   clang::SourceLocation(),
940                                   FD,
941                                   FoundDecl,
942                                   clang::DeclarationNameInfo(
943                                       FD->getDeclName(),
944                                       clang::SourceLocation()),
945                                   nullptr,
946                                   OrigType->getCanonicalTypeInternal(),
947                                   clang::VK_RValue,
948                                   clang::OK_Ordinary);
949 
950     clang::MemberExpr *SrcMember =
951         clang::MemberExpr::Create(C,
952                                   RHS,
953                                   false,
954                                   clang::SourceLocation(),
955                                   clang::NestedNameSpecifierLoc(),
956                                   clang::SourceLocation(),
957                                   FD,
958                                   FoundDecl,
959                                   clang::DeclarationNameInfo(
960                                       FD->getDeclName(),
961                                       clang::SourceLocation()),
962                                   nullptr,
963                                   OrigType->getCanonicalTypeInternal(),
964                                   clang::VK_RValue,
965                                   clang::OK_Ordinary);
966 
967     if (FT->isArrayType()) {
968       FT = FT->getArrayElementTypeNoTypeQual();
969       IsArrayType = true;
970     }
971 
972     DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
973 
974     if (IsArrayType) {
975       clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
976       DiagEngine.Report(
977         clang::FullSourceLoc(Loc, C.getSourceManager()),
978         DiagEngine.getCustomDiagID(
979           clang::DiagnosticsEngine::Error,
980           "Arrays of RS object types within structures cannot be copied"));
981       // TODO(srhines): Support setting arrays of RS objects
982       // StmtArray[StmtCount++] =
983       //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
984     } else if (DT == DataTypeUnknown) {
985       StmtArray[StmtCount++] =
986           CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
987     } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
988       StmtArray[StmtCount++] =
989           CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
990     } else {
991       slangAssert(false);
992     }
993   }
994 
995   slangAssert(StmtCount < FieldsToSet);
996 
997   // We still need to actually do the overall struct copy. For simplicity,
998   // we just do a straight-up assignment (which will still preserve all
999   // the proper RS object reference counts).
1000   clang::BinaryOperator *CopyStruct =
1001       new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1002                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
1003                                    false);
1004   StmtArray[StmtCount++] = CopyStruct;
1005 
1006   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1007       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1008 
1009   delete [] StmtArray;
1010 
1011   return CS;
1012 }
1013 
1014 }  // namespace
1015 
InsertStmt(const clang::ASTContext & C,clang::Stmt * NewStmt)1016 void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
1017                                          clang::Stmt *NewStmt) {
1018   std::vector<clang::Stmt*> newBody;
1019   for (clang::Stmt* S1 : mCS->body()) {
1020     if (S1 == mCurrent) {
1021       newBody.push_back(NewStmt);
1022     }
1023     newBody.push_back(S1);
1024   }
1025   mCS->setStmts(C, newBody);
1026 }
1027 
ReplaceStmt(const clang::ASTContext & C,clang::Stmt * NewStmt)1028 void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
1029                                           clang::Stmt *NewStmt) {
1030   std::vector<clang::Stmt*> newBody;
1031   for (clang::Stmt* S1 : mCS->body()) {
1032     if (S1 == mCurrent) {
1033       newBody.push_back(NewStmt);
1034     } else {
1035       newBody.push_back(S1);
1036     }
1037   }
1038   mCS->setStmts(C, newBody);
1039 }
1040 
ReplaceExpr(const clang::ASTContext & C,clang::Expr * OldExpr,clang::Expr * NewExpr)1041 void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
1042                                           clang::Expr* OldExpr,
1043                                           clang::Expr* NewExpr) {
1044   RSASTReplace R(C);
1045   R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
1046 }
1047 
ReplaceRSObjectAssignment(clang::BinaryOperator * AS)1048 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1049     clang::BinaryOperator *AS) {
1050 
1051   clang::QualType QT = AS->getType();
1052 
1053   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1054       DataTypeRSAllocation)->getASTContext();
1055 
1056   clang::SourceLocation Loc = AS->getExprLoc();
1057   clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1058   clang::Stmt *UpdatedStmt = nullptr;
1059 
1060   if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1061     // By definition, this is a struct assignment if we get here
1062     UpdatedStmt =
1063         CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1064   } else {
1065     UpdatedStmt =
1066         CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1067   }
1068 
1069   RSASTReplace R(C);
1070   R.ReplaceStmt(mCS, AS, UpdatedStmt);
1071 }
1072 
AppendRSObjectInit(clang::VarDecl * VD,clang::DeclStmt * DS,DataType DT,clang::Expr * InitExpr)1073 void RSObjectRefCount::Scope::AppendRSObjectInit(
1074     clang::VarDecl *VD,
1075     clang::DeclStmt *DS,
1076     DataType DT,
1077     clang::Expr *InitExpr) {
1078   slangAssert(VD);
1079 
1080   if (!InitExpr) {
1081     return;
1082   }
1083 
1084   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1085       DataTypeRSAllocation)->getASTContext();
1086   clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1087       DataTypeRSAllocation)->getLocation();
1088   clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1089       DataTypeRSAllocation)->getInnerLocStart();
1090 
1091   if (DT == DataTypeIsStruct) {
1092     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1093     clang::DeclRefExpr *RefRSVar =
1094         clang::DeclRefExpr::Create(C,
1095                                    clang::NestedNameSpecifierLoc(),
1096                                    clang::SourceLocation(),
1097                                    VD,
1098                                    false,
1099                                    Loc,
1100                                    T->getCanonicalTypeInternal(),
1101                                    clang::VK_RValue,
1102                                    nullptr);
1103 
1104     clang::Stmt *RSSetObjectOps =
1105         CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1106     // Fix for b/37363420; consider:
1107     //
1108     // struct foo { rs_matrix m; };
1109     // void bar() {
1110     //   struct foo M = {...};
1111     // }
1112     //
1113     // slang modifies that declaration with initialization to a
1114     // declaration plus an assignment of the initialization values.
1115     //
1116     // void bar() {
1117     //   struct foo M = {};
1118     //   M = {...}; // by CreateStructRSSetObject() above
1119     // }
1120     //
1121     // the slang-generated statement (M = {...}) is a use of M, and we
1122     // need to mark M (clang::VarDecl *VD) as used.
1123     VD->markUsed(C);
1124 
1125     std::list<clang::Stmt*> StmtList;
1126     StmtList.push_back(RSSetObjectOps);
1127     AppendAfterStmt(C, mCS, DS, StmtList);
1128     return;
1129   }
1130 
1131   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1132   slangAssert((SetObjectFD != nullptr) &&
1133               "rsSetObject doesn't cover all RS object types");
1134 
1135   clang::QualType SetObjectFDType = SetObjectFD->getType();
1136   clang::QualType SetObjectFDArgType[2];
1137   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1138   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1139 
1140   clang::Expr *RefRSSetObjectFD =
1141       clang::DeclRefExpr::Create(C,
1142                                  clang::NestedNameSpecifierLoc(),
1143                                  clang::SourceLocation(),
1144                                  SetObjectFD,
1145                                  false,
1146                                  Loc,
1147                                  SetObjectFDType,
1148                                  clang::VK_RValue,
1149                                  nullptr);
1150 
1151   clang::Expr *RSSetObjectFP =
1152       clang::ImplicitCastExpr::Create(C,
1153                                       C.getPointerType(SetObjectFDType),
1154                                       clang::CK_FunctionToPointerDecay,
1155                                       RefRSSetObjectFD,
1156                                       nullptr,
1157                                       clang::VK_RValue);
1158 
1159   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1160   clang::DeclRefExpr *RefRSVar =
1161       clang::DeclRefExpr::Create(C,
1162                                  clang::NestedNameSpecifierLoc(),
1163                                  clang::SourceLocation(),
1164                                  VD,
1165                                  false,
1166                                  Loc,
1167                                  T->getCanonicalTypeInternal(),
1168                                  clang::VK_RValue,
1169                                  nullptr);
1170 
1171   llvm::SmallVector<clang::Expr*, 2> ArgList;
1172   ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1173                                                 clang::UO_AddrOf,
1174                                                 SetObjectFDArgType[0],
1175                                                 clang::VK_RValue,
1176                                                 clang::OK_Ordinary,
1177                                                 Loc));
1178   ArgList.push_back(InitExpr);
1179 
1180   clang::CallExpr *RSSetObjectCall =
1181       new(C) clang::CallExpr(C,
1182                              RSSetObjectFP,
1183                              ArgList,
1184                              SetObjectFD->getCallResultType(),
1185                              clang::VK_RValue,
1186                              Loc);
1187 
1188   std::list<clang::Stmt*> StmtList;
1189   StmtList.push_back(RSSetObjectCall);
1190   AppendAfterStmt(C, mCS, DS, StmtList);
1191 }
1192 
InsertLocalVarDestructors()1193 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1194   if (mRSO.empty()) {
1195     return;
1196   }
1197 
1198   clang::DeclContext* DC = mRSO.front()->getDeclContext();
1199   clang::ASTContext& C = DC->getParentASTContext();
1200   clang::SourceManager& SM = C.getSourceManager();
1201 
1202   const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
1203     return SM.isBeforeInTranslationUnit(L1, L2);
1204   };
1205   typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
1206 
1207   DMap dtors(OccursBefore);
1208 
1209   // Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
1210   for (clang::VarDecl* VD : mRSO) {
1211     clang::SourceLocation Loc = VD->getSourceRange().getBegin();
1212     clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
1213     dtors.insert(std::make_pair(Loc, RSClearObjectCall));
1214   }
1215 
1216   DestructorVisitor Visitor;
1217   Visitor.Visit(mCS);
1218 
1219   // Replace each exiting statement with a block that contains the original statement
1220   // and added rsClearObject() calls before it.
1221   for (clang::Stmt* S : Visitor.getExitingStmts()) {
1222 
1223     const clang::SourceLocation currentLoc = S->getLocStart();
1224 
1225     DMap::iterator firstDtorIter = dtors.begin();
1226     DMap::iterator currentDtorIter = firstDtorIter;
1227     DMap::iterator lastDtorIter = dtors.end();
1228 
1229     while (currentDtorIter != lastDtorIter &&
1230            OccursBefore(currentDtorIter->first, currentLoc)) {
1231       currentDtorIter++;
1232     }
1233 
1234     if (currentDtorIter == firstDtorIter) {
1235       continue;
1236     }
1237 
1238     std::vector<clang::Stmt*> Stmts;
1239 
1240     // Insert rsClearObject() calls for all rsObjects declared before the current statement
1241     for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
1242       Stmts.push_back(it->second);
1243     }
1244     Stmts.push_back(S);
1245 
1246     RSASTReplace R(C);
1247     clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
1248     R.ReplaceStmt(mCS, S, CS);
1249   }
1250 
1251   std::list<clang::Stmt*> Stmts;
1252   for(auto LocCallPair : dtors) {
1253     Stmts.push_back(LocCallPair.second);
1254   }
1255   AppendAfterStmt(C, mCS, nullptr, Stmts);
1256 }
1257 
ClearRSObject(clang::VarDecl * VD,clang::DeclContext * DC)1258 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1259     clang::VarDecl *VD,
1260     clang::DeclContext *DC) {
1261   slangAssert(VD);
1262   clang::ASTContext &C = VD->getASTContext();
1263   clang::SourceLocation Loc = VD->getLocation();
1264   clang::SourceLocation StartLoc = VD->getInnerLocStart();
1265   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1266 
1267   // Reference expr to target RS object variable
1268   clang::DeclRefExpr *RefRSVar =
1269       clang::DeclRefExpr::Create(C,
1270                                  clang::NestedNameSpecifierLoc(),
1271                                  clang::SourceLocation(),
1272                                  VD,
1273                                  false,
1274                                  Loc,
1275                                  T->getCanonicalTypeInternal(),
1276                                  clang::VK_RValue,
1277                                  nullptr);
1278 
1279   if (T->isArrayType()) {
1280     return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1281   }
1282 
1283   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1284 
1285   if (DT == DataTypeUnknown ||
1286       DT == DataTypeIsStruct) {
1287     return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1288   }
1289 
1290   slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1291               "Should be RS object");
1292 
1293   return ClearSingleRSObject(C, RefRSVar, Loc);
1294 }
1295 
InitializeRSObject(clang::VarDecl * VD,DataType * DT,clang::Expr ** InitExpr)1296 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1297                                           DataType *DT,
1298                                           clang::Expr **InitExpr) {
1299   slangAssert(VD && DT && InitExpr);
1300   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1301 
1302   // Loop through array types to get to base type
1303   slangAssert(T);
1304   while (T->isArrayType()) {
1305     T = T->getArrayElementTypeNoTypeQual();
1306     slangAssert(T);
1307   }
1308 
1309   bool DataTypeIsStructWithRSObject = false;
1310   *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1311 
1312   if (*DT == DataTypeUnknown) {
1313     if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1314       *DT = DataTypeIsStruct;
1315       DataTypeIsStructWithRSObject = true;
1316     } else {
1317       return false;
1318     }
1319   }
1320 
1321   bool DataTypeIsRSObject = false;
1322   if (DataTypeIsStructWithRSObject) {
1323     DataTypeIsRSObject = true;
1324   } else {
1325     DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1326   }
1327   *InitExpr = VD->getInit();
1328 
1329   if (!DataTypeIsRSObject && *InitExpr) {
1330     // If we already have an initializer for a matrix type, we are done.
1331     return DataTypeIsRSObject;
1332   }
1333 
1334   clang::Expr *ZeroInitializer =
1335       CreateEmptyInitListExpr(VD->getASTContext(), VD->getLocation());
1336 
1337   if (ZeroInitializer) {
1338     ZeroInitializer->setType(T->getCanonicalTypeInternal());
1339     VD->setInit(ZeroInitializer);
1340   }
1341 
1342   return DataTypeIsRSObject;
1343 }
1344 
CreateEmptyInitListExpr(clang::ASTContext & C,const clang::SourceLocation & Loc)1345 clang::Expr *RSObjectRefCount::CreateEmptyInitListExpr(
1346     clang::ASTContext &C,
1347     const clang::SourceLocation &Loc) {
1348 
1349   // We can cheaply construct a zero initializer by just creating an empty
1350   // initializer list. Clang supports this extension to C(99), and will create
1351   // any necessary constructs to zero out the entire variable.
1352   llvm::SmallVector<clang::Expr*, 1> EmptyInitList;
1353   return new(C) clang::InitListExpr(C, Loc, EmptyInitList, Loc);
1354 }
1355 
CreateGuard(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * E,const llvm::Twine & VarName,std::vector<clang::Stmt * > & NewStmts)1356 clang::DeclRefExpr *RSObjectRefCount::CreateGuard(clang::ASTContext &C,
1357                                                   clang::DeclContext *DC,
1358                                                   clang::Expr *E,
1359                                                   const llvm::Twine &VarName,
1360                                                   std::vector<clang::Stmt*> &NewStmts) {
1361   clang::SourceLocation Loc = E->getLocStart();
1362   const clang::QualType Ty = E->getType();
1363   clang::VarDecl* TmpDecl = clang::VarDecl::Create(
1364       C,                                     // AST context
1365       DC,                                    // Decl context
1366       Loc,                                   // Start location
1367       Loc,                                   // Id location
1368       &C.Idents.get(VarName.str()),          // Id
1369       Ty,                                    // Type
1370       C.getTrivialTypeSourceInfo(Ty),        // Type info
1371       clang::SC_None                         // Storage class
1372   );
1373   const clang::Type *T = Ty.getTypePtr();
1374   clang::Expr *ZeroInitializer =
1375       RSObjectRefCount::CreateEmptyInitListExpr(C, Loc);
1376   ZeroInitializer->setType(T->getCanonicalTypeInternal());
1377   TmpDecl->setInit(ZeroInitializer);
1378   TmpDecl->markUsed(C);
1379   clang::Decl* Decls[] = { TmpDecl };
1380   const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1381       C, Decls, sizeof(Decls) / sizeof(*Decls));
1382   clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
1383   NewStmts.push_back(DS);
1384 
1385   clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1386       C,
1387       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1388       Loc,                                   // TemplateKWLoc
1389       TmpDecl,
1390       false,                                 // RefersToEnclosingVariableOrCapture
1391       Loc,                                   // NameLoc
1392       Ty,
1393       clang::VK_LValue
1394   );
1395 
1396   clang::Stmt *UpdatedStmt = nullptr;
1397   if (CountRSObjectTypes(Ty.getTypePtr()) == 0) {
1398     // The expression E is not an RS object itself. Instead of calling
1399     // rsSetObject(), create an assignment statement to set the value of the
1400     // temporary "guard" variable to the expression.
1401     // This can happen if called from RSObjectRefCount::VisitReturnStmt(),
1402     // when the return expression is not an RS object but references one.
1403     UpdatedStmt =
1404       new(C) clang::BinaryOperator(DRE, E, clang::BO_Assign, Ty,
1405                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
1406                                    false);
1407 
1408   } else if (!RSExportPrimitiveType::IsRSObjectType(Ty.getTypePtr())) {
1409     // By definition, this is a struct assignment if we get here
1410     UpdatedStmt =
1411         CreateStructRSSetObject(C, DRE, E, Loc, Loc);
1412   } else {
1413     UpdatedStmt =
1414         CreateSingleRSSetObject(C, DRE, E, Loc, Loc);
1415   }
1416   NewStmts.push_back(UpdatedStmt);
1417 
1418   return DRE;
1419 }
1420 
CreateParameterGuard(clang::ASTContext & C,clang::DeclContext * DC,clang::ParmVarDecl * PD,std::vector<clang::Stmt * > & NewStmts)1421 void RSObjectRefCount::CreateParameterGuard(clang::ASTContext &C,
1422                                             clang::DeclContext *DC,
1423                                             clang::ParmVarDecl *PD,
1424                                             std::vector<clang::Stmt*> &NewStmts) {
1425   clang::SourceLocation Loc = PD->getLocStart();
1426   clang::DeclRefExpr* ParamDRE = clang::DeclRefExpr::Create(
1427       C,
1428       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1429       Loc,                                   // TemplateKWLoc
1430       PD,
1431       false,                                 // RefersToEnclosingVariableOrCapture
1432       Loc,                                   // NameLoc
1433       PD->getType(),
1434       clang::VK_RValue
1435   );
1436 
1437   CreateGuard(C, DC, ParamDRE,
1438               llvm::Twine(".rs.param.") + llvm::Twine(PD->getName()), NewStmts);
1439 }
1440 
HandleParamsAndLocals(clang::FunctionDecl * FD)1441 void RSObjectRefCount::HandleParamsAndLocals(clang::FunctionDecl *FD) {
1442   std::vector<clang::Stmt*> NewStmts;
1443   std::list<clang::ParmVarDecl*> ObjParams;
1444   for (clang::ParmVarDecl *Param : FD->parameters()) {
1445     clang::QualType QT = Param->getType();
1446     if (CountRSObjectTypes(QT.getTypePtr())) {
1447       // Ignore non-object types
1448       RSObjectRefCount::CreateParameterGuard(mCtx, FD, Param, NewStmts);
1449       ObjParams.push_back(Param);
1450     }
1451   }
1452 
1453   clang::Stmt *OldBody = FD->getBody();
1454   if (ObjParams.empty()) {
1455     Visit(OldBody);
1456     return;
1457   }
1458 
1459   NewStmts.push_back(OldBody);
1460 
1461   clang::SourceLocation Loc = FD->getLocStart();
1462   clang::CompoundStmt *NewBody = BuildCompoundStmt(mCtx, NewStmts, Loc);
1463   Scope S(NewBody);
1464   for (clang::ParmVarDecl *Param : ObjParams) {
1465     S.addRSObject(Param);
1466   }
1467   mScopeStack.push_back(&S);
1468 
1469   // To avoid adding unnecessary ref counting artifacts to newly added temporary
1470   // local variables for parameters, visits only the old function body here.
1471   Visit(OldBody);
1472 
1473   FD->setBody(NewBody);
1474 
1475   S.InsertLocalVarDestructors();
1476   mScopeStack.pop_back();
1477 }
1478 
CreateRetStmtWithTempVar(clang::ASTContext & C,clang::DeclContext * DC,clang::ReturnStmt * RS,const unsigned id)1479 clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
1480     clang::ASTContext& C,
1481     clang::DeclContext* DC,
1482     clang::ReturnStmt* RS,
1483     const unsigned id) {
1484   std::vector<clang::Stmt*> NewStmts;
1485   // Since we insert rsClearObj() calls before the return statement, we need
1486   // to make sure none of the cleared RS objects are referenced in the
1487   // return statement.
1488   // For that, we create a new local variable named .rs.retval, assign the
1489   // original return expression to it, make all necessary rsClearObj()
1490   // calls, then return .rs.retval. Note rsClearObj() is not called on
1491   // .rs.retval.
1492   clang::SourceLocation Loc = RS->getLocStart();
1493   clang::Expr* RetVal = RS->getRetValue();
1494   const clang::QualType RetTy = RetVal->getType();
1495   clang::DeclRefExpr *DRE = CreateGuard(C, DC, RetVal,
1496                                         llvm::Twine(".rs.retval") + llvm::Twine(id),
1497                                         NewStmts);
1498 
1499   // Creates a new return statement
1500   clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(Loc);
1501   clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1502       C,
1503       RetTy,
1504       clang::CK_LValueToRValue,
1505       DRE,
1506       nullptr,
1507       clang::VK_RValue
1508   );
1509   NewRet->setRetValue(CastExpr);
1510   NewStmts.push_back(NewRet);
1511 
1512   return BuildCompoundStmt(C, NewStmts, Loc);
1513 }
1514 
VisitDeclStmt(clang::DeclStmt * DS)1515 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1516   VisitStmt(DS);
1517   getCurrentScope()->setCurrentStmt(DS);
1518   for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1519        I != E;
1520        I++) {
1521     clang::Decl *D = *I;
1522     if (D->getKind() == clang::Decl::Var) {
1523       clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1524       DataType DT = DataTypeUnknown;
1525       clang::Expr *InitExpr = nullptr;
1526       if (InitializeRSObject(VD, &DT, &InitExpr)) {
1527         // We need to zero-init all RS object types (including matrices), ...
1528         getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1529         // ... but, only add to the list of RS objects if we have some
1530         // non-matrix RS object fields.
1531         if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1532           getCurrentScope()->addRSObject(VD);
1533         }
1534       }
1535     }
1536   }
1537 }
1538 
VisitCallExpr(clang::CallExpr * CE)1539 void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
1540   clang::QualType RetTy;
1541   const clang::FunctionDecl* FD = CE->getDirectCallee();
1542 
1543   if (FD) {
1544     // Direct calls
1545 
1546     RetTy = FD->getReturnType();
1547   } else {
1548     // Indirect calls through function pointers
1549 
1550     const clang::Expr* Callee = CE->getCallee();
1551     const clang::Type* CalleeType = Callee->getType().getTypePtr();
1552     const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
1553 
1554     if (!PtrType) {
1555       return;
1556     }
1557 
1558     const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
1559     const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
1560 
1561     if (!FuncType) {
1562       return;
1563     }
1564 
1565     RetTy = FuncType->getReturnType();
1566   }
1567 
1568   // The RenderScript runtime API maintains the invariant that the sysRef of a new RS object would
1569   // be 1, with the exception of rsGetAllocation() (deprecated in API 22), which leaves the sysRef
1570   // 0 for a new allocation. It is the responsibility of the callee of the API to decrement the
1571   // sysRef when a reference of the RS object goes out of scope. The compiler generates code to do
1572   // just that, by creating a temporary variable named ".rs.tmpN" with the result of
1573   // an RS-object-returning API directly assigned to it, and calling rsClearObject() on .rs.tmpN
1574   // right before it exits the current scope. Such code generation is skipped for rsGetAllocation()
1575   // to avoid decrementing its sysRef below zero.
1576 
1577   if (CountRSObjectTypes(RetTy.getTypePtr())==0 ||
1578       (FD && FD->getName() == "rsGetAllocation")) {
1579     return;
1580   }
1581 
1582   clang::SourceLocation Loc = CE->getSourceRange().getBegin();
1583   std::stringstream ss;
1584   ss << ".rs.tmp" << getNextID();
1585   clang::IdentifierInfo *II = &mCtx.Idents.get(ss.str());
1586 
1587   clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
1588       mCtx,                                  // AST context
1589       GetDeclContext(),                      // Decl context
1590       Loc,                                   // Start location
1591       Loc,                                   // Id location
1592       II,                                    // Id
1593       RetTy,                                 // Type
1594       mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
1595       clang::SC_None                         // Storage class
1596   );
1597   TempVarDecl->setInit(CE);
1598   TempVarDecl->markUsed(mCtx);
1599   clang::Decl* Decls[] = { TempVarDecl };
1600   const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1601       mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
1602   clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
1603 
1604   getCurrentScope()->InsertStmt(mCtx, DS);
1605 
1606   clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1607       mCtx,                                  // AST context
1608       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1609       Loc,                                   // TemplateKWLoc
1610       TempVarDecl,
1611       false,                                 // RefersToEnclosingVariableOrCapture
1612       Loc,                                   // NameLoc
1613       RetTy,
1614       clang::VK_LValue
1615   );
1616   clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1617       mCtx,
1618       RetTy,
1619       clang::CK_LValueToRValue,
1620       DRE,
1621       nullptr,
1622       clang::VK_RValue
1623   );
1624 
1625   getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
1626 
1627   // Register TempVarDecl for destruction call (rsClearObj).
1628   getCurrentScope()->addRSObject(TempVarDecl);
1629 }
1630 
VisitCompoundStmt(clang::CompoundStmt * CS)1631 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1632   if (!emptyScope()) {
1633     getCurrentScope()->setCurrentStmt(CS);
1634   }
1635 
1636   if (!CS->body_empty()) {
1637     // Push a new scope
1638     Scope *S = new Scope(CS);
1639     mScopeStack.push_back(S);
1640 
1641     VisitStmt(CS);
1642 
1643     // Destroy the scope
1644     slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1645     S->InsertLocalVarDestructors();
1646     mScopeStack.pop_back();
1647     delete S;
1648   }
1649 }
1650 
VisitBinAssign(clang::BinaryOperator * AS)1651 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1652   getCurrentScope()->setCurrentStmt(AS);
1653   clang::QualType QT = AS->getType();
1654 
1655   if (CountRSObjectTypes(QT.getTypePtr())) {
1656     getCurrentScope()->ReplaceRSObjectAssignment(AS);
1657   }
1658 }
1659 
1660 namespace {
1661 
1662 class FindRSObjRefVisitor : public clang::RecursiveASTVisitor<FindRSObjRefVisitor> {
1663 public:
FindRSObjRefVisitor()1664   explicit FindRSObjRefVisitor() : mRefRSObj(false) {}
VisitExpr(clang::Expr * Expression)1665   bool VisitExpr(clang::Expr* Expression) {
1666     if (CountRSObjectTypes(Expression->getType().getTypePtr()) > 0) {
1667       mRefRSObj = true;
1668       // Found a reference to an RS object. Stop the AST traversal.
1669       return false;
1670     }
1671     return true;
1672   }
1673 
foundRSObjRef() const1674   bool foundRSObjRef() const { return mRefRSObj; }
1675 
1676 private:
1677   bool mRefRSObj;
1678 };
1679 
1680 }  // anonymous namespace
1681 
VisitReturnStmt(clang::ReturnStmt * RS)1682 void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
1683   getCurrentScope()->setCurrentStmt(RS);
1684 
1685   // If there is no local rsObject declared so far, no need to transform the
1686   // return statement.
1687 
1688   bool RSObjDeclared = false;
1689 
1690   for (const Scope* S : mScopeStack) {
1691     if (S->hasRSObject()) {
1692       RSObjDeclared = true;
1693       break;
1694     }
1695   }
1696 
1697   if (!RSObjDeclared) {
1698     return;
1699   }
1700 
1701   FindRSObjRefVisitor visitor;
1702 
1703   visitor.TraverseStmt(RS);
1704 
1705   // If the return statement does not return anything, or if it does not reference
1706   // a rsObject, no need to transform it.
1707 
1708   if (!visitor.foundRSObjRef()) {
1709     return;
1710   }
1711 
1712   // Transform the return statement so that it does not potentially return or
1713   // reference a rsObject that has been cleared.
1714 
1715   clang::CompoundStmt* NewRS;
1716   NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
1717 
1718   getCurrentScope()->ReplaceStmt(mCtx, NewRS);
1719 }
1720 
VisitStmt(clang::Stmt * S)1721 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1722   getCurrentScope()->setCurrentStmt(S);
1723   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1724        I != E;
1725        I++) {
1726     if (clang::Stmt *Child = *I) {
1727       Visit(Child);
1728     }
1729   }
1730 }
1731 
1732 // This function walks the list of global variables and (potentially) creates
1733 // a single global static destructor function that properly decrements
1734 // reference counts on the contained RS object types.
CreateStaticGlobalDtor()1735 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1736   Init();
1737 
1738   clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1739   clang::SourceLocation loc;
1740 
1741   llvm::StringRef SR(".rs.dtor");
1742   clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1743   clang::DeclarationName N(&II);
1744   clang::FunctionProtoType::ExtProtoInfo EPI;
1745   clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1746       llvm::ArrayRef<clang::QualType>(), EPI);
1747   clang::FunctionDecl *FD = nullptr;
1748 
1749   // Generate rsClearObject() call chains for every global variable
1750   // (whether static or extern).
1751   std::vector<clang::Stmt *> StmtList;
1752   for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1753           E = DC->decls_end(); I != E; I++) {
1754     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1755     if (VD) {
1756       if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1757         if (!FD) {
1758           // Only create FD if we are going to use it.
1759           FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1760                                            clang::SC_None);
1761         }
1762         // Mark VD as used.  It might be unused, except for the destructor.
1763         // 'markUsed' has side-effects that are caused only if VD is not already
1764         // used.  Hence no need for an extra check here.
1765         VD->markUsed(mCtx);
1766         // Make sure to create any helpers within the function's DeclContext,
1767         // not the one associated with the global translation unit.
1768         clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1769         StmtList.push_back(RSClearObjectCall);
1770       }
1771     }
1772   }
1773 
1774   // Nothing needs to be destroyed, so don't emit a dtor.
1775   if (StmtList.empty()) {
1776     return nullptr;
1777   }
1778 
1779   clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1780 
1781   slangAssert(FD);
1782   FD->setBody(CS);
1783   // We need some way to tell if this FD is generated by slang
1784   FD->setImplicit();
1785 
1786   return FD;
1787 }
1788 
HasRSObjectType(const clang::Type * T)1789 bool HasRSObjectType(const clang::Type *T) {
1790   return CountRSObjectTypes(T) != 0;
1791 }
1792 
1793 }  // namespace slang
1794