1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "slang_rs_object_ref_count.h"

#include <list>
#include <stack>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32 
33 namespace slang {
34 
35 /* Even though those two arrays are of size DataTypeMax, only entries that
36  * correspond to object types will be set.
37  */
38 clang::FunctionDecl *
39 RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40 clang::FunctionDecl *
41 RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42 
GetRSRefCountingFunctions(clang::ASTContext & C)43 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44   for (unsigned i = 0; i < DataTypeMax; i++) {
45     RSSetObjectFD[i] = nullptr;
46     RSClearObjectFD[i] = nullptr;
47   }
48 
49   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50 
51   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52           E = TUDecl->decls_end(); I != E; I++) {
53     if ((I->getKind() >= clang::Decl::firstFunction) &&
54         (I->getKind() <= clang::Decl::lastFunction)) {
55       clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56 
57       // points to RSSetObjectFD or RSClearObjectFD
58       clang::FunctionDecl **RSObjectFD;
59 
60       if (FD->getName() == "rsSetObject") {
61         slangAssert((FD->getNumParams() == 2) &&
62                     "Invalid rsSetObject function prototype (# params)");
63         RSObjectFD = RSSetObjectFD;
64       } else if (FD->getName() == "rsClearObject") {
65         slangAssert((FD->getNumParams() == 1) &&
66                     "Invalid rsClearObject function prototype (# params)");
67         RSObjectFD = RSClearObjectFD;
68       } else {
69         continue;
70       }
71 
72       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73       clang::QualType PVT = PVD->getOriginalType();
74       // The first parameter must be a pointer like rs_allocation*
75       slangAssert(PVT->isPointerType() &&
76           "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77 
78       // The rs object type passed to the FD
79       clang::QualType RST = PVT->getPointeeType();
80       DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81       slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82              && "must be RS object type");
83 
84       if (DT >= 0 && DT < DataTypeMax) {
85           RSObjectFD[DT] = FD;
86       } else {
87           slangAssert(false && "incorrect type");
88       }
89     }
90   }
91 }
92 
93 namespace {
94 
95 // This function constructs a new CompoundStmt from the input StmtList.
BuildCompoundStmt(clang::ASTContext & C,std::list<clang::Stmt * > & StmtList,clang::SourceLocation Loc)96 static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
97       std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
98   unsigned NewStmtCount = StmtList.size();
99   unsigned CompoundStmtCount = 0;
100 
101   clang::Stmt **CompoundStmtList;
102   CompoundStmtList = new clang::Stmt*[NewStmtCount];
103 
104   std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
105   std::list<clang::Stmt*>::const_iterator E = StmtList.end();
106   for ( ; I != E; I++) {
107     CompoundStmtList[CompoundStmtCount++] = *I;
108   }
109   slangAssert(CompoundStmtCount == NewStmtCount);
110 
111   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
112       C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
113 
114   delete [] CompoundStmtList;
115 
116   return CS;
117 }
118 
AppendAfterStmt(clang::ASTContext & C,clang::CompoundStmt * CS,clang::Stmt * S,std::list<clang::Stmt * > & StmtList)119 static void AppendAfterStmt(clang::ASTContext &C,
120                             clang::CompoundStmt *CS,
121                             clang::Stmt *S,
122                             std::list<clang::Stmt*> &StmtList) {
123   slangAssert(CS);
124   clang::CompoundStmt::body_iterator bI = CS->body_begin();
125   clang::CompoundStmt::body_iterator bE = CS->body_end();
126   clang::Stmt **UpdatedStmtList =
127       new clang::Stmt*[CS->size() + StmtList.size()];
128 
129   unsigned UpdatedStmtCount = 0;
130   unsigned Once = 0;
131   for ( ; bI != bE; bI++) {
132     if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
133       // If we come across a return here, we don't have anything we can
134       // reasonably replace. We should have already inserted our destructor
135       // code in the proper spot, so we just clean up and return.
136       delete [] UpdatedStmtList;
137 
138       return;
139     }
140 
141     UpdatedStmtList[UpdatedStmtCount++] = *bI;
142 
143     if ((*bI == S) && !Once) {
144       Once++;
145       std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
146       std::list<clang::Stmt*>::const_iterator E = StmtList.end();
147       for ( ; I != E; I++) {
148         UpdatedStmtList[UpdatedStmtCount++] = *I;
149       }
150     }
151   }
152   slangAssert(Once <= 1);
153 
154   // When S is nullptr, we are appending to the end of the CompoundStmt.
155   if (!S) {
156     slangAssert(Once == 0);
157     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
158     std::list<clang::Stmt*>::const_iterator E = StmtList.end();
159     for ( ; I != E; I++) {
160       UpdatedStmtList[UpdatedStmtCount++] = *I;
161     }
162   }
163 
164   CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
165 
166   delete [] UpdatedStmtList;
167 }
168 
169 // This class visits a compound statement and inserts DtorStmt
170 // in proper locations. This includes inserting it before any
171 // return statement in any sub-block, at the end of the logical enclosing
172 // scope (compound statement), and/or before any break/continue statement that
173 // would resume outside the declared scope. We will not handle the case for
174 // goto statements that leave a local scope.
175 //
176 // To accomplish these goals, it collects a list of sub-Stmt's that
177 // correspond to scope exit points. It then uses an RSASTReplace visitor to
178 // transform the AST, inserting appropriate destructors before each of those
179 // sub-Stmt's (and also before the exit of the outermost containing Stmt for
180 // the scope).
181 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182  private:
183   clang::ASTContext &mCtx;
184 
185   // The loop depth of the currently visited node.
186   int mLoopDepth;
187 
188   // The switch statement depth of the currently visited node.
189   // Note that this is tracked separately from the loop depth because
190   // SwitchStmt-contained ContinueStmt's should have destructors for the
191   // corresponding loop scope.
192   int mSwitchDepth;
193 
194   // The outermost statement block that we are currently visiting.
195   // This should always be a CompoundStmt.
196   clang::Stmt *mOuterStmt;
197 
198   // The destructor to execute for this scope/variable.
199   clang::Stmt* mDtorStmt;
200 
201   // The stack of statements which should be replaced by a compound statement
202   // containing the new destructor call followed by the original Stmt.
203   std::stack<clang::Stmt*> mReplaceStmtStack;
204 
205   // The source location for the variable declaration that we are trying to
206   // insert destructors for. Note that InsertDestructors() will not generate
207   // destructor calls for source locations that occur lexically before this
208   // location.
209   clang::SourceLocation mVarLoc;
210 
211  public:
212   DestructorVisitor(clang::ASTContext &C,
213                     clang::Stmt* OuterStmt,
214                     clang::Stmt* DtorStmt,
215                     clang::SourceLocation VarLoc);
216 
217   // This code walks the collected list of Stmts to replace and actually does
218   // the replacement. It also finishes up by appending the destructor to the
219   // current outermost CompoundStmt.
InsertDestructors()220   void InsertDestructors() {
221     clang::Stmt *S = nullptr;
222     clang::SourceManager &SM = mCtx.getSourceManager();
223     std::list<clang::Stmt *> StmtList;
224     StmtList.push_back(mDtorStmt);
225 
226     while (!mReplaceStmtStack.empty()) {
227       S = mReplaceStmtStack.top();
228       mReplaceStmtStack.pop();
229 
230       // Skip all source locations that occur before the variable's
231       // declaration, since it won't have been initialized yet.
232       if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
233         continue;
234       }
235 
236       StmtList.push_back(S);
237       clang::CompoundStmt *CS =
238           BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
239       StmtList.pop_back();
240 
241       RSASTReplace R(mCtx);
242       R.ReplaceStmt(mOuterStmt, S, CS);
243     }
244     clang::CompoundStmt *CS =
245       llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
246     slangAssert(CS);
247     AppendAfterStmt(mCtx, CS, nullptr, StmtList);
248   }
249 
250   void VisitStmt(clang::Stmt *S);
251   void VisitCompoundStmt(clang::CompoundStmt *CS);
252 
253   void VisitBreakStmt(clang::BreakStmt *BS);
254   void VisitCaseStmt(clang::CaseStmt *CS);
255   void VisitContinueStmt(clang::ContinueStmt *CS);
256   void VisitDefaultStmt(clang::DefaultStmt *DS);
257   void VisitDoStmt(clang::DoStmt *DS);
258   void VisitForStmt(clang::ForStmt *FS);
259   void VisitIfStmt(clang::IfStmt *IS);
260   void VisitReturnStmt(clang::ReturnStmt *RS);
261   void VisitSwitchCase(clang::SwitchCase *SC);
262   void VisitSwitchStmt(clang::SwitchStmt *SS);
263   void VisitWhileStmt(clang::WhileStmt *WS);
264 };
265 
DestructorVisitor(clang::ASTContext & C,clang::Stmt * OuterStmt,clang::Stmt * DtorStmt,clang::SourceLocation VarLoc)266 DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
267                          clang::Stmt *OuterStmt,
268                          clang::Stmt *DtorStmt,
269                          clang::SourceLocation VarLoc)
270   : mCtx(C),
271     mLoopDepth(0),
272     mSwitchDepth(0),
273     mOuterStmt(OuterStmt),
274     mDtorStmt(DtorStmt),
275     mVarLoc(VarLoc) {
276 }
277 
VisitStmt(clang::Stmt * S)278 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
279   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
280        I != E;
281        I++) {
282     if (clang::Stmt *Child = *I) {
283       Visit(Child);
284     }
285   }
286 }
287 
VisitCompoundStmt(clang::CompoundStmt * CS)288 void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
289   VisitStmt(CS);
290 }
291 
VisitBreakStmt(clang::BreakStmt * BS)292 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
293   VisitStmt(BS);
294   if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
295     mReplaceStmtStack.push(BS);
296   }
297 }
298 
VisitCaseStmt(clang::CaseStmt * CS)299 void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
300   VisitStmt(CS);
301 }
302 
VisitContinueStmt(clang::ContinueStmt * CS)303 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
304   VisitStmt(CS);
305   if (mLoopDepth == 0) {
306     // Switch statements can have nested continues.
307     mReplaceStmtStack.push(CS);
308   }
309 }
310 
VisitDefaultStmt(clang::DefaultStmt * DS)311 void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
312   VisitStmt(DS);
313 }
314 
VisitDoStmt(clang::DoStmt * DS)315 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
316   mLoopDepth++;
317   VisitStmt(DS);
318   mLoopDepth--;
319 }
320 
VisitForStmt(clang::ForStmt * FS)321 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
322   mLoopDepth++;
323   VisitStmt(FS);
324   mLoopDepth--;
325 }
326 
VisitIfStmt(clang::IfStmt * IS)327 void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
328   VisitStmt(IS);
329 }
330 
VisitReturnStmt(clang::ReturnStmt * RS)331 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
332   mReplaceStmtStack.push(RS);
333 }
334 
VisitSwitchCase(clang::SwitchCase * SC)335 void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
336   slangAssert(false && "Both case and default have specialized handlers");
337   VisitStmt(SC);
338 }
339 
VisitSwitchStmt(clang::SwitchStmt * SS)340 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
341   mSwitchDepth++;
342   VisitStmt(SS);
343   mSwitchDepth--;
344 }
345 
VisitWhileStmt(clang::WhileStmt * WS)346 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
347   mLoopDepth++;
348   VisitStmt(WS);
349   mLoopDepth--;
350 }
351 
ClearSingleRSObject(clang::ASTContext & C,clang::Expr * RefRSVar,clang::SourceLocation Loc)352 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
353                                  clang::Expr *RefRSVar,
354                                  clang::SourceLocation Loc) {
355   slangAssert(RefRSVar);
356   const clang::Type *T = RefRSVar->getType().getTypePtr();
357   slangAssert(!T->isArrayType() &&
358               "Should not be destroying arrays with this function");
359 
360   clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
361   slangAssert((ClearObjectFD != nullptr) &&
362               "rsClearObject doesn't cover all RS object types");
363 
364   clang::QualType ClearObjectFDType = ClearObjectFD->getType();
365   clang::QualType ClearObjectFDArgType =
366       ClearObjectFD->getParamDecl(0)->getOriginalType();
367 
368   // Example destructor for "rs_font localFont;"
369   //
370   // (CallExpr 'void'
371   //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
372   //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
373   //   (UnaryOperator 'rs_font *' prefix '&'
374   //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
375 
376   // Get address of targeted RS object
377   clang::Expr *AddrRefRSVar =
378       new(C) clang::UnaryOperator(RefRSVar,
379                                   clang::UO_AddrOf,
380                                   ClearObjectFDArgType,
381                                   clang::VK_RValue,
382                                   clang::OK_Ordinary,
383                                   Loc);
384 
385   clang::Expr *RefRSClearObjectFD =
386       clang::DeclRefExpr::Create(C,
387                                  clang::NestedNameSpecifierLoc(),
388                                  clang::SourceLocation(),
389                                  ClearObjectFD,
390                                  false,
391                                  ClearObjectFD->getLocation(),
392                                  ClearObjectFDType,
393                                  clang::VK_RValue,
394                                  nullptr);
395 
396   clang::Expr *RSClearObjectFP =
397       clang::ImplicitCastExpr::Create(C,
398                                       C.getPointerType(ClearObjectFDType),
399                                       clang::CK_FunctionToPointerDecay,
400                                       RefRSClearObjectFD,
401                                       nullptr,
402                                       clang::VK_RValue);
403 
404   llvm::SmallVector<clang::Expr*, 1> ArgList;
405   ArgList.push_back(AddrRefRSVar);
406 
407   clang::CallExpr *RSClearObjectCall =
408       new(C) clang::CallExpr(C,
409                              RSClearObjectFP,
410                              ArgList,
411                              ClearObjectFD->getCallResultType(),
412                              clang::VK_RValue,
413                              Loc);
414 
415   return RSClearObjectCall;
416 }
417 
ArrayDim(const clang::Type * T)418 static int ArrayDim(const clang::Type *T) {
419   if (!T || !T->isArrayType()) {
420     return 0;
421   }
422 
423   const clang::ConstantArrayType *CAT =
424     static_cast<const clang::ConstantArrayType *>(T);
425   return static_cast<int>(CAT->getSize().getSExtValue());
426 }
427 
// Forward declaration: ClearArrayRSObject() and ClearStructRSObject() call
// each other (arrays of structs, and structs containing arrays or structs).
static clang::Stmt *ClearStructRSObject(
    clang::ASTContext &C,
    clang::DeclContext *DC,
    clang::Expr *RefRSStruct,
    clang::SourceLocation StartLoc,
    clang::SourceLocation Loc);
