1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "slang_rs_object_ref_count.h"

#include <list>
#include <sstream>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32 
33 namespace slang {
34 
35 /* Even though those two arrays are of size DataTypeMax, only entries that
36  * correspond to object types will be set.
37  */
38 clang::FunctionDecl *
39 RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40 clang::FunctionDecl *
41 RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42 
GetRSRefCountingFunctions(clang::ASTContext & C)43 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44   for (unsigned i = 0; i < DataTypeMax; i++) {
45     RSSetObjectFD[i] = nullptr;
46     RSClearObjectFD[i] = nullptr;
47   }
48 
49   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50 
51   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52           E = TUDecl->decls_end(); I != E; I++) {
53     if ((I->getKind() >= clang::Decl::firstFunction) &&
54         (I->getKind() <= clang::Decl::lastFunction)) {
55       clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56 
57       // points to RSSetObjectFD or RSClearObjectFD
58       clang::FunctionDecl **RSObjectFD;
59 
60       if (FD->getName() == "rsSetObject") {
61         slangAssert((FD->getNumParams() == 2) &&
62                     "Invalid rsSetObject function prototype (# params)");
63         RSObjectFD = RSSetObjectFD;
64       } else if (FD->getName() == "rsClearObject") {
65         slangAssert((FD->getNumParams() == 1) &&
66                     "Invalid rsClearObject function prototype (# params)");
67         RSObjectFD = RSClearObjectFD;
68       } else {
69         continue;
70       }
71 
72       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73       clang::QualType PVT = PVD->getOriginalType();
74       // The first parameter must be a pointer like rs_allocation*
75       slangAssert(PVT->isPointerType() &&
76           "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77 
78       // The rs object type passed to the FD
79       clang::QualType RST = PVT->getPointeeType();
80       DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81       slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82              && "must be RS object type");
83 
84       if (DT >= 0 && DT < DataTypeMax) {
85           RSObjectFD[DT] = FD;
86       } else {
87           slangAssert(false && "incorrect type");
88       }
89     }
90   }
91 }
92 
93 namespace {
94 
95 unsigned CountRSObjectTypes(const clang::Type *T);
96 
97 clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
98                                      clang::Expr *DstExpr,
99                                      clang::Expr *SrcExpr,
100                                      clang::SourceLocation StartLoc,
101                                      clang::SourceLocation Loc);
102 
103 // This function constructs a new CompoundStmt from the input StmtList.
BuildCompoundStmt(clang::ASTContext & C,std::list<clang::Stmt * > & StmtList,clang::SourceLocation Loc)104 clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
105       std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
106   unsigned NewStmtCount = StmtList.size();
107   unsigned CompoundStmtCount = 0;
108 
109   clang::Stmt **CompoundStmtList;
110   CompoundStmtList = new clang::Stmt*[NewStmtCount];
111 
112   std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
113   std::list<clang::Stmt*>::const_iterator E = StmtList.end();
114   for ( ; I != E; I++) {
115     CompoundStmtList[CompoundStmtCount++] = *I;
116   }
117   slangAssert(CompoundStmtCount == NewStmtCount);
118 
119   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
120       C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
121 
122   delete [] CompoundStmtList;
123 
124   return CS;
125 }
126 
AppendAfterStmt(clang::ASTContext & C,clang::CompoundStmt * CS,clang::Stmt * S,std::list<clang::Stmt * > & StmtList)127 void AppendAfterStmt(clang::ASTContext &C,
128                      clang::CompoundStmt *CS,
129                      clang::Stmt *S,
130                      std::list<clang::Stmt*> &StmtList) {
131   slangAssert(CS);
132   clang::CompoundStmt::body_iterator bI = CS->body_begin();
133   clang::CompoundStmt::body_iterator bE = CS->body_end();
134   clang::Stmt **UpdatedStmtList =
135       new clang::Stmt*[CS->size() + StmtList.size()];
136 
137   unsigned UpdatedStmtCount = 0;
138   unsigned Once = 0;
139   for ( ; bI != bE; bI++) {
140     if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
141       // If we come across a return here, we don't have anything we can
142       // reasonably replace. We should have already inserted our destructor
143       // code in the proper spot, so we just clean up and return.
144       delete [] UpdatedStmtList;
145 
146       return;
147     }
148 
149     UpdatedStmtList[UpdatedStmtCount++] = *bI;
150 
151     if ((*bI == S) && !Once) {
152       Once++;
153       std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
154       std::list<clang::Stmt*>::const_iterator E = StmtList.end();
155       for ( ; I != E; I++) {
156         UpdatedStmtList[UpdatedStmtCount++] = *I;
157       }
158     }
159   }
160   slangAssert(Once <= 1);
161 
162   // When S is nullptr, we are appending to the end of the CompoundStmt.
163   if (!S) {
164     slangAssert(Once == 0);
165     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
166     std::list<clang::Stmt*>::const_iterator E = StmtList.end();
167     for ( ; I != E; I++) {
168       UpdatedStmtList[UpdatedStmtCount++] = *I;
169     }
170   }
171 
172   CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
173 
174   delete [] UpdatedStmtList;
175 }
176 
177 // This class visits a compound statement and collects a list of all the exiting
178 // statements, such as any return statement in any sub-block, and any
179 // break/continue statement that would resume outside the current scope.
180 // We do not handle the case for goto statements that leave a local scope.
181 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182  private:
183   // The loop depth of the currently visited node.
184   int mLoopDepth;
185 
186   // The switch statement depth of the currently visited node.
187   // Note that this is tracked separately from the loop depth because
188   // SwitchStmt-contained ContinueStmt's should have destructors for the
189   // corresponding loop scope.
190   int mSwitchDepth;
191 
192   // Output of the visitor: the statements that should be replaced by compound
193   // statements, each of which contains rsClearObject() calls followed by the
194   // original statement.
195   std::vector<clang::Stmt*> mExitingStmts;
196 
197  public:
DestructorVisitor()198   DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
199 
getExitingStmts() const200   const std::vector<clang::Stmt*>& getExitingStmts() const {
201     return mExitingStmts;
202   }
203 
204   void VisitStmt(clang::Stmt *S);
205   void VisitBreakStmt(clang::BreakStmt *BS);
206   void VisitContinueStmt(clang::ContinueStmt *CS);
207   void VisitDoStmt(clang::DoStmt *DS);
208   void VisitForStmt(clang::ForStmt *FS);
209   void VisitReturnStmt(clang::ReturnStmt *RS);
210   void VisitSwitchStmt(clang::SwitchStmt *SS);
211   void VisitWhileStmt(clang::WhileStmt *WS);
212 };
213 
VisitStmt(clang::Stmt * S)214 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
215   for (clang::Stmt* Child : S->children()) {
216     if (Child) {
217       Visit(Child);
218     }
219   }
220 }
221 
VisitBreakStmt(clang::BreakStmt * BS)222 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
223   VisitStmt(BS);
224   if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
225     mExitingStmts.push_back(BS);
226   }
227 }
228 
VisitContinueStmt(clang::ContinueStmt * CS)229 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
230   VisitStmt(CS);
231   if (mLoopDepth == 0) {
232     // Switch statements can have nested continues.
233     mExitingStmts.push_back(CS);
234   }
235 }
236 
VisitDoStmt(clang::DoStmt * DS)237 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
238   mLoopDepth++;
239   VisitStmt(DS);
240   mLoopDepth--;
241 }
242 
VisitForStmt(clang::ForStmt * FS)243 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
244   mLoopDepth++;
245   VisitStmt(FS);
246   mLoopDepth--;
247 }
248 
VisitReturnStmt(clang::ReturnStmt * RS)249 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
250   mExitingStmts.push_back(RS);
251 }
252 
VisitSwitchStmt(clang::SwitchStmt * SS)253 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
254   mSwitchDepth++;
255   VisitStmt(SS);
256   mSwitchDepth--;
257 }
258 
VisitWhileStmt(clang::WhileStmt * WS)259 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
260   mLoopDepth++;
261   VisitStmt(WS);
262   mLoopDepth--;
263 }
264 
ClearSingleRSObject(clang::ASTContext & C,clang::Expr * RefRSVar,clang::SourceLocation Loc)265 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
266                                  clang::Expr *RefRSVar,
267                                  clang::SourceLocation Loc) {
268   slangAssert(RefRSVar);
269   const clang::Type *T = RefRSVar->getType().getTypePtr();
270   slangAssert(!T->isArrayType() &&
271               "Should not be destroying arrays with this function");
272 
273   clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
274   slangAssert((ClearObjectFD != nullptr) &&
275               "rsClearObject doesn't cover all RS object types");
276 
277   clang::QualType ClearObjectFDType = ClearObjectFD->getType();
278   clang::QualType ClearObjectFDArgType =
279       ClearObjectFD->getParamDecl(0)->getOriginalType();
280 
281   // Example destructor for "rs_font localFont;"
282   //
283   // (CallExpr 'void'
284   //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
285   //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
286   //   (UnaryOperator 'rs_font *' prefix '&'
287   //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
288 
289   // Get address of targeted RS object
290   clang::Expr *AddrRefRSVar =
291       new(C) clang::UnaryOperator(RefRSVar,
292                                   clang::UO_AddrOf,
293                                   ClearObjectFDArgType,
294                                   clang::VK_RValue,
295                                   clang::OK_Ordinary,
296                                   Loc);
297 
298   clang::Expr *RefRSClearObjectFD =
299       clang::DeclRefExpr::Create(C,
300                                  clang::NestedNameSpecifierLoc(),
301                                  clang::SourceLocation(),
302                                  ClearObjectFD,
303                                  false,
304                                  ClearObjectFD->getLocation(),
305                                  ClearObjectFDType,
306                                  clang::VK_RValue,
307                                  nullptr);
308 
309   clang::Expr *RSClearObjectFP =
310       clang::ImplicitCastExpr::Create(C,
311                                       C.getPointerType(ClearObjectFDType),
312                                       clang::CK_FunctionToPointerDecay,
313                                       RefRSClearObjectFD,
314                                       nullptr,
315                                       clang::VK_RValue);
316 
317   llvm::SmallVector<clang::Expr*, 1> ArgList;
318   ArgList.push_back(AddrRefRSVar);
319 
320   clang::CallExpr *RSClearObjectCall =
321       new(C) clang::CallExpr(C,
322                              RSClearObjectFP,
323                              ArgList,
324                              ClearObjectFD->getCallResultType(),
325                              clang::VK_RValue,
326                              Loc);
327 
328   return RSClearObjectCall;
329 }
330 
ArrayDim(const clang::Type * T)331 static int ArrayDim(const clang::Type *T) {
332   if (!T || !T->isArrayType()) {
333     return 0;
334   }
335 
336   const clang::ConstantArrayType *CAT =
337     static_cast<const clang::ConstantArrayType *>(T);
338   return static_cast<int>(CAT->getSize().getSExtValue());
339 }
340 
341 clang::Stmt *ClearStructRSObject(
342     clang::ASTContext &C,
343     clang::DeclContext *DC,
344     clang::Expr *RefRSStruct,
345     clang::SourceLocation StartLoc,
346     clang::SourceLocation Loc);
347 
// Builds a for-loop that releases every element of the array expression
// RefRSArr:
//   - elements that are themselves arrays are handled by a nested loop
//     (recursive call);
//   - elements with no RS-specific type (DataTypeUnknown) are handled as
//     structs via ClearStructRSObject();
//   - all other elements get an rsClearObject() call.
// DC is the DeclContext in which the loop's induction variable is created.
// StartLoc/Loc are the source locations given to the synthesized nodes.
// Returns nullptr if ArrayDim() reports no elements (<= 0).
clang::Stmt *ClearArrayRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSArr,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc) {
  const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
  slangAssert(BaseType->isArrayType());

  int NumArrayElements = ArrayDim(BaseType);
  // Element type of the outermost dimension. For a multi-dimensional array
  // this is itself an array type, which is handled by the recursion below.
  BaseType = BaseType->getArrayElementTypeNoTypeQual();

  if (NumArrayElements <= 0) {
    return nullptr;
  }

  // Example destructor loop for "rs_font fontArr[10];"
  //
  // (ForStmt
  //   (DeclStmt
  //     (VarDecl used rsIntIter 'int' cinit
  //       (IntegerLiteral 'int' 0)))
  //   (BinaryOperator 'int' '<'
  //     (ImplicitCastExpr int LValueToRValue
  //       (DeclRefExpr 'int' Var='rsIntIter'))
  //     (IntegerLiteral 'int' 10)
  //   nullptr << CondVar >>
  //   (UnaryOperator 'int' postfix '++'
  //     (DeclRefExpr 'int' Var='rsIntIter'))
  //   (CallExpr 'void'
  //     (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
  //       (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
  //     (UnaryOperator 'rs_font *' prefix '&'
  //       (ArraySubscriptExpr 'rs_font':'rs_font'
  //         (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
  //           (DeclRefExpr 'rs_font [10]' Var='fontArr'))
  //         (DeclRefExpr 'int' Var='rsIntIter'))))))

  // Create helper variable for iterating through elements.
  // The function-local counter gives every synthesized induction variable a
  // distinct name (rsIntIter0, rsIntIter1, ...). Nested loops created by the
  // recursive call below therefore never shadow the outer one, because the
  // counter is bumped before recursing. NOTE(review): the counter is a plain
  // static with no synchronization.
  static unsigned sIterCounter = 0;
  std::stringstream UniqueIterName;
  UniqueIterName << "rsIntIter" << sIterCounter++;
  clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
  clang::VarDecl *IIVD =
      clang::VarDecl::Create(C,
                             DC,
                             StartLoc,
                             Loc,
                             II,
                             C.IntTy,
                             C.getTrivialTypeSourceInfo(C.IntTy),
                             clang::SC_None);
  // Mark "rsIntIter" as used
  IIVD->markUsed(C);

  // Form the actual destructor loop
  // for (Init; Cond; Inc)
  //   RSClearObjectCall;

  // Init -> "int rsIntIter = 0"
  clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
      llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
  IIVD->setInit(Int0);

  clang::Decl *IID = (clang::Decl *)IIVD;
  clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
  clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);

  // Cond -> "rsIntIter < NumArrayElements"
  // NOTE(review): RefrsIntIterLValue is used both as the operand of the
  // LValueToRValue cast below and as the operand of Inc, so this node ends up
  // with two parents in the AST.
  clang::DeclRefExpr *RefrsIntIterLValue =
      clang::DeclRefExpr::Create(C,
                                 clang::NestedNameSpecifierLoc(),
                                 clang::SourceLocation(),
                                 IIVD,
                                 false,
                                 Loc,
                                 C.IntTy,
                                 clang::VK_LValue,
                                 nullptr);

  // NOTE(review): RefrsIntIterRValue is likewise shared between Cond and the
  // array subscript in the loop body.
  clang::Expr *RefrsIntIterRValue =
      clang::ImplicitCastExpr::Create(C,
                                      RefrsIntIterLValue->getType(),
                                      clang::CK_LValueToRValue,
                                      RefrsIntIterLValue,
                                      nullptr,
                                      clang::VK_RValue);

  clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
      llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);

  clang::BinaryOperator *Cond =
      new(C) clang::BinaryOperator(RefrsIntIterRValue,
                                   NumArrayElementsExpr,
                                   clang::BO_LT,
                                   C.IntTy,
                                   clang::VK_RValue,
                                   clang::OK_Ordinary,
                                   Loc,
                                   false);

  // Inc -> "rsIntIter++"
  clang::UnaryOperator *Inc =
      new(C) clang::UnaryOperator(RefrsIntIterLValue,
                                  clang::UO_PostInc,
                                  C.IntTy,
                                  clang::VK_RValue,
                                  clang::OK_Ordinary,
                                  Loc);

  // Body -> "rsClearObject(&VD[rsIntIter]);"
  // Destructor loop operates on individual array elements

  // Decay the array to a pointer to its (canonical) element type...
  clang::Expr *RefRSArrPtr =
      clang::ImplicitCastExpr::Create(C,
          C.getPointerType(BaseType->getCanonicalTypeInternal()),
          clang::CK_ArrayToPointerDecay,
          RefRSArr,
          nullptr,
          clang::VK_RValue);

  // ...and index it with the induction variable: "RefRSArr[rsIntIter]".
  clang::Expr *RefRSArrPtrSubscript =
      new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
                                       RefrsIntIterRValue,
                                       BaseType->getCanonicalTypeInternal(),
                                       clang::VK_RValue,
                                       clang::OK_Ordinary,
                                       Loc);

  DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);

  // Choose the per-element destructor by element kind.
  // NOTE(review): the nested ClearArrayRSObject() call can return nullptr
  // (zero-length inner dimension), which would leave the loop body null.
  clang::Stmt *RSClearObjectCall = nullptr;
  if (BaseType->isArrayType()) {
    RSClearObjectCall =
        ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
  } else if (DT == DataTypeUnknown) {
    RSClearObjectCall =
        ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
  } else {
    RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
  }

  clang::ForStmt *DestructorLoop =
      new(C) clang::ForStmt(C,
                            Init,
                            Cond,
                            nullptr,  // no condVar
                            Inc,
                            RSClearObjectCall,
                            Loc,
                            Loc,
                            Loc);

  return DestructorLoop;
}
504 
CountRSObjectTypes(const clang::Type * T)505 unsigned CountRSObjectTypes(const clang::Type *T) {
506   slangAssert(T);
507   unsigned RSObjectCount = 0;
508 
509   if (T->isArrayType()) {
510     return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
511   }
512 
513   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
514   if (DT != DataTypeUnknown) {
515     return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
516   }
517 
518   if (T->isUnionType()) {
519     clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
520     RD = RD->getDefinition();
521     for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
522            FE = RD->field_end();
523          FI != FE;
524          FI++) {
525       const clang::FieldDecl *FD = *FI;
526       const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
527       if (CountRSObjectTypes(FT)) {
528         slangAssert(false && "can't have unions with RS object types!");
529         return 0;
530       }
531     }
532   }
533 
534   if (!T->isStructureType()) {
535     return 0;
536   }
537 
538   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
539   RD = RD->getDefinition();
540   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
541          FE = RD->field_end();
542        FI != FE;
543        FI++) {
544     const clang::FieldDecl *FD = *FI;
545     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
546     if (CountRSObjectTypes(FT)) {
547       // Sub-structs should only count once (as should arrays, etc.)
548       RSObjectCount++;
549     }
550   }
551 
552   return RSObjectCount;
553 }
554 
ClearStructRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSStruct,clang::SourceLocation StartLoc,clang::SourceLocation Loc)555 clang::Stmt *ClearStructRSObject(
556     clang::ASTContext &C,
557     clang::DeclContext *DC,
558     clang::Expr *RefRSStruct,
559     clang::SourceLocation StartLoc,
560     clang::SourceLocation Loc) {
561   const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
562 
563   slangAssert(!BaseType->isArrayType());
564 
565   // Structs should show up as unknown primitive types
566   slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
567               DataTypeUnknown);
568 
569   unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
570   slangAssert(FieldsToDestroy != 0);
571 
572   unsigned StmtCount = 0;
573   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
574   for (unsigned i = 0; i < FieldsToDestroy; i++) {
575     StmtArray[i] = nullptr;
576   }
577 
578   // Populate StmtArray by creating a destructor for each RS object field
579   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
580   RD = RD->getDefinition();
581   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
582          FE = RD->field_end();
583        FI != FE;
584        FI++) {
585     // We just look through all field declarations to see if we find a
586     // declaration for an RS object type (or an array of one).
587     bool IsArrayType = false;
588     clang::FieldDecl *FD = *FI;
589     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
590     const clang::Type *OrigType = FT;
591     while (FT && FT->isArrayType()) {
592       FT = FT->getArrayElementTypeNoTypeQual();
593       IsArrayType = true;
594     }
595 
596     // Pass a DeclarationNameInfo with a valid DeclName, since name equality
597     // gets asserted during CodeGen.
598     clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
599                                               FD->getLocation());
600 
601     if (RSExportPrimitiveType::IsRSObjectType(FT)) {
602       clang::DeclAccessPair FoundDecl =
603           clang::DeclAccessPair::make(FD, clang::AS_none);
604       clang::MemberExpr *RSObjectMember =
605           clang::MemberExpr::Create(C,
606                                     RefRSStruct,
607                                     false,
608                                     clang::SourceLocation(),
609                                     clang::NestedNameSpecifierLoc(),
610                                     clang::SourceLocation(),
611                                     FD,
612                                     FoundDecl,
613                                     FDDeclNameInfo,
614                                     nullptr,
615                                     OrigType->getCanonicalTypeInternal(),
616                                     clang::VK_RValue,
617                                     clang::OK_Ordinary);
618 
619       slangAssert(StmtCount < FieldsToDestroy);
620 
621       if (IsArrayType) {
622         StmtArray[StmtCount++] = ClearArrayRSObject(C,
623                                                     DC,
624                                                     RSObjectMember,
625                                                     StartLoc,
626                                                     Loc);
627       } else {
628         StmtArray[StmtCount++] = ClearSingleRSObject(C,
629                                                      RSObjectMember,
630                                                      Loc);
631       }
632     } else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
633       // In this case, we have a nested struct. We may not end up filling all
634       // of the spaces in StmtArray (sub-structs should handle themselves
635       // with separate compound statements).
636       clang::DeclAccessPair FoundDecl =
637           clang::DeclAccessPair::make(FD, clang::AS_none);
638       clang::MemberExpr *RSObjectMember =
639           clang::MemberExpr::Create(C,
640                                     RefRSStruct,
641                                     false,
642                                     clang::SourceLocation(),
643                                     clang::NestedNameSpecifierLoc(),
644                                     clang::SourceLocation(),
645                                     FD,
646                                     FoundDecl,
647                                     clang::DeclarationNameInfo(),
648                                     nullptr,
649                                     OrigType->getCanonicalTypeInternal(),
650                                     clang::VK_RValue,
651                                     clang::OK_Ordinary);
652 
653       if (IsArrayType) {
654         StmtArray[StmtCount++] = ClearArrayRSObject(C,
655                                                     DC,
656                                                     RSObjectMember,
657                                                     StartLoc,
658                                                     Loc);
659       } else {
660         StmtArray[StmtCount++] = ClearStructRSObject(C,
661                                                      DC,
662                                                      RSObjectMember,
663                                                      StartLoc,
664                                                      Loc);
665       }
666     }
667   }
668 
669   slangAssert(StmtCount > 0);
670   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
671       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
672 
673   delete [] StmtArray;
674 
675   return CS;
676 }
677 
CreateSingleRSSetObject(clang::ASTContext & C,clang::Expr * DstExpr,clang::Expr * SrcExpr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)678 clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
679                                      clang::Expr *DstExpr,
680                                      clang::Expr *SrcExpr,
681                                      clang::SourceLocation StartLoc,
682                                      clang::SourceLocation Loc) {
683   const clang::Type *T = DstExpr->getType().getTypePtr();
684   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
685   slangAssert((SetObjectFD != nullptr) &&
686               "rsSetObject doesn't cover all RS object types");
687 
688   clang::QualType SetObjectFDType = SetObjectFD->getType();
689   clang::QualType SetObjectFDArgType[2];
690   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
691   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
692 
693   clang::Expr *RefRSSetObjectFD =
694       clang::DeclRefExpr::Create(C,
695                                  clang::NestedNameSpecifierLoc(),
696                                  clang::SourceLocation(),
697                                  SetObjectFD,
698                                  false,
699                                  Loc,
700                                  SetObjectFDType,
701                                  clang::VK_RValue,
702                                  nullptr);
703 
704   clang::Expr *RSSetObjectFP =
705       clang::ImplicitCastExpr::Create(C,
706                                       C.getPointerType(SetObjectFDType),
707                                       clang::CK_FunctionToPointerDecay,
708                                       RefRSSetObjectFD,
709                                       nullptr,
710                                       clang::VK_RValue);
711 
712   llvm::SmallVector<clang::Expr*, 2> ArgList;
713   ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
714                                                 clang::UO_AddrOf,
715                                                 SetObjectFDArgType[0],
716                                                 clang::VK_RValue,
717                                                 clang::OK_Ordinary,
718                                                 Loc));
719   ArgList.push_back(SrcExpr);
720 
721   clang::CallExpr *RSSetObjectCall =
722       new(C) clang::CallExpr(C,
723                              RSSetObjectFP,
724                              ArgList,
725                              SetObjectFD->getCallResultType(),
726                              clang::VK_RValue,
727                              Loc);
728 
729   return RSSetObjectCall;
730 }
731 
732 clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
733                                      clang::Expr *LHS,
734                                      clang::Expr *RHS,
735                                      clang::SourceLocation StartLoc,
736                                      clang::SourceLocation Loc);
737 
738 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
739                                            clang::Expr *DstArr,
740                                            clang::Expr *SrcArr,
741                                            clang::SourceLocation StartLoc,
742                                            clang::SourceLocation Loc) {
743   clang::DeclContext *DC = nullptr;
744   const clang::Type *BaseType = DstArr->getType().getTypePtr();
745   slangAssert(BaseType->isArrayType());
746 
747   int NumArrayElements = ArrayDim(BaseType);
748   // Actually extract out the base RS object type for use later
749   BaseType = BaseType->getArrayElementTypeNoTypeQual();
750 
751   clang::Stmt *StmtArray[2] = {nullptr};
752   int StmtCtr = 0;
753 
754   if (NumArrayElements <= 0) {
755     return nullptr;
756   }
757 
758   // Create helper variable for iterating through elements
759   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
760   clang::VarDecl *IIVD =
761       clang::VarDecl::Create(C,
762                              DC,
763                              StartLoc,
764                              Loc,
765                              &II,
766                              C.IntTy,
767                              C.getTrivialTypeSourceInfo(C.IntTy),
768                              clang::SC_None,
769                              clang::SC_None);
770   clang::Decl *IID = (clang::Decl *)IIVD;
771 
772   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
773   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
774 
775   // Form the actual loop
776   // for (Init; Cond; Inc)
777   //   RSSetObjectCall;
778 
779   // Init -> "rsIntIter = 0"
780   clang::DeclRefExpr *RefrsIntIter =
781       clang::DeclRefExpr::Create(C,
782                                  clang::NestedNameSpecifierLoc(),
783                                  IIVD,
784                                  Loc,
785                                  C.IntTy,
786                                  clang::VK_RValue,
787                                  nullptr);
788 
789   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
790       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
791 
792   clang::BinaryOperator *Init =
793       new(C) clang::BinaryOperator(RefrsIntIter,
794                                    Int0,
795                                    clang::BO_Assign,
796                                    C.IntTy,
797                                    clang::VK_RValue,
798                                    clang::OK_Ordinary,
799                                    Loc);
800 
801   // Cond -> "rsIntIter < NumArrayElements"
802   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
803       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
804 
805   clang::BinaryOperator *Cond =
806       new(C) clang::BinaryOperator(RefrsIntIter,
807                                    NumArrayElementsExpr,
808                                    clang::BO_LT,
809                                    C.IntTy,
810                                    clang::VK_RValue,
811                                    clang::OK_Ordinary,
812                                    Loc);
813 
814   // Inc -> "rsIntIter++"
815   clang::UnaryOperator *Inc =
816       new(C) clang::UnaryOperator(RefrsIntIter,
817                                   clang::UO_PostInc,
818                                   C.IntTy,
819                                   clang::VK_RValue,
820                                   clang::OK_Ordinary,
821                                   Loc);
822 
823   // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
824   // Loop operates on individual array elements
825 
826   clang::Expr *DstArrPtr =
827       clang::ImplicitCastExpr::Create(C,
828           C.getPointerType(BaseType->getCanonicalTypeInternal()),
829           clang::CK_ArrayToPointerDecay,
830           DstArr,
831           nullptr,
832           clang::VK_RValue);
833 
834   clang::Expr *DstArrPtrSubscript =
835       new(C) clang::ArraySubscriptExpr(DstArrPtr,
836                                        RefrsIntIter,
837                                        BaseType->getCanonicalTypeInternal(),
838                                        clang::VK_RValue,
839                                        clang::OK_Ordinary,
840                                        Loc);
841 
842   clang::Expr *SrcArrPtr =
843       clang::ImplicitCastExpr::Create(C,
844           C.getPointerType(BaseType->getCanonicalTypeInternal()),
845           clang::CK_ArrayToPointerDecay,
846           SrcArr,
847           nullptr,
848           clang::VK_RValue);
849 
850   clang::Expr *SrcArrPtrSubscript =
851       new(C) clang::ArraySubscriptExpr(SrcArrPtr,
852                                        RefrsIntIter,
853                                        BaseType->getCanonicalTypeInternal(),
854                                        clang::VK_RValue,
855                                        clang::OK_Ordinary,
856                                        Loc);
857 
858   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
859 
860   clang::Stmt *RSSetObjectCall = nullptr;
861   if (BaseType->isArrayType()) {
862     RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
863                                              SrcArrPtrSubscript,
864                                              StartLoc, Loc);
865   } else if (DT == DataTypeUnknown) {
866     RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
867                                               SrcArrPtrSubscript,
868                                               StartLoc, Loc);
869   } else {
870     RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
871                                               SrcArrPtrSubscript,
872                                               StartLoc, Loc);
873   }
874 
875   clang::ForStmt *DestructorLoop =
876       new(C) clang::ForStmt(C,
877                             Init,
878                             Cond,
879                             nullptr,  // no condVar
880                             Inc,
881                             RSSetObjectCall,
882                             Loc,
883                             Loc,
884                             Loc);
885 
886   StmtArray[StmtCtr++] = DestructorLoop;
887   slangAssert(StmtCtr == 2);
888 
889   clang::CompoundStmt *CS =
890       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
891 
892   return CS;
893 } */
894 
CreateStructRSSetObject(clang::ASTContext & C,clang::Expr * LHS,clang::Expr * RHS,clang::SourceLocation StartLoc,clang::SourceLocation Loc)895 clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
896                                      clang::Expr *LHS,
897                                      clang::Expr *RHS,
898                                      clang::SourceLocation StartLoc,
899                                      clang::SourceLocation Loc) {
900   clang::QualType QT = LHS->getType();
901   const clang::Type *T = QT.getTypePtr();
902   slangAssert(T->isStructureType());
903   slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
904 
905   // Keep an extra slot for the original copy (memcpy)
906   unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
907 
908   unsigned StmtCount = 0;
909   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
910   for (unsigned i = 0; i < FieldsToSet; i++) {
911     StmtArray[i] = nullptr;
912   }
913 
914   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
915   RD = RD->getDefinition();
916   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
917          FE = RD->field_end();
918        FI != FE;
919        FI++) {
920     bool IsArrayType = false;
921     clang::FieldDecl *FD = *FI;
922     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
923     const clang::Type *OrigType = FT;
924 
925     if (!CountRSObjectTypes(FT)) {
926       // Skip to next if we don't have any viable RS object types
927       continue;
928     }
929 
930     clang::DeclAccessPair FoundDecl =
931         clang::DeclAccessPair::make(FD, clang::AS_none);
932     clang::MemberExpr *DstMember =
933         clang::MemberExpr::Create(C,
934                                   LHS,
935                                   false,
936                                   clang::SourceLocation(),
937                                   clang::NestedNameSpecifierLoc(),
938                                   clang::SourceLocation(),
939                                   FD,
940                                   FoundDecl,
941                                   clang::DeclarationNameInfo(),
942                                   nullptr,
943                                   OrigType->getCanonicalTypeInternal(),
944                                   clang::VK_RValue,
945                                   clang::OK_Ordinary);
946 
947     clang::MemberExpr *SrcMember =
948         clang::MemberExpr::Create(C,
949                                   RHS,
950                                   false,
951                                   clang::SourceLocation(),
952                                   clang::NestedNameSpecifierLoc(),
953                                   clang::SourceLocation(),
954                                   FD,
955                                   FoundDecl,
956                                   clang::DeclarationNameInfo(),
957                                   nullptr,
958                                   OrigType->getCanonicalTypeInternal(),
959                                   clang::VK_RValue,
960                                   clang::OK_Ordinary);
961 
962     if (FT->isArrayType()) {
963       FT = FT->getArrayElementTypeNoTypeQual();
964       IsArrayType = true;
965     }
966 
967     DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
968 
969     if (IsArrayType) {
970       clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
971       DiagEngine.Report(
972         clang::FullSourceLoc(Loc, C.getSourceManager()),
973         DiagEngine.getCustomDiagID(
974           clang::DiagnosticsEngine::Error,
975           "Arrays of RS object types within structures cannot be copied"));
976       // TODO(srhines): Support setting arrays of RS objects
977       // StmtArray[StmtCount++] =
978       //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
979     } else if (DT == DataTypeUnknown) {
980       StmtArray[StmtCount++] =
981           CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
982     } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
983       StmtArray[StmtCount++] =
984           CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
985     } else {
986       slangAssert(false);
987     }
988   }
989 
990   slangAssert(StmtCount < FieldsToSet);
991 
992   // We still need to actually do the overall struct copy. For simplicity,
993   // we just do a straight-up assignment (which will still preserve all
994   // the proper RS object reference counts).
995   clang::BinaryOperator *CopyStruct =
996       new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
997                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
998                                    false);
999   StmtArray[StmtCount++] = CopyStruct;
1000 
1001   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1002       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1003 
1004   delete [] StmtArray;
1005 
1006   return CS;
1007 }
1008 
1009 }  // namespace
1010 
InsertStmt(const clang::ASTContext & C,clang::Stmt * NewStmt)1011 void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
1012                                          clang::Stmt *NewStmt) {
1013   std::vector<clang::Stmt*> newBody;
1014   for (clang::Stmt* S1 : mCS->body()) {
1015     if (S1 == mCurrent) {
1016       newBody.push_back(NewStmt);
1017     }
1018     newBody.push_back(S1);
1019   }
1020   mCS->setStmts(C, newBody);
1021 }
1022 
ReplaceStmt(const clang::ASTContext & C,clang::Stmt * NewStmt)1023 void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
1024                                           clang::Stmt *NewStmt) {
1025   std::vector<clang::Stmt*> newBody;
1026   for (clang::Stmt* S1 : mCS->body()) {
1027     if (S1 == mCurrent) {
1028       newBody.push_back(NewStmt);
1029     } else {
1030       newBody.push_back(S1);
1031     }
1032   }
1033   mCS->setStmts(C, newBody);
1034 }
1035 
ReplaceExpr(const clang::ASTContext & C,clang::Expr * OldExpr,clang::Expr * NewExpr)1036 void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
1037                                           clang::Expr* OldExpr,
1038                                           clang::Expr* NewExpr) {
1039   RSASTReplace R(C);
1040   R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
1041 }
1042 
ReplaceRSObjectAssignment(clang::BinaryOperator * AS)1043 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1044     clang::BinaryOperator *AS) {
1045 
1046   clang::QualType QT = AS->getType();
1047 
1048   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1049       DataTypeRSAllocation)->getASTContext();
1050 
1051   clang::SourceLocation Loc = AS->getExprLoc();
1052   clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1053   clang::Stmt *UpdatedStmt = nullptr;
1054 
1055   if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1056     // By definition, this is a struct assignment if we get here
1057     UpdatedStmt =
1058         CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1059   } else {
1060     UpdatedStmt =
1061         CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1062   }
1063 
1064   RSASTReplace R(C);
1065   R.ReplaceStmt(mCS, AS, UpdatedStmt);
1066 }
1067 
AppendRSObjectInit(clang::VarDecl * VD,clang::DeclStmt * DS,DataType DT,clang::Expr * InitExpr)1068 void RSObjectRefCount::Scope::AppendRSObjectInit(
1069     clang::VarDecl *VD,
1070     clang::DeclStmt *DS,
1071     DataType DT,
1072     clang::Expr *InitExpr) {
1073   slangAssert(VD);
1074 
1075   if (!InitExpr) {
1076     return;
1077   }
1078 
1079   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1080       DataTypeRSAllocation)->getASTContext();
1081   clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1082       DataTypeRSAllocation)->getLocation();
1083   clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1084       DataTypeRSAllocation)->getInnerLocStart();
1085 
1086   if (DT == DataTypeIsStruct) {
1087     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1088     clang::DeclRefExpr *RefRSVar =
1089         clang::DeclRefExpr::Create(C,
1090                                    clang::NestedNameSpecifierLoc(),
1091                                    clang::SourceLocation(),
1092                                    VD,
1093                                    false,
1094                                    Loc,
1095                                    T->getCanonicalTypeInternal(),
1096                                    clang::VK_RValue,
1097                                    nullptr);
1098 
1099     clang::Stmt *RSSetObjectOps =
1100         CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1101 
1102     std::list<clang::Stmt*> StmtList;
1103     StmtList.push_back(RSSetObjectOps);
1104     AppendAfterStmt(C, mCS, DS, StmtList);
1105     return;
1106   }
1107 
1108   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1109   slangAssert((SetObjectFD != nullptr) &&
1110               "rsSetObject doesn't cover all RS object types");
1111 
1112   clang::QualType SetObjectFDType = SetObjectFD->getType();
1113   clang::QualType SetObjectFDArgType[2];
1114   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1115   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1116 
1117   clang::Expr *RefRSSetObjectFD =
1118       clang::DeclRefExpr::Create(C,
1119                                  clang::NestedNameSpecifierLoc(),
1120                                  clang::SourceLocation(),
1121                                  SetObjectFD,
1122                                  false,
1123                                  Loc,
1124                                  SetObjectFDType,
1125                                  clang::VK_RValue,
1126                                  nullptr);
1127 
1128   clang::Expr *RSSetObjectFP =
1129       clang::ImplicitCastExpr::Create(C,
1130                                       C.getPointerType(SetObjectFDType),
1131                                       clang::CK_FunctionToPointerDecay,
1132                                       RefRSSetObjectFD,
1133                                       nullptr,
1134                                       clang::VK_RValue);
1135 
1136   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1137   clang::DeclRefExpr *RefRSVar =
1138       clang::DeclRefExpr::Create(C,
1139                                  clang::NestedNameSpecifierLoc(),
1140                                  clang::SourceLocation(),
1141                                  VD,
1142                                  false,
1143                                  Loc,
1144                                  T->getCanonicalTypeInternal(),
1145                                  clang::VK_RValue,
1146                                  nullptr);
1147 
1148   llvm::SmallVector<clang::Expr*, 2> ArgList;
1149   ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1150                                                 clang::UO_AddrOf,
1151                                                 SetObjectFDArgType[0],
1152                                                 clang::VK_RValue,
1153                                                 clang::OK_Ordinary,
1154                                                 Loc));
1155   ArgList.push_back(InitExpr);
1156 
1157   clang::CallExpr *RSSetObjectCall =
1158       new(C) clang::CallExpr(C,
1159                              RSSetObjectFP,
1160                              ArgList,
1161                              SetObjectFD->getCallResultType(),
1162                              clang::VK_RValue,
1163                              Loc);
1164 
1165   std::list<clang::Stmt*> StmtList;
1166   StmtList.push_back(RSSetObjectCall);
1167   AppendAfterStmt(C, mCS, DS, StmtList);
1168 }
1169 
InsertLocalVarDestructors()1170 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1171   if (mRSO.empty()) {
1172     return;
1173   }
1174 
1175   clang::DeclContext* DC = mRSO.front()->getDeclContext();
1176   clang::ASTContext& C = DC->getParentASTContext();
1177   clang::SourceManager& SM = C.getSourceManager();
1178 
1179   const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
1180     return SM.isBeforeInTranslationUnit(L1, L2);
1181   };
1182   typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
1183 
1184   DMap dtors(OccursBefore);
1185 
1186   // Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
1187   for (clang::VarDecl* VD : mRSO) {
1188     clang::SourceLocation Loc = VD->getSourceRange().getBegin();
1189     clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
1190     dtors.insert(std::make_pair(Loc, RSClearObjectCall));
1191   }
1192 
1193   DestructorVisitor Visitor;
1194   Visitor.Visit(mCS);
1195 
1196   // Replace each exiting statement with a block that contains the original statement
1197   // and added rsClearObject() calls before it.
1198   for (clang::Stmt* S : Visitor.getExitingStmts()) {
1199 
1200     const clang::SourceLocation currentLoc = S->getLocStart();
1201 
1202     DMap::iterator firstDtorIter = dtors.begin();
1203     DMap::iterator currentDtorIter = firstDtorIter;
1204     DMap::iterator lastDtorIter = dtors.end();
1205 
1206     while (currentDtorIter != lastDtorIter &&
1207            OccursBefore(currentDtorIter->first, currentLoc)) {
1208       currentDtorIter++;
1209     }
1210 
1211     if (currentDtorIter == firstDtorIter) {
1212       continue;
1213     }
1214 
1215     std::list<clang::Stmt*> Stmts;
1216 
1217     // Insert rsClearObject() calls for all rsObjects declared before the current statement
1218     for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
1219       Stmts.push_back(it->second);
1220     }
1221     Stmts.push_back(S);
1222 
1223     RSASTReplace R(C);
1224     clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
1225     R.ReplaceStmt(mCS, S, CS);
1226   }
1227 
1228   std::list<clang::Stmt*> Stmts;
1229   for(auto LocCallPair : dtors) {
1230     Stmts.push_back(LocCallPair.second);
1231   }
1232   AppendAfterStmt(C, mCS, nullptr, Stmts);
1233 }
1234 
ClearRSObject(clang::VarDecl * VD,clang::DeclContext * DC)1235 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1236     clang::VarDecl *VD,
1237     clang::DeclContext *DC) {
1238   slangAssert(VD);
1239   clang::ASTContext &C = VD->getASTContext();
1240   clang::SourceLocation Loc = VD->getLocation();
1241   clang::SourceLocation StartLoc = VD->getInnerLocStart();
1242   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1243 
1244   // Reference expr to target RS object variable
1245   clang::DeclRefExpr *RefRSVar =
1246       clang::DeclRefExpr::Create(C,
1247                                  clang::NestedNameSpecifierLoc(),
1248                                  clang::SourceLocation(),
1249                                  VD,
1250                                  false,
1251                                  Loc,
1252                                  T->getCanonicalTypeInternal(),
1253                                  clang::VK_RValue,
1254                                  nullptr);
1255 
1256   if (T->isArrayType()) {
1257     return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1258   }
1259 
1260   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1261 
1262   if (DT == DataTypeUnknown ||
1263       DT == DataTypeIsStruct) {
1264     return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1265   }
1266 
1267   slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1268               "Should be RS object");
1269 
1270   return ClearSingleRSObject(C, RefRSVar, Loc);
1271 }
1272 
InitializeRSObject(clang::VarDecl * VD,DataType * DT,clang::Expr ** InitExpr)1273 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1274                                           DataType *DT,
1275                                           clang::Expr **InitExpr) {
1276   slangAssert(VD && DT && InitExpr);
1277   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1278 
1279   // Loop through array types to get to base type
1280   while (T && T->isArrayType()) {
1281     T = T->getArrayElementTypeNoTypeQual();
1282   }
1283 
1284   bool DataTypeIsStructWithRSObject = false;
1285   *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1286 
1287   if (*DT == DataTypeUnknown) {
1288     if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1289       *DT = DataTypeIsStruct;
1290       DataTypeIsStructWithRSObject = true;
1291     } else {
1292       return false;
1293     }
1294   }
1295 
1296   bool DataTypeIsRSObject = false;
1297   if (DataTypeIsStructWithRSObject) {
1298     DataTypeIsRSObject = true;
1299   } else {
1300     DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1301   }
1302   *InitExpr = VD->getInit();
1303 
1304   if (!DataTypeIsRSObject && *InitExpr) {
1305     // If we already have an initializer for a matrix type, we are done.
1306     return DataTypeIsRSObject;
1307   }
1308 
1309   clang::Expr *ZeroInitializer =
1310       CreateEmptyInitListExpr(VD->getASTContext(), VD->getLocation());
1311 
1312   if (ZeroInitializer) {
1313     ZeroInitializer->setType(T->getCanonicalTypeInternal());
1314     VD->setInit(ZeroInitializer);
1315   }
1316 
1317   return DataTypeIsRSObject;
1318 }
1319 
CreateEmptyInitListExpr(clang::ASTContext & C,const clang::SourceLocation & Loc)1320 clang::Expr *RSObjectRefCount::CreateEmptyInitListExpr(
1321     clang::ASTContext &C,
1322     const clang::SourceLocation &Loc) {
1323 
1324   // We can cheaply construct a zero initializer by just creating an empty
1325   // initializer list. Clang supports this extension to C(99), and will create
1326   // any necessary constructs to zero out the entire variable.
1327   llvm::SmallVector<clang::Expr*, 1> EmptyInitList;
1328   return new(C) clang::InitListExpr(C, Loc, EmptyInitList, Loc);
1329 }
1330 
CreateRetStmtWithTempVar(clang::ASTContext & C,clang::DeclContext * DC,clang::ReturnStmt * RS,const unsigned id)1331 clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
1332     clang::ASTContext& C,
1333     clang::DeclContext* DC,
1334     clang::ReturnStmt* RS,
1335     const unsigned id) {
1336   std::list<clang::Stmt*> NewStmts;
1337   // Since we insert rsClearObj() calls before the return statement, we need
1338   // to make sure none of the cleared RS objects are referenced in the
1339   // return statement.
1340   // For that, we create a new local variable named .rs.retval, assign the
1341   // original return expression to it, make all necessary rsClearObj()
1342   // calls, then return .rs.retval. Note rsClearObj() is not called on
1343   // .rs.retval.
1344 
1345   clang::SourceLocation Loc = RS->getLocStart();
1346   std::stringstream ss;
1347   ss << ".rs.retval" << id;
1348   llvm::StringRef VarName(ss.str());
1349 
1350   clang::Expr* RetVal = RS->getRetValue();
1351   const clang::QualType RetTy = RetVal->getType();
1352   clang::VarDecl* RSRetValDecl = clang::VarDecl::Create(
1353       C,                                     // AST context
1354       DC,                                    // Decl context
1355       Loc,                                   // Start location
1356       Loc,                                   // Id location
1357       &C.Idents.get(VarName),                // Id
1358       RetTy,                                 // Type
1359       C.getTrivialTypeSourceInfo(RetTy),     // Type info
1360       clang::SC_None                         // Storage class
1361   );
1362   const clang::Type *T = RetTy.getTypePtr();
1363   clang::Expr *ZeroInitializer =
1364       RSObjectRefCount::CreateEmptyInitListExpr(C, Loc);
1365   ZeroInitializer->setType(T->getCanonicalTypeInternal());
1366   RSRetValDecl->setInit(ZeroInitializer);
1367   clang::Decl* Decls[] = { RSRetValDecl };
1368   const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1369       C, Decls, sizeof(Decls) / sizeof(*Decls));
1370   clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
1371   NewStmts.push_back(DS);
1372 
1373   clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1374       C,
1375       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1376       Loc,                                   // TemplateKWLoc
1377       RSRetValDecl,
1378       false,                                 // RefersToEnclosingVariableOrCapture
1379       Loc,                                   // NameLoc
1380       RetTy,
1381       clang::VK_LValue
1382   );
1383   clang::Stmt* SetRetTempVar = CreateSingleRSSetObject(C, DRE, RetVal, Loc, Loc);
1384   NewStmts.push_back(SetRetTempVar);
1385 
1386   // Creates a new return statement
1387   clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(RS->getReturnLoc());
1388   clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1389       C,
1390       RetTy,
1391       clang::CK_LValueToRValue,
1392       DRE,
1393       nullptr,
1394       clang::VK_RValue
1395   );
1396   NewRet->setRetValue(CastExpr);
1397   NewStmts.push_back(NewRet);
1398 
1399   return BuildCompoundStmt(C, NewStmts, Loc);
1400 }
1401 
VisitDeclStmt(clang::DeclStmt * DS)1402 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1403   VisitStmt(DS);
1404   getCurrentScope()->setCurrentStmt(DS);
1405   for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1406        I != E;
1407        I++) {
1408     clang::Decl *D = *I;
1409     if (D->getKind() == clang::Decl::Var) {
1410       clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1411       DataType DT = DataTypeUnknown;
1412       clang::Expr *InitExpr = nullptr;
1413       if (InitializeRSObject(VD, &DT, &InitExpr)) {
1414         // We need to zero-init all RS object types (including matrices), ...
1415         getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1416         // ... but, only add to the list of RS objects if we have some
1417         // non-matrix RS object fields.
1418         if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1419           getCurrentScope()->addRSObject(VD);
1420         }
1421       }
1422     }
1423   }
1424 }
1425 
VisitCallExpr(clang::CallExpr * CE)1426 void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
1427   clang::QualType RetTy;
1428   const clang::FunctionDecl* FD = CE->getDirectCallee();
1429 
1430   if (FD) {
1431     // Direct calls
1432 
1433     RetTy = FD->getReturnType();
1434   } else {
1435     // Indirect calls through function pointers
1436 
1437     const clang::Expr* Callee = CE->getCallee();
1438     const clang::Type* CalleeType = Callee->getType().getTypePtr();
1439     const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
1440 
1441     if (!PtrType) {
1442       return;
1443     }
1444 
1445     const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
1446     const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
1447 
1448     if (!FuncType) {
1449       return;
1450     }
1451 
1452     RetTy = FuncType->getReturnType();
1453   }
1454 
1455   if (CountRSObjectTypes(RetTy.getTypePtr())==0) {
1456     return;
1457   }
1458 
1459   clang::SourceLocation Loc = CE->getSourceRange().getBegin();
1460   std::stringstream ss;
1461   ss << ".rs.tmp" << getNextID();
1462   llvm::StringRef VarName(ss.str());
1463 
1464   clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
1465       mCtx,                                  // AST context
1466       GetDeclContext(),                      // Decl context
1467       Loc,                                   // Start location
1468       Loc,                                   // Id location
1469       &mCtx.Idents.get(VarName),             // Id
1470       RetTy,                                 // Type
1471       mCtx.getTrivialTypeSourceInfo(RetTy),  // Type info
1472       clang::SC_None                         // Storage class
1473   );
1474   TempVarDecl->setInit(CE);
1475   clang::Decl* Decls[] = { TempVarDecl };
1476   const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
1477       mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
1478   clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
1479 
1480   getCurrentScope()->InsertStmt(mCtx, DS);
1481 
1482   clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
1483       mCtx,                                  // AST context
1484       clang::NestedNameSpecifierLoc(),       // QualifierLoc
1485       Loc,                                   // TemplateKWLoc
1486       TempVarDecl,
1487       false,                                 // RefersToEnclosingVariableOrCapture
1488       Loc,                                   // NameLoc
1489       RetTy,
1490       clang::VK_LValue
1491   );
1492   clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
1493       mCtx,
1494       RetTy,
1495       clang::CK_LValueToRValue,
1496       DRE,
1497       nullptr,
1498       clang::VK_RValue
1499   );
1500 
1501   getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
1502 
1503   // Register TempVarDecl for destruction call (rsClearObj).
1504   getCurrentScope()->addRSObject(TempVarDecl);
1505 }
1506 
VisitCompoundStmt(clang::CompoundStmt * CS)1507 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1508   if (!emptyScope()) {
1509     getCurrentScope()->setCurrentStmt(CS);
1510   }
1511 
1512   if (!CS->body_empty()) {
1513     // Push a new scope
1514     Scope *S = new Scope(CS);
1515     mScopeStack.push_back(S);
1516 
1517     VisitStmt(CS);
1518 
1519     // Destroy the scope
1520     slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1521     S->InsertLocalVarDestructors();
1522     mScopeStack.pop_back();
1523     delete S;
1524   }
1525 }
1526 
VisitBinAssign(clang::BinaryOperator * AS)1527 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1528   getCurrentScope()->setCurrentStmt(AS);
1529   clang::QualType QT = AS->getType();
1530 
1531   if (CountRSObjectTypes(QT.getTypePtr())) {
1532     getCurrentScope()->ReplaceRSObjectAssignment(AS);
1533   }
1534 }
1535 
VisitReturnStmt(clang::ReturnStmt * RS)1536 void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
1537   getCurrentScope()->setCurrentStmt(RS);
1538 
1539   // If there is no local rsObject declared so far, no need to transform the
1540   // return statement.
1541 
1542   bool RSObjDeclared = false;
1543 
1544   for (const Scope* S : mScopeStack) {
1545     if (S->hasRSObject()) {
1546       RSObjDeclared = true;
1547       break;
1548     }
1549   }
1550 
1551   if (!RSObjDeclared) {
1552     return;
1553   }
1554 
1555   // If the return statement does not return anything, or if it does not return
1556   // a rsObject, no need to transform it.
1557 
1558   clang::Expr* RetVal = RS->getRetValue();
1559   if (!RetVal || CountRSObjectTypes(RetVal->getType().getTypePtr()) == 0) {
1560     return;
1561   }
1562 
1563   // Transform the return statement so that it does not potentially return or
1564   // reference a rsObject that has been cleared.
1565 
1566   clang::CompoundStmt* NewRS;
1567   NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
1568 
1569   getCurrentScope()->ReplaceStmt(mCtx, NewRS);
1570 }
1571 
VisitStmt(clang::Stmt * S)1572 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1573   getCurrentScope()->setCurrentStmt(S);
1574   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1575        I != E;
1576        I++) {
1577     if (clang::Stmt *Child = *I) {
1578       Visit(Child);
1579     }
1580   }
1581 }
1582 
1583 // This function walks the list of global variables and (potentially) creates
1584 // a single global static destructor function that properly decrements
1585 // reference counts on the contained RS object types.
CreateStaticGlobalDtor()1586 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1587   Init();
1588 
1589   clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1590   clang::SourceLocation loc;
1591 
1592   llvm::StringRef SR(".rs.dtor");
1593   clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1594   clang::DeclarationName N(&II);
1595   clang::FunctionProtoType::ExtProtoInfo EPI;
1596   clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1597       llvm::ArrayRef<clang::QualType>(), EPI);
1598   clang::FunctionDecl *FD = nullptr;
1599 
1600   // Generate rsClearObject() call chains for every global variable
1601   // (whether static or extern).
1602   std::list<clang::Stmt *> StmtList;
1603   for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1604           E = DC->decls_end(); I != E; I++) {
1605     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1606     if (VD) {
1607       if (CountRSObjectTypes(VD->getType().getTypePtr())) {
1608         if (!FD) {
1609           // Only create FD if we are going to use it.
1610           FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1611                                            clang::SC_None);
1612         }
1613         // Mark VD as used.  It might be unused, except for the destructor.
1614         // 'markUsed' has side-effects that are caused only if VD is not already
1615         // used.  Hence no need for an extra check here.
1616         VD->markUsed(mCtx);
1617         // Make sure to create any helpers within the function's DeclContext,
1618         // not the one associated with the global translation unit.
1619         clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1620         StmtList.push_back(RSClearObjectCall);
1621       }
1622     }
1623   }
1624 
1625   // Nothing needs to be destroyed, so don't emit a dtor.
1626   if (StmtList.empty()) {
1627     return nullptr;
1628   }
1629 
1630   clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1631 
1632   FD->setBody(CS);
1633 
1634   return FD;
1635 }
1636 
HasRSObjectType(const clang::Type * T)1637 bool HasRSObjectType(const clang::Type *T) {
1638   return CountRSObjectTypes(T) != 0;
1639 }
1640 
1641 }  // namespace slang
1642