1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_object_ref_count.h"
18 
19 #include "clang/AST/DeclGroup.h"
20 #include "clang/AST/Expr.h"
21 #include "clang/AST/NestedNameSpecifier.h"
22 #include "clang/AST/OperationKinds.h"
23 #include "clang/AST/RecursiveASTVisitor.h"
24 #include "clang/AST/Stmt.h"
25 #include "clang/AST/StmtVisitor.h"
26 
27 #include "slang_assert.h"
28 #include "slang.h"
29 #include "slang_rs_ast_replace.h"
30 #include "slang_rs_export_type.h"
31 
32 namespace slang {
33 
34 /* Even though those two arrays are of size DataTypeMax, only entries that
35  * correspond to object types will be set.
36  */
37 clang::FunctionDecl *
38 RSObjectRefCount::RSSetObjectFD[DataTypeMax];
39 clang::FunctionDecl *
40 RSObjectRefCount::RSClearObjectFD[DataTypeMax];
41 
GetRSRefCountingFunctions(clang::ASTContext & C)42 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43   for (unsigned i = 0; i < DataTypeMax; i++) {
44     RSSetObjectFD[i] = nullptr;
45     RSClearObjectFD[i] = nullptr;
46   }
47 
48   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
49 
50   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
51           E = TUDecl->decls_end(); I != E; I++) {
52     if ((I->getKind() >= clang::Decl::firstFunction) &&
53         (I->getKind() <= clang::Decl::lastFunction)) {
54       clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
55 
56       // points to RSSetObjectFD or RSClearObjectFD
57       clang::FunctionDecl **RSObjectFD;
58 
59       if (FD->getName() == "rsSetObject") {
60         slangAssert((FD->getNumParams() == 2) &&
61                     "Invalid rsSetObject function prototype (# params)");
62         RSObjectFD = RSSetObjectFD;
63       } else if (FD->getName() == "rsClearObject") {
64         slangAssert((FD->getNumParams() == 1) &&
65                     "Invalid rsClearObject function prototype (# params)");
66         RSObjectFD = RSClearObjectFD;
67       } else {
68         continue;
69       }
70 
71       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
72       clang::QualType PVT = PVD->getOriginalType();
73       // The first parameter must be a pointer like rs_allocation*
74       slangAssert(PVT->isPointerType() &&
75           "Invalid rs{Set,Clear}Object function prototype (pointer param)");
76 
77       // The rs object type passed to the FD
78       clang::QualType RST = PVT->getPointeeType();
79       DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
80       slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
81              && "must be RS object type");
82 
83       if (DT >= 0 && DT < DataTypeMax) {
84           RSObjectFD[DT] = FD;
85       } else {
86           slangAssert(false && "incorrect type");
87       }
88     }
89   }
90 }
91 
92 namespace {
93 
94 unsigned CountRSObjectTypes(const clang::Type *T);
95 
96 clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
97                                      clang::Expr *DstExpr,
98                                      clang::Expr *SrcExpr,
99                                      clang::SourceLocation StartLoc,
100                                      clang::SourceLocation Loc);
101 
102 // This function constructs a new CompoundStmt from the input StmtList.
BuildCompoundStmt(clang::ASTContext & C,std::vector<clang::Stmt * > & StmtList,clang::SourceLocation Loc)103 clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
104       std::vector<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
105   unsigned NewStmtCount = StmtList.size();
106   unsigned CompoundStmtCount = 0;
107 
108   clang::Stmt **CompoundStmtList;
109   CompoundStmtList = new clang::Stmt*[NewStmtCount];
110 
111   std::vector<clang::Stmt*>::const_iterator I = StmtList.begin();
112   std::vector<clang::Stmt*>::const_iterator E = StmtList.end();
113   for ( ; I != E; I++) {
114     CompoundStmtList[CompoundStmtCount++] = *I;
115   }
116   slangAssert(CompoundStmtCount == NewStmtCount);
117 
118   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
119       C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
120 
121   delete [] CompoundStmtList;
122 
123   return CS;
124 }
125 
AppendAfterStmt(clang::ASTContext & C,clang::CompoundStmt * CS,clang::Stmt * S,std::list<clang::Stmt * > & StmtList)126 void AppendAfterStmt(clang::ASTContext &C,
127                      clang::CompoundStmt *CS,
128                      clang::Stmt *S,
129                      std::list<clang::Stmt*> &StmtList) {
130   slangAssert(CS);
131   clang::CompoundStmt::body_iterator bI = CS->body_begin();
132   clang::CompoundStmt::body_iterator bE = CS->body_end();
133   clang::Stmt **UpdatedStmtList =
134       new clang::Stmt*[CS->size() + StmtList.size()];
135 
136   unsigned UpdatedStmtCount = 0;
137   unsigned Once = 0;
138   for ( ; bI != bE; bI++) {
139     if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
140       // If we come across a return here, we don't have anything we can
141       // reasonably replace. We should have already inserted our destructor
142       // code in the proper spot, so we just clean up and return.
143       delete [] UpdatedStmtList;
144 
145       return;
146     }
147 
148     UpdatedStmtList[UpdatedStmtCount++] = *bI;
149 
150     if ((*bI == S) && !Once) {
151       Once++;
152       std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
153       std::list<clang::Stmt*>::const_iterator E = StmtList.end();
154       for ( ; I != E; I++) {
155         UpdatedStmtList[UpdatedStmtCount++] = *I;
156       }
157     }
158   }
159   slangAssert(Once <= 1);
160 
161   // When S is nullptr, we are appending to the end of the CompoundStmt.
162   if (!S) {
163     slangAssert(Once == 0);
164     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
165     std::list<clang::Stmt*>::const_iterator E = StmtList.end();
166     for ( ; I != E; I++) {
167       UpdatedStmtList[UpdatedStmtCount++] = *I;
168     }
169   }
170 
171   CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
172 
173   delete [] UpdatedStmtList;
174 }
175 
176 // This class visits a compound statement and collects a list of all the exiting
177 // statements, such as any return statement in any sub-block, and any
178 // break/continue statement that would resume outside the current scope.
179 // We do not handle the case for goto statements that leave a local scope.
180 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
181  private:
182   // The loop depth of the currently visited node.
183   int mLoopDepth;
184 
185   // The switch statement depth of the currently visited node.
186   // Note that this is tracked separately from the loop depth because
187   // SwitchStmt-contained ContinueStmt's should have destructors for the
188   // corresponding loop scope.
189   int mSwitchDepth;
190 
191   // Output of the visitor: the statements that should be replaced by compound
192   // statements, each of which contains rsClearObject() calls followed by the
193   // original statement.
194   std::vector<clang::Stmt*> mExitingStmts;
195 
196  public:
DestructorVisitor()197   DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
198 
getExitingStmts() const199   const std::vector<clang::Stmt*>& getExitingStmts() const {
200     return mExitingStmts;
201   }
202 
203   void VisitStmt(clang::Stmt *S);
204   void VisitBreakStmt(clang::BreakStmt *BS);
205   void VisitContinueStmt(clang::ContinueStmt *CS);
206   void VisitDoStmt(clang::DoStmt *DS);
207   void VisitForStmt(clang::ForStmt *FS);
208   void VisitReturnStmt(clang::ReturnStmt *RS);
209   void VisitSwitchStmt(clang::SwitchStmt *SS);
210   void VisitWhileStmt(clang::WhileStmt *WS);
211 };
212 
VisitStmt(clang::Stmt * S)213 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
214   for (clang::Stmt* Child : S->children()) {
215     if (Child) {
216       Visit(Child);
217     }
218   }
219 }
220 
VisitBreakStmt(clang::BreakStmt * BS)221 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
222   VisitStmt(BS);
223   if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
224     mExitingStmts.push_back(BS);
225   }
226 }
227 
VisitContinueStmt(clang::ContinueStmt * CS)228 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
229   VisitStmt(CS);
230   if (mLoopDepth == 0) {
231     // Switch statements can have nested continues.
232     mExitingStmts.push_back(CS);
233   }
234 }
235 
VisitDoStmt(clang::DoStmt * DS)236 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
237   mLoopDepth++;
238   VisitStmt(DS);
239   mLoopDepth--;
240 }
241 
VisitForStmt(clang::ForStmt * FS)242 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
243   mLoopDepth++;
244   VisitStmt(FS);
245   mLoopDepth--;
246 }
247 
VisitReturnStmt(clang::ReturnStmt * RS)248 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
249   mExitingStmts.push_back(RS);
250 }
251 
VisitSwitchStmt(clang::SwitchStmt * SS)252 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
253   mSwitchDepth++;
254   VisitStmt(SS);
255   mSwitchDepth--;
256 }
257 
VisitWhileStmt(clang::WhileStmt * WS)258 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
259   mLoopDepth++;
260   VisitStmt(WS);
261   mLoopDepth--;
262 }
263 
ClearSingleRSObject(clang::ASTContext & C,clang::Expr * RefRSVar,clang::SourceLocation Loc)264 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
265                                  clang::Expr *RefRSVar,
266                                  clang::SourceLocation Loc) {
267   slangAssert(RefRSVar);
268   const clang::Type *T = RefRSVar->getType().getTypePtr();
269   slangAssert(!T->isArrayType() &&
270               "Should not be destroying arrays with this function");
271 
272   clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
273   slangAssert((ClearObjectFD != nullptr) &&
274               "rsClearObject doesn't cover all RS object types");
275 
276   clang::QualType ClearObjectFDType = ClearObjectFD->getType();
277   clang::QualType ClearObjectFDArgType =
278       ClearObjectFD->getParamDecl(0)->getOriginalType();
279 
280   // Example destructor for "rs_font localFont;"
281   //
282   // (CallExpr 'void'
283   //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
284   //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
285   //   (UnaryOperator 'rs_font *' prefix '&'
286   //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
287 
288   // Get address of targeted RS object
289   clang::Expr *AddrRefRSVar =
290       new(C) clang::UnaryOperator(RefRSVar,
291                                   clang::UO_AddrOf,
292                                   ClearObjectFDArgType,
293                                   clang::VK_RValue,
294                                   clang::OK_Ordinary,
295                                   Loc);
296 
297   clang::Expr *RefRSClearObjectFD =
298       clang::DeclRefExpr::Create(C,
299                                  clang::NestedNameSpecifierLoc(),
300                                  clang::SourceLocation(),
301                                  ClearObjectFD,
302                                  false,
303                                  ClearObjectFD->getLocation(),
304                                  ClearObjectFDType,
305                                  clang::VK_RValue,
306                                  nullptr);
307 
308   clang::Expr *RSClearObjectFP =
309       clang::ImplicitCastExpr::Create(C,
310                                       C.getPointerType(ClearObjectFDType),
311                                       clang::CK_FunctionToPointerDecay,
312                                       RefRSClearObjectFD,
313                                       nullptr,
314                                       clang::VK_RValue);
315 
316   llvm::SmallVector<clang::Expr*, 1> ArgList;
317   ArgList.push_back(AddrRefRSVar);
318 
319   clang::CallExpr *RSClearObjectCall =
320       new(C) clang::CallExpr(C,
321                              RSClearObjectFP,
322                              ArgList,
323                              ClearObjectFD->getCallResultType(),
324                              clang::VK_RValue,
325                              Loc);
326 
327   return RSClearObjectCall;
328 }
329 
ArrayDim(const clang::Type * T)330 static int ArrayDim(const clang::Type *T) {
331   if (!T || !T->isArrayType()) {
332     return 0;
333   }
334 
335   const clang::ConstantArrayType *CAT =
336     static_cast<const clang::ConstantArrayType *>(T);
337   return static_cast<int>(CAT->getSize().getSExtValue());
338 }
339 
340 clang::Stmt *ClearStructRSObject(
341     clang::ASTContext &C,
342     clang::DeclContext *DC,
343     clang::Expr *RefRSStruct,
344     clang::SourceLocation StartLoc,
345     clang::SourceLocation Loc);
346 
ClearArrayRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSArr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)347 clang::Stmt *ClearArrayRSObject(
348     clang::ASTContext &C,
349     clang::DeclContext *DC,
350     clang::Expr *RefRSArr,
351     clang::SourceLocation StartLoc,
352     clang::SourceLocation Loc) {
353   const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
354   slangAssert(BaseType->isArrayType());
355 
356   int NumArrayElements = ArrayDim(BaseType);
357   // Actually extract out the base RS object type for use later
358   BaseType = BaseType->getArrayElementTypeNoTypeQual();
359 
360   if (NumArrayElements <= 0) {
361     return nullptr;
362   }
363 
364   // Example destructor loop for "rs_font fontArr[10];"
365   //
366   // (ForStmt
367   //   (DeclStmt
368   //     (VarDecl used rsIntIter 'int' cinit
369   //       (IntegerLiteral 'int' 0)))
370   //   (BinaryOperator 'int' '<'
371   //     (ImplicitCastExpr int LValueToRValue
372   //       (DeclRefExpr 'int' Var='rsIntIter'))
373   //     (IntegerLiteral 'int' 10)
374   //   nullptr << CondVar >>
375   //   (UnaryOperator 'int' postfix '++'
376   //     (DeclRefExpr 'int' Var='rsIntIter'))
377   //   (CallExpr 'void'
378   //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
379   //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
380   //     (UnaryOperator 'rs_font *' prefix '&'
381   //       (ArraySubscriptExpr 'rs_font':'rs_font'
382   //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
383   //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
384   //         (DeclRefExpr 'int' Var='rsIntIter'))))))
385 
386   // Create helper variable for iterating through elements
387   static unsigned sIterCounter = 0;
388   std::stringstream UniqueIterName;
389   UniqueIterName << "rsIntIter" << sIterCounter++;
390   clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
391   clang::VarDecl *IIVD =
392       clang::VarDecl::Create(C,
393                              DC,
394                              StartLoc,
395                              Loc,
396                              II,
397                              C.IntTy,
398                              C.getTrivialTypeSourceInfo(C.IntTy),
399                              clang::SC_None);
400   // Mark "rsIntIter" as used
401   IIVD->markUsed(C);
402 
403   // Form the actual destructor loop
404   // for (Init; Cond; Inc)
405   //   RSClearObjectCall;
406 
407   // Init -> "int rsIntIter = 0"
408   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
409       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
410   IIVD->setInit(Int0);
411 
412   clang::Decl *IID = (clang::Decl *)IIVD;
413   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
414   clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);
415 
416   // Cond -> "rsIntIter < NumArrayElements"
417   clang::DeclRefExpr *RefrsIntIterLValue =
418       clang::DeclRefExpr::Create(C,
419                                  clang::NestedNameSpecifierLoc(),
420                                  clang::SourceLocation(),
421                                  IIVD,
422                                  false,
423                                  Loc,
424                                  C.IntTy,
425                                  clang::VK_LValue,
426                                  nullptr);
427 
428   clang::Expr *RefrsIntIterRValue =
429       clang::ImplicitCastExpr::Create(C,
430                                       RefrsIntIterLValue->getType(),
431                                       clang::CK_LValueToRValue,
432                                       RefrsIntIterLValue,
433                                       nullptr,
434                                       clang::VK_RValue);
435 
436   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
437       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
438 
439   clang::BinaryOperator *Cond =
440       new(C) clang::BinaryOperator(RefrsIntIterRValue,
441                                    NumArrayElementsExpr,
442                                    clang::BO_LT,
443                                    C.IntTy,
444                                    clang::VK_RValue,
445                                    clang::OK_Ordinary,
446                                    Loc,
447                                    false);
448 
449   // Inc -> "rsIntIter++"
450   clang::UnaryOperator *Inc =
451       new(C) clang::UnaryOperator(RefrsIntIterLValue,
452                                   clang::UO_PostInc,
453                                   C.IntTy,
454                                   clang::VK_RValue,
455                                   clang::OK_Ordinary,
456                                   Loc);
457 
458   // Body -> "rsClearObject(&VD[rsIntIter]);"
459   // Destructor loop operates on individual array elements
460 
461   clang::Expr *RefRSArrPtr =
462       clang::ImplicitCastExpr::Create(C,
463           C.getPointerType(BaseType->getCanonicalTypeInternal()),
464           clang::CK_ArrayToPointerDecay,
465           RefRSArr,
466           nullptr,
467           clang::VK_RValue);
468 
469   clang::Expr *RefRSArrPtrSubscript =
470       new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
471                                        RefrsIntIterRValue,
472                                        BaseType->getCanonicalTypeInternal(),
473                                        clang::VK_RValue,
474                                        clang::OK_Ordinary,
475                                        Loc);
476 
477   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
478 
479   clang::Stmt *RSClearObjectCall = nullptr;
480   if (BaseType->isArrayType()) {
481     RSClearObjectCall =
482         ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
483   } else if (DT == DataTypeUnknown) {
484     RSClearObjectCall =
485         ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
486   } else {
487     RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
488   }
489 
490   clang::ForStmt *DestructorLoop =
491       new(C) clang::ForStmt(C,
492                             Init,
493                             Cond,
494                             nullptr,  // no condVar
495                             Inc,
496                             RSClearObjectCall,
497                             Loc,
498                             Loc,
499                             Loc);
500 
501   return DestructorLoop;
502 }
503 
CountRSObjectTypes(const clang::Type * T)504 unsigned CountRSObjectTypes(const clang::Type *T) {
505   slangAssert(T);
506   unsigned RSObjectCount = 0;
507 
508   if (T->isArrayType()) {
509     return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
510   }
511 
512   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
513   if (DT != DataTypeUnknown) {
514     return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
515   }
516 
517   if (T->isUnionType()) {
518     clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
519     RD = RD->getDefinition();
520     for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
521            FE = RD->field_end();
522          FI != FE;
523          FI++) {
524       const clang::FieldDecl *FD = *FI;
525       const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
526       if (CountRSObjectTypes(FT)) {
527         slangAssert(false && "can't have unions with RS object types!");
528         return 0;
529       }
530     }
531   }
532 
533   if (!T->isStructureType()) {
534     return 0;
535   }
536 
537   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
538   RD = RD->getDefinition();
539   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
540          FE = RD->field_end();
541        FI != FE;
542        FI++) {
543     const clang::FieldDecl *FD = *FI;
544     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
545     if (CountRSObjectTypes(FT)) {
546       // Sub-structs should only count once (as should arrays, etc.)
547       RSObjectCount++;
548     }
549   }
550 
551   return RSObjectCount;
552 }
553 
ClearStructRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSStruct,clang::SourceLocation StartLoc,clang::SourceLocation Loc)554 clang::Stmt *ClearStructRSObject(
555     clang::ASTContext &C,
556     clang::DeclContext *DC,
557     clang::Expr *RefRSStruct,
558     clang::SourceLocation StartLoc,
559     clang::SourceLocation Loc) {
560   const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
561 
562   slangAssert(!BaseType->isArrayType());
563 
564   // Structs should show up as unknown primitive types
565   slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
566               DataTypeUnknown);
567 
568   unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
569   slangAssert(FieldsToDestroy != 0);
570 
571   unsigned StmtCount = 0;
572   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
573   for (unsigned i = 0; i < FieldsToDestroy; i++) {
574     StmtArray[i] = nullptr;
575   }
576 
577   // Populate StmtArray by creating a destructor for each RS object field
578   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
579   RD = RD->getDefinition();
580   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
581          FE = RD->field_end();
582        FI != FE;
583        FI++) {
584     // We just look through all field declarations to see if we find a
585     // declaration for an RS object type (or an array of one).
586     bool IsArrayType = false;
587     clang::FieldDecl *FD = *FI;
588     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
589     const clang::Type *OrigType = FT;
590     while (FT && FT->isArrayType()) {
591       FT = FT->getArrayElementTypeNoTypeQual();
592       IsArrayType = true;
593     }
594 
595     // Pass a DeclarationNameInfo with a valid DeclName, since name equality
596     // gets asserted during CodeGen.
597     clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
598                                               FD->getLocation());
599 
600     if (RSExportPrimitiveType::IsRSObjectType(FT)) {
601       clang::DeclAccessPair FoundDecl =
602           clang::DeclAccessPair::make(FD, clang::AS_none);
603       clang::MemberExpr *RSObjectMember =
604           clang::MemberExpr::Create(C,
605                                     RefRSStruct,
606                                     false,
607                                     clang::SourceLocation(),
608                                     clang::NestedNameSpecifierLoc(),
609                                     clang::SourceLocation(),
610                                     FD,
611                                     FoundDecl,
612                                     FDDeclNameInfo,
613                                     nullptr,
614                                     OrigType->getCanonicalTypeInternal(),
615                                     clang::VK_RValue,
616                                     clang::OK_Ordinary);
617 
618       slangAssert(StmtCount < FieldsToDestroy);
619 
620       if (IsArrayType) {
621         StmtArray[StmtCount++] = ClearArrayRSObject(C,
622                                                     DC,
623                                                     RSObjectMember,
624                                                     StartLoc,
625                                                     Loc);
626       } else {
627         StmtArray[StmtCount++] = ClearSingleRSObject(C,
628                                                      RSObjectMember,
629                                                      Loc);
630       }
631     } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
632       // In this case, we have a nested struct. We may not end up filling all
633       // of the spaces in StmtArray (sub-structs should handle themselves
634       // with separate compound statements).
635       clang::DeclAccessPair FoundDecl =
636           clang::DeclAccessPair::make(FD, clang::AS_none);
637       clang::MemberExpr *RSObjectMember =
638           clang::MemberExpr::Create(C,
639                                     RefRSStruct,
640                                     false,
641                                     clang::SourceLocation(),
642                                     clang::NestedNameSpecifierLoc(),
643                                     clang::SourceLocation(),
644                                     FD,
645                                     FoundDecl,
646                                     clang::DeclarationNameInfo(),
647                                     nullptr,
648                                     OrigType->getCanonicalTypeInternal(),
649                                     clang::VK_RValue,
650                                     clang::OK_Ordinary);
651 
652       if (IsArrayType) {
653         StmtArray[StmtCount++] = ClearArrayRSObject(C,
654                                                     DC,
655                                                     RSObjectMember,
656                                                     StartLoc,
657                                                     Loc);
658       } else {
659         StmtArray[StmtCount++] = ClearStructRSObject(C,
660                                                      DC,
661                                                      RSObjectMember,
662                                                      StartLoc,
663                                                      Loc);
664       }
665     }
666   }
667 
668   slangAssert(StmtCount > 0);
669   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
670       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
671 
672   delete [] StmtArray;
673 
674   return CS;
675 }
676 
CreateSingleRSSetObject(clang::ASTContext & C,clang::Expr * DstExpr,clang::Expr * SrcExpr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)677 clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
678                                      clang::Expr *DstExpr,
679                                      clang::Expr *SrcExpr,
680                                      clang::SourceLocation StartLoc,
681                                      clang::SourceLocation Loc) {
682   const clang::Type *T = DstExpr->getType().getTypePtr();
683   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
684   slangAssert((SetObjectFD != nullptr) &&
685               "rsSetObject doesn't cover all RS object types");
686 
687   clang::QualType SetObjectFDType = SetObjectFD->getType();
688   clang::QualType SetObjectFDArgType[2];
689   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
690   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
691 
692   clang::Expr *RefRSSetObjectFD =
693       clang::DeclRefExpr::Create(C,
694                                  clang::NestedNameSpecifierLoc(),
695                                  clang::SourceLocation(),
696                                  SetObjectFD,
697                                  false,
698                                  Loc,
699                                  SetObjectFDType,
700                                  clang::VK_RValue,
701                                  nullptr);
702 
703   clang::Expr *RSSetObjectFP =
704       clang::ImplicitCastExpr::Create(C,
705                                       C.getPointerType(SetObjectFDType),
706                                       clang::CK_FunctionToPointerDecay,
707                                       RefRSSetObjectFD,
708                                       nullptr,
709                                       clang::VK_RValue);
710 
711   llvm::SmallVector<clang::Expr*, 2> ArgList;
712   ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
713                                                 clang::UO_AddrOf,
714                                                 SetObjectFDArgType[0],
715                                                 clang::VK_RValue,
716                                                 clang::OK_Ordinary,
717                                                 Loc));
718   ArgList.push_back(SrcExpr);
719 
720   clang::CallExpr *RSSetObjectCall =
721       new(C) clang::CallExpr(C,
722                              RSSetObjectFP,
723                              ArgList,
724                              SetObjectFD->getCallResultType(),
725                              clang::VK_RValue,
726                              Loc);
727 
728   return RSSetObjectCall;
729 }
730 
731 clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
732                                      clang::Expr *LHS,
733                                      clang::Expr *RHS,
734                                      clang::SourceLocation StartLoc,
735                                      clang::SourceLocation Loc);
736 
737 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
738                                            clang::Expr *DstArr,
739                                            clang::Expr *SrcArr,
740                                            clang::SourceLocation StartLoc,
741                                            clang::SourceLocation Loc) {
742   clang::DeclContext *DC = nullptr;
743   const clang::Type *BaseType = DstArr->getType().getTypePtr();
744   slangAssert(BaseType->isArrayType());
745 
746   int NumArrayElements = ArrayDim(BaseType);
747   // Actually extract out the base RS object type for use later
748   BaseType = BaseType->getArrayElementTypeNoTypeQual();
749 
750   clang::Stmt *StmtArray[2] = {nullptr};
751   int StmtCtr = 0;
752 
753   if (NumArrayElements <= 0) {
754     return nullptr;
755   }
756 
757   // Create helper variable for iterating through elements
758   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
759   clang::VarDecl *IIVD =
760       clang::VarDecl::Create(C,
761                              DC,
762                              StartLoc,
763                              Loc,
764                              &II,
765                              C.IntTy,
766                              C.getTrivialTypeSourceInfo(C.IntTy),
767                              clang::SC_None,
768                              clang::SC_None);
769   clang::Decl *IID = (clang::Decl *)IIVD;
770 
771   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
772   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
773 
774   // Form the actual loop
775   // for (Init; Cond; Inc)
776   //   RSSetObjectCall;
777 
778   // Init -> "rsIntIter = 0"
779   clang::DeclRefExpr *RefrsIntIter =
780       clang::DeclRefExpr::Create(C,
781                                  clang::NestedNameSpecifierLoc(),
782                                  IIVD,
783                                  Loc,
784                                  C.IntTy,
785                                  clang::VK_RValue,
786                                  nullptr);
787 
788   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
789       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
790 
791   clang::BinaryOperator *Init =
792       new(C) clang::BinaryOperator(RefrsIntIter,
793                                    Int0,
794                                    clang::BO_Assign,
795                                    C.IntTy,
796                                    clang::VK_RValue,
797                                    clang::OK_Ordinary,
798                                    Loc);
799 
800   // Cond -> "rsIntIter < NumArrayElements"
801   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
802       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
803 
804   clang::BinaryOperator *Cond =
805       new(C) clang::BinaryOperator(RefrsIntIter,
806                                    NumArrayElementsExpr,
807                                    clang::BO_LT,
808                                    C.IntTy,
809                                    clang::VK_RValue,
810                                    clang::OK_Ordinary,
811                                    Loc);
812 
813   // Inc -> "rsIntIter++"
814   clang::UnaryOperator *Inc =
815       new(C) clang::UnaryOperator(RefrsIntIter,
816                                   clang::UO_PostInc,
817                                   C.IntTy,
818                                   clang::VK_RValue,
819                                   clang::OK_Ordinary,
820                                   Loc);
821 
822   // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
823   // Loop operates on individual array elements
824 
825   clang::Expr *DstArrPtr =
826       clang::ImplicitCastExpr::Create(C,
827           C.getPointerType(BaseType->getCanonicalTypeInternal()),
828           clang::CK_ArrayToPointerDecay,
829           DstArr,
830           nullptr,
831           clang::VK_RValue);
832 
833   clang::Expr *DstArrPtrSubscript =
834       new(C) clang::ArraySubscriptExpr(DstArrPtr,
835                                        RefrsIntIter,
836                                        BaseType->getCanonicalTypeInternal(),
837                                        clang::VK_RValue,
838                                        clang::OK_Ordinary,
839                                        Loc);
840 
841   clang::Expr *SrcArrPtr =
842       clang::ImplicitCastExpr::Create(C,
843           C.getPointerType(BaseType->getCanonicalTypeInternal()),
844           clang::CK_ArrayToPointerDecay,
845           SrcArr,
846           nullptr,
847           clang::VK_RValue);
848 
849   clang::Expr *SrcArrPtrSubscript =
850       new(C) clang::ArraySubscriptExpr(SrcArrPtr,
851                                        RefrsIntIter,
852                                        BaseType->getCanonicalTypeInternal(),
853                                        clang::VK_RValue,
854                                        clang::OK_Ordinary,
855                                        Loc);
856 
857   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
858 
859   clang::Stmt *RSSetObjectCall = nullptr;
860   if (BaseType->isArrayType()) {
861     RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
862                                              SrcArrPtrSubscript,
863                                              StartLoc, Loc);
864   } else if (DT == DataTypeUnknown) {
865     RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
866                                               SrcArrPtrSubscript,
867                                               StartLoc, Loc);
868   } else {
869     RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
870                                               SrcArrPtrSubscript,
871                                               StartLoc, Loc);
872   }
873 
874   clang::ForStmt *DestructorLoop =
875       new(C) clang::ForStmt(C,
876                             Init,
877                             Cond,
878                             nullptr,  // no condVar
879                             Inc,
880                             RSSetObjectCall,
881                             Loc,
882                             Loc,
883                             Loc);
884 
885   StmtArray[StmtCtr++] = DestructorLoop;
886   slangAssert(StmtCtr == 2);
887 
888   clang::CompoundStmt *CS =
889       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
890 
891   return CS;
892 } */
893 
CreateStructRSSetObject(clang::ASTContext & C,clang::Expr * LHS,clang::Expr * RHS,clang::SourceLocation StartLoc,clang::SourceLocation Loc)894 clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
895                                      clang::Expr *LHS,
896                                      clang::Expr *RHS,
897                                      clang::SourceLocation StartLoc,
898                                      clang::SourceLocation Loc) {
899   clang::QualType QT = LHS->getType();
900   const clang::Type *T = QT.getTypePtr();
901   slangAssert(T->isStructureType());
902   slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
903 
904   // Keep an extra slot for the original copy (memcpy)
905   unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
906 
907   unsigned StmtCount = 0;
908   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
909   for (unsigned i = 0; i < FieldsToSet; i++) {
910     StmtArray[i] = nullptr;
911   }
912 
913   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
914   RD = RD->getDefinition();
915   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
916          FE = RD->field_end();
917        FI != FE;
918        FI++) {
919     bool IsArrayType = false;
920     clang::FieldDecl *FD = *FI;
921     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
922     const clang::Type *OrigType = FT;
923 
924     if (!CountRSObjectTypes(FT)) {
925       // Skip to next if we don't have any viable RS object types
926       continue;
927     }
928 
929     clang::DeclAccessPair FoundDecl =
930         clang::DeclAccessPair::make(FD, clang::AS_none);
931     clang::MemberExpr *DstMember =
932         clang::MemberExpr::Create(C,
933                                   LHS,
934                                   false,
935                                   clang::SourceLocation(),
936                                   clang::NestedNameSpecifierLoc(),
937                                   clang::SourceLocation(),
938                                   FD,
939                                   FoundDecl,
940                                   clang::DeclarationNameInfo(
941                                       FD->getDeclName(),
942                                       clang::SourceLocation()),
943                                   nullptr,
944                                   OrigType->getCanonicalTypeInternal(),
945                                   clang::VK_RValue,
946                                   clang::OK_Ordinary);
947 
948     clang::MemberExpr *SrcMember =
949         clang::MemberExpr::Create(C,
950                                   RHS,
951                                   false,
952                                   clang::SourceLocation(),
953                                   clang::NestedNameSpecifierLoc(),
954                                   clang::SourceLocation(),
955                                   FD,
956                                   FoundDecl,
957                                   clang::DeclarationNameInfo(
958                                       FD->getDeclName(),
959                                       clang::SourceLocation()),
960                                   nullptr,
961                                   OrigType->getCanonicalTypeInternal(),
962                                   clang::VK_RValue,
963                                   clang::OK_Ordinary);
964 
965     if (FT->isArrayType()) {
966       FT = FT->getArrayElementTypeNoTypeQual();
967       IsArrayType = true;
968     }
969 
970     DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
971 
972     if (IsArrayType) {
973       clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
974       DiagEngine.Report(
975         clang::FullSourceLoc(Loc, C.getSourceManager()),
976         DiagEngine.getCustomDiagID(
977           clang::DiagnosticsEngine::Error,
978           "Arrays of RS object types within structures cannot be copied"));
979       // TODO(srhines): Support setting arrays of RS objects
980       // StmtArray[StmtCount++] =
981       //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
982     } else if (DT == DataTypeUnknown) {
983       StmtArray[StmtCount++] =
984           CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
985     } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
986       StmtArray[StmtCount++] =
987           CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
988     } else {
989       slangAssert(false);
990     }
991   }
992 
993   slangAssert(StmtCount < FieldsToSet);
994 
995   // We still need to actually do the overall struct copy. For simplicity,
996   // we just do a straight-up assignment (which will still preserve all
997   // the proper RS object reference counts).
998   clang::BinaryOperator *CopyStruct =
999       new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1000                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
1001                                    false);
1002   StmtArray[StmtCount++] = CopyStruct;
1003 
1004   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1005       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1006 
1007   delete [] StmtArray;
1008 
1009   return CS;
1010 }
1011 
1012 }  // namespace
1013 
InsertStmt(const clang::ASTContext & C,clang::Stmt * NewStmt)1014 void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
1015                                          clang::Stmt *NewStmt) {
1016   std::vector<clang::Stmt*> newBody;
1017   for (clang::Stmt* S1 : mCS->body()) {
1018     if (S1 == mCurrent) {
1019       newBody.push_back(NewStmt);
1020     }
1021     newBody.push_back(S1);
1022   }
1023   mCS->setStmts(C, newBody);
1024 }
1025 
ReplaceStmt(const clang::ASTContext & C,clang::Stmt * NewStmt)1026 void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
1027                                           clang::Stmt *NewStmt) {
1028   std::vector<clang::Stmt*> newBody;
1029   for (clang::Stmt* S1 : mCS->body()) {
1030     if (S1 == mCurrent) {
1031       newBody.push_back(NewStmt);
1032     } else {
1033       newBody.push_back(S1);
1034     }
1035   }
1036   mCS->setStmts(C, newBody);
1037 }
1038 
ReplaceExpr(const clang::ASTContext & C,clang::Expr * OldExpr,clang::Expr * NewExpr)1039 void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
1040                                           clang::Expr* OldExpr,
1041                                           clang::Expr* NewExpr) {
1042   RSASTReplace R(C);
1043   R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
1044 }
1045 
ReplaceRSObjectAssignment(clang::BinaryOperator * AS)1046 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1047     clang::BinaryOperator *AS) {
1048 
1049   clang::QualType QT = AS->getType();
1050 
1051   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1052       DataTypeRSAllocation)->getASTContext();
1053 
1054   clang::SourceLocation Loc = AS->getExprLoc();
1055   clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1056   clang::Stmt *UpdatedStmt = nullptr;
1057 
1058   if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1059     // By definition, this is a struct assignment if we get here
1060     UpdatedStmt =
1061         CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1062   } else {
1063     UpdatedStmt =
1064         CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1065   }
1066 
1067   RSASTReplace R(C);
1068   R.ReplaceStmt(mCS, AS, UpdatedStmt);
1069 }
1070 
AppendRSObjectInit(clang::VarDecl * VD,clang::DeclStmt * DS,DataType DT,clang::Expr * InitExpr)1071 void RSObjectRefCount::Scope::AppendRSObjectInit(
1072     clang::VarDecl *VD,
1073     clang::DeclStmt *DS,
1074     DataType DT,
1075     clang::Expr *InitExpr) {
1076   slangAssert(VD);
1077 
1078   if (!InitExpr) {
1079     return;
1080   }
1081 
1082   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1083       DataTypeRSAllocation)->getASTContext();
1084   clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1085       DataTypeRSAllocation)->getLocation();
1086   clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1087       DataTypeRSAllocation)->getInnerLocStart();
1088 
1089   if (DT == DataTypeIsStruct) {
1090     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1091     clang::DeclRefExpr *RefRSVar =
1092         clang::DeclRefExpr::Create(C,
1093                                    clang::NestedNameSpecifierLoc(),
1094                                    clang::SourceLocation(),
1095                                    VD,
1096                                    false,
1097                                    Loc,
1098                                    T->getCanonicalTypeInternal(),
1099                                    clang::VK_RValue,
1100                                    nullptr);
1101 
1102     clang::Stmt *RSSetObjectOps =
1103         CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1104     // Fix for b/37363420; consider:
1105     //
1106     // struct foo { rs_matrix m; };
1107     // void bar() {
1108     //   struct foo M = {...};
1109     // }
1110     //
1111     // slang modifies that declaration with initialization to a
1112     // declaration plus an assignment of the initialization values.
1113     //
1114     // void bar() {
1115     //   struct foo M = {};
1116     //   M = {...}; // by CreateStructRSSetObject() above
1117     // }
1118     //
1119     // the slang-generated statement (M = {...}) is a use of M, and we
1120     // need to mark M (clang::VarDecl *VD) as used.
1121     VD->markUsed(C);
1122 
1123     std::list<clang::Stmt*> StmtList;
1124     StmtList.push_back(RSSetObjectOps);
1125     AppendAfterStmt(C, mCS, DS, StmtList);
1126     return;
1127   }
1128 
1129   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1130   slangAssert((SetObjectFD != nullptr) &&
1131               "rsSetObject doesn't cover all RS object types");
1132 
1133   clang::QualType SetObjectFDType = SetObjectFD->getType();
1134   clang::QualType SetObjectFDArgType[2];
1135   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1136   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1137 
1138   clang::Expr *RefRSSetObjectFD =
1139       clang::DeclRefExpr::Create(C,
1140                                  clang::NestedNameSpecifierLoc(),
1141                                  clang::SourceLocation(),
1142                                  SetObjectFD,
1143                                  false,
1144                                  Loc,
1145                                  SetObjectFDType,
1146                                  clang::VK_RValue,
1147                                  nullptr);
1148 
1149   clang::Expr *RSSetObjectFP =
1150       clang::ImplicitCastExpr::Create(C,
1151                                       C.getPointerType(SetObjectFDType),
1152                                       clang::CK_FunctionToPointerDecay,
1153                                       RefRSSetObjectFD,
1154                                       nullptr,
1155                                       clang::VK_RValue);
1156 
1157   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1158   clang::DeclRefExpr *RefRSVar =
1159       clang::DeclRefExpr::Create(C,
1160                                  clang::NestedNameSpecifierLoc(),
1161                                  clang::SourceLocation(),
1162                                  VD,
1163                                  false,
1164                                  Loc,
1165                                  T->getCanonicalTypeInternal(),
1166                                  clang::VK_RValue,
1167                                  nullptr);
1168 
1169   llvm::SmallVector<clang::Expr*, 2> ArgList;
1170   ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1171                                                 clang::UO_AddrOf,
1172                                                 SetObjectFDArgType[0],
1173                                                 clang::VK_RValue,
1174                                                 clang::OK_Ordinary,
1175                                                 Loc));
1176   ArgList.push_back(InitExpr);
1177 
1178   clang::CallExpr *RSSetObjectCall =
1179       new(C) clang::CallExpr(C,
1180                              RSSetObjectFP,
1181                              ArgList,
1182                              SetObjectFD->getCallResultType(),
1183                              clang::VK_RValue,
1184                              Loc);
1185 
1186   std::list<clang::Stmt*> StmtList;
1187   StmtList.push_back(RSSetObjectCall);
1188   AppendAfterStmt(C, mCS, DS, StmtList);
1189 }
1190 
InsertLocalVarDestructors()1191 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1192   if (mRSO.empty()) {
1193     return;
1194   }
1195 
1196   clang::DeclContext* DC = mRSO.front()->getDeclContext();
1197   clang::ASTContext& C = DC->getParentASTContext();
1198   clang::SourceManager& SM = C.getSourceManager();
1199 
1200   const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
1201     return SM.isBeforeInTranslationUnit(L1, L2);
1202   };
1203   typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
1204 
1205   DMap dtors(OccursBefore);
1206 
1207   // Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
1208   for (clang::VarDecl* VD : mRSO) {
1209     clang::SourceLocation Loc = VD->getSourceRange().getBegin();
1210     clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
1211     dtors.insert(std::make_pair(Loc, RSClearObjectCall));
1212   }
1213 
1214   DestructorVisitor Visitor;
1215   Visitor.Visit(mCS);
1216 
1217   // Replace each exiting statement with a block that contains the original statement
1218   // and added rsClearObject() calls before it.
1219   for (clang::Stmt* S : Visitor.getExitingStmts()) {
1220 
1221     const clang::SourceLocation currentLoc = S->getLocStart();
1222 
1223     DMap::iterator firstDtorIter = dtors.begin();
1224     DMap::iterator currentDtorIter = firstDtorIter;
1225     DMap::iterator lastDtorIter = dtors.end();
1226 
1227     while (currentDtorIter != lastDtorIter &&
1228            OccursBefore(currentDtorIter->first, currentLoc)) {
1229       currentDtorIter++;
1230     }
1231 
1232     if (currentDtorIter == firstDtorIter) {
1233       continue;
1234     }
1235 
1236     std::vector<clang::Stmt*> Stmts;
1237 
1238     // Insert rsClearObject() calls for all rsObjects declared before the current statement
1239     for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
1240       Stmts.push_back(it->second);
1241     }
1242     Stmts.push_back(S);
1243 
1244     RSASTReplace R(C);
1245     clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
1246     R.ReplaceStmt(mCS, S, CS);
1247   }
1248 
1249   std::list<clang::Stmt*> Stmts;
1250   for(auto LocCallPair : dtors) {
1251     Stmts.push_back(LocCallPair.second);
1252   }
1253   AppendAfterStmt(C, mCS, nullptr, Stmts);
1254 }
1255 
ClearRSObject(clang::VarDecl * VD,clang::DeclContext * DC)1256 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1257     clang::VarDecl *VD,
1258     clang::DeclContext *DC) {
1259   slangAssert(VD);
1260   clang::ASTContext &C = VD->getASTContext();
1261   clang::SourceLocation Loc = VD->getLocation();
1262   clang::SourceLocation StartLoc = VD->getInnerLocStart();
1263   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1264 
1265   // Reference expr to target RS object variable
1266   clang::DeclRefExpr *RefRSVar =
1267       clang::DeclRefExpr::Create(C,
1268                                  clang::NestedNameSpecifierLoc(),
1269                                  clang::SourceLocation(),
1270                                  VD,
1271                                  false,
1272                                  Loc,
1273                                  T->getCanonicalTypeInternal(),
1274                                  clang::VK_RValue,
1275                                  nullptr);
1276 
1277   if (T->isArrayType()) {
1278     return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1279   }
1280 
1281   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1282 
1283   if (DT == DataTypeUnknown ||
1284       DT == DataTypeIsStruct) {
1285     return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1286   }
1287 
1288   slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1289               "Should be RS object");
1290 
1291   return ClearSingleRSObject(C, RefRSVar, Loc);
1292 }
1293 
InitializeRSObject(clang::VarDecl * VD,DataType * DT,clang::Expr ** InitExpr)1294 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1295                                           DataType *DT,
1296                                           clang::Expr **InitExpr) {
1297   slangAssert(VD && DT && InitExpr);
1298   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1299 
1300   // Loop through array types to get to base type
1301   while (T && T->isArrayType()) {
1302     T = T->getArrayElementTypeNoTypeQual();
1303   }
1304 
1305   bool DataTypeIsStructWithRSObject = false;
1306   *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1307 
1308   if (*DT == DataTypeUnknown) {
1309     if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1310       *DT = DataTypeIsStruct;
1311       DataTypeIsStructWithRSObject = true;
1312     } else {
1313       return false;
1314     }
1315   }
1316 
1317   bool DataTypeIsRSObject = false;
1318   if (DataTypeIsStructWithRSObject) {
1319     DataTypeIsRSObject = true;
1320   } else {
1321     DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1322   }
1323   *InitExpr = VD->getInit();
1324 
1325   if (!DataTypeIsRSObject && *InitExpr) {
1326     // If we already have an initializer for a matrix type, we are done.
1327     return DataTypeIsRSObject;
1328   }
1329 
1330   clang::Expr *ZeroInitializer =
1331       CreateEmptyInitListExpr(VD->getASTContext(), VD->getLocation());
1332 
1333   if (ZeroInitializer) {
1334     ZeroInitializer->setType(T->getCanonicalTypeInternal());
1335     VD->setInit(ZeroInitializer);
1336   }
1337 
1338   return DataTypeIsRSObject;
1339 }
1340 
CreateEmptyInitListExpr(clang::ASTContext & C,const clang::SourceLocation & Loc)1341 clang::Expr *RSObjectRefCount::CreateEmptyInitListExpr(
1342     clang::ASTContext &C,
1343     const clang::SourceLocation &Loc) {
1344 
1345   // We can cheaply construct a zero initializer by just creating an empty
1346   // initializer list. Clang supports this extension to C(99), and will create
1347   // any necessary constructs to zero out the entire variable.
1348   llvm::SmallVector<clang::Expr*, 1> EmptyInitList;
1349   return new(C) clang::InitListExpr(C, Loc, EmptyInitList, Loc);
1350 }
1351 
CreateGuard(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * E,const llvm::Twine & VarName,std::vector<clang::Stmt * > & NewStmts)1352 clang::DeclRefExpr *RSObjectRefCount::CreateGuard(clang::ASTContext &C,
1353                                                   clang::DeclContext *DC,
1354                                                   clang::Expr *E,
1355                                                   const llvm::Twine &VarName,
1356                                                   std::vector<clang::Stmt*> &NewStmts) {
1357   clang::SourceLocation Loc = E->getLocStart();
1358   const clang::QualType Ty = E->getType();
1359   clang::VarDecl* TmpDecl = clang::VarDecl::Create(
1360       C,                                     // AST context
1361       DC,                                    // Decl context
1362       Loc,                                   // Start location
1363       Loc,                                   // Id location
1364       &C.Idents.get(VarName.str()),          // Id
1365       Ty,                                    // Type
1366       C.getTrivialTypeSourceInfo(Ty),        // Type info
1367       clang::SC_None                         // Storage class
1368   );
1369   const clang::Type *T = Ty.getTypePtr();
1370   clang::Expr *ZeroInitializer =
1371       RSObjectRefCount::CreateEmptyInitListExpr(C, Loc);
1372   ZeroInitializer->setType(T->getCanonicalTypeInternal());
1373   TmpDecl->setInit(ZeroInitializer);
1374   TmpDecl->markUsed(C);
1375   clang::Decl* Decls[] = { TmpDecl };
1376   const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1377       C, Decls, sizeof(Decls) / sizeof(*Decls));
1378   clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
1379   NewStmts.push_back(DS);
1380 
1381   clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1382       C,
1383       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1384       Loc,                                   // TemplateKWLoc
1385       TmpDecl,
1386       false,                                 // RefersToEnclosingVariableOrCapture
1387       Loc,                                   // NameLoc
1388       Ty,
1389       clang::VK_LValue
1390   );
1391 
1392   clang::Stmt *UpdatedStmt = nullptr;
1393   if (CountRSObjectTypes(Ty.getTypePtr()) == 0) {
1394     // The expression E is not an RS object itself. Instead of calling
1395     // rsSetObject(), create an assignment statement to set the value of the
1396     // temporary "guard" variable to the expression.
1397     // This can happen if called from RSObjectRefCount::VisitReturnStmt(),
1398     // when the return expression is not an RS object but references one.
1399     UpdatedStmt =
1400       new(C) clang::BinaryOperator(DRE, E, clang::BO_Assign, Ty,
1401                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
1402                                    false);
1403 
1404   } else if (!RSExportPrimitiveType::IsRSObjectType(Ty.getTypePtr())) {
1405     // By definition, this is a struct assignment if we get here
1406     UpdatedStmt =
1407         CreateStructRSSetObject(C, DRE, E, Loc, Loc);
1408   } else {
1409     UpdatedStmt =
1410         CreateSingleRSSetObject(C, DRE, E, Loc, Loc);
1411   }
1412   NewStmts.push_back(UpdatedStmt);
1413 
1414   return DRE;
1415 }
1416 
CreateParameterGuard(clang::ASTContext & C,clang::DeclContext * DC,clang::ParmVarDecl * PD,std::vector<clang::Stmt * > & NewStmts)1417 void RSObjectRefCount::CreateParameterGuard(clang::ASTContext &C,
1418                                             clang::DeclContext *DC,
1419                                             clang::ParmVarDecl *PD,
1420                                             std::vector<clang::Stmt*> &NewStmts) {
1421   clang::SourceLocation Loc = PD->getLocStart();
1422   clang::DeclRefExpr* ParamDRE = clang::DeclRefExpr::Create(
1423       C,
1424       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1425       Loc,                                   // TemplateKWLoc
1426       PD,
1427       false,                                 // RefersToEnclosingVariableOrCapture
1428       Loc,                                   // NameLoc
1429       PD->getType(),
1430       clang::VK_RValue
1431   );
1432 
1433   CreateGuard(C, DC, ParamDRE,
1434               llvm::Twine(".rs.param.") + llvm::Twine(PD->getName()), NewStmts);
1435 }
1436 
HandleParamsAndLocals(clang::FunctionDecl * FD)1437 void RSObjectRefCount::HandleParamsAndLocals(clang::FunctionDecl *FD) {
1438   std::vector<clang::Stmt*> NewStmts;
1439   std::list<clang::ParmVarDecl*> ObjParams;
1440   for (clang::ParmVarDecl *Param : FD->parameters()) {
1441     clang::QualType QT = Param->getType();
1442     if (CountRSObjectTypes(QT.getTypePtr())) {
1443       // Ignore non-object types
1444       RSObjectRefCount::CreateParameterGuard(mCtx, FD, Param, NewStmts);
1445       ObjParams.push_back(Param);
1446     }
1447   }
1448 
1449   clang::Stmt *OldBody = FD->getBody();
1450   if (ObjParams.empty()) {
1451     Visit(OldBody);
1452     return;
1453   }
1454 
1455   NewStmts.push_back(OldBody);
1456 
1457   clang::SourceLocation Loc = FD->getLocStart();
1458   clang::CompoundStmt *NewBody = BuildCompoundStmt(mCtx, NewStmts, Loc);
1459   Scope S(NewBody);
1460   for (clang::ParmVarDecl *Param : ObjParams) {
1461     S.addRSObject(Param);
1462   }
1463   mScopeStack.push_back(&S);
1464 
1465   // To avoid adding unnecessary ref counting artifacts to newly added temporary
1466   // local variables for parameters, visits only the old function body here.
1467   Visit(OldBody);
1468 
1469   FD->setBody(NewBody);
1470 
1471   S.InsertLocalVarDestructors();
1472   mScopeStack.pop_back();
1473 }
1474 
CreateRetStmtWithTempVar(clang::ASTContext & C,clang::DeclContext * DC,clang::ReturnStmt * RS,const unsigned id)1475 clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
1476     clang::ASTContext& C,
1477     clang::DeclContext* DC,
1478     clang::ReturnStmt* RS,
1479     const unsigned id) {
1480   std::vector<clang::Stmt*> NewStmts;
1481   // Since we insert rsClearObj() calls before the return statement, we need
1482   // to make sure none of the cleared RS objects are referenced in the
1483   // return statement.
1484   // For that, we create a new local variable named .rs.retval, assign the
1485   // original return expression to it, make all necessary rsClearObj()
1486   // calls, then return .rs.retval. Note rsClearObj() is not called on
1487   // .rs.retval.
1488   clang::SourceLocation Loc = RS->getLocStart();
1489   clang::Expr* RetVal = RS->getRetValue();
1490   const clang::QualType RetTy = RetVal->getType();
1491   clang::DeclRefExpr *DRE = CreateGuard(C, DC, RetVal,
1492                                         llvm::Twine(".rs.retval") + llvm::Twine(id),
1493                                         NewStmts);
1494 
1495   // Creates a new return statement
1496   clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(Loc);
1497   clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1498       C,
1499       RetTy,
1500       clang::CK_LValueToRValue,
1501       DRE,
1502       nullptr,
1503       clang::VK_RValue
1504   );
1505   NewRet->setRetValue(CastExpr);
1506   NewStmts.push_back(NewRet);
1507 
1508   return BuildCompoundStmt(C, NewStmts, Loc);
1509 }
1510 
VisitDeclStmt(clang::DeclStmt * DS)1511 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1512   VisitStmt(DS);
1513   getCurrentScope()->setCurrentStmt(DS);
1514   for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1515        I != E;
1516        I++) {
1517     clang::Decl *D = *I;
1518     if (D->getKind() == clang::Decl::Var) {
1519       clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1520       DataType DT = DataTypeUnknown;
1521       clang::Expr *InitExpr = nullptr;
1522       if (InitializeRSObject(VD, &DT, &InitExpr)) {
1523         // We need to zero-init all RS object types (including matrices), ...
1524         getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1525         // ... but, only add to the list of RS objects if we have some
1526         // non-matrix RS object fields.
1527         if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1528           getCurrentScope()->addRSObject(VD);
1529         }
1530       }
1531     }
1532   }
1533 }
1534 
VisitCallExpr(clang::CallExpr * CE)1535 void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
1536   clang::QualType RetTy;
1537   const clang::FunctionDecl* FD = CE->getDirectCallee();
1538 
1539   if (FD) {
1540     // Direct calls
1541 
1542     RetTy = FD->getReturnType();
1543   } else {
1544     // Indirect calls through function pointers
1545 
1546     const clang::Expr* Callee = CE->getCallee();
1547     const clang::Type* CalleeType = Callee->getType().getTypePtr();
1548     const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
1549 
1550     if (!PtrType) {
1551       return;
1552     }
1553 
1554     const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
1555     const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
1556 
1557     if (!FuncType) {
1558       return;
1559     }
1560 
1561     RetTy = FuncType->getReturnType();
1562   }
1563 
1564   // The RenderScript runtime API maintains the invariant that the sysRef of a new RS object would
1565   // be 1, with the exception of rsGetAllocation() (deprecated in API 22), which leaves the sysRef
1566   // 0 for a new allocation. It is the responsibility of the callee of the API to decrement the
1567   // sysRef when a reference of the RS object goes out of scope. The compiler generates code to do
1568   // just that, by creating a temporary variable named ".rs.tmpN" with the result of
1569   // an RS-object-returning API directly assigned to it, and calling rsClearObject() on .rs.tmpN
1570   // right before it exits the current scope. Such code generation is skipped for rsGetAllocation()
1571   // to avoid decrementing its sysRef below zero.
1572 
1573   if (CountRSObjectTypes(RetTy.getTypePtr())==0 ||
1574       (FD && FD->getName() == "rsGetAllocation")) {
1575     return;
1576   }
1577 
1578   clang::SourceLocation Loc = CE->getSourceRange().getBegin();
1579   std::stringstream ss;
1580   ss << ".rs.tmp" << getNextID();
1581   llvm::StringRef VarName(ss.str());
1582 
1583   clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
1584       mCtx,                                  // AST context
1585       GetDeclContext(),                      // Decl context
1586       Loc,                                   // Start location
1587       Loc,                                   // Id location
1588       &mCtx.Idents.get(VarName),             // Id
1589       RetTy,                                 // Type
1590       mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
1591       clang::SC_None                         // Storage class
1592   );
1593   TempVarDecl->setInit(CE);
1594   TempVarDecl->markUsed(mCtx);
1595   clang::Decl* Decls[] = { TempVarDecl };
1596   const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1597       mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
1598   clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
1599 
1600   getCurrentScope()->InsertStmt(mCtx, DS);
1601 
1602   clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1603       mCtx,                                  // AST context
1604       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1605       Loc,                                   // TemplateKWLoc
1606       TempVarDecl,
1607       false,                                 // RefersToEnclosingVariableOrCapture
1608       Loc,                                   // NameLoc
1609       RetTy,
1610       clang::VK_LValue
1611   );
1612   clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1613       mCtx,
1614       RetTy,
1615       clang::CK_LValueToRValue,
1616       DRE,
1617       nullptr,
1618       clang::VK_RValue
1619   );
1620 
1621   getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
1622 
1623   // Register TempVarDecl for destruction call (rsClearObj).
1624   getCurrentScope()->addRSObject(TempVarDecl);
1625 }
1626 
VisitCompoundStmt(clang::CompoundStmt * CS)1627 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1628   if (!emptyScope()) {
1629     getCurrentScope()->setCurrentStmt(CS);
1630   }
1631 
1632   if (!CS->body_empty()) {
1633     // Push a new scope
1634     Scope *S = new Scope(CS);
1635     mScopeStack.push_back(S);
1636 
1637     VisitStmt(CS);
1638 
1639     // Destroy the scope
1640     slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1641     S->InsertLocalVarDestructors();
1642     mScopeStack.pop_back();
1643     delete S;
1644   }
1645 }
1646 
VisitBinAssign(clang::BinaryOperator * AS)1647 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1648   getCurrentScope()->setCurrentStmt(AS);
1649   clang::QualType QT = AS->getType();
1650 
1651   if (CountRSObjectTypes(QT.getTypePtr())) {
1652     getCurrentScope()->ReplaceRSObjectAssignment(AS);
1653   }
1654 }
1655 
1656 namespace {
1657 
1658 class FindRSObjRefVisitor : public clang::RecursiveASTVisitor<FindRSObjRefVisitor> {
1659 public:
FindRSObjRefVisitor()1660   explicit FindRSObjRefVisitor() : mRefRSObj(false) {}
VisitExpr(clang::Expr * Expression)1661   bool VisitExpr(clang::Expr* Expression) {
1662     if (CountRSObjectTypes(Expression->getType().getTypePtr()) > 0) {
1663       mRefRSObj = true;
1664       // Found a reference to an RS object. Stop the AST traversal.
1665       return false;
1666     }
1667     return true;
1668   }
1669 
foundRSObjRef() const1670   bool foundRSObjRef() const { return mRefRSObj; }
1671 
1672 private:
1673   bool mRefRSObj;
1674 };
1675 
1676 }  // anonymous namespace
1677 
VisitReturnStmt(clang::ReturnStmt * RS)1678 void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
1679   getCurrentScope()->setCurrentStmt(RS);
1680 
1681   // If there is no local rsObject declared so far, no need to transform the
1682   // return statement.
1683 
1684   bool RSObjDeclared = false;
1685 
1686   for (const Scope* S : mScopeStack) {
1687     if (S->hasRSObject()) {
1688       RSObjDeclared = true;
1689       break;
1690     }
1691   }
1692 
1693   if (!RSObjDeclared) {
1694     return;
1695   }
1696 
1697   FindRSObjRefVisitor visitor;
1698 
1699   visitor.TraverseStmt(RS);
1700 
1701   // If the return statement does not return anything, or if it does not reference
1702   // a rsObject, no need to transform it.
1703 
1704   if (!visitor.foundRSObjRef()) {
1705     return;
1706   }
1707 
1708   // Transform the return statement so that it does not potentially return or
1709   // reference a rsObject that has been cleared.
1710 
1711   clang::CompoundStmt* NewRS;
1712   NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
1713 
1714   getCurrentScope()->ReplaceStmt(mCtx, NewRS);
1715 }
1716 
VisitStmt(clang::Stmt * S)1717 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1718   getCurrentScope()->setCurrentStmt(S);
1719   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1720        I != E;
1721        I++) {
1722     if (clang::Stmt *Child = *I) {
1723       Visit(Child);
1724     }
1725   }
1726 }
1727 
1728 // This function walks the list of global variables and (potentially) creates
1729 // a single global static destructor function that properly decrements
1730 // reference counts on the contained RS object types.
CreateStaticGlobalDtor()1731 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1732   Init();
1733 
1734   clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1735   clang::SourceLocation loc;
1736 
1737   llvm::StringRef SR(".rs.dtor");
1738   clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1739   clang::DeclarationName N(&II);
1740   clang::FunctionProtoType::ExtProtoInfo EPI;
1741   clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1742       llvm::ArrayRef<clang::QualType>(), EPI);
1743   clang::FunctionDecl *FD = nullptr;
1744 
1745   // Generate rsClearObject() call chains for every global variable
1746   // (whether static or extern).
1747   std::vector<clang::Stmt *> StmtList;
1748   for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1749           E = DC->decls_end(); I != E; I++) {
1750     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1751     if (VD) {
1752       if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1753         if (!FD) {
1754           // Only create FD if we are going to use it.
1755           FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1756                                            clang::SC_None);
1757         }
1758         // Mark VD as used.  It might be unused, except for the destructor.
1759         // 'markUsed' has side-effects that are caused only if VD is not already
1760         // used.  Hence no need for an extra check here.
1761         VD->markUsed(mCtx);
1762         // Make sure to create any helpers within the function's DeclContext,
1763         // not the one associated with the global translation unit.
1764         clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1765         StmtList.push_back(RSClearObjectCall);
1766       }
1767     }
1768   }
1769 
1770   // Nothing needs to be destroyed, so don't emit a dtor.
1771   if (StmtList.empty()) {
1772     return nullptr;
1773   }
1774 
1775   clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1776 
1777   FD->setBody(CS);
1778   // We need some way to tell if this FD is generated by slang
1779   FD->setImplicit();
1780 
1781   return FD;
1782 }
1783 
HasRSObjectType(const clang::Type * T)1784 bool HasRSObjectType(const clang::Type *T) {
1785   return CountRSObjectTypes(T) != 0;
1786 }
1787 
1788 }  // namespace slang
1789