434 
ClearArrayRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSArr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)435 static clang::Stmt *ClearArrayRSObject(
436     clang::ASTContext &C,
437     clang::DeclContext *DC,
438     clang::Expr *RefRSArr,
439     clang::SourceLocation StartLoc,
440     clang::SourceLocation Loc) {
441   const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
442   slangAssert(BaseType->isArrayType());
443 
444   int NumArrayElements = ArrayDim(BaseType);
445   // Actually extract out the base RS object type for use later
446   BaseType = BaseType->getArrayElementTypeNoTypeQual();
447 
448   clang::Stmt *StmtArray[2] = {nullptr};
449   int StmtCtr = 0;
450 
451   if (NumArrayElements <= 0) {
452     return nullptr;
453   }
454 
455   // Example destructor loop for "rs_font fontArr[10];"
456   //
457   // (CompoundStmt
458   //   (DeclStmt "int rsIntIter")
459   //   (ForStmt
460   //     (BinaryOperator 'int' '='
461   //       (DeclRefExpr 'int' Var='rsIntIter')
462   //       (IntegerLiteral 'int' 0))
463   //     (BinaryOperator 'int' '<'
464   //       (DeclRefExpr 'int' Var='rsIntIter')
465   //       (IntegerLiteral 'int' 10)
466   //     nullptr << CondVar >>
467   //     (UnaryOperator 'int' postfix '++'
468   //       (DeclRefExpr 'int' Var='rsIntIter'))
469   //     (CallExpr 'void'
470   //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
471   //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
472   //       (UnaryOperator 'rs_font *' prefix '&'
473   //         (ArraySubscriptExpr 'rs_font':'rs_font'
474   //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
475   //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
476   //           (DeclRefExpr 'int' Var='rsIntIter')))))))
477 
478   // Create helper variable for iterating through elements
479   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
480   clang::VarDecl *IIVD =
481       clang::VarDecl::Create(C,
482                              DC,
483                              StartLoc,
484                              Loc,
485                              &II,
486                              C.IntTy,
487                              C.getTrivialTypeSourceInfo(C.IntTy),
488                              clang::SC_None);
489   // Mark "rsIntIter" as used
490   IIVD->markUsed(C);
491   clang::Decl *IID = (clang::Decl *)IIVD;
492 
493   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
494   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
495 
496   // Form the actual destructor loop
497   // for (Init; Cond; Inc)
498   //   RSClearObjectCall;
499 
500   // Init -> "rsIntIter = 0"
501   clang::DeclRefExpr *RefrsIntIter =
502       clang::DeclRefExpr::Create(C,
503                                  clang::NestedNameSpecifierLoc(),
504                                  clang::SourceLocation(),
505                                  IIVD,
506                                  false,
507                                  Loc,
508                                  C.IntTy,
509                                  clang::VK_RValue,
510                                  nullptr);
511 
512   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
513       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
514 
515   clang::BinaryOperator *Init =
516       new(C) clang::BinaryOperator(RefrsIntIter,
517                                    Int0,
518                                    clang::BO_Assign,
519                                    C.IntTy,
520                                    clang::VK_RValue,
521                                    clang::OK_Ordinary,
522                                    Loc,
523                                    false);
524 
525   // Cond -> "rsIntIter < NumArrayElements"
526   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
527       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
528 
529   clang::BinaryOperator *Cond =
530       new(C) clang::BinaryOperator(RefrsIntIter,
531                                    NumArrayElementsExpr,
532                                    clang::BO_LT,
533                                    C.IntTy,
534                                    clang::VK_RValue,
535                                    clang::OK_Ordinary,
536                                    Loc,
537                                    false);
538 
539   // Inc -> "rsIntIter++"
540   clang::UnaryOperator *Inc =
541       new(C) clang::UnaryOperator(RefrsIntIter,
542                                   clang::UO_PostInc,
543                                   C.IntTy,
544                                   clang::VK_RValue,
545                                   clang::OK_Ordinary,
546                                   Loc);
547 
548   // Body -> "rsClearObject(&VD[rsIntIter]);"
549   // Destructor loop operates on individual array elements
550 
551   clang::Expr *RefRSArrPtr =
552       clang::ImplicitCastExpr::Create(C,
553           C.getPointerType(BaseType->getCanonicalTypeInternal()),
554           clang::CK_ArrayToPointerDecay,
555           RefRSArr,
556           nullptr,
557           clang::VK_RValue);
558 
559   clang::Expr *RefRSArrPtrSubscript =
560       new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
561                                        RefrsIntIter,
562                                        BaseType->getCanonicalTypeInternal(),
563                                        clang::VK_RValue,
564                                        clang::OK_Ordinary,
565                                        Loc);
566 
567   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
568 
569   clang::Stmt *RSClearObjectCall = nullptr;
570   if (BaseType->isArrayType()) {
571     RSClearObjectCall =
572         ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
573   } else if (DT == DataTypeUnknown) {
574     RSClearObjectCall =
575         ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
576   } else {
577     RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
578   }
579 
580   clang::ForStmt *DestructorLoop =
581       new(C) clang::ForStmt(C,
582                             Init,
583                             Cond,
584                             nullptr,  // no condVar
585                             Inc,
586                             RSClearObjectCall,
587                             Loc,
588                             Loc,
589                             Loc);
590 
591   StmtArray[StmtCtr++] = DestructorLoop;
592   slangAssert(StmtCtr == 2);
593 
594   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
595       C, llvm::makeArrayRef(StmtArray, StmtCtr), Loc, Loc);
596 
597   return CS;
598 }
599 
CountRSObjectTypes(clang::ASTContext & C,const clang::Type * T,clang::SourceLocation Loc)600 static unsigned CountRSObjectTypes(clang::ASTContext &C,
601                                    const clang::Type *T,
602                                    clang::SourceLocation Loc) {
603   slangAssert(T);
604   unsigned RSObjectCount = 0;
605 
606   if (T->isArrayType()) {
607     return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
608   }
609 
610   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
611   if (DT != DataTypeUnknown) {
612     return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
613   }
614 
615   if (T->isUnionType()) {
616     clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
617     RD = RD->getDefinition();
618     for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
619            FE = RD->field_end();
620          FI != FE;
621          FI++) {
622       const clang::FieldDecl *FD = *FI;
623       const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
624       if (CountRSObjectTypes(C, FT, Loc)) {
625         slangAssert(false && "can't have unions with RS object types!");
626         return 0;
627       }
628     }
629   }
630 
631   if (!T->isStructureType()) {
632     return 0;
633   }
634 
635   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
636   RD = RD->getDefinition();
637   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
638          FE = RD->field_end();
639        FI != FE;
640        FI++) {
641     const clang::FieldDecl *FD = *FI;
642     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
643     if (CountRSObjectTypes(C, FT, Loc)) {
644       // Sub-structs should only count once (as should arrays, etc.)
645       RSObjectCount++;
646     }
647   }
648 
649   return RSObjectCount;
650 }
651 
ClearStructRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSStruct,clang::SourceLocation StartLoc,clang::SourceLocation Loc)652 static clang::Stmt *ClearStructRSObject(
653     clang::ASTContext &C,
654     clang::DeclContext *DC,
655     clang::Expr *RefRSStruct,
656     clang::SourceLocation StartLoc,
657     clang::SourceLocation Loc) {
658   const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
659 
660   slangAssert(!BaseType->isArrayType());
661 
662   // Structs should show up as unknown primitive types
663   slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
664               DataTypeUnknown);
665 
666   unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
667   slangAssert(FieldsToDestroy != 0);
668 
669   unsigned StmtCount = 0;
670   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
671   for (unsigned i = 0; i < FieldsToDestroy; i++) {
672     StmtArray[i] = nullptr;
673   }
674 
675   // Populate StmtArray by creating a destructor for each RS object field
676   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
677   RD = RD->getDefinition();
678   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
679          FE = RD->field_end();
680        FI != FE;
681        FI++) {
682     // We just look through all field declarations to see if we find a
683     // declaration for an RS object type (or an array of one).
684     bool IsArrayType = false;
685     clang::FieldDecl *FD = *FI;
686     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
687     const clang::Type *OrigType = FT;
688     while (FT && FT->isArrayType()) {
689       FT = FT->getArrayElementTypeNoTypeQual();
690       IsArrayType = true;
691     }
692 
693     // Pass a DeclarationNameInfo with a valid DeclName, since name equality
694     // gets asserted during CodeGen.
695     clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
696                                               FD->getLocation());
697 
698     if (RSExportPrimitiveType::IsRSObjectType(FT)) {
699       clang::DeclAccessPair FoundDecl =
700           clang::DeclAccessPair::make(FD, clang::AS_none);
701       clang::MemberExpr *RSObjectMember =
702           clang::MemberExpr::Create(C,
703                                     RefRSStruct,
704                                     false,
705                                     clang::SourceLocation(),
706                                     clang::NestedNameSpecifierLoc(),
707                                     clang::SourceLocation(),
708                                     FD,
709                                     FoundDecl,
710                                     FDDeclNameInfo,
711                                     nullptr,
712                                     OrigType->getCanonicalTypeInternal(),
713                                     clang::VK_RValue,
714                                     clang::OK_Ordinary);
715 
716       slangAssert(StmtCount < FieldsToDestroy);
717 
718       if (IsArrayType) {
719         StmtArray[StmtCount++] = ClearArrayRSObject(C,
720                                                     DC,
721                                                     RSObjectMember,
722                                                     StartLoc,
723                                                     Loc);
724       } else {
725         StmtArray[StmtCount++] = ClearSingleRSObject(C,
726                                                      RSObjectMember,
727                                                      Loc);
728       }
729     } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
730       // In this case, we have a nested struct. We may not end up filling all
731       // of the spaces in StmtArray (sub-structs should handle themselves
732       // with separate compound statements).
733       clang::DeclAccessPair FoundDecl =
734           clang::DeclAccessPair::make(FD, clang::AS_none);
735       clang::MemberExpr *RSObjectMember =
736           clang::MemberExpr::Create(C,
737                                     RefRSStruct,
738                                     false,
739                                     clang::SourceLocation(),
740                                     clang::NestedNameSpecifierLoc(),
741                                     clang::SourceLocation(),
742                                     FD,
743                                     FoundDecl,
744                                     clang::DeclarationNameInfo(),
745                                     nullptr,
746                                     OrigType->getCanonicalTypeInternal(),
747                                     clang::VK_RValue,
748                                     clang::OK_Ordinary);
749 
750       if (IsArrayType) {
751         StmtArray[StmtCount++] = ClearArrayRSObject(C,
752                                                     DC,
753                                                     RSObjectMember,
754                                                     StartLoc,
755                                                     Loc);
756       } else {
757         StmtArray[StmtCount++] = ClearStructRSObject(C,
758                                                      DC,
759                                                      RSObjectMember,
760                                                      StartLoc,
761                                                      Loc);
762       }
763     }
764   }
765 
766   slangAssert(StmtCount > 0);
767   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
768       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
769 
770   delete [] StmtArray;
771 
772   return CS;
773 }
774 
CreateSingleRSSetObject(clang::ASTContext & C,clang::Expr * DstExpr,clang::Expr * SrcExpr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)775 static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
776                                             clang::Expr *DstExpr,
777                                             clang::Expr *SrcExpr,
778                                             clang::SourceLocation StartLoc,
779                                             clang::SourceLocation Loc) {
780   const clang::Type *T = DstExpr->getType().getTypePtr();
781   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
782   slangAssert((SetObjectFD != nullptr) &&
783               "rsSetObject doesn't cover all RS object types");
784 
785   clang::QualType SetObjectFDType = SetObjectFD->getType();
786   clang::QualType SetObjectFDArgType[2];
787   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
788   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
789 
790   clang::Expr *RefRSSetObjectFD =
791       clang::DeclRefExpr::Create(C,
792                                  clang::NestedNameSpecifierLoc(),
793                                  clang::SourceLocation(),
794                                  SetObjectFD,
795                                  false,
796                                  Loc,
797                                  SetObjectFDType,
798                                  clang::VK_RValue,
799                                  nullptr);
800 
801   clang::Expr *RSSetObjectFP =
802       clang::ImplicitCastExpr::Create(C,
803                                       C.getPointerType(SetObjectFDType),
804                                       clang::CK_FunctionToPointerDecay,
805                                       RefRSSetObjectFD,
806                                       nullptr,
807                                       clang::VK_RValue);
808 
809   llvm::SmallVector<clang::Expr*, 2> ArgList;
810   ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
811                                                 clang::UO_AddrOf,
812                                                 SetObjectFDArgType[0],
813                                                 clang::VK_RValue,
814                                                 clang::OK_Ordinary,
815                                                 Loc));
816   ArgList.push_back(SrcExpr);
817 
818   clang::CallExpr *RSSetObjectCall =
819       new(C) clang::CallExpr(C,
820                              RSSetObjectFP,
821                              ArgList,
822                              SetObjectFD->getCallResultType(),
823                              clang::VK_RValue,
824                              Loc);
825 
826   return RSSetObjectCall;
827 }
828 
// Forward declaration of CreateStructRSSetObject(), which is defined further
// below. The disabled CreateArrayRSSetObject() draft that follows refers to it
// before its definition.
static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
                                            clang::Expr *LHS,
                                            clang::Expr *RHS,
                                            clang::SourceLocation StartLoc,
                                            clang::SourceLocation Loc);
