1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //  This file implements semantic analysis for C++ Coroutines.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/Sema/SemaInternal.h"
15 #include "clang/AST/Decl.h"
16 #include "clang/AST/ExprCXX.h"
17 #include "clang/AST/StmtCXX.h"
18 #include "clang/Lex/Preprocessor.h"
19 #include "clang/Sema/Initialization.h"
20 #include "clang/Sema/Overload.h"
21 using namespace clang;
22 using namespace sema;
23 
24 /// Look up the std::coroutine_traits<...>::promise_type for the given
25 /// function type.
lookupPromiseType(Sema & S,const FunctionProtoType * FnType,SourceLocation Loc)26 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
27                                   SourceLocation Loc) {
28   // FIXME: Cache std::coroutine_traits once we've found it.
29   NamespaceDecl *Std = S.getStdNamespace();
30   if (!Std) {
31     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
32     return QualType();
33   }
34 
35   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
36                       Loc, Sema::LookupOrdinaryName);
37   if (!S.LookupQualifiedName(Result, Std)) {
38     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
39     return QualType();
40   }
41 
42   ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
43   if (!CoroTraits) {
44     Result.suppressDiagnostics();
45     // We found something weird. Complain about the first thing we found.
46     NamedDecl *Found = *Result.begin();
47     S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
48     return QualType();
49   }
50 
51   // Form template argument list for coroutine_traits<R, P1, P2, ...>.
52   TemplateArgumentListInfo Args(Loc, Loc);
53   Args.addArgument(TemplateArgumentLoc(
54       TemplateArgument(FnType->getReturnType()),
55       S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
56   // FIXME: If the function is a non-static member function, add the type
57   // of the implicit object parameter before the formal parameters.
58   for (QualType T : FnType->getParamTypes())
59     Args.addArgument(TemplateArgumentLoc(
60         TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
61 
62   // Build the template-id.
63   QualType CoroTrait =
64       S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
65   if (CoroTrait.isNull())
66     return QualType();
67   if (S.RequireCompleteType(Loc, CoroTrait,
68                             diag::err_coroutine_traits_missing_specialization))
69     return QualType();
70 
71   CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
72   assert(RD && "specialization of class template is not a class?");
73 
74   // Look up the ::promise_type member.
75   LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
76                  Sema::LookupOrdinaryName);
77   S.LookupQualifiedName(R, RD);
78   auto *Promise = R.getAsSingle<TypeDecl>();
79   if (!Promise) {
80     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
81       << RD;
82     return QualType();
83   }
84 
85   // The promise type is required to be a class type.
86   QualType PromiseType = S.Context.getTypeDeclType(Promise);
87   if (!PromiseType->getAsCXXRecordDecl()) {
88     // Use the fully-qualified name of the type.
89     auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, Std);
90     NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
91                                       CoroTrait.getTypePtr());
92     PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
93 
94     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
95       << PromiseType;
96     return QualType();
97   }
98 
99   return PromiseType;
100 }
101 
102 /// Check that this is a context in which a coroutine suspension can appear.
103 static FunctionScopeInfo *
checkCoroutineContext(Sema & S,SourceLocation Loc,StringRef Keyword)104 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
105   // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
106   if (S.isUnevaluatedContext()) {
107     S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
108     return nullptr;
109   }
110 
111   // Any other usage must be within a function.
112   // FIXME: Reject a coroutine with a deduced return type.
113   auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
114   if (!FD) {
115     S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
116                     ? diag::err_coroutine_objc_method
117                     : diag::err_coroutine_outside_function) << Keyword;
118   } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) {
119     // Coroutines TS [special]/6:
120     //   A special member function shall not be a coroutine.
121     //
122     // FIXME: We assume that this really means that a coroutine cannot
123     //        be a constructor or destructor.
124     S.Diag(Loc, diag::err_coroutine_ctor_dtor)
125       << isa<CXXDestructorDecl>(FD) << Keyword;
126   } else if (FD->isConstexpr()) {
127     S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword;
128   } else if (FD->isVariadic()) {
129     S.Diag(Loc, diag::err_coroutine_varargs) << Keyword;
130   } else {
131     auto *ScopeInfo = S.getCurFunction();
132     assert(ScopeInfo && "missing function scope for function");
133 
134     // If we don't have a promise variable, build one now.
135     if (!ScopeInfo->CoroutinePromise) {
136       QualType T =
137           FD->getType()->isDependentType()
138               ? S.Context.DependentTy
139               : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
140                                   Loc);
141       if (T.isNull())
142         return nullptr;
143 
144       // Create and default-initialize the promise.
145       ScopeInfo->CoroutinePromise =
146           VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
147                           &S.PP.getIdentifierTable().get("__promise"), T,
148                           S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
149       S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
150       if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
151         S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
152     }
153 
154     return ScopeInfo;
155   }
156 
157   return nullptr;
158 }
159 
160 /// Build a call to 'operator co_await' if there is a suitable operator for
161 /// the given expression.
buildOperatorCoawaitCall(Sema & SemaRef,Scope * S,SourceLocation Loc,Expr * E)162 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
163                                            SourceLocation Loc, Expr *E) {
164   UnresolvedSet<16> Functions;
165   SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
166                                        Functions);
167   return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
168 }
169 
/// The member calls built by buildCoawaitCalls for a co_await / co_yield
/// operand.
struct ReadySuspendResumeResult {
  /// True if any of the three calls could not be built.
  bool IsInvalid;
  /// The calls, in order: await_ready, await_suspend, await_resume.
  /// Entries that were not built are null.
  Expr *Results[3];
};
174 
buildMemberCall(Sema & S,Expr * Base,SourceLocation Loc,StringRef Name,MutableArrayRef<Expr * > Args)175 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
176                                   StringRef Name,
177                                   MutableArrayRef<Expr *> Args) {
178   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
179 
180   // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
181   CXXScopeSpec SS;
182   ExprResult Result = S.BuildMemberReferenceExpr(
183       Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
184       SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
185       /*Scope=*/nullptr);
186   if (Result.isInvalid())
187     return ExprError();
188 
189   return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
190 }
191 
192 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
193 /// expression.
buildCoawaitCalls(Sema & S,SourceLocation Loc,Expr * E)194 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
195                                                   Expr *E) {
196   // Assume invalid until we see otherwise.
197   ReadySuspendResumeResult Calls = {true, {}};
198 
199   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
200   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
201     Expr *Operand = new (S.Context) OpaqueValueExpr(
202         Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
203 
204     // FIXME: Pass coroutine handle to await_suspend.
205     ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
206     if (Result.isInvalid())
207       return Calls;
208     Calls.Results[I] = Result.get();
209   }
210 
211   Calls.IsInvalid = false;
212   return Calls;
213 }
214 
ActOnCoawaitExpr(Scope * S,SourceLocation Loc,Expr * E)215 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
216   if (E->getType()->isPlaceholderType()) {
217     ExprResult R = CheckPlaceholderExpr(E);
218     if (R.isInvalid()) return ExprError();
219     E = R.get();
220   }
221 
222   ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
223   if (Awaitable.isInvalid())
224     return ExprError();
225   return BuildCoawaitExpr(Loc, Awaitable.get());
226 }
BuildCoawaitExpr(SourceLocation Loc,Expr * E)227 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
228   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
229   if (!Coroutine)
230     return ExprError();
231 
232   if (E->getType()->isPlaceholderType()) {
233     ExprResult R = CheckPlaceholderExpr(E);
234     if (R.isInvalid()) return ExprError();
235     E = R.get();
236   }
237 
238   if (E->getType()->isDependentType()) {
239     Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
240     Coroutine->CoroutineStmts.push_back(Res);
241     return Res;
242   }
243 
244   // If the expression is a temporary, materialize it as an lvalue so that we
245   // can use it multiple times.
246   if (E->getValueKind() == VK_RValue)
247     E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
248 
249   // Build the await_ready, await_suspend, await_resume calls.
250   ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
251   if (RSS.IsInvalid)
252     return ExprError();
253 
254   Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
255                                         RSS.Results[2]);
256   Coroutine->CoroutineStmts.push_back(Res);
257   return Res;
258 }
259 
buildPromiseCall(Sema & S,FunctionScopeInfo * Coroutine,SourceLocation Loc,StringRef Name,MutableArrayRef<Expr * > Args)260 static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
261                                    SourceLocation Loc, StringRef Name,
262                                    MutableArrayRef<Expr *> Args) {
263   assert(Coroutine->CoroutinePromise && "no promise for coroutine");
264 
265   // Form a reference to the promise.
266   auto *Promise = Coroutine->CoroutinePromise;
267   ExprResult PromiseRef = S.BuildDeclRefExpr(
268       Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
269   if (PromiseRef.isInvalid())
270     return ExprError();
271 
272   // Call 'yield_value', passing in E.
273   return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
274 }
275 
ActOnCoyieldExpr(Scope * S,SourceLocation Loc,Expr * E)276 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
277   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
278   if (!Coroutine)
279     return ExprError();
280 
281   // Build yield_value call.
282   ExprResult Awaitable =
283       buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
284   if (Awaitable.isInvalid())
285     return ExprError();
286 
287   // Build 'operator co_await' call.
288   Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
289   if (Awaitable.isInvalid())
290     return ExprError();
291 
292   return BuildCoyieldExpr(Loc, Awaitable.get());
293 }
BuildCoyieldExpr(SourceLocation Loc,Expr * E)294 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
295   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
296   if (!Coroutine)
297     return ExprError();
298 
299   if (E->getType()->isPlaceholderType()) {
300     ExprResult R = CheckPlaceholderExpr(E);
301     if (R.isInvalid()) return ExprError();
302     E = R.get();
303   }
304 
305   if (E->getType()->isDependentType()) {
306     Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
307     Coroutine->CoroutineStmts.push_back(Res);
308     return Res;
309   }
310 
311   // If the expression is a temporary, materialize it as an lvalue so that we
312   // can use it multiple times.
313   if (E->getValueKind() == VK_RValue)
314     E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
315 
316   // Build the await_ready, await_suspend, await_resume calls.
317   ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
318   if (RSS.IsInvalid)
319     return ExprError();
320 
321   Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
322                                         RSS.Results[2]);
323   Coroutine->CoroutineStmts.push_back(Res);
324   return Res;
325 }
326 
ActOnCoreturnStmt(SourceLocation Loc,Expr * E)327 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
328   return BuildCoreturnStmt(Loc, E);
329 }
BuildCoreturnStmt(SourceLocation Loc,Expr * E)330 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
331   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
332   if (!Coroutine)
333     return StmtError();
334 
335   if (E && E->getType()->isPlaceholderType() &&
336       !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
337     ExprResult R = CheckPlaceholderExpr(E);
338     if (R.isInvalid()) return StmtError();
339     E = R.get();
340   }
341 
342   // FIXME: If the operand is a reference to a variable that's about to go out
343   // of scope, we should treat the operand as an xvalue for this overload
344   // resolution.
345   ExprResult PC;
346   if (E && !E->getType()->isVoidType()) {
347     PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
348   } else {
349     E = MakeFullDiscardedValueExpr(E).get();
350     PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
351   }
352   if (PC.isInvalid())
353     return StmtError();
354 
355   Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
356 
357   Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
358   Coroutine->CoroutineStmts.push_back(Res);
359   return Res;
360 }
361 
CheckCompletedCoroutineBody(FunctionDecl * FD,Stmt * & Body)362 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
363   FunctionScopeInfo *Fn = getCurFunction();
364   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
365 
366   // Coroutines [stmt.return]p1:
367   //   A return statement shall not appear in a coroutine.
368   if (Fn->FirstReturnLoc.isValid()) {
369     Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
370     auto *First = Fn->CoroutineStmts[0];
371     Diag(First->getLocStart(), diag::note_declared_coroutine_here)
372       << (isa<CoawaitExpr>(First) ? 0 :
373           isa<CoyieldExpr>(First) ? 1 : 2);
374   }
375 
376   bool AnyCoawaits = false;
377   bool AnyCoyields = false;
378   for (auto *CoroutineStmt : Fn->CoroutineStmts) {
379     AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
380     AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
381   }
382 
383   if (!AnyCoawaits && !AnyCoyields)
384     Diag(Fn->CoroutineStmts.front()->getLocStart(),
385          diag::ext_coroutine_without_co_await_co_yield);
386 
387   SourceLocation Loc = FD->getLocation();
388 
389   // Form a declaration statement for the promise declaration, so that AST
390   // visitors can more easily find it.
391   StmtResult PromiseStmt =
392       ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
393   if (PromiseStmt.isInvalid())
394     return FD->setInvalidDecl();
395 
396   // Form and check implicit 'co_await p.initial_suspend();' statement.
397   ExprResult InitialSuspend =
398       buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
399   // FIXME: Support operator co_await here.
400   if (!InitialSuspend.isInvalid())
401     InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
402   InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
403   if (InitialSuspend.isInvalid())
404     return FD->setInvalidDecl();
405 
406   // Form and check implicit 'co_await p.final_suspend();' statement.
407   ExprResult FinalSuspend =
408       buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
409   // FIXME: Support operator co_await here.
410   if (!FinalSuspend.isInvalid())
411     FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
412   FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
413   if (FinalSuspend.isInvalid())
414     return FD->setInvalidDecl();
415 
416   // FIXME: Perform analysis of set_exception call.
417 
418   // FIXME: Try to form 'p.return_void();' expression statement to handle
419   // control flowing off the end of the coroutine.
420 
421   // Build implicit 'p.get_return_object()' expression and form initialization
422   // of return type from it.
423   ExprResult ReturnObject =
424     buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
425   if (ReturnObject.isInvalid())
426     return FD->setInvalidDecl();
427   QualType RetType = FD->getReturnType();
428   if (!RetType->isDependentType()) {
429     InitializedEntity Entity =
430         InitializedEntity::InitializeResult(Loc, RetType, false);
431     ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
432                                                    ReturnObject.get());
433     if (ReturnObject.isInvalid())
434       return FD->setInvalidDecl();
435   }
436   ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
437   if (ReturnObject.isInvalid())
438     return FD->setInvalidDecl();
439 
440   // FIXME: Perform move-initialization of parameters into frame-local copies.
441   SmallVector<Expr*, 16> ParamMoves;
442 
443   // Build body for the coroutine wrapper statement.
444   Body = new (Context) CoroutineBodyStmt(
445       Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
446       /*SetException*/nullptr, /*Fallthrough*/nullptr,
447       ReturnObject.get(), ParamMoves);
448 }
449