1 /*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "slang_rs_object_ref_count.h"
18
19 #include <list>
20
21 #include "clang/AST/DeclGroup.h"
22 #include "clang/AST/Expr.h"
23 #include "clang/AST/NestedNameSpecifier.h"
24 #include "clang/AST/OperationKinds.h"
25 #include "clang/AST/Stmt.h"
26 #include "clang/AST/StmtVisitor.h"
27
28 #include "slang_assert.h"
29 #include "slang.h"
30 #include "slang_rs_ast_replace.h"
31 #include "slang_rs_export_type.h"
32
33 namespace slang {
34
35 /* Even though those two arrays are of size DataTypeMax, only entries that
36 * correspond to object types will be set.
37 */
38 clang::FunctionDecl *
39 RSObjectRefCount::RSSetObjectFD[DataTypeMax];
40 clang::FunctionDecl *
41 RSObjectRefCount::RSClearObjectFD[DataTypeMax];
42
GetRSRefCountingFunctions(clang::ASTContext & C)43 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
44 for (unsigned i = 0; i < DataTypeMax; i++) {
45 RSSetObjectFD[i] = nullptr;
46 RSClearObjectFD[i] = nullptr;
47 }
48
49 clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
50
51 for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
52 E = TUDecl->decls_end(); I != E; I++) {
53 if ((I->getKind() >= clang::Decl::firstFunction) &&
54 (I->getKind() <= clang::Decl::lastFunction)) {
55 clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
56
57 // points to RSSetObjectFD or RSClearObjectFD
58 clang::FunctionDecl **RSObjectFD;
59
60 if (FD->getName() == "rsSetObject") {
61 slangAssert((FD->getNumParams() == 2) &&
62 "Invalid rsSetObject function prototype (# params)");
63 RSObjectFD = RSSetObjectFD;
64 } else if (FD->getName() == "rsClearObject") {
65 slangAssert((FD->getNumParams() == 1) &&
66 "Invalid rsClearObject function prototype (# params)");
67 RSObjectFD = RSClearObjectFD;
68 } else {
69 continue;
70 }
71
72 const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
73 clang::QualType PVT = PVD->getOriginalType();
74 // The first parameter must be a pointer like rs_allocation*
75 slangAssert(PVT->isPointerType() &&
76 "Invalid rs{Set,Clear}Object function prototype (pointer param)");
77
78 // The rs object type passed to the FD
79 clang::QualType RST = PVT->getPointeeType();
80 DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
81 slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
82 && "must be RS object type");
83
84 if (DT >= 0 && DT < DataTypeMax) {
85 RSObjectFD[DT] = FD;
86 } else {
87 slangAssert(false && "incorrect type");
88 }
89 }
90 }
91 }
92
93 namespace {
94
95 // This function constructs a new CompoundStmt from the input StmtList.
BuildCompoundStmt(clang::ASTContext & C,std::list<clang::Stmt * > & StmtList,clang::SourceLocation Loc)96 static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
97 std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
98 unsigned NewStmtCount = StmtList.size();
99 unsigned CompoundStmtCount = 0;
100
101 clang::Stmt **CompoundStmtList;
102 CompoundStmtList = new clang::Stmt*[NewStmtCount];
103
104 std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
105 std::list<clang::Stmt*>::const_iterator E = StmtList.end();
106 for ( ; I != E; I++) {
107 CompoundStmtList[CompoundStmtCount++] = *I;
108 }
109 slangAssert(CompoundStmtCount == NewStmtCount);
110
111 clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
112 C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
113
114 delete [] CompoundStmtList;
115
116 return CS;
117 }
118
AppendAfterStmt(clang::ASTContext & C,clang::CompoundStmt * CS,clang::Stmt * S,std::list<clang::Stmt * > & StmtList)119 static void AppendAfterStmt(clang::ASTContext &C,
120 clang::CompoundStmt *CS,
121 clang::Stmt *S,
122 std::list<clang::Stmt*> &StmtList) {
123 slangAssert(CS);
124 clang::CompoundStmt::body_iterator bI = CS->body_begin();
125 clang::CompoundStmt::body_iterator bE = CS->body_end();
126 clang::Stmt **UpdatedStmtList =
127 new clang::Stmt*[CS->size() + StmtList.size()];
128
129 unsigned UpdatedStmtCount = 0;
130 unsigned Once = 0;
131 for ( ; bI != bE; bI++) {
132 if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
133 // If we come across a return here, we don't have anything we can
134 // reasonably replace. We should have already inserted our destructor
135 // code in the proper spot, so we just clean up and return.
136 delete [] UpdatedStmtList;
137
138 return;
139 }
140
141 UpdatedStmtList[UpdatedStmtCount++] = *bI;
142
143 if ((*bI == S) && !Once) {
144 Once++;
145 std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
146 std::list<clang::Stmt*>::const_iterator E = StmtList.end();
147 for ( ; I != E; I++) {
148 UpdatedStmtList[UpdatedStmtCount++] = *I;
149 }
150 }
151 }
152 slangAssert(Once <= 1);
153
154 // When S is nullptr, we are appending to the end of the CompoundStmt.
155 if (!S) {
156 slangAssert(Once == 0);
157 std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
158 std::list<clang::Stmt*>::const_iterator E = StmtList.end();
159 for ( ; I != E; I++) {
160 UpdatedStmtList[UpdatedStmtCount++] = *I;
161 }
162 }
163
164 CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
165
166 delete [] UpdatedStmtList;
167 }
168
169 // This class visits a compound statement and inserts DtorStmt
170 // in proper locations. This includes inserting it before any
171 // return statement in any sub-block, at the end of the logical enclosing
172 // scope (compound statement), and/or before any break/continue statement that
173 // would resume outside the declared scope. We will not handle the case for
174 // goto statements that leave a local scope.
175 //
176 // To accomplish these goals, it collects a list of sub-Stmt's that
177 // correspond to scope exit points. It then uses an RSASTReplace visitor to
178 // transform the AST, inserting appropriate destructors before each of those
179 // sub-Stmt's (and also before the exit of the outermost containing Stmt for
180 // the scope).
181 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182 private:
183 clang::ASTContext &mCtx;
184
185 // The loop depth of the currently visited node.
186 int mLoopDepth;
187
188 // The switch statement depth of the currently visited node.
189 // Note that this is tracked separately from the loop depth because
190 // SwitchStmt-contained ContinueStmt's should have destructors for the
191 // corresponding loop scope.
192 int mSwitchDepth;
193
194 // The outermost statement block that we are currently visiting.
195 // This should always be a CompoundStmt.
196 clang::Stmt *mOuterStmt;
197
198 // The destructor to execute for this scope/variable.
199 clang::Stmt* mDtorStmt;
200
201 // The stack of statements which should be replaced by a compound statement
202 // containing the new destructor call followed by the original Stmt.
203 std::stack<clang::Stmt*> mReplaceStmtStack;
204
205 // The source location for the variable declaration that we are trying to
206 // insert destructors for. Note that InsertDestructors() will not generate
207 // destructor calls for source locations that occur lexically before this
208 // location.
209 clang::SourceLocation mVarLoc;
210
211 public:
212 DestructorVisitor(clang::ASTContext &C,
213 clang::Stmt* OuterStmt,
214 clang::Stmt* DtorStmt,
215 clang::SourceLocation VarLoc);
216
217 // This code walks the collected list of Stmts to replace and actually does
218 // the replacement. It also finishes up by appending the destructor to the
219 // current outermost CompoundStmt.
InsertDestructors()220 void InsertDestructors() {
221 clang::Stmt *S = nullptr;
222 clang::SourceManager &SM = mCtx.getSourceManager();
223 std::list<clang::Stmt *> StmtList;
224 StmtList.push_back(mDtorStmt);
225
226 while (!mReplaceStmtStack.empty()) {
227 S = mReplaceStmtStack.top();
228 mReplaceStmtStack.pop();
229
230 // Skip all source locations that occur before the variable's
231 // declaration, since it won't have been initialized yet.
232 if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
233 continue;
234 }
235
236 StmtList.push_back(S);
237 clang::CompoundStmt *CS =
238 BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
239 StmtList.pop_back();
240
241 RSASTReplace R(mCtx);
242 R.ReplaceStmt(mOuterStmt, S, CS);
243 }
244 clang::CompoundStmt *CS =
245 llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
246 slangAssert(CS);
247 AppendAfterStmt(mCtx, CS, nullptr, StmtList);
248 }
249
250 void VisitStmt(clang::Stmt *S);
251 void VisitCompoundStmt(clang::CompoundStmt *CS);
252
253 void VisitBreakStmt(clang::BreakStmt *BS);
254 void VisitCaseStmt(clang::CaseStmt *CS);
255 void VisitContinueStmt(clang::ContinueStmt *CS);
256 void VisitDefaultStmt(clang::DefaultStmt *DS);
257 void VisitDoStmt(clang::DoStmt *DS);
258 void VisitForStmt(clang::ForStmt *FS);
259 void VisitIfStmt(clang::IfStmt *IS);
260 void VisitReturnStmt(clang::ReturnStmt *RS);
261 void VisitSwitchCase(clang::SwitchCase *SC);
262 void VisitSwitchStmt(clang::SwitchStmt *SS);
263 void VisitWhileStmt(clang::WhileStmt *WS);
264 };
265
DestructorVisitor(clang::ASTContext & C,clang::Stmt * OuterStmt,clang::Stmt * DtorStmt,clang::SourceLocation VarLoc)266 DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
267 clang::Stmt *OuterStmt,
268 clang::Stmt *DtorStmt,
269 clang::SourceLocation VarLoc)
270 : mCtx(C),
271 mLoopDepth(0),
272 mSwitchDepth(0),
273 mOuterStmt(OuterStmt),
274 mDtorStmt(DtorStmt),
275 mVarLoc(VarLoc) {
276 }
277
VisitStmt(clang::Stmt * S)278 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
279 for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
280 I != E;
281 I++) {
282 if (clang::Stmt *Child = *I) {
283 Visit(Child);
284 }
285 }
286 }
287
VisitCompoundStmt(clang::CompoundStmt * CS)288 void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
289 VisitStmt(CS);
290 }
291
VisitBreakStmt(clang::BreakStmt * BS)292 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
293 VisitStmt(BS);
294 if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
295 mReplaceStmtStack.push(BS);
296 }
297 }
298
VisitCaseStmt(clang::CaseStmt * CS)299 void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
300 VisitStmt(CS);
301 }
302
VisitContinueStmt(clang::ContinueStmt * CS)303 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
304 VisitStmt(CS);
305 if (mLoopDepth == 0) {
306 // Switch statements can have nested continues.
307 mReplaceStmtStack.push(CS);
308 }
309 }
310
VisitDefaultStmt(clang::DefaultStmt * DS)311 void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
312 VisitStmt(DS);
313 }
314
VisitDoStmt(clang::DoStmt * DS)315 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
316 mLoopDepth++;
317 VisitStmt(DS);
318 mLoopDepth--;
319 }
320
VisitForStmt(clang::ForStmt * FS)321 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
322 mLoopDepth++;
323 VisitStmt(FS);
324 mLoopDepth--;
325 }
326
VisitIfStmt(clang::IfStmt * IS)327 void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
328 VisitStmt(IS);
329 }
330
VisitReturnStmt(clang::ReturnStmt * RS)331 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
332 mReplaceStmtStack.push(RS);
333 }
334
VisitSwitchCase(clang::SwitchCase * SC)335 void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
336 slangAssert(false && "Both case and default have specialized handlers");
337 VisitStmt(SC);
338 }
339
VisitSwitchStmt(clang::SwitchStmt * SS)340 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
341 mSwitchDepth++;
342 VisitStmt(SS);
343 mSwitchDepth--;
344 }
345
VisitWhileStmt(clang::WhileStmt * WS)346 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
347 mLoopDepth++;
348 VisitStmt(WS);
349 mLoopDepth--;
350 }
351
ClearSingleRSObject(clang::ASTContext & C,clang::Expr * RefRSVar,clang::SourceLocation Loc)352 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
353 clang::Expr *RefRSVar,
354 clang::SourceLocation Loc) {
355 slangAssert(RefRSVar);
356 const clang::Type *T = RefRSVar->getType().getTypePtr();
357 slangAssert(!T->isArrayType() &&
358 "Should not be destroying arrays with this function");
359
360 clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
361 slangAssert((ClearObjectFD != nullptr) &&
362 "rsClearObject doesn't cover all RS object types");
363
364 clang::QualType ClearObjectFDType = ClearObjectFD->getType();
365 clang::QualType ClearObjectFDArgType =
366 ClearObjectFD->getParamDecl(0)->getOriginalType();
367
368 // Example destructor for "rs_font localFont;"
369 //
370 // (CallExpr 'void'
371 // (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
372 // (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
373 // (UnaryOperator 'rs_font *' prefix '&'
374 // (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
375
376 // Get address of targeted RS object
377 clang::Expr *AddrRefRSVar =
378 new(C) clang::UnaryOperator(RefRSVar,
379 clang::UO_AddrOf,
380 ClearObjectFDArgType,
381 clang::VK_RValue,
382 clang::OK_Ordinary,
383 Loc);
384
385 clang::Expr *RefRSClearObjectFD =
386 clang::DeclRefExpr::Create(C,
387 clang::NestedNameSpecifierLoc(),
388 clang::SourceLocation(),
389 ClearObjectFD,
390 false,
391 ClearObjectFD->getLocation(),
392 ClearObjectFDType,
393 clang::VK_RValue,
394 nullptr);
395
396 clang::Expr *RSClearObjectFP =
397 clang::ImplicitCastExpr::Create(C,
398 C.getPointerType(ClearObjectFDType),
399 clang::CK_FunctionToPointerDecay,
400 RefRSClearObjectFD,
401 nullptr,
402 clang::VK_RValue);
403
404 llvm::SmallVector<clang::Expr*, 1> ArgList;
405 ArgList.push_back(AddrRefRSVar);
406
407 clang::CallExpr *RSClearObjectCall =
408 new(C) clang::CallExpr(C,
409 RSClearObjectFP,
410 ArgList,
411 ClearObjectFD->getCallResultType(),
412 clang::VK_RValue,
413 Loc);
414
415 return RSClearObjectCall;
416 }
417
ArrayDim(const clang::Type * T)418 static int ArrayDim(const clang::Type *T) {
419 if (!T || !T->isArrayType()) {
420 return 0;
421 }
422
423 const clang::ConstantArrayType *CAT =
424 static_cast<const clang::ConstantArrayType *>(T);
425 return static_cast<int>(CAT->getSize().getSExtValue());
426 }
427
428 static clang::Stmt *ClearStructRSObject(
429 clang::ASTContext &C,
430 clang::DeclContext *DC,
431 clang::Expr *RefRSStruct,
432 clang::SourceLocation StartLoc,
433 clang::SourceLocation Loc);
434
ClearArrayRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSArr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)435 static clang::Stmt *ClearArrayRSObject(
436 clang::ASTContext &C,
437 clang::DeclContext *DC,
438 clang::Expr *RefRSArr,
439 clang::SourceLocation StartLoc,
440 clang::SourceLocation Loc) {
441 const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
442 slangAssert(BaseType->isArrayType());
443
444 int NumArrayElements = ArrayDim(BaseType);
445 // Actually extract out the base RS object type for use later
446 BaseType = BaseType->getArrayElementTypeNoTypeQual();
447
448 clang::Stmt *StmtArray[2] = {nullptr};
449 int StmtCtr = 0;
450
451 if (NumArrayElements <= 0) {
452 return nullptr;
453 }
454
455 // Example destructor loop for "rs_font fontArr[10];"
456 //
457 // (CompoundStmt
458 // (DeclStmt "int rsIntIter")
459 // (ForStmt
460 // (BinaryOperator 'int' '='
461 // (DeclRefExpr 'int' Var='rsIntIter')
462 // (IntegerLiteral 'int' 0))
463 // (BinaryOperator 'int' '<'
464 // (DeclRefExpr 'int' Var='rsIntIter')
465 // (IntegerLiteral 'int' 10)
466 // nullptr << CondVar >>
467 // (UnaryOperator 'int' postfix '++'
468 // (DeclRefExpr 'int' Var='rsIntIter'))
469 // (CallExpr 'void'
470 // (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
471 // (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
472 // (UnaryOperator 'rs_font *' prefix '&'
473 // (ArraySubscriptExpr 'rs_font':'rs_font'
474 // (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
475 // (DeclRefExpr 'rs_font [10]' Var='fontArr'))
476 // (DeclRefExpr 'int' Var='rsIntIter')))))))
477
478 // Create helper variable for iterating through elements
479 clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
480 clang::VarDecl *IIVD =
481 clang::VarDecl::Create(C,
482 DC,
483 StartLoc,
484 Loc,
485 &II,
486 C.IntTy,
487 C.getTrivialTypeSourceInfo(C.IntTy),
488 clang::SC_None);
489 // Mark "rsIntIter" as used
490 IIVD->markUsed(C);
491 clang::Decl *IID = (clang::Decl *)IIVD;
492
493 clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
494 StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
495
496 // Form the actual destructor loop
497 // for (Init; Cond; Inc)
498 // RSClearObjectCall;
499
500 // Init -> "rsIntIter = 0"
501 clang::DeclRefExpr *RefrsIntIter =
502 clang::DeclRefExpr::Create(C,
503 clang::NestedNameSpecifierLoc(),
504 clang::SourceLocation(),
505 IIVD,
506 false,
507 Loc,
508 C.IntTy,
509 clang::VK_RValue,
510 nullptr);
511
512 clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
513 llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
514
515 clang::BinaryOperator *Init =
516 new(C) clang::BinaryOperator(RefrsIntIter,
517 Int0,
518 clang::BO_Assign,
519 C.IntTy,
520 clang::VK_RValue,
521 clang::OK_Ordinary,
522 Loc,
523 false);
524
525 // Cond -> "rsIntIter < NumArrayElements"
526 clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
527 llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
528
529 clang::BinaryOperator *Cond =
530 new(C) clang::BinaryOperator(RefrsIntIter,
531 NumArrayElementsExpr,
532 clang::BO_LT,
533 C.IntTy,
534 clang::VK_RValue,
535 clang::OK_Ordinary,
536 Loc,
537 false);
538
539 // Inc -> "rsIntIter++"
540 clang::UnaryOperator *Inc =
541 new(C) clang::UnaryOperator(RefrsIntIter,
542 clang::UO_PostInc,
543 C.IntTy,
544 clang::VK_RValue,
545 clang::OK_Ordinary,
546 Loc);
547
548 // Body -> "rsClearObject(&VD[rsIntIter]);"
549 // Destructor loop operates on individual array elements
550
551 clang::Expr *RefRSArrPtr =
552 clang::ImplicitCastExpr::Create(C,
553 C.getPointerType(BaseType->getCanonicalTypeInternal()),
554 clang::CK_ArrayToPointerDecay,
555 RefRSArr,
556 nullptr,
557 clang::VK_RValue);
558
559 clang::Expr *RefRSArrPtrSubscript =
560 new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
561 RefrsIntIter,
562 BaseType->getCanonicalTypeInternal(),
563 clang::VK_RValue,
564 clang::OK_Ordinary,
565 Loc);
566
567 DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
568
569 clang::Stmt *RSClearObjectCall = nullptr;
570 if (BaseType->isArrayType()) {
571 RSClearObjectCall =
572 ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
573 } else if (DT == DataTypeUnknown) {
574 RSClearObjectCall =
575 ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
576 } else {
577 RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
578 }
579
580 clang::ForStmt *DestructorLoop =
581 new(C) clang::ForStmt(C,
582 Init,
583 Cond,
584 nullptr, // no condVar
585 Inc,
586 RSClearObjectCall,
587 Loc,
588 Loc,
589 Loc);
590
591 StmtArray[StmtCtr++] = DestructorLoop;
592 slangAssert(StmtCtr == 2);
593
594 clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
595 C, llvm::makeArrayRef(StmtArray, StmtCtr), Loc, Loc);
596
597 return CS;
598 }
599
CountRSObjectTypes(clang::ASTContext & C,const clang::Type * T,clang::SourceLocation Loc)600 static unsigned CountRSObjectTypes(clang::ASTContext &C,
601 const clang::Type *T,
602 clang::SourceLocation Loc) {
603 slangAssert(T);
604 unsigned RSObjectCount = 0;
605
606 if (T->isArrayType()) {
607 return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
608 }
609
610 DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
611 if (DT != DataTypeUnknown) {
612 return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
613 }
614
615 if (T->isUnionType()) {
616 clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
617 RD = RD->getDefinition();
618 for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
619 FE = RD->field_end();
620 FI != FE;
621 FI++) {
622 const clang::FieldDecl *FD = *FI;
623 const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
624 if (CountRSObjectTypes(C, FT, Loc)) {
625 slangAssert(false && "can't have unions with RS object types!");
626 return 0;
627 }
628 }
629 }
630
631 if (!T->isStructureType()) {
632 return 0;
633 }
634
635 clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
636 RD = RD->getDefinition();
637 for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
638 FE = RD->field_end();
639 FI != FE;
640 FI++) {
641 const clang::FieldDecl *FD = *FI;
642 const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
643 if (CountRSObjectTypes(C, FT, Loc)) {
644 // Sub-structs should only count once (as should arrays, etc.)
645 RSObjectCount++;
646 }
647 }
648
649 return RSObjectCount;
650 }
651
ClearStructRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSStruct,clang::SourceLocation StartLoc,clang::SourceLocation Loc)652 static clang::Stmt *ClearStructRSObject(
653 clang::ASTContext &C,
654 clang::DeclContext *DC,
655 clang::Expr *RefRSStruct,
656 clang::SourceLocation StartLoc,
657 clang::SourceLocation Loc) {
658 const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
659
660 slangAssert(!BaseType->isArrayType());
661
662 // Structs should show up as unknown primitive types
663 slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
664 DataTypeUnknown);
665
666 unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
667 slangAssert(FieldsToDestroy != 0);
668
669 unsigned StmtCount = 0;
670 clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
671 for (unsigned i = 0; i < FieldsToDestroy; i++) {
672 StmtArray[i] = nullptr;
673 }
674
675 // Populate StmtArray by creating a destructor for each RS object field
676 clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
677 RD = RD->getDefinition();
678 for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
679 FE = RD->field_end();
680 FI != FE;
681 FI++) {
682 // We just look through all field declarations to see if we find a
683 // declaration for an RS object type (or an array of one).
684 bool IsArrayType = false;
685 clang::FieldDecl *FD = *FI;
686 const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
687 const clang::Type *OrigType = FT;
688 while (FT && FT->isArrayType()) {
689 FT = FT->getArrayElementTypeNoTypeQual();
690 IsArrayType = true;
691 }
692
693 // Pass a DeclarationNameInfo with a valid DeclName, since name equality
694 // gets asserted during CodeGen.
695 clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
696 FD->getLocation());
697
698 if (RSExportPrimitiveType::IsRSObjectType(FT)) {
699 clang::DeclAccessPair FoundDecl =
700 clang::DeclAccessPair::make(FD, clang::AS_none);
701 clang::MemberExpr *RSObjectMember =
702 clang::MemberExpr::Create(C,
703 RefRSStruct,
704 false,
705 clang::SourceLocation(),
706 clang::NestedNameSpecifierLoc(),
707 clang::SourceLocation(),
708 FD,
709 FoundDecl,
710 FDDeclNameInfo,
711 nullptr,
712 OrigType->getCanonicalTypeInternal(),
713 clang::VK_RValue,
714 clang::OK_Ordinary);
715
716 slangAssert(StmtCount < FieldsToDestroy);
717
718 if (IsArrayType) {
719 StmtArray[StmtCount++] = ClearArrayRSObject(C,
720 DC,
721 RSObjectMember,
722 StartLoc,
723 Loc);
724 } else {
725 StmtArray[StmtCount++] = ClearSingleRSObject(C,
726 RSObjectMember,
727 Loc);
728 }
729 } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
730 // In this case, we have a nested struct. We may not end up filling all
731 // of the spaces in StmtArray (sub-structs should handle themselves
732 // with separate compound statements).
733 clang::DeclAccessPair FoundDecl =
734 clang::DeclAccessPair::make(FD, clang::AS_none);
735 clang::MemberExpr *RSObjectMember =
736 clang::MemberExpr::Create(C,
737 RefRSStruct,
738 false,
739 clang::SourceLocation(),
740 clang::NestedNameSpecifierLoc(),
741 clang::SourceLocation(),
742 FD,
743 FoundDecl,
744 clang::DeclarationNameInfo(),
745 nullptr,
746 OrigType->getCanonicalTypeInternal(),
747 clang::VK_RValue,
748 clang::OK_Ordinary);
749
750 if (IsArrayType) {
751 StmtArray[StmtCount++] = ClearArrayRSObject(C,
752 DC,
753 RSObjectMember,
754 StartLoc,
755 Loc);
756 } else {
757 StmtArray[StmtCount++] = ClearStructRSObject(C,
758 DC,
759 RSObjectMember,
760 StartLoc,
761 Loc);
762 }
763 }
764 }
765
766 slangAssert(StmtCount > 0);
767 clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
768 C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
769
770 delete [] StmtArray;
771
772 return CS;
773 }
774
CreateSingleRSSetObject(clang::ASTContext & C,clang::Expr * DstExpr,clang::Expr * SrcExpr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)775 static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
776 clang::Expr *DstExpr,
777 clang::Expr *SrcExpr,
778 clang::SourceLocation StartLoc,
779 clang::SourceLocation Loc) {
780 const clang::Type *T = DstExpr->getType().getTypePtr();
781 clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
782 slangAssert((SetObjectFD != nullptr) &&
783 "rsSetObject doesn't cover all RS object types");
784
785 clang::QualType SetObjectFDType = SetObjectFD->getType();
786 clang::QualType SetObjectFDArgType[2];
787 SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
788 SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
789
790 clang::Expr *RefRSSetObjectFD =
791 clang::DeclRefExpr::Create(C,
792 clang::NestedNameSpecifierLoc(),
793 clang::SourceLocation(),
794 SetObjectFD,
795 false,
796 Loc,
797 SetObjectFDType,
798 clang::VK_RValue,
799 nullptr);
800
801 clang::Expr *RSSetObjectFP =
802 clang::ImplicitCastExpr::Create(C,
803 C.getPointerType(SetObjectFDType),
804 clang::CK_FunctionToPointerDecay,
805 RefRSSetObjectFD,
806 nullptr,
807 clang::VK_RValue);
808
809 llvm::SmallVector<clang::Expr*, 2> ArgList;
810 ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
811 clang::UO_AddrOf,
812 SetObjectFDArgType[0],
813 clang::VK_RValue,
814 clang::OK_Ordinary,
815 Loc));
816 ArgList.push_back(SrcExpr);
817
818 clang::CallExpr *RSSetObjectCall =
819 new(C) clang::CallExpr(C,
820 RSSetObjectFP,
821 ArgList,
822 SetObjectFD->getCallResultType(),
823 clang::VK_RValue,
824 Loc);
825
826 return RSSetObjectCall;
827 }
828
829 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
830 clang::Expr *LHS,
831 clang::Expr *RHS,
832 clang::SourceLocation StartLoc,
833 clang::SourceLocation Loc);
834
835 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
836 clang::Expr *DstArr,
837 clang::Expr *SrcArr,
838 clang::SourceLocation StartLoc,
839 clang::SourceLocation Loc) {
840 clang::DeclContext *DC = nullptr;
841 const clang::Type *BaseType = DstArr->getType().getTypePtr();
842 slangAssert(BaseType->isArrayType());
843
844 int NumArrayElements = ArrayDim(BaseType);
845 // Actually extract out the base RS object type for use later
846 BaseType = BaseType->getArrayElementTypeNoTypeQual();
847
848 clang::Stmt *StmtArray[2] = {nullptr};
849 int StmtCtr = 0;
850
851 if (NumArrayElements <= 0) {
852 return nullptr;
853 }
854
855 // Create helper variable for iterating through elements
856 clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
857 clang::VarDecl *IIVD =
858 clang::VarDecl::Create(C,
859 DC,
860 StartLoc,
861 Loc,
862 &II,
863 C.IntTy,
864 C.getTrivialTypeSourceInfo(C.IntTy),
865 clang::SC_None,
866 clang::SC_None);
867 clang::Decl *IID = (clang::Decl *)IIVD;
868
869 clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
870 StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
871
872 // Form the actual loop
873 // for (Init; Cond; Inc)
874 // RSSetObjectCall;
875
876 // Init -> "rsIntIter = 0"
877 clang::DeclRefExpr *RefrsIntIter =
878 clang::DeclRefExpr::Create(C,
879 clang::NestedNameSpecifierLoc(),
880 IIVD,
881 Loc,
882 C.IntTy,
883 clang::VK_RValue,
884 nullptr);
885
886 clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
887 llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
888
889 clang::BinaryOperator *Init =
890 new(C) clang::BinaryOperator(RefrsIntIter,
891 Int0,
892 clang::BO_Assign,
893 C.IntTy,
894 clang::VK_RValue,
895 clang::OK_Ordinary,
896 Loc);
897
898 // Cond -> "rsIntIter < NumArrayElements"
899 clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
900 llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
901
902 clang::BinaryOperator *Cond =
903 new(C) clang::BinaryOperator(RefrsIntIter,
904 NumArrayElementsExpr,
905 clang::BO_LT,
906 C.IntTy,
907 clang::VK_RValue,
908 clang::OK_Ordinary,
909 Loc);
910
911 // Inc -> "rsIntIter++"
912 clang::UnaryOperator *Inc =
913 new(C) clang::UnaryOperator(RefrsIntIter,
914 clang::UO_PostInc,
915 C.IntTy,
916 clang::VK_RValue,
917 clang::OK_Ordinary,
918 Loc);
919
920 // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
921 // Loop operates on individual array elements
922
923 clang::Expr *DstArrPtr =
924 clang::ImplicitCastExpr::Create(C,
925 C.getPointerType(BaseType->getCanonicalTypeInternal()),
926 clang::CK_ArrayToPointerDecay,
927 DstArr,
928 nullptr,
929 clang::VK_RValue);
930
931 clang::Expr *DstArrPtrSubscript =
932 new(C) clang::ArraySubscriptExpr(DstArrPtr,
933 RefrsIntIter,
934 BaseType->getCanonicalTypeInternal(),
935 clang::VK_RValue,
936 clang::OK_Ordinary,
937 Loc);
938
939 clang::Expr *SrcArrPtr =
940 clang::ImplicitCastExpr::Create(C,
941 C.getPointerType(BaseType->getCanonicalTypeInternal()),
942 clang::CK_ArrayToPointerDecay,
943 SrcArr,
944 nullptr,
945 clang::VK_RValue);
946
947 clang::Expr *SrcArrPtrSubscript =
948 new(C) clang::ArraySubscriptExpr(SrcArrPtr,
949 RefrsIntIter,
950 BaseType->getCanonicalTypeInternal(),
951 clang::VK_RValue,
952 clang::OK_Ordinary,
953 Loc);
954
955 DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
956
957 clang::Stmt *RSSetObjectCall = nullptr;
958 if (BaseType->isArrayType()) {
959 RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
960 SrcArrPtrSubscript,
961 StartLoc, Loc);
962 } else if (DT == DataTypeUnknown) {
963 RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
964 SrcArrPtrSubscript,
965 StartLoc, Loc);
966 } else {
967 RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
968 SrcArrPtrSubscript,
969 StartLoc, Loc);
970 }
971
972 clang::ForStmt *DestructorLoop =
973 new(C) clang::ForStmt(C,
974 Init,
975 Cond,
976 nullptr, // no condVar
977 Inc,
978 RSSetObjectCall,
979 Loc,
980 Loc,
981 Loc);
982
983 StmtArray[StmtCtr++] = DestructorLoop;
984 slangAssert(StmtCtr == 2);
985
986 clang::CompoundStmt *CS =
987 new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
988
989 return CS;
990 } */
991
CreateStructRSSetObject(clang::ASTContext & C,clang::Expr * LHS,clang::Expr * RHS,clang::SourceLocation StartLoc,clang::SourceLocation Loc)992 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
993 clang::Expr *LHS,
994 clang::Expr *RHS,
995 clang::SourceLocation StartLoc,
996 clang::SourceLocation Loc) {
997 clang::QualType QT = LHS->getType();
998 const clang::Type *T = QT.getTypePtr();
999 slangAssert(T->isStructureType());
1000 slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1001
1002 // Keep an extra slot for the original copy (memcpy)
1003 unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1004
1005 unsigned StmtCount = 0;
1006 clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1007 for (unsigned i = 0; i < FieldsToSet; i++) {
1008 StmtArray[i] = nullptr;
1009 }
1010
1011 clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1012 RD = RD->getDefinition();
1013 for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1014 FE = RD->field_end();
1015 FI != FE;
1016 FI++) {
1017 bool IsArrayType = false;
1018 clang::FieldDecl *FD = *FI;
1019 const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1020 const clang::Type *OrigType = FT;
1021
1022 if (!CountRSObjectTypes(C, FT, Loc)) {
1023 // Skip to next if we don't have any viable RS object types
1024 continue;
1025 }
1026
1027 clang::DeclAccessPair FoundDecl =
1028 clang::DeclAccessPair::make(FD, clang::AS_none);
1029 clang::MemberExpr *DstMember =
1030 clang::MemberExpr::Create(C,
1031 LHS,
1032 false,
1033 clang::SourceLocation(),
1034 clang::NestedNameSpecifierLoc(),
1035 clang::SourceLocation(),
1036 FD,
1037 FoundDecl,
1038 clang::DeclarationNameInfo(),
1039 nullptr,
1040 OrigType->getCanonicalTypeInternal(),
1041 clang::VK_RValue,
1042 clang::OK_Ordinary);
1043
1044 clang::MemberExpr *SrcMember =
1045 clang::MemberExpr::Create(C,
1046 RHS,
1047 false,
1048 clang::SourceLocation(),
1049 clang::NestedNameSpecifierLoc(),
1050 clang::SourceLocation(),
1051 FD,
1052 FoundDecl,
1053 clang::DeclarationNameInfo(),
1054 nullptr,
1055 OrigType->getCanonicalTypeInternal(),
1056 clang::VK_RValue,
1057 clang::OK_Ordinary);
1058
1059 if (FT->isArrayType()) {
1060 FT = FT->getArrayElementTypeNoTypeQual();
1061 IsArrayType = true;
1062 }
1063
1064 DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
1065
1066 if (IsArrayType) {
1067 clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1068 DiagEngine.Report(
1069 clang::FullSourceLoc(Loc, C.getSourceManager()),
1070 DiagEngine.getCustomDiagID(
1071 clang::DiagnosticsEngine::Error,
1072 "Arrays of RS object types within structures cannot be copied"));
1073 // TODO(srhines): Support setting arrays of RS objects
1074 // StmtArray[StmtCount++] =
1075 // CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1076 } else if (DT == DataTypeUnknown) {
1077 StmtArray[StmtCount++] =
1078 CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1079 } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1080 StmtArray[StmtCount++] =
1081 CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1082 } else {
1083 slangAssert(false);
1084 }
1085 }
1086
1087 slangAssert(StmtCount < FieldsToSet);
1088
1089 // We still need to actually do the overall struct copy. For simplicity,
1090 // we just do a straight-up assignment (which will still preserve all
1091 // the proper RS object reference counts).
1092 clang::BinaryOperator *CopyStruct =
1093 new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1094 clang::VK_RValue, clang::OK_Ordinary, Loc,
1095 false);
1096 StmtArray[StmtCount++] = CopyStruct;
1097
1098 clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1099 C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1100
1101 delete [] StmtArray;
1102
1103 return CS;
1104 }
1105
1106 } // namespace
1107
ReplaceRSObjectAssignment(clang::BinaryOperator * AS)1108 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1109 clang::BinaryOperator *AS) {
1110
1111 clang::QualType QT = AS->getType();
1112
1113 clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1114 DataTypeRSAllocation)->getASTContext();
1115
1116 clang::SourceLocation Loc = AS->getExprLoc();
1117 clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1118 clang::Stmt *UpdatedStmt = nullptr;
1119
1120 if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1121 // By definition, this is a struct assignment if we get here
1122 UpdatedStmt =
1123 CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1124 } else {
1125 UpdatedStmt =
1126 CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1127 }
1128
1129 RSASTReplace R(C);
1130 R.ReplaceStmt(mCS, AS, UpdatedStmt);
1131 }
1132
AppendRSObjectInit(clang::VarDecl * VD,clang::DeclStmt * DS,DataType DT,clang::Expr * InitExpr)1133 void RSObjectRefCount::Scope::AppendRSObjectInit(
1134 clang::VarDecl *VD,
1135 clang::DeclStmt *DS,
1136 DataType DT,
1137 clang::Expr *InitExpr) {
1138 slangAssert(VD);
1139
1140 if (!InitExpr) {
1141 return;
1142 }
1143
1144 clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1145 DataTypeRSAllocation)->getASTContext();
1146 clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1147 DataTypeRSAllocation)->getLocation();
1148 clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1149 DataTypeRSAllocation)->getInnerLocStart();
1150
1151 if (DT == DataTypeIsStruct) {
1152 const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1153 clang::DeclRefExpr *RefRSVar =
1154 clang::DeclRefExpr::Create(C,
1155 clang::NestedNameSpecifierLoc(),
1156 clang::SourceLocation(),
1157 VD,
1158 false,
1159 Loc,
1160 T->getCanonicalTypeInternal(),
1161 clang::VK_RValue,
1162 nullptr);
1163
1164 clang::Stmt *RSSetObjectOps =
1165 CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1166
1167 std::list<clang::Stmt*> StmtList;
1168 StmtList.push_back(RSSetObjectOps);
1169 AppendAfterStmt(C, mCS, DS, StmtList);
1170 return;
1171 }
1172
1173 clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1174 slangAssert((SetObjectFD != nullptr) &&
1175 "rsSetObject doesn't cover all RS object types");
1176
1177 clang::QualType SetObjectFDType = SetObjectFD->getType();
1178 clang::QualType SetObjectFDArgType[2];
1179 SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1180 SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1181
1182 clang::Expr *RefRSSetObjectFD =
1183 clang::DeclRefExpr::Create(C,
1184 clang::NestedNameSpecifierLoc(),
1185 clang::SourceLocation(),
1186 SetObjectFD,
1187 false,
1188 Loc,
1189 SetObjectFDType,
1190 clang::VK_RValue,
1191 nullptr);
1192
1193 clang::Expr *RSSetObjectFP =
1194 clang::ImplicitCastExpr::Create(C,
1195 C.getPointerType(SetObjectFDType),
1196 clang::CK_FunctionToPointerDecay,
1197 RefRSSetObjectFD,
1198 nullptr,
1199 clang::VK_RValue);
1200
1201 const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1202 clang::DeclRefExpr *RefRSVar =
1203 clang::DeclRefExpr::Create(C,
1204 clang::NestedNameSpecifierLoc(),
1205 clang::SourceLocation(),
1206 VD,
1207 false,
1208 Loc,
1209 T->getCanonicalTypeInternal(),
1210 clang::VK_RValue,
1211 nullptr);
1212
1213 llvm::SmallVector<clang::Expr*, 2> ArgList;
1214 ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1215 clang::UO_AddrOf,
1216 SetObjectFDArgType[0],
1217 clang::VK_RValue,
1218 clang::OK_Ordinary,
1219 Loc));
1220 ArgList.push_back(InitExpr);
1221
1222 clang::CallExpr *RSSetObjectCall =
1223 new(C) clang::CallExpr(C,
1224 RSSetObjectFP,
1225 ArgList,
1226 SetObjectFD->getCallResultType(),
1227 clang::VK_RValue,
1228 Loc);
1229
1230 std::list<clang::Stmt*> StmtList;
1231 StmtList.push_back(RSSetObjectCall);
1232 AppendAfterStmt(C, mCS, DS, StmtList);
1233 }
1234
InsertLocalVarDestructors()1235 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1236 for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1237 E = mRSO.end();
1238 I != E;
1239 I++) {
1240 clang::VarDecl *VD = *I;
1241 clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1242 if (RSClearObjectCall) {
1243 clang::ASTContext &C = (*mRSO.begin())->getASTContext();
1244 // Mark VD as used. It might be unused, except for the destructor.
1245 // 'markUsed' has side-effects that are caused only if VD is not already
1246 // used. Hence no need for an extra check here.
1247 VD->markUsed(C);
1248 DestructorVisitor DV(C,
1249 mCS,
1250 RSClearObjectCall,
1251 VD->getSourceRange().getBegin());
1252 DV.Visit(mCS);
1253 DV.InsertDestructors();
1254 }
1255 }
1256 }
1257
ClearRSObject(clang::VarDecl * VD,clang::DeclContext * DC)1258 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1259 clang::VarDecl *VD,
1260 clang::DeclContext *DC) {
1261 slangAssert(VD);
1262 clang::ASTContext &C = VD->getASTContext();
1263 clang::SourceLocation Loc = VD->getLocation();
1264 clang::SourceLocation StartLoc = VD->getInnerLocStart();
1265 const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1266
1267 // Reference expr to target RS object variable
1268 clang::DeclRefExpr *RefRSVar =
1269 clang::DeclRefExpr::Create(C,
1270 clang::NestedNameSpecifierLoc(),
1271 clang::SourceLocation(),
1272 VD,
1273 false,
1274 Loc,
1275 T->getCanonicalTypeInternal(),
1276 clang::VK_RValue,
1277 nullptr);
1278
1279 if (T->isArrayType()) {
1280 return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1281 }
1282
1283 DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
1284
1285 if (DT == DataTypeUnknown ||
1286 DT == DataTypeIsStruct) {
1287 return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1288 }
1289
1290 slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1291 "Should be RS object");
1292
1293 return ClearSingleRSObject(C, RefRSVar, Loc);
1294 }
1295
InitializeRSObject(clang::VarDecl * VD,DataType * DT,clang::Expr ** InitExpr)1296 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1297 DataType *DT,
1298 clang::Expr **InitExpr) {
1299 slangAssert(VD && DT && InitExpr);
1300 const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1301
1302 // Loop through array types to get to base type
1303 while (T && T->isArrayType()) {
1304 T = T->getArrayElementTypeNoTypeQual();
1305 }
1306
1307 bool DataTypeIsStructWithRSObject = false;
1308 *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1309
1310 if (*DT == DataTypeUnknown) {
1311 if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1312 *DT = DataTypeIsStruct;
1313 DataTypeIsStructWithRSObject = true;
1314 } else {
1315 return false;
1316 }
1317 }
1318
1319 bool DataTypeIsRSObject = false;
1320 if (DataTypeIsStructWithRSObject) {
1321 DataTypeIsRSObject = true;
1322 } else {
1323 DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1324 }
1325 *InitExpr = VD->getInit();
1326
1327 if (!DataTypeIsRSObject && *InitExpr) {
1328 // If we already have an initializer for a matrix type, we are done.
1329 return DataTypeIsRSObject;
1330 }
1331
1332 clang::Expr *ZeroInitializer =
1333 CreateZeroInitializerForRSSpecificType(*DT,
1334 VD->getASTContext(),
1335 VD->getLocation());
1336
1337 if (ZeroInitializer) {
1338 ZeroInitializer->setType(T->getCanonicalTypeInternal());
1339 VD->setInit(ZeroInitializer);
1340 }
1341
1342 return DataTypeIsRSObject;
1343 }
1344
CreateZeroInitializerForRSSpecificType(DataType DT,clang::ASTContext & C,const clang::SourceLocation & Loc)1345 clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1346 DataType DT,
1347 clang::ASTContext &C,
1348 const clang::SourceLocation &Loc) {
1349 clang::Expr *Res = nullptr;
1350 switch (DT) {
1351 case DataTypeIsStruct:
1352 case DataTypeRSElement:
1353 case DataTypeRSType:
1354 case DataTypeRSAllocation:
1355 case DataTypeRSSampler:
1356 case DataTypeRSScript:
1357 case DataTypeRSMesh:
1358 case DataTypeRSPath:
1359 case DataTypeRSProgramFragment:
1360 case DataTypeRSProgramVertex:
1361 case DataTypeRSProgramRaster:
1362 case DataTypeRSProgramStore:
1363 case DataTypeRSFont: {
1364 // (ImplicitCastExpr 'nullptr_t'
1365 // (IntegerLiteral 0)))
1366 llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1367 clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1368 clang::Expr *CastToNull =
1369 clang::ImplicitCastExpr::Create(C,
1370 C.NullPtrTy,
1371 clang::CK_IntegralToPointer,
1372 Int0,
1373 nullptr,
1374 clang::VK_RValue);
1375
1376 llvm::SmallVector<clang::Expr*, 1>InitList;
1377 InitList.push_back(CastToNull);
1378
1379 Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
1380 break;
1381 }
1382 case DataTypeRSMatrix2x2:
1383 case DataTypeRSMatrix3x3:
1384 case DataTypeRSMatrix4x4: {
1385 // RS matrix is not completely an RS object. They hold data by themselves.
1386 // (InitListExpr rs_matrix2x2
1387 // (InitListExpr float[4]
1388 // (FloatingLiteral 0)
1389 // (FloatingLiteral 0)
1390 // (FloatingLiteral 0)
1391 // (FloatingLiteral 0)))
1392 clang::QualType FloatTy = C.FloatTy;
1393 // Constructor sets value to 0.0f by default
1394 llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1395 clang::FloatingLiteral *Float0Val =
1396 clang::FloatingLiteral::Create(C,
1397 Val,
1398 /* isExact = */true,
1399 FloatTy,
1400 Loc);
1401
1402 unsigned N = 0;
1403 if (DT == DataTypeRSMatrix2x2)
1404 N = 2;
1405 else if (DT == DataTypeRSMatrix3x3)
1406 N = 3;
1407 else if (DT == DataTypeRSMatrix4x4)
1408 N = 4;
1409 unsigned N_2 = N * N;
1410
1411 // Assume we are going to be allocating 16 elements, since 4x4 is max.
1412 llvm::SmallVector<clang::Expr*, 16> InitVals;
1413 for (unsigned i = 0; i < N_2; i++)
1414 InitVals.push_back(Float0Val);
1415 clang::Expr *InitExpr =
1416 new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
1417 InitExpr->setType(C.getConstantArrayType(FloatTy,
1418 llvm::APInt(32, N_2),
1419 clang::ArrayType::Normal,
1420 /* EltTypeQuals = */0));
1421 llvm::SmallVector<clang::Expr*, 1> InitExprVec;
1422 InitExprVec.push_back(InitExpr);
1423
1424 Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
1425 break;
1426 }
1427 case DataTypeUnknown:
1428 case DataTypeFloat16:
1429 case DataTypeFloat32:
1430 case DataTypeFloat64:
1431 case DataTypeSigned8:
1432 case DataTypeSigned16:
1433 case DataTypeSigned32:
1434 case DataTypeSigned64:
1435 case DataTypeUnsigned8:
1436 case DataTypeUnsigned16:
1437 case DataTypeUnsigned32:
1438 case DataTypeUnsigned64:
1439 case DataTypeBoolean:
1440 case DataTypeUnsigned565:
1441 case DataTypeUnsigned5551:
1442 case DataTypeUnsigned4444:
1443 case DataTypeMax: {
1444 slangAssert(false && "Not RS object type!");
1445 }
1446 // No default case will enable compiler detecting the missing cases
1447 }
1448
1449 return Res;
1450 }
1451
VisitDeclStmt(clang::DeclStmt * DS)1452 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1453 for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1454 I != E;
1455 I++) {
1456 clang::Decl *D = *I;
1457 if (D->getKind() == clang::Decl::Var) {
1458 clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1459 DataType DT = DataTypeUnknown;
1460 clang::Expr *InitExpr = nullptr;
1461 if (InitializeRSObject(VD, &DT, &InitExpr)) {
1462 // We need to zero-init all RS object types (including matrices), ...
1463 getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1464 // ... but, only add to the list of RS objects if we have some
1465 // non-matrix RS object fields.
1466 if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
1467 VD->getLocation())) {
1468 getCurrentScope()->addRSObject(VD);
1469 }
1470 }
1471 }
1472 }
1473 }
1474
VisitCompoundStmt(clang::CompoundStmt * CS)1475 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1476 if (!CS->body_empty()) {
1477 // Push a new scope
1478 Scope *S = new Scope(CS);
1479 mScopeStack.push(S);
1480
1481 VisitStmt(CS);
1482
1483 // Destroy the scope
1484 slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1485 S->InsertLocalVarDestructors();
1486 mScopeStack.pop();
1487 delete S;
1488 }
1489 }
1490
VisitBinAssign(clang::BinaryOperator * AS)1491 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1492 clang::QualType QT = AS->getType();
1493
1494 if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1495 getCurrentScope()->ReplaceRSObjectAssignment(AS);
1496 }
1497 }
1498
VisitStmt(clang::Stmt * S)1499 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1500 for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1501 I != E;
1502 I++) {
1503 if (clang::Stmt *Child = *I) {
1504 Visit(Child);
1505 }
1506 }
1507 }
1508
1509 // This function walks the list of global variables and (potentially) creates
1510 // a single global static destructor function that properly decrements
1511 // reference counts on the contained RS object types.
CreateStaticGlobalDtor()1512 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1513 Init();
1514
1515 clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1516 clang::SourceLocation loc;
1517
1518 llvm::StringRef SR(".rs.dtor");
1519 clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1520 clang::DeclarationName N(&II);
1521 clang::FunctionProtoType::ExtProtoInfo EPI;
1522 clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1523 llvm::ArrayRef<clang::QualType>(), EPI);
1524 clang::FunctionDecl *FD = nullptr;
1525
1526 // Generate rsClearObject() call chains for every global variable
1527 // (whether static or extern).
1528 std::list<clang::Stmt *> StmtList;
1529 for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1530 E = DC->decls_end(); I != E; I++) {
1531 clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1532 if (VD) {
1533 if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1534 if (!FD) {
1535 // Only create FD if we are going to use it.
1536 FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
1537 clang::SC_None);
1538 }
1539 // Mark VD as used. It might be unused, except for the destructor.
1540 // 'markUsed' has side-effects that are caused only if VD is not already
1541 // used. Hence no need for an extra check here.
1542 VD->markUsed(mCtx);
1543 // Make sure to create any helpers within the function's DeclContext,
1544 // not the one associated with the global translation unit.
1545 clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1546 StmtList.push_back(RSClearObjectCall);
1547 }
1548 }
1549 }
1550
1551 // Nothing needs to be destroyed, so don't emit a dtor.
1552 if (StmtList.empty()) {
1553 return nullptr;
1554 }
1555
1556 clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1557
1558 FD->setBody(CS);
1559
1560 return FD;
1561 }
1562
1563 } // namespace slang
1564