834 
835 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
836                                            clang::Expr *DstArr,
837                                            clang::Expr *SrcArr,
838                                            clang::SourceLocation StartLoc,
839                                            clang::SourceLocation Loc) {
840   clang::DeclContext *DC = nullptr;
841   const clang::Type *BaseType = DstArr->getType().getTypePtr();
842   slangAssert(BaseType->isArrayType());
843 
844   int NumArrayElements = ArrayDim(BaseType);
845   // Actually extract out the base RS object type for use later
846   BaseType = BaseType->getArrayElementTypeNoTypeQual();
847 
848   clang::Stmt *StmtArray[2] = {nullptr};
849   int StmtCtr = 0;
850 
851   if (NumArrayElements <= 0) {
852     return nullptr;
853   }
854 
855   // Create helper variable for iterating through elements
856   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
857   clang::VarDecl *IIVD =
858       clang::VarDecl::Create(C,
859                              DC,
860                              StartLoc,
861                              Loc,
862                              &II,
863                              C.IntTy,
864                              C.getTrivialTypeSourceInfo(C.IntTy),
865                              clang::SC_None,
866                              clang::SC_None);
867   clang::Decl *IID = (clang::Decl *)IIVD;
868 
869   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
870   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
871 
872   // Form the actual loop
873   // for (Init; Cond; Inc)
874   //   RSSetObjectCall;
875 
876   // Init -> "rsIntIter = 0"
877   clang::DeclRefExpr *RefrsIntIter =
878       clang::DeclRefExpr::Create(C,
879                                  clang::NestedNameSpecifierLoc(),
880                                  IIVD,
881                                  Loc,
882                                  C.IntTy,
883                                  clang::VK_RValue,
884                                  nullptr);
885 
886   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
887       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
888 
889   clang::BinaryOperator *Init =
890       new(C) clang::BinaryOperator(RefrsIntIter,
891                                    Int0,
892                                    clang::BO_Assign,
893                                    C.IntTy,
894                                    clang::VK_RValue,
895                                    clang::OK_Ordinary,
896                                    Loc);
897 
898   // Cond -> "rsIntIter < NumArrayElements"
899   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
900       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
901 
902   clang::BinaryOperator *Cond =
903       new(C) clang::BinaryOperator(RefrsIntIter,
904                                    NumArrayElementsExpr,
905                                    clang::BO_LT,
906                                    C.IntTy,
907                                    clang::VK_RValue,
908                                    clang::OK_Ordinary,
909                                    Loc);
910 
911   // Inc -> "rsIntIter++"
912   clang::UnaryOperator *Inc =
913       new(C) clang::UnaryOperator(RefrsIntIter,
914                                   clang::UO_PostInc,
915                                   C.IntTy,
916                                   clang::VK_RValue,
917                                   clang::OK_Ordinary,
918                                   Loc);
919 
920   // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
921   // Loop operates on individual array elements
922 
923   clang::Expr *DstArrPtr =
924       clang::ImplicitCastExpr::Create(C,
925           C.getPointerType(BaseType->getCanonicalTypeInternal()),
926           clang::CK_ArrayToPointerDecay,
927           DstArr,
928           nullptr,
929           clang::VK_RValue);
930 
931   clang::Expr *DstArrPtrSubscript =
932       new(C) clang::ArraySubscriptExpr(DstArrPtr,
933                                        RefrsIntIter,
934                                        BaseType->getCanonicalTypeInternal(),
935                                        clang::VK_RValue,
936                                        clang::OK_Ordinary,
937                                        Loc);
938 
939   clang::Expr *SrcArrPtr =
940       clang::ImplicitCastExpr::Create(C,
941           C.getPointerType(BaseType->getCanonicalTypeInternal()),
942           clang::CK_ArrayToPointerDecay,
943           SrcArr,
944           nullptr,
945           clang::VK_RValue);
946 
947   clang::Expr *SrcArrPtrSubscript =
948       new(C) clang::ArraySubscriptExpr(SrcArrPtr,
949                                        RefrsIntIter,
950                                        BaseType->getCanonicalTypeInternal(),
951                                        clang::VK_RValue,
952                                        clang::OK_Ordinary,
953                                        Loc);
954 
955   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
956 
957   clang::Stmt *RSSetObjectCall = nullptr;
958   if (BaseType->isArrayType()) {
959     RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
960                                              SrcArrPtrSubscript,
961                                              StartLoc, Loc);
962   } else if (DT == DataTypeUnknown) {
963     RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
964                                               SrcArrPtrSubscript,
965                                               StartLoc, Loc);
966   } else {
967     RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
968                                               SrcArrPtrSubscript,
969                                               StartLoc, Loc);
970   }
971 
972   clang::ForStmt *DestructorLoop =
973       new(C) clang::ForStmt(C,
974                             Init,
975                             Cond,
976                             nullptr,  // no condVar
977                             Inc,
978                             RSSetObjectCall,
979                             Loc,
980                             Loc,
981                             Loc);
982 
983   StmtArray[StmtCtr++] = DestructorLoop;
984   slangAssert(StmtCtr == 2);
985 
986   clang::CompoundStmt *CS =
987       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
988 
989   return CS;
990 } */
991 
CreateStructRSSetObject(clang::ASTContext & C,clang::Expr * LHS,clang::Expr * RHS,clang::SourceLocation StartLoc,clang::SourceLocation Loc)992 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
993                                             clang::Expr *LHS,
994                                             clang::Expr *RHS,
995                                             clang::SourceLocation StartLoc,
996                                             clang::SourceLocation Loc) {
997   clang::QualType QT = LHS->getType();
998   const clang::Type *T = QT.getTypePtr();
999   slangAssert(T->isStructureType());
1000   slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1001 
1002   // Keep an extra slot for the original copy (memcpy)
1003   unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1004 
1005   unsigned StmtCount = 0;
1006   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1007   for (unsigned i = 0; i < FieldsToSet; i++) {
1008     StmtArray[i] = nullptr;
1009   }
1010 
1011   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1012   RD = RD->getDefinition();
1013   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1014          FE = RD->field_end();
1015        FI != FE;
1016        FI++) {
1017     bool IsArrayType = false;
1018     clang::FieldDecl *FD = *FI;
1019     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1020     const clang::Type *OrigType = FT;
1021 
1022     if (!CountRSObjectTypes(C, FT, Loc)) {
1023       // Skip to next if we don't have any viable RS object types
1024       continue;
1025     }
1026 
1027     clang::DeclAccessPair FoundDecl =
1028         clang::DeclAccessPair::make(FD, clang::AS_none);
1029     clang::MemberExpr *DstMember =
1030         clang::MemberExpr::Create(C,
1031                                   LHS,
1032                                   false,
1033                                   clang::SourceLocation(),
1034                                   clang::NestedNameSpecifierLoc(),
1035                                   clang::SourceLocation(),
1036                                   FD,
1037                                   FoundDecl,
1038                                   clang::DeclarationNameInfo(),
1039                                   nullptr,
1040                                   OrigType->getCanonicalTypeInternal(),
1041                                   clang::VK_RValue,
1042                                   clang::OK_Ordinary);
1043 
1044     clang::MemberExpr *SrcMember =
1045         clang::MemberExpr::Create(C,
1046                                   RHS,
1047                                   false,
1048                                   clang::SourceLocation(),
1049                                   clang::NestedNameSpecifierLoc(),
1050                                   clang::SourceLocation(),
1051                                   FD,
1052                                   FoundDecl,
1053                                   clang::DeclarationNameInfo(),
1054                                   nullptr,
1055                                   OrigType->getCanonicalTypeInternal(),
1056                                   clang::VK_RValue,
1057                                   clang::OK_Ordinary);
1058 
1059     if (FT->isArrayType()) {
1060       FT = FT->getArrayElementTypeNoTypeQual();
1061       IsArrayType = true;
1062     }
1063 
1064     DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
1065 
1066     if (IsArrayType) {
1067       clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1068       DiagEngine.Report(
1069         clang::FullSourceLoc(Loc, C.getSourceManager()),
1070         DiagEngine.getCustomDiagID(
1071           clang::DiagnosticsEngine::Error,
1072           "Arrays of RS object types within structures cannot be copied"));
1073       // TODO(srhines): Support setting arrays of RS objects
1074       // StmtArray[StmtCount++] =
1075       //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1076     } else if (DT == DataTypeUnknown) {
1077       StmtArray[StmtCount++] =
1078           CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1079     } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1080       StmtArray[StmtCount++] =
1081           CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1082     } else {
1083       slangAssert(false);
1084     }
1085   }
1086 
1087   slangAssert(StmtCount < FieldsToSet);
1088 
1089   // We still need to actually do the overall struct copy. For simplicity,
1090   // we just do a straight-up assignment (which will still preserve all
1091   // the proper RS object reference counts).
1092   clang::BinaryOperator *CopyStruct =
1093       new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1094                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
1095                                    false);
1096   StmtArray[StmtCount++] = CopyStruct;
1097 
1098   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1099       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1100 
1101   delete [] StmtArray;
1102 
1103   return CS;
1104 }
1105 
1106 }  // namespace
1107 
ReplaceRSObjectAssignment(clang::BinaryOperator * AS)1108 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1109     clang::BinaryOperator *AS) {
1110 
1111   clang::QualType QT = AS->getType();
1112 
1113   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1114       DataTypeRSAllocation)->getASTContext();
1115 
1116   clang::SourceLocation Loc = AS->getExprLoc();
1117   clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1118   clang::Stmt *UpdatedStmt = nullptr;
1119 
1120   if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1121     // By definition, this is a struct assignment if we get here
1122     UpdatedStmt =
1123         CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1124   } else {
1125     UpdatedStmt =
1126         CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1127   }
1128 
1129   RSASTReplace R(C);
1130   R.ReplaceStmt(mCS, AS, UpdatedStmt);
1131 }
1132 
AppendRSObjectInit(clang::VarDecl * VD,clang::DeclStmt * DS,DataType DT,clang::Expr * InitExpr)1133 void RSObjectRefCount::Scope::AppendRSObjectInit(
1134     clang::VarDecl *VD,
1135     clang::DeclStmt *DS,
1136     DataType DT,
1137     clang::Expr *InitExpr) {
1138   slangAssert(VD);
1139 
1140   if (!InitExpr) {
1141     return;
1142   }
1143 
1144   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1145       DataTypeRSAllocation)->getASTContext();
1146   clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1147       DataTypeRSAllocation)->getLocation();
1148   clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1149       DataTypeRSAllocation)->getInnerLocStart();
1150 
1151   if (DT == DataTypeIsStruct) {
1152     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1153     clang::DeclRefExpr *RefRSVar =
1154         clang::DeclRefExpr::Create(C,
1155                                    clang::NestedNameSpecifierLoc(),
1156                                    clang::SourceLocation(),
1157                                    VD,
1158                                    false,
1159                                    Loc,
1160                                    T->getCanonicalTypeInternal(),
1161                                    clang::VK_RValue,
1162                                    nullptr);
1163 
1164     clang::Stmt *RSSetObjectOps =
1165         CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1166 
1167     std::list<clang::Stmt*> StmtList;
1168     StmtList.push_back(RSSetObjectOps);
1169     AppendAfterStmt(C, mCS, DS, StmtList);
1170     return;
1171   }
1172 
1173   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1174   slangAssert((SetObjectFD != nullptr) &&
1175               "rsSetObject doesn't cover all RS object types");
1176 
1177   clang::QualType SetObjectFDType = SetObjectFD->getType();
1178   clang::QualType SetObjectFDArgType[2];
1179   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1180   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1181 
1182   clang::Expr *RefRSSetObjectFD =
1183       clang::DeclRefExpr::Create(C,
1184                                  clang::NestedNameSpecifierLoc(),
1185                                  clang::SourceLocation(),
1186                                  SetObjectFD,
1187                                  false,
1188                                  Loc,
1189                                  SetObjectFDType,
1190                                  clang::VK_RValue,
1191                                  nullptr);
1192 
1193   clang::Expr *RSSetObjectFP =
1194       clang::ImplicitCastExpr::Create(C,
1195                                       C.getPointerType(SetObjectFDType),
1196                                       clang::CK_FunctionToPointerDecay,
1197                                       RefRSSetObjectFD,
1198                                       nullptr,
1199                                       clang::VK_RValue);
1200 
1201   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1202   clang::DeclRefExpr *RefRSVar =
1203       clang::DeclRefExpr::Create(C,
1204                                  clang::NestedNameSpecifierLoc(),
1205                                  clang::SourceLocation(),
1206                                  VD,
1207                                  false,
1208                                  Loc,
1209                                  T->getCanonicalTypeInternal(),
1210                                  clang::VK_RValue,
1211                                  nullptr);
1212 
1213   llvm::SmallVector<clang::Expr*, 2> ArgList;
1214   ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1215                                                 clang::UO_AddrOf,
1216                                                 SetObjectFDArgType[0],
1217                                                 clang::VK_RValue,
1218                                                 clang::OK_Ordinary,
1219                                                 Loc));
1220   ArgList.push_back(InitExpr);
1221 
1222   clang::CallExpr *RSSetObjectCall =
1223       new(C) clang::CallExpr(C,
1224                              RSSetObjectFP,
1225                              ArgList,
1226                              SetObjectFD->getCallResultType(),
1227                              clang::VK_RValue,
1228                              Loc);
1229 
1230   std::list<clang::Stmt*> StmtList;
1231   StmtList.push_back(RSSetObjectCall);
1232   AppendAfterStmt(C, mCS, DS, StmtList);
1233 }
1234 
InsertLocalVarDestructors()1235 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1236   for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1237           E = mRSO.end();
1238         I != E;
1239         I++) {
1240     clang::VarDecl *VD = *I;
1241     clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1242     if (RSClearObjectCall) {
1243       clang::ASTContext &C = (*mRSO.begin())->getASTContext();
1244       // Mark VD as used.  It might be unused, except for the destructor.
1245       // 'markUsed' has side-effects that are caused only if VD is not already
1246       // used.  Hence no need for an extra check here.
1247       VD->markUsed(C);
1248       DestructorVisitor DV(C,
1249                            mCS,
1250                            RSClearObjectCall,
1251                            VD->getSourceRange().getBegin());
1252       DV.Visit(mCS);
1253       DV.InsertDestructors();
1254     }
1255   }
1256 }
1257 
ClearRSObject(clang::VarDecl * VD,clang::DeclContext * DC)1258 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1259     clang::VarDecl *VD,
1260     clang::DeclContext *DC) {
1261   slangAssert(VD);
1262   clang::ASTContext &C = VD->getASTContext();
1263   clang::SourceLocation Loc = VD->getLocation();
1264   clang::SourceLocation StartLoc = VD->getInnerLocStart();
1265   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1266 
1267   // Reference expr to target RS object variable
1268   clang::DeclRefExpr *RefRSVar =
1269       clang::DeclRefExpr::Create(C,
1270                                  clang::NestedNameSpecifierLoc(),
1271                                  clang::SourceLocation(),
1272                                  VD,
1273                                  false,
1274                                  Loc,
1275                                  T->getCanonicalTypeInternal(),
1276                                  clang::VK_RValue,
1277                                  nullptr);
1278 
1279   if (T->isArrayType()) {
1280     return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1281   }
1282 
1283   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1284 
1285   if (DT == DataTypeUnknown ||
1286       DT == DataTypeIsStruct) {
1287     return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1288   }
1289 
1290   slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1291               "Should be RS object");
1292 
1293   return ClearSingleRSObject(C, RefRSVar, Loc);
1294 }
1295 
InitializeRSObject(clang::VarDecl * VD,DataType * DT,clang::Expr ** InitExpr)1296 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1297                                           DataType *DT,
1298                                           clang::Expr **InitExpr) {
1299   slangAssert(VD && DT && InitExpr);
1300   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1301 
1302   // Loop through array types to get to base type
1303   while (T && T->isArrayType()) {
1304     T = T->getArrayElementTypeNoTypeQual();
1305   }
1306 
1307   bool DataTypeIsStructWithRSObject = false;
1308   *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1309 
1310   if (*DT == DataTypeUnknown) {
1311     if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1312       *DT = DataTypeIsStruct;
1313       DataTypeIsStructWithRSObject = true;
1314     } else {
1315       return false;
1316     }
1317   }
1318 
1319   bool DataTypeIsRSObject = false;
1320   if (DataTypeIsStructWithRSObject) {
1321     DataTypeIsRSObject = true;
1322   } else {
1323     DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1324   }
1325   *InitExpr = VD->getInit();
1326 
1327   if (!DataTypeIsRSObject && *InitExpr) {
1328     // If we already have an initializer for a matrix type, we are done.
1329     return DataTypeIsRSObject;
1330   }
1331 
1332   clang::Expr *ZeroInitializer =
1333       CreateZeroInitializerForRSSpecificType(*DT,
1334                                              VD->getASTContext(),
1335                                              VD->getLocation());
1336 
1337   if (ZeroInitializer) {
1338     ZeroInitializer->setType(T->getCanonicalTypeInternal());
1339     VD->setInit(ZeroInitializer);
1340   }
1341 
1342   return DataTypeIsRSObject;
1343 }
1344 
CreateZeroInitializerForRSSpecificType(DataType DT,clang::ASTContext & C,const clang::SourceLocation & Loc)1345 clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1346     DataType DT,
1347     clang::ASTContext &C,
1348     const clang::SourceLocation &Loc) {
1349   clang::Expr *Res = nullptr;
1350   switch (DT) {
1351     case DataTypeIsStruct:
1352     case DataTypeRSElement:
1353     case DataTypeRSType:
1354     case DataTypeRSAllocation:
1355     case DataTypeRSSampler:
1356     case DataTypeRSScript:
1357     case DataTypeRSMesh:
1358     case DataTypeRSPath:
1359     case DataTypeRSProgramFragment:
1360     case DataTypeRSProgramVertex:
1361     case DataTypeRSProgramRaster:
1362     case DataTypeRSProgramStore:
1363     case DataTypeRSFont: {
1364       //    (ImplicitCastExpr 'nullptr_t'
1365       //      (IntegerLiteral 0)))
1366       llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1367       clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1368       clang::Expr *CastToNull =
1369           clang::ImplicitCastExpr::Create(C,
1370                                           C.NullPtrTy,
1371                                           clang::CK_IntegralToPointer,
1372                                           Int0,
1373                                           nullptr,
1374                                           clang::VK_RValue);
1375 
1376       llvm::SmallVector<clang::Expr*, 1>InitList;
1377       InitList.push_back(CastToNull);
1378 
1379       Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
1380       break;
1381     }
1382     case DataTypeRSMatrix2x2:
1383     case DataTypeRSMatrix3x3:
1384     case DataTypeRSMatrix4x4: {
1385       // RS matrix is not completely an RS object. They hold data by themselves.
1386       // (InitListExpr rs_matrix2x2
1387       //   (InitListExpr float[4]
1388       //     (FloatingLiteral 0)
1389       //     (FloatingLiteral 0)
1390       //     (FloatingLiteral 0)
1391       //     (FloatingLiteral 0)))
1392       clang::QualType FloatTy = C.FloatTy;
1393       // Constructor sets value to 0.0f by default
1394       llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1395       clang::FloatingLiteral *Float0Val =
1396           clang::FloatingLiteral::Create(C,
1397                                          Val,
1398                                          /* isExact = */true,
1399                                          FloatTy,
1400                                          Loc);
1401 
1402       unsigned N = 0;
1403       if (DT == DataTypeRSMatrix2x2)
1404         N = 2;
1405       else if (DT == DataTypeRSMatrix3x3)
1406         N = 3;
1407       else if (DT == DataTypeRSMatrix4x4)
1408         N = 4;
1409       unsigned N_2 = N * N;
1410 
1411       // Assume we are going to be allocating 16 elements, since 4x4 is max.
1412       llvm::SmallVector<clang::Expr*, 16> InitVals;
1413       for (unsigned i = 0; i < N_2; i++)
1414         InitVals.push_back(Float0Val);
1415       clang::Expr *InitExpr =
1416           new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
1417       InitExpr->setType(C.getConstantArrayType(FloatTy,
1418                                                llvm::APInt(32, N_2),
1419                                                clang::ArrayType::Normal,
1420                                                /* EltTypeQuals = */0));
1421       llvm::SmallVector<clang::Expr*, 1> InitExprVec;
1422       InitExprVec.push_back(InitExpr);
1423 
1424       Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
1425       break;
1426     }
1427     case DataTypeUnknown:
1428     case DataTypeFloat16:
1429     case DataTypeFloat32:
1430     case DataTypeFloat64:
1431     case DataTypeSigned8:
1432     case DataTypeSigned16:
1433     case DataTypeSigned32:
1434     case DataTypeSigned64:
1435     case DataTypeUnsigned8:
1436     case DataTypeUnsigned16:
1437     case DataTypeUnsigned32:
1438     case DataTypeUnsigned64:
1439     case DataTypeBoolean:
1440     case DataTypeUnsigned565:
1441     case DataTypeUnsigned5551:
1442     case DataTypeUnsigned4444:
1443     case DataTypeMax: {
1444       slangAssert(false && "Not RS object type!");
1445     }
1446     // No default case will enable compiler detecting the missing cases
1447   }
1448 
1449   return Res;
1450 }
1451 
VisitDeclStmt(clang::DeclStmt * DS)1452 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1453   for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1454        I != E;
1455        I++) {
1456     clang::Decl *D = *I;
1457     if (D->getKind() == clang::Decl::Var) {
1458       clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1459       DataType DT = DataTypeUnknown;
1460       clang::Expr *InitExpr = nullptr;
1461       if (InitializeRSObject(VD, &DT, &InitExpr)) {
1462         // We need to zero-init all RS object types (including matrices), ...
1463         getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1464         // ... but, only add to the list of RS objects if we have some
1465         // non-matrix RS object fields.
1466         if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
1467                                VD->getLocation())) {
1468           getCurrentScope()->addRSObject(VD);
1469         }
1470       }
1471     }
1472   }
1473 }
1474 
VisitCompoundStmt(clang::CompoundStmt * CS)1475 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1476   if (!CS->body_empty()) {
1477     // Push a new scope
1478     Scope *S = new Scope(CS);
1479     mScopeStack.push(S);
1480 
1481     VisitStmt(CS);
1482 
1483     // Destroy the scope
1484     slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1485     S->InsertLocalVarDestructors();
1486     mScopeStack.pop();
1487     delete S;
1488   }
1489 }
1490 
VisitBinAssign(clang::BinaryOperator * AS)1491 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1492   clang::QualType QT = AS->getType();
1493 
1494   if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1495     getCurrentScope()->ReplaceRSObjectAssignment(AS);
1496   }
1497 }
1498 
VisitStmt(clang::Stmt * S)1499 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1500   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1501        I != E;
1502        I++) {
1503     if (clang::Stmt *Child = *I) {
1504       Visit(Child);
1505     }
1506   }
1507 }
1508 
1509 // This function walks the list of global variables and (potentially) creates
1510 // a single global static destructor function that properly decrements
1511 // reference counts on the contained RS object types.
CreateStaticGlobalDtor()1512 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1513   Init();
1514 
1515   clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1516   clang::SourceLocation loc;
1517 
1518   llvm::StringRef SR(".rs.dtor");
1519   clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1520   clang::DeclarationName N(&II);
1521   clang::FunctionProtoType::ExtProtoInfo EPI;
1522   clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1523       llvm::ArrayRef<clang::QualType>(), EPI);
1524   clang::FunctionDecl *FD = nullptr;
1525 
1526   // Generate rsClearObject() call chains for every global variable
1527   // (whether static or extern).
1528   std::list<clang::Stmt *> StmtList;
1529   for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1530           E = DC->decls_end(); I != E; I++) {
1531     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1532     if (VD) {
1533       if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1534         if (!FD) {
1535           // Only create FD if we are going to use it.
1536           FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1537                                            clang::SC_None);
1538         }
1539         // Mark VD as used.  It might be unused, except for the destructor.
1540         // 'markUsed' has side-effects that are caused only if VD is not already
1541         // used.  Hence no need for an extra check here.
1542         VD->markUsed(mCtx);
1543         // Make sure to create any helpers within the function's DeclContext,
1544         // not the one associated with the global translation unit.
1545         clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1546         StmtList.push_back(RSClearObjectCall);
1547       }
1548     }
1549   }
1550 
1551   // Nothing needs to be destroyed, so don't emit a dtor.
1552   if (StmtList.empty()) {
1553     return nullptr;
1554   }
1555 
1556   clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1557 
1558   FD->setBody(CS);
1559 
1560   return FD;
1561 }
1562 
1563 }  // namespace slang
1564