1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLIRGenerator.h"
9 
10 #include "limits.h"
11 #include <iterator>
12 #include <memory>
13 #include <unordered_set>
14 
15 #include "include/private/SkSLLayout.h"
16 #include "include/private/SkTArray.h"
17 #include "src/core/SkScopeExit.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLCompiler.h"
20 #include "src/sksl/SkSLConstantFolder.h"
21 #include "src/sksl/SkSLOperators.h"
22 #include "src/sksl/SkSLParser.h"
23 #include "src/sksl/SkSLUtil.h"
24 #include "src/sksl/ir/SkSLBinaryExpression.h"
25 #include "src/sksl/ir/SkSLBoolLiteral.h"
26 #include "src/sksl/ir/SkSLBreakStatement.h"
27 #include "src/sksl/ir/SkSLConstructor.h"
28 #include "src/sksl/ir/SkSLContinueStatement.h"
29 #include "src/sksl/ir/SkSLDiscardStatement.h"
30 #include "src/sksl/ir/SkSLDoStatement.h"
31 #include "src/sksl/ir/SkSLEnum.h"
32 #include "src/sksl/ir/SkSLExpressionStatement.h"
33 #include "src/sksl/ir/SkSLExternalFunctionCall.h"
34 #include "src/sksl/ir/SkSLExternalFunctionReference.h"
35 #include "src/sksl/ir/SkSLField.h"
36 #include "src/sksl/ir/SkSLFieldAccess.h"
37 #include "src/sksl/ir/SkSLFloatLiteral.h"
38 #include "src/sksl/ir/SkSLForStatement.h"
39 #include "src/sksl/ir/SkSLFunctionCall.h"
40 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
41 #include "src/sksl/ir/SkSLFunctionDefinition.h"
42 #include "src/sksl/ir/SkSLFunctionPrototype.h"
43 #include "src/sksl/ir/SkSLFunctionReference.h"
44 #include "src/sksl/ir/SkSLIfStatement.h"
45 #include "src/sksl/ir/SkSLIndexExpression.h"
46 #include "src/sksl/ir/SkSLIntLiteral.h"
47 #include "src/sksl/ir/SkSLInterfaceBlock.h"
48 #include "src/sksl/ir/SkSLNop.h"
49 #include "src/sksl/ir/SkSLPostfixExpression.h"
50 #include "src/sksl/ir/SkSLPrefixExpression.h"
51 #include "src/sksl/ir/SkSLReturnStatement.h"
52 #include "src/sksl/ir/SkSLSetting.h"
53 #include "src/sksl/ir/SkSLStructDefinition.h"
54 #include "src/sksl/ir/SkSLSwitchCase.h"
55 #include "src/sksl/ir/SkSLSwitchStatement.h"
56 #include "src/sksl/ir/SkSLSwizzle.h"
57 #include "src/sksl/ir/SkSLTernaryExpression.h"
58 #include "src/sksl/ir/SkSLUnresolvedFunction.h"
59 #include "src/sksl/ir/SkSLVarDeclarations.h"
60 #include "src/sksl/ir/SkSLVariable.h"
61 #include "src/sksl/ir/SkSLVariableReference.h"
62 
63 namespace SkSL {
64 
65 class AutoSymbolTable {
66 public:
AutoSymbolTable(IRGenerator * ir)67     AutoSymbolTable(IRGenerator* ir)
68     : fIR(ir)
69     , fPrevious(fIR->fSymbolTable) {
70         fIR->pushSymbolTable();
71     }
72 
~AutoSymbolTable()73     ~AutoSymbolTable() {
74         fIR->popSymbolTable();
75         SkASSERT(fPrevious == fIR->fSymbolTable);
76     }
77 
78     IRGenerator* fIR;
79     std::shared_ptr<SymbolTable> fPrevious;
80 };
81 
IRGenerator(const Context * context)82 IRGenerator::IRGenerator(const Context* context)
83         : fContext(*context) {}
84 
pushSymbolTable()85 void IRGenerator::pushSymbolTable() {
86     auto childSymTable = std::make_shared<SymbolTable>(std::move(fSymbolTable), fIsBuiltinCode);
87     fSymbolTable = std::move(childSymTable);
88 }
89 
popSymbolTable()90 void IRGenerator::popSymbolTable() {
91     fSymbolTable = fSymbolTable->fParent;
92 }
93 
detectVarDeclarationWithoutScope(const Statement & stmt)94 bool IRGenerator::detectVarDeclarationWithoutScope(const Statement& stmt) {
95     // Parsing an AST node containing a single variable declaration creates a lone VarDeclaration
96     // statement. An AST with multiple variable declarations creates an unscoped Block containing
97     // multiple VarDeclaration statements. We need to detect either case.
98     const Variable* var;
99     if (stmt.is<VarDeclaration>()) {
100         // The single-variable case. No blocks at all.
101         var = &stmt.as<VarDeclaration>().var();
102     } else if (stmt.is<Block>()) {
103         // The multiple-variable case: an unscoped, non-empty block...
104         const Block& block = stmt.as<Block>();
105         if (block.isScope() || block.children().empty()) {
106             return false;
107         }
108         // ... holding a variable declaration.
109         const Statement& innerStmt = *block.children().front();
110         if (!innerStmt.is<VarDeclaration>()) {
111             return false;
112         }
113         var = &innerStmt.as<VarDeclaration>().var();
114     } else {
115         // This statement wasn't a variable declaration. No problem.
116         return false;
117     }
118 
119     // Report an error.
120     SkASSERT(var);
121     this->errorReporter().error(stmt.fOffset,
122                                 "variable '" + var->name() + "' must be created in a scope");
123     return true;
124 }
125 
convertExtension(int offset,StringFragment name)126 std::unique_ptr<Extension> IRGenerator::convertExtension(int offset, StringFragment name) {
127     if (this->programKind() != ProgramKind::kFragment &&
128         this->programKind() != ProgramKind::kVertex &&
129         this->programKind() != ProgramKind::kGeometry) {
130         this->errorReporter().error(offset, "extensions are not allowed here");
131         return nullptr;
132     }
133 
134     return std::make_unique<Extension>(offset, name);
135 }
136 
convertStatement(const ASTNode & statement)137 std::unique_ptr<Statement> IRGenerator::convertStatement(const ASTNode& statement) {
138     switch (statement.fKind) {
139         case ASTNode::Kind::kBlock:
140             return this->convertBlock(statement);
141         case ASTNode::Kind::kVarDeclarations:
142             return this->convertVarDeclarationStatement(statement);
143         case ASTNode::Kind::kIf:
144             return this->convertIf(statement);
145         case ASTNode::Kind::kFor:
146             return this->convertFor(statement);
147         case ASTNode::Kind::kWhile:
148             return this->convertWhile(statement);
149         case ASTNode::Kind::kDo:
150             return this->convertDo(statement);
151         case ASTNode::Kind::kSwitch:
152             return this->convertSwitch(statement);
153         case ASTNode::Kind::kReturn:
154             return this->convertReturn(statement);
155         case ASTNode::Kind::kBreak:
156             return this->convertBreak(statement);
157         case ASTNode::Kind::kContinue:
158             return this->convertContinue(statement);
159         case ASTNode::Kind::kDiscard:
160             return this->convertDiscard(statement);
161         case ASTNode::Kind::kType:
162             // TODO: add IRNode for struct definition inside a function
163             return nullptr;
164         default:
165             // it's an expression
166             std::unique_ptr<Statement> result = this->convertExpressionStatement(statement);
167             if (fRTAdjust && this->programKind() == ProgramKind::kGeometry) {
168                 SkASSERT(result->is<ExpressionStatement>());
169                 Expression& expr = *result->as<ExpressionStatement>().expression();
170                 if (expr.is<FunctionCall>()) {
171                     FunctionCall& fc = expr.as<FunctionCall>();
172                     if (fc.function().isBuiltin() && fc.function().name() == "EmitVertex") {
173                         StatementArray statements;
174                         statements.reserve_back(2);
175                         statements.push_back(getNormalizeSkPositionCode());
176                         statements.push_back(std::move(result));
177                         return Block::Make(statement.fOffset, std::move(statements),
178                                            fSymbolTable, /*isScope=*/true);
179                     }
180                 }
181             }
182             return result;
183     }
184 }
185 
convertBlock(const ASTNode & block)186 std::unique_ptr<Block> IRGenerator::convertBlock(const ASTNode& block) {
187     SkASSERT(block.fKind == ASTNode::Kind::kBlock);
188     AutoSymbolTable table(this);
189     StatementArray statements;
190     for (const auto& child : block) {
191         std::unique_ptr<Statement> statement = this->convertStatement(child);
192         if (!statement) {
193             return nullptr;
194         }
195         statements.push_back(std::move(statement));
196     }
197     return Block::Make(block.fOffset, std::move(statements), fSymbolTable);
198 }
199 
convertVarDeclarationStatement(const ASTNode & s)200 std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement(const ASTNode& s) {
201     SkASSERT(s.fKind == ASTNode::Kind::kVarDeclarations);
202     auto decls = this->convertVarDeclarations(s, Variable::Storage::kLocal);
203     if (decls.empty()) {
204         return nullptr;
205     }
206     return Block::MakeUnscoped(s.fOffset, std::move(decls));
207 }
208 
convertArraySize(const Type & type,int offset,const ASTNode & s)209 int IRGenerator::convertArraySize(const Type& type, int offset, const ASTNode& s) {
210     if (!s) {
211         this->errorReporter().error(offset, "array must have a size");
212         return 0;
213     }
214     auto size = this->convertExpression(s);
215     if (!size) {
216         return 0;
217     }
218     return this->convertArraySize(type, std::move(size));
219 }
220 
convertArraySize(const Type & type,std::unique_ptr<Expression> size)221 int IRGenerator::convertArraySize(const Type& type, std::unique_ptr<Expression> size) {
222     size = this->coerce(std::move(size), *fContext.fTypes.fInt);
223     if (!size) {
224         return 0;
225     }
226     if (type.isVoid()) {
227         this->errorReporter().error(size->fOffset, "type 'void' may not be used in an array");
228         return 0;
229     }
230     if (type.isOpaque()) {
231         this->errorReporter().error(
232                 size->fOffset, "opaque type '" + type.name() + "' may not be used in an array");
233         return 0;
234     }
235     if (!size->is<IntLiteral>()) {
236         this->errorReporter().error(size->fOffset, "array size must be an integer");
237         return 0;
238     }
239     SKSL_INT count = size->as<IntLiteral>().value();
240     if (count <= 0) {
241         this->errorReporter().error(size->fOffset, "array size must be positive");
242         return 0;
243     }
244     if (!SkTFitsIn<int>(count)) {
245         this->errorReporter().error(size->fOffset, "array size is too large");
246         return 0;
247     }
248     return static_cast<int>(count);
249 }
250 
checkVarDeclaration(int offset,const Modifiers & modifiers,const Type * baseType,Variable::Storage storage)251 void IRGenerator::checkVarDeclaration(int offset, const Modifiers& modifiers, const Type* baseType,
252                                       Variable::Storage storage) {
253     if (this->strictES2Mode() && baseType->isArray()) {
254         this->errorReporter().error(offset, "array size must appear after variable name");
255     }
256 
257     if (baseType->componentType().isOpaque() && storage != Variable::Storage::kGlobal) {
258         this->errorReporter().error(
259                 offset,
260                 "variables of type '" + baseType->displayName() + "' must be global");
261     }
262     if (this->programKind() != ProgramKind::kFragmentProcessor) {
263         if ((modifiers.fFlags & Modifiers::kIn_Flag) && baseType->isMatrix()) {
264             this->errorReporter().error(offset, "'in' variables may not have matrix type");
265         }
266         if ((modifiers.fFlags & Modifiers::kIn_Flag) &&
267             (modifiers.fFlags & Modifiers::kUniform_Flag)) {
268             this->errorReporter().error(
269                     offset,
270                     "'in uniform' variables only permitted within fragment processors");
271         }
272         if (modifiers.fLayout.fWhen.fLength) {
273             this->errorReporter().error(offset,
274                                         "'when' is only permitted within fragment processors");
275         }
276         if (modifiers.fLayout.fFlags & Layout::kTracked_Flag) {
277             this->errorReporter().error(offset,
278                                         "'tracked' is only permitted within fragment processors");
279         }
280         if (modifiers.fLayout.fCType != Layout::CType::kDefault) {
281             this->errorReporter().error(offset,
282                                         "'ctype' is only permitted within fragment processors");
283         }
284         if (modifiers.fLayout.fFlags & Layout::kKey_Flag) {
285             this->errorReporter().error(offset,
286                                         "'key' is only permitted within fragment processors");
287         }
288     }
289     if (this->programKind() == ProgramKind::kRuntimeColorFilter ||
290         this->programKind() == ProgramKind::kRuntimeShader) {
291         if (modifiers.fFlags & Modifiers::kIn_Flag) {
292             this->errorReporter().error(offset, "'in' variables not permitted in runtime effects");
293         }
294     }
295     if (baseType->isEffectChild() && !(modifiers.fFlags & Modifiers::kUniform_Flag)) {
296         this->errorReporter().error(
297                 offset, "variables of type '" + baseType->displayName() + "' must be uniform");
298     }
299     if ((modifiers.fLayout.fFlags & Layout::kKey_Flag) &&
300         (modifiers.fFlags & Modifiers::kUniform_Flag)) {
301         this->errorReporter().error(offset, "'key' is not permitted on 'uniform' variables");
302     }
303     if (modifiers.fLayout.fFlags & Layout::kSRGBUnpremul_Flag) {
304         if (this->programKind() != ProgramKind::kRuntimeColorFilter &&
305             this->programKind() != ProgramKind::kRuntimeShader) {
306             this->errorReporter().error(offset,
307                                         "'srgb_unpremul' is only permitted in runtime effects");
308         }
309         if (!(modifiers.fFlags & Modifiers::kUniform_Flag)) {
310             this->errorReporter().error(offset,
311                                         "'srgb_unpremul' is only permitted on 'uniform' variables");
312         }
313         auto validColorXformType = [](const Type& t) {
314             return t.isVector() && t.componentType().isFloat() &&
315                    (t.columns() == 3 || t.columns() == 4);
316         };
317         if (!validColorXformType(*baseType) && !(baseType->isArray() &&
318                                                  validColorXformType(baseType->componentType()))) {
319             this->errorReporter().error(offset,
320                                         "'srgb_unpremul' is only permitted on half3, half4, "
321                                         "float3, or float4 variables");
322         }
323     }
324     int permitted = Modifiers::kConst_Flag;
325     if (storage == Variable::Storage::kGlobal) {
326         permitted |= Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag |
327                      Modifiers::kFlat_Flag | Modifiers::kNoPerspective_Flag;
328     }
329     // TODO(skbug.com/11301): Migrate above checks into building a mask of permitted layout flags
330     CheckModifiers(fContext, offset, modifiers, permitted, /*permittedLayoutFlags=*/~0);
331 }
332 
convertVar(int offset,const Modifiers & modifiers,const Type * baseType,StringFragment name,bool isArray,std::unique_ptr<Expression> arraySize,Variable::Storage storage)333 std::unique_ptr<Variable> IRGenerator::convertVar(int offset, const Modifiers& modifiers,
334                                                   const Type* baseType, StringFragment name,
335                                                   bool isArray,
336                                                   std::unique_ptr<Expression> arraySize,
337                                                   Variable::Storage storage) {
338     if (modifiers.fLayout.fLocation == 0 && modifiers.fLayout.fIndex == 0 &&
339         (modifiers.fFlags & Modifiers::kOut_Flag) &&
340         this->programKind() == ProgramKind::kFragment && name != Compiler::FRAGCOLOR_NAME) {
341         this->errorReporter().error(offset,
342                                     "out location=0, index=0 is reserved for sk_FragColor");
343     }
344     const Type* type = baseType;
345     int arraySizeValue = 0;
346     if (isArray) {
347         SkASSERT(arraySize);
348         arraySizeValue = this->convertArraySize(*type, std::move(arraySize));
349         if (!arraySizeValue) {
350             return {};
351         }
352         type = fSymbolTable->addArrayDimension(type, arraySizeValue);
353     }
354     return std::make_unique<Variable>(offset, this->modifiersPool().add(modifiers), name,
355                                       type, fIsBuiltinCode, storage);
356 }
357 
convertVarDeclaration(std::unique_ptr<Variable> var,std::unique_ptr<Expression> value)358 std::unique_ptr<Statement> IRGenerator::convertVarDeclaration(std::unique_ptr<Variable> var,
359                                                               std::unique_ptr<Expression> value) {
360     std::unique_ptr<Statement> varDecl = VarDeclaration::Convert(fContext, var.get(),
361                                                                  std::move(value));
362     if (!varDecl) {
363         return nullptr;
364     }
365 
366     // Detect the declaration of magical variables.
367     if ((var->storage() == Variable::Storage::kGlobal) && var->name() == Compiler::FRAGCOLOR_NAME) {
368         // Silently ignore duplicate definitions of `sk_FragColor`.
369         const Symbol* symbol = (*fSymbolTable)[var->name()];
370         if (symbol) {
371             return nullptr;
372         }
373     } else if ((var->storage() == Variable::Storage::kGlobal ||
374                 var->storage() == Variable::Storage::kInterfaceBlock) &&
375                var->name() == Compiler::RTADJUST_NAME) {
376         // `sk_RTAdjust` is special, and makes the IR generator emit position-fixup expressions.
377         if (fRTAdjust) {
378             this->errorReporter().error(var->fOffset, "duplicate definition of 'sk_RTAdjust'");
379             return nullptr;
380         }
381         if (var->type() != *fContext.fTypes.fFloat4) {
382             this->errorReporter().error(var->fOffset, "sk_RTAdjust must have type 'float4'");
383             return nullptr;
384         }
385         fRTAdjust = var.get();
386     }
387 
388     fSymbolTable->add(std::move(var));
389     return varDecl;
390 }
391 
convertVarDeclaration(int offset,const Modifiers & modifiers,const Type * baseType,StringFragment name,bool isArray,std::unique_ptr<Expression> arraySize,std::unique_ptr<Expression> value,Variable::Storage storage)392 std::unique_ptr<Statement> IRGenerator::convertVarDeclaration(int offset,
393                                                               const Modifiers& modifiers,
394                                                               const Type* baseType,
395                                                               StringFragment name,
396                                                               bool isArray,
397                                                               std::unique_ptr<Expression> arraySize,
398                                                               std::unique_ptr<Expression> value,
399                                                               Variable::Storage storage) {
400     std::unique_ptr<Variable> var = this->convertVar(offset, modifiers, baseType, name, isArray,
401                                                      std::move(arraySize), storage);
402     if (!var) {
403         return nullptr;
404     }
405     return this->convertVarDeclaration(std::move(var), std::move(value));
406 }
407 
convertVarDeclarations(const ASTNode & decls,Variable::Storage storage)408 StatementArray IRGenerator::convertVarDeclarations(const ASTNode& decls,
409                                                    Variable::Storage storage) {
410     SkASSERT(decls.fKind == ASTNode::Kind::kVarDeclarations);
411     auto declarationsIter = decls.begin();
412     const Modifiers& modifiers = declarationsIter++->getModifiers();
413     const ASTNode& rawType = *(declarationsIter++);
414     const Type* baseType = this->convertType(rawType);
415     if (!baseType) {
416         return {};
417     }
418 
419     this->checkVarDeclaration(decls.fOffset, modifiers, baseType, storage);
420 
421     StatementArray varDecls;
422     for (; declarationsIter != decls.end(); ++declarationsIter) {
423         const ASTNode& varDecl = *declarationsIter;
424         const ASTNode::VarData& varData = varDecl.getVarData();
425         std::unique_ptr<Expression> arraySize;
426         std::unique_ptr<Expression> value;
427         auto iter = varDecl.begin();
428         if (iter != varDecl.end() && varData.fIsArray) {
429             if (*iter) {
430                 arraySize = this->convertExpression(*iter++);
431             } else {
432                 this->errorReporter().error(decls.fOffset, "array must have a size");
433                 continue;
434             }
435         }
436         if (iter != varDecl.end()) {
437             value = this->convertExpression(*iter);
438             if (!value) {
439                 continue;
440             }
441         }
442         std::unique_ptr<Statement> varDeclStmt = this->convertVarDeclaration(varDecl.fOffset,
443                                                                              modifiers,
444                                                                              baseType,
445                                                                              varData.fName,
446                                                                              varData.fIsArray,
447                                                                              std::move(arraySize),
448                                                                              std::move(value),
449                                                                              storage);
450         if (varDeclStmt) {
451             varDecls.push_back(std::move(varDeclStmt));
452         }
453     }
454     return varDecls;
455 }
456 
convertModifiersDeclaration(const ASTNode & m)457 std::unique_ptr<ModifiersDeclaration> IRGenerator::convertModifiersDeclaration(const ASTNode& m) {
458     if (this->programKind() != ProgramKind::kFragment &&
459         this->programKind() != ProgramKind::kVertex &&
460         this->programKind() != ProgramKind::kGeometry) {
461         this->errorReporter().error(m.fOffset, "layout qualifiers are not allowed here");
462         return nullptr;
463     }
464 
465     SkASSERT(m.fKind == ASTNode::Kind::kModifiers);
466     Modifiers modifiers = m.getModifiers();
467     if (modifiers.fLayout.fInvocations != -1) {
468         if (this->programKind() != ProgramKind::kGeometry) {
469             this->errorReporter().error(m.fOffset,
470                                         "'invocations' is only legal in geometry shaders");
471             return nullptr;
472         }
473         fInvocations = modifiers.fLayout.fInvocations;
474         if (!this->caps().gsInvocationsSupport()) {
475             modifiers.fLayout.fInvocations = -1;
476             if (modifiers.fLayout.description() == "") {
477                 return nullptr;
478             }
479         }
480     }
481     if (modifiers.fLayout.fMaxVertices != -1 && fInvocations > 0 &&
482         !this->caps().gsInvocationsSupport()) {
483         modifiers.fLayout.fMaxVertices *= fInvocations;
484     }
485     return std::make_unique<ModifiersDeclaration>(this->modifiersPool().add(modifiers));
486 }
487 
convertIf(const ASTNode & n)488 std::unique_ptr<Statement> IRGenerator::convertIf(const ASTNode& n) {
489     SkASSERT(n.fKind == ASTNode::Kind::kIf);
490     auto iter = n.begin();
491     std::unique_ptr<Expression> test = this->convertExpression(*(iter++));
492     if (!test) {
493         return nullptr;
494     }
495     std::unique_ptr<Statement> ifTrue = this->convertStatement(*(iter++));
496     if (!ifTrue) {
497         return nullptr;
498     }
499     if (this->detectVarDeclarationWithoutScope(*ifTrue)) {
500         return nullptr;
501     }
502     std::unique_ptr<Statement> ifFalse;
503     if (iter != n.end()) {
504         ifFalse = this->convertStatement(*(iter++));
505         if (!ifFalse) {
506             return nullptr;
507         }
508         if (this->detectVarDeclarationWithoutScope(*ifFalse)) {
509             return nullptr;
510         }
511     }
512     bool isStatic = n.getBool();
513     return IfStatement::Convert(fContext, n.fOffset, isStatic, std::move(test),
514                                 std::move(ifTrue), std::move(ifFalse));
515 }
516 
convertFor(const ASTNode & f)517 std::unique_ptr<Statement> IRGenerator::convertFor(const ASTNode& f) {
518     SkASSERT(f.fKind == ASTNode::Kind::kFor);
519     AutoSymbolTable table(this);
520     std::unique_ptr<Statement> initializer;
521     auto iter = f.begin();
522     if (*iter) {
523         initializer = this->convertStatement(*iter);
524         if (!initializer) {
525             return nullptr;
526         }
527     }
528     ++iter;
529     std::unique_ptr<Expression> test;
530     if (*iter) {
531         test = this->convertExpression(*iter);
532         if (!test) {
533             return nullptr;
534         }
535     }
536     ++iter;
537     std::unique_ptr<Expression> next;
538     if (*iter) {
539         next = this->convertExpression(*iter);
540         if (!next) {
541             return nullptr;
542         }
543     }
544     ++iter;
545     std::unique_ptr<Statement> statement = this->convertStatement(*iter);
546     if (!statement) {
547         return nullptr;
548     }
549     if (this->detectVarDeclarationWithoutScope(*statement)) {
550         return nullptr;
551     }
552 
553     return ForStatement::Convert(fContext, f.fOffset, std::move(initializer), std::move(test),
554                                  std::move(next), std::move(statement), fSymbolTable);
555 }
556 
convertWhile(const ASTNode & w)557 std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTNode& w) {
558     SkASSERT(w.fKind == ASTNode::Kind::kWhile);
559     auto iter = w.begin();
560     std::unique_ptr<Expression> test = this->convertExpression(*(iter++));
561     if (!test) {
562         return nullptr;
563     }
564     std::unique_ptr<Statement> statement = this->convertStatement(*(iter++));
565     if (!statement) {
566         return nullptr;
567     }
568     if (this->detectVarDeclarationWithoutScope(*statement)) {
569         return nullptr;
570     }
571     return ForStatement::ConvertWhile(fContext, w.fOffset, std::move(test), std::move(statement),
572                                       fSymbolTable);
573 }
574 
convertDo(const ASTNode & d)575 std::unique_ptr<Statement> IRGenerator::convertDo(const ASTNode& d) {
576     SkASSERT(d.fKind == ASTNode::Kind::kDo);
577     auto iter = d.begin();
578     std::unique_ptr<Statement> statement = this->convertStatement(*(iter++));
579     if (!statement) {
580         return nullptr;
581     }
582     std::unique_ptr<Expression> test = this->convertExpression(*(iter++));
583     if (!test) {
584         return nullptr;
585     }
586     if (this->detectVarDeclarationWithoutScope(*statement)) {
587         return nullptr;
588     }
589     return DoStatement::Convert(fContext, std::move(statement), std::move(test));
590 }
591 
convertSwitch(const ASTNode & s)592 std::unique_ptr<Statement> IRGenerator::convertSwitch(const ASTNode& s) {
593     SkASSERT(s.fKind == ASTNode::Kind::kSwitch);
594 
595     auto iter = s.begin();
596     std::unique_ptr<Expression> value = this->convertExpression(*(iter++));
597     if (!value) {
598         return nullptr;
599     }
600     AutoSymbolTable table(this);
601     ExpressionArray caseValues;
602     StatementArray caseStatements;
603     for (; iter != s.end(); ++iter) {
604         const ASTNode& c = *iter;
605         SkASSERT(c.fKind == ASTNode::Kind::kSwitchCase);
606         std::unique_ptr<Expression>& caseValue = caseValues.emplace_back();
607         auto childIter = c.begin();
608         if (*childIter) {
609             caseValue = this->convertExpression(*childIter);
610             if (!caseValue) {
611                 return nullptr;
612             }
613         }
614         ++childIter;
615 
616         StatementArray statements;
617         for (; childIter != c.end(); ++childIter) {
618             std::unique_ptr<Statement> converted = this->convertStatement(*childIter);
619             if (!converted) {
620                 return nullptr;
621             }
622             statements.push_back(std::move(converted));
623         }
624 
625         caseStatements.push_back(Block::MakeUnscoped(c.fOffset, std::move(statements)));
626     }
627     return SwitchStatement::Convert(fContext, s.fOffset, s.getBool(), std::move(value),
628                                     std::move(caseValues), std::move(caseStatements), fSymbolTable);
629 }
630 
convertExpressionStatement(const ASTNode & s)631 std::unique_ptr<Statement> IRGenerator::convertExpressionStatement(const ASTNode& s) {
632     std::unique_ptr<Expression> e = this->convertExpression(s);
633     if (!e) {
634         return nullptr;
635     }
636     return ExpressionStatement::Make(fContext, std::move(e));
637 }
638 
convertReturn(int offset,std::unique_ptr<Expression> result)639 std::unique_ptr<Statement> IRGenerator::convertReturn(int offset,
640                                                       std::unique_ptr<Expression> result) {
641     return ReturnStatement::Make(offset, std::move(result));
642 }
643 
convertReturn(const ASTNode & r)644 std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTNode& r) {
645     SkASSERT(r.fKind == ASTNode::Kind::kReturn);
646     if (r.begin() != r.end()) {
647         std::unique_ptr<Expression> value = this->convertExpression(*r.begin());
648         if (!value) {
649             return nullptr;
650         }
651         return this->convertReturn(r.fOffset, std::move(value));
652     } else {
653         return this->convertReturn(r.fOffset, /*result=*/nullptr);
654     }
655 }
656 
convertBreak(const ASTNode & b)657 std::unique_ptr<Statement> IRGenerator::convertBreak(const ASTNode& b) {
658     SkASSERT(b.fKind == ASTNode::Kind::kBreak);
659     return BreakStatement::Make(b.fOffset);
660 }
661 
convertContinue(const ASTNode & c)662 std::unique_ptr<Statement> IRGenerator::convertContinue(const ASTNode& c) {
663     SkASSERT(c.fKind == ASTNode::Kind::kContinue);
664     return ContinueStatement::Make(c.fOffset);
665 }
666 
convertDiscard(const ASTNode & d)667 std::unique_ptr<Statement> IRGenerator::convertDiscard(const ASTNode& d) {
668     SkASSERT(d.fKind == ASTNode::Kind::kDiscard);
669     if (this->programKind() != ProgramKind::kFragment &&
670         this->programKind() != ProgramKind::kFragmentProcessor) {
671         this->errorReporter().error(d.fOffset,
672                                     "discard statement is only permitted in fragment shaders");
673         return nullptr;
674     }
675     return DiscardStatement::Make(d.fOffset);
676 }
677 
applyInvocationIDWorkaround(std::unique_ptr<Block> main)678 std::unique_ptr<Block> IRGenerator::applyInvocationIDWorkaround(std::unique_ptr<Block> main) {
679     Layout invokeLayout;
680     Modifiers invokeModifiers(invokeLayout, Modifiers::kHasSideEffects_Flag);
681     const FunctionDeclaration* invokeDecl = fSymbolTable->add(std::make_unique<FunctionDeclaration>(
682             /*offset=*/-1,
683             this->modifiersPool().add(invokeModifiers),
684             "_invoke",
685             std::vector<const Variable*>(),
686             fContext.fTypes.fVoid.get(),
687             fIsBuiltinCode));
688     auto invokeDef = std::make_unique<FunctionDefinition>(/*offset=*/-1, invokeDecl, fIsBuiltinCode,
689                                                           std::move(main));
690     invokeDecl->setDefinition(invokeDef.get());
691     fProgramElements->push_back(std::move(invokeDef));
692     std::vector<std::unique_ptr<VarDeclaration>> variables;
693     const Variable* loopIdx = &(*fSymbolTable)["sk_InvocationID"]->as<Variable>();
694     auto test = BinaryExpression::Make(
695             fContext,
696             VariableReference::Make(/*offset=*/-1, loopIdx),
697             Token::Kind::TK_LT,
698             IntLiteral::Make(fContext, /*offset=*/-1, fInvocations));
699     auto next = PostfixExpression::Make(
700             fContext,
701             VariableReference::Make(/*offset=*/-1, loopIdx, VariableRefKind::kReadWrite),
702             Token::Kind::TK_PLUSPLUS);
703     ASTNode endPrimitiveID(&fFile->fNodes, -1, ASTNode::Kind::kIdentifier, "EndPrimitive");
704     std::unique_ptr<Expression> endPrimitive = this->convertExpression(endPrimitiveID);
705     SkASSERT(endPrimitive);
706 
707     StatementArray loopBody;
708     loopBody.reserve_back(2);
709     loopBody.push_back(ExpressionStatement::Make(fContext, this->call(/*offset=*/-1,
710                                                                       *invokeDecl,
711                                                                       ExpressionArray{})));
712     loopBody.push_back(ExpressionStatement::Make(fContext, this->call(/*offset=*/-1,
713                                                                       std::move(endPrimitive),
714                                                                       ExpressionArray{})));
715     auto assignment = BinaryExpression::Make(
716             fContext,
717             VariableReference::Make(/*offset=*/-1, loopIdx, VariableRefKind::kWrite),
718             Token::Kind::TK_EQ,
719             IntLiteral::Make(fContext, /*offset=*/-1, /*value=*/0));
720     auto initializer = ExpressionStatement::Make(fContext, std::move(assignment));
721     auto loop = ForStatement::Make(fContext, /*offset=*/-1,
722                                    std::move(initializer),
723                                    std::move(test),
724                                    std::move(next),
725                                    Block::Make(/*offset=*/-1, std::move(loopBody)),
726                                    fSymbolTable);
727     StatementArray children;
728     children.push_back(std::move(loop));
729     return Block::Make(/*offset=*/-1, std::move(children));
730 }
731 
// Builds the statement that convertFunction appends to the end of a vertex program's main()
// when fRTAdjust is set:
//     sk_Position = float4(sk_Position.xy * rtAdjust.xz + sk_Position.ww * rtAdjust.yw,
//                          0,
//                          sk_Position.w);
// sk_Position is accessed as field 0 of the per-vertex interface block, which is looked up in
// fIntrinsics under Compiler::PERVERTEX_NAME. rtAdjust is read in one of two ways:
//  - from fRTAdjust directly, or
//  - when it was declared inside an interface block, as field fRTAdjustFieldIndex of
//    fRTAdjustInterfaceBlock.
std::unique_ptr<Statement> IRGenerator::getNormalizeSkPositionCode() {
    const Variable* skPerVertex = nullptr;
    if (const ProgramElement* perVertexDecl = fIntrinsics->find(Compiler::PERVERTEX_NAME)) {
        SkASSERT(perVertexDecl->is<InterfaceBlock>());
        skPerVertex = &perVertexDecl->as<InterfaceBlock>().variable();
    }

    // sk_Position = float4(sk_Position.xy * rtAdjust.xz + sk_Position.ww * rtAdjust.yw,
    //                      0,
    //                      sk_Position.w);
    SkASSERT(skPerVertex && fRTAdjust);
    // Small IR builders used below. Every call creates a fresh node, because each use site
    // needs its own owned Expression.
    // Ref: read reference to `var`.
    auto Ref = [](const Variable* var) -> std::unique_ptr<Expression> {
        return VariableReference::Make(/*offset=*/-1, var, VariableReference::RefKind::kRead);
    };
    // WRef: write reference to `var`.
    auto WRef = [](const Variable* var) -> std::unique_ptr<Expression> {
        return VariableReference::Make(/*offset=*/-1, var, VariableReference::RefKind::kWrite);
    };
    // Field: read of field `idx` of an anonymous-interface-block variable.
    auto Field = [&](const Variable* var, int idx) -> std::unique_ptr<Expression> {
        return FieldAccess::Make(fContext, Ref(var), idx,
                                 FieldAccess::OwnerKind::kAnonymousInterfaceBlock);
    };
    // Pos: sk_Position, i.e. field 0 of the per-vertex block.
    // NOTE(review): this always builds a kWrite reference, even on the right-hand side where
    // sk_Position is only read. Confirm that is intended.
    auto Pos = [&]() -> std::unique_ptr<Expression> {
        return FieldAccess::Make(fContext, WRef(skPerVertex), 0,
                                 FieldAccess::OwnerKind::kAnonymousInterfaceBlock);
    };
    // Adjust: the rtAdjust value, taken from its interface block when it was declared in one.
    auto Adjust = [&]() -> std::unique_ptr<Expression> {
        return fRTAdjustInterfaceBlock ? Field(fRTAdjustInterfaceBlock, fRTAdjustFieldIndex)
                                       : Ref(fRTAdjust);
    };
    // Swizzle: `expr.<comp>`.
    auto Swizzle = [&](std::unique_ptr<Expression> expr,
                       const ComponentArray& comp) -> std::unique_ptr<Expression> {
        return std::make_unique<SkSL::Swizzle>(fContext, std::move(expr), comp);
    };
    // Op: `left <op> right`.
    auto Op = [&](std::unique_ptr<Expression> left, Token::Kind op,
                  std::unique_ptr<Expression> right) -> std::unique_ptr<Expression> {
        return BinaryExpression::Make(fContext, std::move(left), op, std::move(right));
    };

    // Swizzle component indices (x=0, y=1, z=2, w=3).
    static const ComponentArray kXYIndices{0, 1};
    static const ComponentArray kXZIndices{0, 2};
    static const ComponentArray kYWIndices{1, 3};
    static const ComponentArray kWWIndices{3, 3};
    static const ComponentArray kWIndex{3};

    // float4 constructor arguments: (xy * adj.xz + ww * adj.yw), 0, w.
    ExpressionArray children;
    children.reserve_back(3);
    children.push_back(Op(
            Op(Swizzle(Pos(), kXYIndices), Token::Kind::TK_STAR, Swizzle(Adjust(), kXZIndices)),
            Token::Kind::TK_PLUS,
            Op(Swizzle(Pos(), kWWIndices), Token::Kind::TK_STAR, Swizzle(Adjust(), kYWIndices))));
    children.push_back(FloatLiteral::Make(fContext, /*offset=*/-1, /*value=*/0.0));
    children.push_back(Swizzle(Pos(), kWIndex));
    std::unique_ptr<Expression> result =
            Op(Pos(), Token::Kind::TK_EQ,
               Constructor::Convert(fContext, /*offset=*/-1, *fContext.fTypes.fFloat4,
                                    std::move(children)));
    return ExpressionStatement::Make(fContext, std::move(result));
}
790 
CheckModifiers(const Context & context,int offset,const Modifiers & modifiers,int permittedModifierFlags,int permittedLayoutFlags)791 void IRGenerator::CheckModifiers(const Context& context,
792                                  int offset,
793                                  const Modifiers& modifiers,
794                                  int permittedModifierFlags,
795                                  int permittedLayoutFlags) {
796     ErrorReporter& errorReporter = context.fErrors;
797     int flags = modifiers.fFlags;
798     auto checkModifier = [&](Modifiers::Flag flag, const char* name) {
799         if (flags & flag) {
800             if (!(permittedModifierFlags & flag)) {
801                 errorReporter.error(offset, "'" + String(name) + "' is not permitted here");
802             }
803             flags &= ~flag;
804         }
805     };
806 
807     checkModifier(Modifiers::kConst_Flag,          "const");
808     checkModifier(Modifiers::kIn_Flag,             "in");
809     checkModifier(Modifiers::kOut_Flag,            "out");
810     checkModifier(Modifiers::kUniform_Flag,        "uniform");
811     checkModifier(Modifiers::kFlat_Flag,           "flat");
812     checkModifier(Modifiers::kNoPerspective_Flag,  "noperspective");
813     checkModifier(Modifiers::kHasSideEffects_Flag, "sk_has_side_effects");
814     checkModifier(Modifiers::kInline_Flag,         "inline");
815     checkModifier(Modifiers::kNoInline_Flag,       "noinline");
816     SkASSERT(flags == 0);
817 
818     int layoutFlags = modifiers.fLayout.fFlags;
819     auto checkLayout = [&](Layout::Flag flag, const char* name) {
820         if (layoutFlags & flag) {
821             if (!(permittedLayoutFlags & flag)) {
822                 errorReporter.error(offset, "layout qualifier '" + String(name) +
823                                             "' is not permitted here");
824             }
825             layoutFlags &= ~flag;
826         }
827     };
828 
829     checkLayout(Layout::kOriginUpperLeft_Flag,          "origin_upper_left");
830     checkLayout(Layout::kOverrideCoverage_Flag,         "override_coverage");
831     checkLayout(Layout::kPushConstant_Flag,             "push_constant");
832     checkLayout(Layout::kBlendSupportAllEquations_Flag, "blend_support_all_equations");
833     checkLayout(Layout::kTracked_Flag,                  "tracked");
834     checkLayout(Layout::kSRGBUnpremul_Flag,             "srgb_unpremul");
835     checkLayout(Layout::kKey_Flag,                      "key");
836     checkLayout(Layout::kLocation_Flag,                 "location");
837     checkLayout(Layout::kOffset_Flag,                   "offset");
838     checkLayout(Layout::kBinding_Flag,                  "binding");
839     checkLayout(Layout::kIndex_Flag,                    "index");
840     checkLayout(Layout::kSet_Flag,                      "set");
841     checkLayout(Layout::kBuiltin_Flag,                  "builtin");
842     checkLayout(Layout::kInputAttachmentIndex_Flag,     "input_attachment_index");
843     checkLayout(Layout::kPrimitive_Flag,                "primitive-type");
844     checkLayout(Layout::kMaxVertices_Flag,              "max_vertices");
845     checkLayout(Layout::kInvocations_Flag,              "invocations");
846     checkLayout(Layout::kWhen_Flag,                     "when");
847     checkLayout(Layout::kCType_Flag,                    "ctype");
848     SkASSERT(layoutFlags == 0);
849 }
850 
// Post-processes a freshly converted function body.
// A statement-only walk (expressions are never visited) does the following:
//  - coerces every `return` expression to the function's return type;
//  - reports a value returned from a void function;
//  - reports a `return` without a value in a non-void function;
//  - reports `break` outside any for/do/switch, and `continue` outside any for/do.
// Afterwards, an error is reported if Analysis::CanExitWithoutReturningValue says the function
// can reach its end without returning a value.
void IRGenerator::finalizeFunction(const FunctionDeclaration& funcDecl, Statement* body) {
    // Visitor that performs the per-statement checks described above.
    class Finalizer : public ProgramWriter {
    public:
        Finalizer(IRGenerator* irGenerator, const FunctionDeclaration* function)
            : fIRGenerator(irGenerator)
            , fFunction(function) {}

        ~Finalizer() override {
            // Every level pushed while visiting a loop or switch must have been popped again.
            SkASSERT(!fBreakableLevel);
            SkASSERT(!fContinuableLevel);
        }

        bool functionReturnsValue() const {
            return !fFunction->returnType().isVoid();
        }

        bool visitExpression(Expression& expr) override {
            // Do not recurse into expressions.
            return false;
        }

        bool visitStatement(Statement& stmt) override {
            switch (stmt.kind()) {
                case Statement::Kind::kReturn: {
                    // early returns from a vertex main function will bypass the sk_Position
                    // normalization, so SkASSERT that we aren't doing that. It is of course
                    // possible to fix this by adding a normalization before each return, but it
                    // will probably never actually be necessary.
                    SkASSERT(fIRGenerator->programKind() != ProgramKind::kVertex ||
                             !fIRGenerator->fRTAdjust ||
                             !fFunction->isMain());

                    // Verify that the return statement matches the function's return type.
                    ReturnStatement& returnStmt = stmt.as<ReturnStatement>();
                    const Type& returnType = fFunction->returnType();
                    if (returnStmt.expression()) {
                        if (this->functionReturnsValue()) {
                            // Coerce return expression to the function's return type.
                            returnStmt.setExpression(fIRGenerator->coerce(
                                    std::move(returnStmt.expression()), returnType));
                        } else {
                            // Returning something from a function with a void return type.
                            fIRGenerator->errorReporter().error(returnStmt.fOffset,
                                                     "may not return a value from a void function");
                        }
                    } else {
                        if (this->functionReturnsValue()) {
                            // Returning nothing from a function with a non-void return type.
                            fIRGenerator->errorReporter().error(returnStmt.fOffset,
                                  "expected function to return '" + returnType.displayName() + "'");
                        }
                    }
                    break;
                }
                case Statement::Kind::kDo:
                case Statement::Kind::kFor: {
                    // Loops are both breakable and continuable. The children are visited here,
                    // with the levels raised, so return directly instead of reaching the visit
                    // at the end of this function (which would visit them a second time).
                    ++fBreakableLevel;
                    ++fContinuableLevel;
                    bool result = INHERITED::visitStatement(stmt);
                    --fContinuableLevel;
                    --fBreakableLevel;
                    return result;
                }
                case Statement::Kind::kSwitch: {
                    // A switch permits `break` but not `continue`.
                    ++fBreakableLevel;
                    bool result = INHERITED::visitStatement(stmt);
                    --fBreakableLevel;
                    return result;
                }
                case Statement::Kind::kBreak:
                    if (!fBreakableLevel) {
                        fIRGenerator->errorReporter().error(stmt.fOffset,
                                                 "break statement must be inside a loop or switch");
                    }
                    break;
                case Statement::Kind::kContinue:
                    if (!fContinuableLevel) {
                        fIRGenerator->errorReporter().error(stmt.fOffset,
                                                        "continue statement must be inside a loop");
                    }
                    break;
                default:
                    break;
            }
            // Recurse into child statements for every kind not handled by an early return above.
            return INHERITED::visitStatement(stmt);
        }

    private:
        IRGenerator* fIRGenerator;
        const FunctionDeclaration* fFunction;
        // how deeply nested we are in breakable constructs (for, do, switch).
        int fBreakableLevel = 0;
        // how deeply nested we are in continuable constructs (for, do).
        int fContinuableLevel = 0;

        using INHERITED = ProgramWriter;
    };

    Finalizer finalizer{this, &funcDecl};
    finalizer.visitStatement(*body);

    if (Analysis::CanExitWithoutReturningValue(funcDecl, *body)) {
        this->errorReporter().error(funcDecl.fOffset, "function '" + funcDecl.name() +
                                                      "' can exit without returning a value");
    }
}
957 
// Converts a function AST node and appends the result to fProgramElements. The result is a
// FunctionPrototype when there is no body, otherwise a FunctionDefinition.
// The node's children are, in order: the return type, one kParameter node per parameter, and
// optionally the body.
// Nothing is emitted if any of the following fails to convert: the return type, a parameter type
// or array size, the declaration itself, or the body.
void IRGenerator::convertFunction(const ASTNode& f) {
    // fReferencedIntrinsics is moved into the FunctionDefinition below. Clearing it at scope exit
    // leaves it empty on every path, including early returns.
    SkASSERT(fReferencedIntrinsics.empty());
    SK_AT_SCOPE_EXIT(fReferencedIntrinsics.clear());

    auto iter = f.begin();
    const Type* returnType = this->convertType(*(iter++), /*allowVoid=*/true);
    if (returnType == nullptr) {
        return;
    }
    const ASTNode::FunctionData& funcData = f.getFunctionData();
    bool isMain = (funcData.fName == "main");

    // Convert each parameter into a kParameter Variable. A parameter node's children are its
    // type, followed by an array-size node when the parameter is an array.
    std::vector<std::unique_ptr<Variable>> parameters;
    parameters.reserve(funcData.fParameterCount);
    for (size_t i = 0; i < funcData.fParameterCount; ++i) {
        const ASTNode& param = *(iter++);
        SkASSERT(param.fKind == ASTNode::Kind::kParameter);
        const ASTNode::ParameterData& pd = param.getParameterData();
        auto paramIter = param.begin();
        const Type* type = this->convertType(*(paramIter++));
        if (!type) {
            return;
        }
        if (pd.fIsArray) {
            int arraySize = this->convertArraySize(*type, param.fOffset, *paramIter++);
            if (!arraySize) {
                return;
            }
            type = fSymbolTable->addArrayDimension(type, arraySize);
        }

        parameters.push_back(std::make_unique<Variable>(param.fOffset,
                                                        this->modifiersPool().add(pd.fModifiers),
                                                        pd.fName,
                                                        type,
                                                        fIsBuiltinCode,
                                                        Variable::Storage::kParameter));
    }

    // Conservatively assume all user-defined functions have side effects.
    Modifiers declModifiers = funcData.fModifiers;
    if (!fIsBuiltinCode) {
        declModifiers.fFlags |= Modifiers::kHasSideEffects_Flag;
    }

    if (fContext.fConfig->fSettings.fForceNoInline) {
        // Apply the `noinline` modifier to every function. This allows us to test Runtime
        // Effects without any inlining, even when the code is later added to a paint.
        declModifiers.fFlags &= ~Modifiers::kInline_Flag;
        declModifiers.fFlags |= Modifiers::kNoInline_Flag;
    }

    const FunctionDeclaration* decl = FunctionDeclaration::Convert(
                                                           fContext,
                                                           *fSymbolTable,
                                                           f.fOffset,
                                                           this->modifiersPool().add(declModifiers),
                                                           funcData.fName,
                                                           std::move(parameters),
                                                           returnType,
                                                           fIsBuiltinCode);
    if (!decl) {
        return;
    }
    if (iter == f.end()) {
        // If there's no body, we've found a prototype.
        fProgramElements->push_back(std::make_unique<FunctionPrototype>(f.fOffset, decl,
                                                                        fIsBuiltinCode));
    } else {
        // Compile function body.
        // The parameters are visible inside a new scope for the body; the declaration keeps
        // ownership of them (addWithoutOwnership).
        AutoSymbolTable table(this);
        for (const Variable* param : decl->parameters()) {
            fSymbolTable->addWithoutOwnership(param);
        }
        // The geometry-shader invocation workaround applies only to main(), when invocations
        // were requested (fInvocations != -1) but the caps lack gsInvocationsSupport().
        bool needInvocationIDWorkaround = fInvocations != -1 && isMain &&
                                          !this->caps().gsInvocationsSupport();
        std::unique_ptr<Block> body = this->convertBlock(*iter);
        if (!body) {
            return;
        }
        if (needInvocationIDWorkaround) {
            body = this->applyInvocationIDWorkaround(std::move(body));
        }
        // In vertex programs with fRTAdjust set, append sk_Position normalization to main().
        // finalizeFunction asserts there are no returns that would skip it.
        if (ProgramKind::kVertex == this->programKind() && isMain && fRTAdjust) {
            body->children().push_back(this->getNormalizeSkPositionCode());
        }
        this->finalizeFunction(*decl, body.get());
        auto result = std::make_unique<FunctionDefinition>(
                f.fOffset, decl, fIsBuiltinCode, std::move(body), std::move(fReferencedIntrinsics));
        decl->setDefinition(result.get());
        result->setSource(&f);
        fProgramElements->push_back(std::move(result));
    }
}
1052 
convertStructDefinition(const ASTNode & node)1053 std::unique_ptr<StructDefinition> IRGenerator::convertStructDefinition(const ASTNode& node) {
1054     SkASSERT(node.fKind == ASTNode::Kind::kType);
1055 
1056     const Type* type = this->convertType(node);
1057     if (!type) {
1058         return nullptr;
1059     }
1060     if (!type->isStruct()) {
1061         this->errorReporter().error(node.fOffset,
1062                                     "expected a struct here, found '" + type->name() + "'");
1063         return nullptr;
1064     }
1065     SkDEBUGCODE(auto [iter, wasInserted] =) fDefinedStructs.insert(type);
1066     SkASSERT(wasInserted);
1067     return std::make_unique<StructDefinition>(node.fOffset, *type);
1068 }
1069 
// Converts an interface-block AST node into an InterfaceBlock. Interface blocks are accepted
// only in fragment, vertex and geometry programs.
// The member declarations are converted inside a fresh symbol table, kept as the block's own
// `symbols` table. They become the fields of a new struct type owned by the enclosing table.
// A variable of that (possibly arrayed) type is then created for the block:
//  - with an instance name, it is added to the enclosing table under that name;
//  - without one, each field is added to the enclosing table as its own Field symbol.
// If a member is fRTAdjust, its field index and the block variable are recorded in
// fRTAdjustFieldIndex and fRTAdjustInterfaceBlock.
// Returns nullptr in these cases:
//  - the program kind disallows interface blocks;
//  - a member declaration fails to convert;
//  - the array size is invalid.
// Opaque member types and member initializers are reported as errors without aborting.
std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTNode& intf) {
    if (this->programKind() != ProgramKind::kFragment &&
        this->programKind() != ProgramKind::kVertex &&
        this->programKind() != ProgramKind::kGeometry) {
        this->errorReporter().error(intf.fOffset, "interface block is not allowed here");
        return nullptr;
    }

    SkASSERT(intf.fKind == ASTNode::Kind::kInterfaceBlock);
    const ASTNode::InterfaceBlockData& id = intf.getInterfaceBlockData();
    // `old` is the enclosing table; `symbols` captures the block's own scope.
    std::shared_ptr<SymbolTable> old = fSymbolTable;
    std::shared_ptr<SymbolTable> symbols;
    std::vector<Type::Field> fields;
    bool foundRTAdjust = false;
    auto iter = intf.begin();
    {
        AutoSymbolTable table(this);
        symbols = fSymbolTable;
        for (size_t i = 0; i < id.fDeclarationCount; ++i) {
            StatementArray decls = this->convertVarDeclarations(*(iter++),
                                                                Variable::Storage::kInterfaceBlock);
            if (decls.empty()) {
                return nullptr;
            }
            for (const auto& decl : decls) {
                const VarDeclaration& vd = decl->as<VarDeclaration>();
                if (vd.var().type().isOpaque()) {
                    this->errorReporter().error(decl->fOffset,
                                                "opaque type '" + vd.var().type().name() +
                                                        "' is not permitted in an interface block");
                }
                if (&vd.var() == fRTAdjust) {
                    // Remember where rtAdjust lives so getNormalizeSkPositionCode can read it
                    // as a field of this block.
                    foundRTAdjust = true;
                    SkASSERT(vd.var().type() == *fContext.fTypes.fFloat4);
                    fRTAdjustFieldIndex = fields.size();
                }
                fields.push_back(Type::Field(vd.var().modifiers(), vd.var().name(),
                                            &vd.var().type()));
                if (vd.value()) {
                    this->errorReporter().error(
                            decl->fOffset,
                            "initializers are not permitted on interface block fields");
                }
            }
        }
    }
    const Type* type = old->takeOwnershipOfSymbol(Type::MakeStructType(intf.fOffset, id.fTypeName,
                                                                       fields));
    int arraySize = 0;
    if (id.fIsArray) {
        const ASTNode& size = *(iter++);
        if (size) {
            // convertArraySize rejects unsized arrays. This is the one place we allow those, but
            // we've already checked for that, so this is verifying the other aspects (constant,
            // positive, not too large).
            arraySize = this->convertArraySize(*type, size.fOffset, size);
            if (!arraySize) {
                return nullptr;
            }
        } else {
            arraySize = Type::kUnsizedArray;
        }
        type = symbols->addArrayDimension(type, arraySize);
    }
    // The block variable is named after the instance name when present, else the type name.
    const Variable* var = old->takeOwnershipOfSymbol(
            std::make_unique<Variable>(intf.fOffset,
                                       this->modifiersPool().add(id.fModifiers),
                                       id.fInstanceName.fLength ? id.fInstanceName : id.fTypeName,
                                       type,
                                       fIsBuiltinCode,
                                       Variable::Storage::kGlobal));
    if (foundRTAdjust) {
        fRTAdjustInterfaceBlock = var;
    }
    if (id.fInstanceName.fLength) {
        old->addWithoutOwnership(var);
    } else {
        // Anonymous block: expose each field directly in the enclosing scope.
        for (size_t i = 0; i < fields.size(); i++) {
            old->add(std::make_unique<Field>(intf.fOffset, var, (int)i));
        }
    }
    return std::make_unique<InterfaceBlock>(intf.fOffset,
                                            var,
                                            id.fTypeName,
                                            id.fInstanceName,
                                            arraySize,
                                            symbols);
}
1158 
convertGlobalVarDeclarations(const ASTNode & decl)1159 void IRGenerator::convertGlobalVarDeclarations(const ASTNode& decl) {
1160     StatementArray decls = this->convertVarDeclarations(decl, Variable::Storage::kGlobal);
1161     for (std::unique_ptr<Statement>& stmt : decls) {
1162         const Type* type = &stmt->as<VarDeclaration>().baseType();
1163         if (type->isStruct()) {
1164             auto [iter, wasInserted] = fDefinedStructs.insert(type);
1165             if (wasInserted) {
1166                 fProgramElements->push_back(
1167                         std::make_unique<StructDefinition>(decl.fOffset, *type));
1168             }
1169         }
1170         fProgramElements->push_back(std::make_unique<GlobalVarDeclaration>(std::move(stmt)));
1171     }
1172 }
1173 
convertEnum(const ASTNode & e)1174 void IRGenerator::convertEnum(const ASTNode& e) {
1175     if (this->strictES2Mode()) {
1176         this->errorReporter().error(e.fOffset, "enum is not allowed here");
1177         return;
1178     }
1179 
1180     SkASSERT(e.fKind == ASTNode::Kind::kEnum);
1181     SKSL_INT currentValue = 0;
1182     Layout layout;
1183     ASTNode enumType(e.fNodes, e.fOffset, ASTNode::Kind::kType, e.getString());
1184     const Type* type = this->convertType(enumType);
1185     Modifiers modifiers(layout, Modifiers::kConst_Flag);
1186     std::shared_ptr<SymbolTable> oldTable = fSymbolTable;
1187     fSymbolTable = std::make_shared<SymbolTable>(fSymbolTable, fIsBuiltinCode);
1188     for (auto iter = e.begin(); iter != e.end(); ++iter) {
1189         const ASTNode& child = *iter;
1190         SkASSERT(child.fKind == ASTNode::Kind::kEnumCase);
1191         std::unique_ptr<Expression> value;
1192         if (child.begin() != child.end()) {
1193             value = this->convertExpression(*child.begin());
1194             if (!value) {
1195                 fSymbolTable = oldTable;
1196                 return;
1197             }
1198             if (!ConstantFolder::GetConstantInt(*value, &currentValue)) {
1199                 this->errorReporter().error(value->fOffset,
1200                                             "enum value must be a constant integer");
1201                 fSymbolTable = oldTable;
1202                 return;
1203             }
1204         }
1205         value = IntLiteral::Make(fContext, e.fOffset, currentValue);
1206         ++currentValue;
1207         auto var = std::make_unique<Variable>(e.fOffset, this->modifiersPool().add(modifiers),
1208                                               child.getString(), type, fIsBuiltinCode,
1209                                               Variable::Storage::kGlobal);
1210         // enum variables aren't really 'declared', but we have to create a declaration to store
1211         // the value
1212         auto declaration = VarDeclaration::Make(fContext, var.get(), &var->type(), /*arraySize=*/0,
1213                                                 std::move(value));
1214         fSymbolTable->add(std::move(var));
1215         fSymbolTable->takeOwnershipOfIRNode(std::move(declaration));
1216     }
1217     // Now we orphanize the Enum's symbol table, so that future lookups in it are strict
1218     fSymbolTable->fParent = nullptr;
1219     fProgramElements->push_back(std::make_unique<Enum>(e.fOffset, e.getString(), fSymbolTable,
1220                                                        /*isSharedWithCpp=*/fIsBuiltinCode,
1221                                                        /*isBuiltin=*/fIsBuiltinCode));
1222     fSymbolTable = oldTable;
1223 }
1224 
typeContainsPrivateFields(const Type & type)1225 bool IRGenerator::typeContainsPrivateFields(const Type& type) {
1226     // Checks for usage of private types, including fields inside a struct.
1227     if (type.isPrivate()) {
1228         return true;
1229     }
1230     if (type.isStruct()) {
1231         for (const auto& f : type.fields()) {
1232             if (this->typeContainsPrivateFields(*f.fType)) {
1233                 return true;
1234             }
1235         }
1236     }
1237     return false;
1238 }
1239 
convertType(const ASTNode & type,bool allowVoid)1240 const Type* IRGenerator::convertType(const ASTNode& type, bool allowVoid) {
1241     StringFragment name = type.getString();
1242     const Symbol* symbol = (*fSymbolTable)[name];
1243     if (!symbol || !symbol->is<Type>()) {
1244         this->errorReporter().error(type.fOffset, "unknown type '" + name + "'");
1245         return nullptr;
1246     }
1247     const Type* result = &symbol->as<Type>();
1248     const bool isArray = (type.begin() != type.end());
1249     if (result->isVoid() && !allowVoid) {
1250         this->errorReporter().error(type.fOffset,
1251                                     "type '" + name + "' not allowed in this context");
1252         return nullptr;
1253     }
1254     if (!fIsBuiltinCode && this->typeContainsPrivateFields(*result)) {
1255         this->errorReporter().error(type.fOffset, "type '" + name + "' is private");
1256         return nullptr;
1257     }
1258     if (isArray) {
1259         auto iter = type.begin();
1260         int arraySize = this->convertArraySize(*result, type.fOffset, *iter);
1261         if (!arraySize) {
1262             return nullptr;
1263         }
1264         result = fSymbolTable->addArrayDimension(result, arraySize);
1265     }
1266     return result;
1267 }
1268 
convertExpression(const ASTNode & expr)1269 std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTNode& expr) {
1270     switch (expr.fKind) {
1271         case ASTNode::Kind::kBinary:
1272             return this->convertBinaryExpression(expr);
1273         case ASTNode::Kind::kBool:
1274             return BoolLiteral::Make(fContext, expr.fOffset, expr.getBool());
1275         case ASTNode::Kind::kCall:
1276             return this->convertCallExpression(expr);
1277         case ASTNode::Kind::kField:
1278             return this->convertFieldExpression(expr);
1279         case ASTNode::Kind::kFloat:
1280             return FloatLiteral::Make(fContext, expr.fOffset, expr.getFloat());
1281         case ASTNode::Kind::kIdentifier:
1282             return this->convertIdentifier(expr);
1283         case ASTNode::Kind::kIndex:
1284             return this->convertIndexExpression(expr);
1285         case ASTNode::Kind::kInt:
1286             return IntLiteral::Make(fContext, expr.fOffset, expr.getInt());
1287         case ASTNode::Kind::kPostfix:
1288             return this->convertPostfixExpression(expr);
1289         case ASTNode::Kind::kPrefix:
1290             return this->convertPrefixExpression(expr);
1291         case ASTNode::Kind::kScope:
1292             return this->convertScopeExpression(expr);
1293         case ASTNode::Kind::kTernary:
1294             return this->convertTernaryExpression(expr);
1295         default:
1296             SkDEBUGFAILF("unsupported expression: %s\n", expr.description().c_str());
1297             return nullptr;
1298     }
1299 }
1300 
// Resolves `name` against the current symbol table and wraps the symbol it finds in the matching
// kind of expression: a function reference (single or overloaded), a variable reference, a field
// access on an anonymous interface block, a type reference, or an external function reference.
// Reports "unknown identifier" and returns null when the name is not found; aborts on any other
// symbol kind.
std::unique_ptr<Expression> IRGenerator::convertIdentifier(int offset, StringFragment name) {
    const Symbol* result = (*fSymbolTable)[name];
    if (!result) {
        this->errorReporter().error(offset, "unknown identifier '" + name + "'");
        return nullptr;
    }
    switch (result->kind()) {
        case Symbol::Kind::kFunctionDeclaration: {
            // A lone declaration is wrapped as a one-element candidate list.
            std::vector<const FunctionDeclaration*> f = {
                &result->as<FunctionDeclaration>()
            };
            return std::make_unique<FunctionReference>(fContext, offset, f);
        }
        case Symbol::Kind::kUnresolvedFunction: {
            // An overloaded name: every candidate is kept; call() chooses among them later.
            const UnresolvedFunction* f = &result->as<UnresolvedFunction>();
            return std::make_unique<FunctionReference>(fContext, offset, f->functions());
        }
        case Symbol::Kind::kVariable: {
            const Variable* var = &result->as<Variable>();
            const Modifiers& modifiers = var->modifiers();
            // Referencing certain builtins records program inputs. Only the frag-coord builtin is
            // handled, and only outside the standalone compiler: it always sets fInputs.fFlipY,
            // and also fInputs.fRTHeight when Y-flipping is enabled but the caps report no
            // frag-coord conventions extension.
            switch (modifiers.fLayout.fBuiltin) {
#ifndef SKSL_STANDALONE
                case SK_FRAGCOORD_BUILTIN:
                    fInputs.fFlipY = true;
                    if (this->settings().fFlipY &&
                        !this->caps().fragCoordConventionsExtensionString()) {
                        fInputs.fRTHeight = true;
                    }
#endif
            }
            // In fragment processor programs, a plain 'in' variable (not a uniform, not
            // layout(key), not a builtin, not a child fragment processor or a sampler) can only be
            // used if the file also provides a custom "setData" section.
            if (this->programKind() == ProgramKind::kFragmentProcessor &&
                (modifiers.fFlags & Modifiers::kIn_Flag) &&
                !(modifiers.fFlags & Modifiers::kUniform_Flag) &&
                !(modifiers.fLayout.fFlags & Layout::kKey_Flag) &&
                modifiers.fLayout.fBuiltin == -1 &&
                !var->type().isFragmentProcessor() &&
                var->type().typeKind() != Type::TypeKind::kSampler) {
                // Scan the top-level AST declarations for a section named "setData".
                bool valid = false;
                for (const auto& decl : fFile->root()) {
                    if (decl.fKind == ASTNode::Kind::kSection) {
                        const ASTNode::SectionData& section = decl.getSectionData();
                        if (section.fName == "setData") {
                            valid = true;
                            break;
                        }
                    }
                }
                if (!valid) {
                    this->errorReporter().error(
                            offset,
                            "'in' variable must be either 'uniform' or 'layout(key)', or there "
                            "must be a custom @setData function");
                }
            }
            // default to kRead_RefKind; this will be corrected later if the variable is written to
            return VariableReference::Make(offset, var, VariableReference::RefKind::kRead);
        }
        case Symbol::Kind::kField: {
            // A field of an anonymous interface block is referenced by its bare name; expand it to
            // a field access on the block's owning variable.
            const Field* field = &result->as<Field>();
            auto base = VariableReference::Make(offset, &field->owner(),
                                                VariableReference::RefKind::kRead);
            return FieldAccess::Make(fContext, std::move(base), field->fieldIndex(),
                                     FieldAccess::OwnerKind::kAnonymousInterfaceBlock);
        }
        case Symbol::Kind::kType: {
            // Type names become TypeReferences (used later for constructors, array types, '::').
            const Type* t = &result->as<Type>();
            return std::make_unique<TypeReference>(fContext, offset, t);
        }
        case Symbol::Kind::kExternal: {
            const ExternalFunction* r = &result->as<ExternalFunction>();
            return std::make_unique<ExternalFunctionReference>(offset, r);
        }
        default:
            SK_ABORT("unsupported symbol type %d\n", (int) result->kind());
    }
}
1377 
convertIdentifier(const ASTNode & identifier)1378 std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTNode& identifier) {
1379     return this->convertIdentifier(identifier.fOffset, identifier.getString());
1380 }
1381 
convertSection(const ASTNode & s)1382 std::unique_ptr<Section> IRGenerator::convertSection(const ASTNode& s) {
1383     if (this->programKind() != ProgramKind::kFragmentProcessor) {
1384         this->errorReporter().error(s.fOffset, "syntax error");
1385         return nullptr;
1386     }
1387 
1388     const ASTNode::SectionData& section = s.getSectionData();
1389     return std::make_unique<Section>(s.fOffset, section.fName, section.fArgument,
1390                                                 section.fText);
1391 }
1392 
coerce(std::unique_ptr<Expression> expr,const Type & type)1393 std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr,
1394                                                 const Type& type) {
1395     return type.coerceExpression(std::move(expr), fContext);
1396 }
1397 
convertBinaryExpression(const ASTNode & expression)1398 std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(const ASTNode& expression) {
1399     SkASSERT(expression.fKind == ASTNode::Kind::kBinary);
1400     auto iter = expression.begin();
1401     std::unique_ptr<Expression> left = this->convertExpression(*(iter++));
1402     if (!left) {
1403         return nullptr;
1404     }
1405     std::unique_ptr<Expression> right = this->convertExpression(*(iter++));
1406     if (!right) {
1407         return nullptr;
1408     }
1409     return BinaryExpression::Convert(fContext, std::move(left), expression.getOperator(),
1410                                      std::move(right));
1411 }
1412 
convertTernaryExpression(const ASTNode & node)1413 std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(const ASTNode& node) {
1414     SkASSERT(node.fKind == ASTNode::Kind::kTernary);
1415     auto iter = node.begin();
1416     std::unique_ptr<Expression> test = this->convertExpression(*(iter++));
1417     if (!test) {
1418         return nullptr;
1419     }
1420     std::unique_ptr<Expression> ifTrue = this->convertExpression(*(iter++));
1421     if (!ifTrue) {
1422         return nullptr;
1423     }
1424     std::unique_ptr<Expression> ifFalse = this->convertExpression(*(iter++));
1425     if (!ifFalse) {
1426         return nullptr;
1427     }
1428     return TernaryExpression::Convert(fContext, std::move(test),
1429                                       std::move(ifTrue), std::move(ifFalse));
1430 }
1431 
copyIntrinsicIfNeeded(const FunctionDeclaration & function)1432 void IRGenerator::copyIntrinsicIfNeeded(const FunctionDeclaration& function) {
1433     if (const ProgramElement* found = fIntrinsics->findAndInclude(function.description())) {
1434         const FunctionDefinition& original = found->as<FunctionDefinition>();
1435 
1436         // Sort the referenced intrinsics into a consistent order; otherwise our output will become
1437         // non-deterministic.
1438         std::vector<const FunctionDeclaration*> intrinsics(original.referencedIntrinsics().begin(),
1439                                                            original.referencedIntrinsics().end());
1440         std::sort(intrinsics.begin(), intrinsics.end(),
1441                   [](const FunctionDeclaration* a, const FunctionDeclaration* b) {
1442                       if (a->isBuiltin() != b->isBuiltin()) {
1443                           return a->isBuiltin() < b->isBuiltin();
1444                       }
1445                       if (a->fOffset != b->fOffset) {
1446                           return a->fOffset < b->fOffset;
1447                       }
1448                       if (a->name() != b->name()) {
1449                           return a->name() < b->name();
1450                       }
1451                       return a->description() < b->description();
1452                   });
1453         for (const FunctionDeclaration* f : intrinsics) {
1454             this->copyIntrinsicIfNeeded(*f);
1455         }
1456 
1457         fSharedElements->push_back(found);
1458     }
1459 }
1460 
call(int offset,const FunctionDeclaration & function,ExpressionArray arguments)1461 std::unique_ptr<Expression> IRGenerator::call(int offset,
1462                                               const FunctionDeclaration& function,
1463                                               ExpressionArray arguments) {
1464     if (function.isBuiltin()) {
1465         if (function.definition()) {
1466             fReferencedIntrinsics.insert(&function);
1467         }
1468         if (!fIsBuiltinCode && fIntrinsics) {
1469             this->copyIntrinsicIfNeeded(function);
1470         }
1471     }
1472 
1473     return FunctionCall::Convert(fContext, offset, function, std::move(arguments));
1474 }
1475 
1476 /**
1477  * Determines the cost of coercing the arguments of a function to the required types. Cost has no
1478  * particular meaning other than "lower costs are preferred". Returns CoercionCost::Impossible() if
1479  * the call is not valid.
1480  */
callCost(const FunctionDeclaration & function,const ExpressionArray & arguments)1481 CoercionCost IRGenerator::callCost(const FunctionDeclaration& function,
1482                                    const ExpressionArray& arguments) {
1483     if (function.parameters().size() != arguments.size()) {
1484         return CoercionCost::Impossible();
1485     }
1486     FunctionDeclaration::ParamTypes types;
1487     const Type* ignored;
1488     if (!function.determineFinalTypes(arguments, &types, &ignored)) {
1489         return CoercionCost::Impossible();
1490     }
1491     CoercionCost total = CoercionCost::Free();
1492     for (size_t i = 0; i < arguments.size(); i++) {
1493         total = total + arguments[i]->coercionCost(*types[i]);
1494     }
1495     return total;
1496 }
1497 
call(int offset,std::unique_ptr<Expression> functionValue,ExpressionArray arguments)1498 std::unique_ptr<Expression> IRGenerator::call(int offset,
1499                                               std::unique_ptr<Expression> functionValue,
1500                                               ExpressionArray arguments) {
1501     switch (functionValue->kind()) {
1502         case Expression::Kind::kTypeReference:
1503             return Constructor::Convert(fContext,
1504                                         offset,
1505                                         functionValue->as<TypeReference>().value(),
1506                                         std::move(arguments));
1507         case Expression::Kind::kExternalFunctionReference: {
1508             const ExternalFunction& f = functionValue->as<ExternalFunctionReference>().function();
1509             int count = f.callParameterCount();
1510             if (count != (int) arguments.size()) {
1511                 this->errorReporter().error(offset, "external function expected " +
1512                                                     to_string(count) + " arguments, but found " +
1513                                                     to_string((int)arguments.size()));
1514                 return nullptr;
1515             }
1516             static constexpr int PARAMETER_MAX = 16;
1517             SkASSERT(count < PARAMETER_MAX);
1518             const Type* types[PARAMETER_MAX];
1519             f.getCallParameterTypes(types);
1520             for (int i = 0; i < count; ++i) {
1521                 arguments[i] = this->coerce(std::move(arguments[i]), *types[i]);
1522                 if (!arguments[i]) {
1523                     return nullptr;
1524                 }
1525             }
1526             return std::make_unique<ExternalFunctionCall>(offset, &f, std::move(arguments));
1527         }
1528         case Expression::Kind::kFunctionReference: {
1529             const FunctionReference& ref = functionValue->as<FunctionReference>();
1530             const std::vector<const FunctionDeclaration*>& functions = ref.functions();
1531             CoercionCost bestCost = CoercionCost::Impossible();
1532             const FunctionDeclaration* best = nullptr;
1533             if (functions.size() > 1) {
1534                 for (const auto& f : functions) {
1535                     CoercionCost cost = this->callCost(*f, arguments);
1536                     if (cost < bestCost) {
1537                         bestCost = cost;
1538                         best = f;
1539                     }
1540                 }
1541                 if (best) {
1542                     return this->call(offset, *best, std::move(arguments));
1543                 }
1544                 String msg = "no match for " + functions[0]->name() + "(";
1545                 String separator;
1546                 for (size_t i = 0; i < arguments.size(); i++) {
1547                     msg += separator;
1548                     separator = ", ";
1549                     msg += arguments[i]->type().displayName();
1550                 }
1551                 msg += ")";
1552                 this->errorReporter().error(offset, msg);
1553                 return nullptr;
1554             }
1555             return this->call(offset, *functions[0], std::move(arguments));
1556         }
1557         default:
1558             this->errorReporter().error(offset, "not a function");
1559             return nullptr;
1560     }
1561 }
1562 
convertPrefixExpression(const ASTNode & expression)1563 std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(const ASTNode& expression) {
1564     SkASSERT(expression.fKind == ASTNode::Kind::kPrefix);
1565     std::unique_ptr<Expression> base = this->convertExpression(*expression.begin());
1566     if (!base) {
1567         return nullptr;
1568     }
1569     return PrefixExpression::Convert(fContext, expression.getOperator(), std::move(base));
1570 }
1571 
1572 // Swizzles are complicated due to constant components. The most difficult case is a mask like
1573 // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that evaluates
1574 // 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0') together and use a
1575 // secondary swizzle to put them back into the right order, so in this case we end up with
1576 // 'float4(base.xw, 1, 0).xzyw'.
convertSwizzle(std::unique_ptr<Expression> base,String fields)1577 std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base,
1578                                                         String fields) {
1579     const int offset = base->fOffset;
1580     const Type& baseType = base->type();
1581     if (!baseType.isVector() && !baseType.isNumber()) {
1582         this->errorReporter().error(
1583                 offset, "cannot swizzle value of type '" + baseType.displayName() + "'");
1584         return nullptr;
1585     }
1586 
1587     if (fields.length() > 4) {
1588         this->errorReporter().error(offset, "too many components in swizzle mask '" + fields + "'");
1589         return nullptr;
1590     }
1591 
1592     ComponentArray components;
1593     bool foundXYZW = false;
1594     for (char field : fields) {
1595         switch (field) {
1596             case '0':
1597                 components.push_back(SwizzleComponent::ZERO);
1598                 break;
1599             case '1':
1600                 components.push_back(SwizzleComponent::ONE);
1601                 break;
1602             case 'x':
1603             case 'r':
1604             case 's':
1605             case 'L':
1606                 components.push_back(SwizzleComponent::X);
1607                 foundXYZW = true;
1608                 break;
1609             case 'y':
1610             case 'g':
1611             case 't':
1612             case 'T':
1613                 if (baseType.columns() >= 2) {
1614                     components.push_back(SwizzleComponent::Y);
1615                     foundXYZW = true;
1616                     break;
1617                 }
1618                 [[fallthrough]];
1619             case 'z':
1620             case 'b':
1621             case 'p':
1622             case 'R':
1623                 if (baseType.columns() >= 3) {
1624                     components.push_back(SwizzleComponent::Z);
1625                     foundXYZW = true;
1626                     break;
1627                 }
1628                 [[fallthrough]];
1629             case 'w':
1630             case 'a':
1631             case 'q':
1632             case 'B':
1633                 if (baseType.columns() >= 4) {
1634                     components.push_back(SwizzleComponent::W);
1635                     foundXYZW = true;
1636                     break;
1637                 }
1638                 [[fallthrough]];
1639             default:
1640                 this->errorReporter().error(
1641                         offset, String::printf("invalid swizzle component '%c'", field));
1642                 return nullptr;
1643         }
1644     }
1645 
1646     if (!foundXYZW) {
1647         this->errorReporter().error(offset, "swizzle must refer to base expression");
1648         return nullptr;
1649     }
1650 
1651     return Swizzle::Convert(fContext, std::move(base), components);
1652 }
1653 
convertTypeField(int offset,const Type & type,StringFragment field)1654 std::unique_ptr<Expression> IRGenerator::convertTypeField(int offset, const Type& type,
1655                                                           StringFragment field) {
1656     const ProgramElement* enumElement = nullptr;
1657     // Find the Enum element that this type refers to, start by searching our elements
1658     for (const std::unique_ptr<ProgramElement>& e : *fProgramElements) {
1659         if (e->is<Enum>() && type.name() == e->as<Enum>().typeName()) {
1660             enumElement = e.get();
1661             break;
1662         }
1663     }
1664     // ... if that fails, look in our shared elements
1665     if (!enumElement) {
1666         for (const ProgramElement* e : *fSharedElements) {
1667             if (e->is<Enum>() && type.name() == e->as<Enum>().typeName()) {
1668                 enumElement = e;
1669                 break;
1670             }
1671         }
1672     }
1673     // ... and if that fails, check the intrinsics, add it to our shared elements
1674     if (!enumElement && !fIsBuiltinCode && fIntrinsics) {
1675         if (const ProgramElement* found = fIntrinsics->findAndInclude(type.name())) {
1676             fSharedElements->push_back(found);
1677             enumElement = found;
1678         }
1679     }
1680     if (!enumElement) {
1681         this->errorReporter().error(offset,
1682                                     "type '" + type.displayName() + "' is not a known enum");
1683         return nullptr;
1684     }
1685 
1686     // We found the Enum element. Look for 'field' as a member.
1687     std::shared_ptr<SymbolTable> old = fSymbolTable;
1688     fSymbolTable = enumElement->as<Enum>().symbols();
1689     std::unique_ptr<Expression> result =
1690             convertIdentifier(ASTNode(&fFile->fNodes, offset, ASTNode::Kind::kIdentifier, field));
1691     if (result) {
1692         const Variable& v = *result->as<VariableReference>().variable();
1693         SkASSERT(v.initialValue());
1694         result = IntLiteral::Make(offset, v.initialValue()->as<IntLiteral>().value(), &type);
1695     } else {
1696         this->errorReporter().error(
1697                 offset, "type '" + type.name() + "' does not contain enumerator '" + field + "'");
1698     }
1699     fSymbolTable = old;
1700     return result;
1701 }
1702 
convertIndexExpression(const ASTNode & index)1703 std::unique_ptr<Expression> IRGenerator::convertIndexExpression(const ASTNode& index) {
1704     SkASSERT(index.fKind == ASTNode::Kind::kIndex);
1705     auto iter = index.begin();
1706     std::unique_ptr<Expression> base = this->convertExpression(*(iter++));
1707     if (!base) {
1708         return nullptr;
1709     }
1710     if (base->is<TypeReference>()) {
1711         // Convert an index expression starting with a type name: `int[12]`
1712         if (iter == index.end()) {
1713             this->errorReporter().error(index.fOffset, "array must have a size");
1714             return nullptr;
1715         }
1716         const Type* type = &base->as<TypeReference>().value();
1717         int arraySize = this->convertArraySize(*type, index.fOffset, *iter);
1718         if (!arraySize) {
1719             return nullptr;
1720         }
1721         type = fSymbolTable->addArrayDimension(type, arraySize);
1722         return std::make_unique<TypeReference>(fContext, base->fOffset, type);
1723     }
1724 
1725     if (iter == index.end()) {
1726         this->errorReporter().error(base->fOffset, "missing index in '[]'");
1727         return nullptr;
1728     }
1729     std::unique_ptr<Expression> converted = this->convertExpression(*(iter++));
1730     if (!converted) {
1731         return nullptr;
1732     }
1733     return IndexExpression::Convert(fContext, std::move(base), std::move(converted));
1734 }
1735 
convertCallExpression(const ASTNode & callNode)1736 std::unique_ptr<Expression> IRGenerator::convertCallExpression(const ASTNode& callNode) {
1737     SkASSERT(callNode.fKind == ASTNode::Kind::kCall);
1738     auto iter = callNode.begin();
1739     std::unique_ptr<Expression> base = this->convertExpression(*(iter++));
1740     if (!base) {
1741         return nullptr;
1742     }
1743     ExpressionArray arguments;
1744     for (; iter != callNode.end(); ++iter) {
1745         std::unique_ptr<Expression> converted = this->convertExpression(*iter);
1746         if (!converted) {
1747             return nullptr;
1748         }
1749         arguments.push_back(std::move(converted));
1750     }
1751     return this->call(callNode.fOffset, std::move(base), std::move(arguments));
1752 }
1753 
convertFieldExpression(const ASTNode & fieldNode)1754 std::unique_ptr<Expression> IRGenerator::convertFieldExpression(const ASTNode& fieldNode) {
1755     std::unique_ptr<Expression> base = this->convertExpression(*fieldNode.begin());
1756     if (!base) {
1757         return nullptr;
1758     }
1759     const StringFragment& field = fieldNode.getString();
1760     const Type& baseType = base->type();
1761     if (baseType == *fContext.fTypes.fSkCaps) {
1762         return Setting::Convert(fContext, fieldNode.fOffset, field);
1763     }
1764     if (baseType.isStruct()) {
1765         return FieldAccess::Convert(fContext, std::move(base), field);
1766     }
1767     return this->convertSwizzle(std::move(base), field);
1768 }
1769 
convertScopeExpression(const ASTNode & scopeNode)1770 std::unique_ptr<Expression> IRGenerator::convertScopeExpression(const ASTNode& scopeNode) {
1771     std::unique_ptr<Expression> base = this->convertExpression(*scopeNode.begin());
1772     if (!base) {
1773         return nullptr;
1774     }
1775     if (!base->is<TypeReference>()) {
1776         this->errorReporter().error(scopeNode.fOffset, "'::' must follow a type name");
1777         return nullptr;
1778     }
1779     const StringFragment& member = scopeNode.getString();
1780     return this->convertTypeField(base->fOffset, base->as<TypeReference>().value(), member);
1781 }
1782 
convertPostfixExpression(const ASTNode & expression)1783 std::unique_ptr<Expression> IRGenerator::convertPostfixExpression(const ASTNode& expression) {
1784     SkASSERT(expression.fKind == ASTNode::Kind::kPostfix);
1785     std::unique_ptr<Expression> base = this->convertExpression(*expression.begin());
1786     if (!base) {
1787         return nullptr;
1788     }
1789     return PostfixExpression::Convert(fContext, std::move(base), expression.getOperator());
1790 }
1791 
checkValid(const Expression & expr)1792 void IRGenerator::checkValid(const Expression& expr) {
1793     switch (expr.kind()) {
1794         case Expression::Kind::kFunctionCall: {
1795             const FunctionDeclaration& decl = expr.as<FunctionCall>().function();
1796             if (!decl.isBuiltin() && !decl.definition()) {
1797                 this->errorReporter().error(expr.fOffset,
1798                                             "function '" + decl.description() + "' is not defined");
1799             }
1800             break;
1801         }
1802         case Expression::Kind::kFunctionReference:
1803         case Expression::Kind::kTypeReference:
1804             SkDEBUGFAIL("invalid reference-expression, should have been reported by coerce()");
1805             this->errorReporter().error(expr.fOffset, "invalid expression");
1806             break;
1807         default:
1808             if (expr.type() == *fContext.fTypes.fInvalid) {
1809                 this->errorReporter().error(expr.fOffset, "invalid expression");
1810             }
1811             break;
1812     }
1813 }
1814 
// Pulls in the declarations of builtin variables the program needs. Every program element is
// visited, and each builtin variable reference queues its declaring element from fIntrinsics. In
// addition, sk_FragColor is kept when main() returns half4, and sk_Clockwise is always requested
// for fragment programs. The collected elements are inserted at the *front* of fSharedElements, in
// discovery order.
void IRGenerator::findAndDeclareBuiltinVariables() {
    // Visitor that records the declaring element of each builtin variable it encounters.
    class BuiltinVariableScanner : public ProgramVisitor {
    public:
        BuiltinVariableScanner(IRGenerator* generator) : fGenerator(generator) {}

        // Queues the element that declares builtin `name`, if it has not been included yet.
        void addDeclaringElement(const String& name) {
            // If this is the *first* time we've seen this builtin, findAndInclude will return
            // the corresponding ProgramElement.
            if (const ProgramElement* decl = fGenerator->fIntrinsics->findAndInclude(name)) {
                SkASSERT(decl->is<GlobalVarDeclaration>() || decl->is<InterfaceBlock>());
                fNewElements.push_back(decl);
            }
        }

        bool visitProgramElement(const ProgramElement& pe) override {
            if (pe.is<FunctionDefinition>()) {
                const FunctionDefinition& funcDef = pe.as<FunctionDefinition>();
                // We synthesize writes to sk_FragColor if main() returns a color, even if it's
                // otherwise unreferenced. Check main's return type to see if it's half4.
                if (funcDef.declaration().isMain() &&
                    funcDef.declaration().returnType() == *fGenerator->fContext.fTypes.fHalf4) {
                    fPreserveFragColor = true;
                }
            }
            return INHERITED::visitProgramElement(pe);
        }

        bool visitExpression(const Expression& e) override {
            // Any reference to a builtin variable requires that variable's declaration.
            if (e.is<VariableReference>() && e.as<VariableReference>().variable()->isBuiltin()) {
                this->addDeclaringElement(e.as<VariableReference>().variable()->name());
            }
            return INHERITED::visitExpression(e);
        }

        IRGenerator* fGenerator;
        // Declaring elements found so far, in discovery order.
        std::vector<const ProgramElement*> fNewElements;
        // Set when main() returns half4, so sk_FragColor's declaration must be kept.
        bool fPreserveFragColor = false;

        using INHERITED = ProgramVisitor;
        using INHERITED::visitProgramElement;
    };

    BuiltinVariableScanner scanner(this);
    SkASSERT(fProgramElements);
    for (auto& e : *fProgramElements) {
        scanner.visitProgramElement(*e);
    }

    if (scanner.fPreserveFragColor) {
        // main() returns a half4, so make sure we don't dead-strip sk_FragColor.
        scanner.addDeclaringElement(Compiler::FRAGCOLOR_NAME);
    }

    switch (this->programKind()) {
        case ProgramKind::kFragment:
            // Vulkan requires certain builtin variables be present, even if they're unused. At one
            // time, validation errors would result if sk_Clockwise was missing. Now, it's just
            // (Adreno) driver bugs that drop or corrupt draws if they're missing.
            scanner.addDeclaringElement("sk_Clockwise");
            break;
        default:
            break;
    }

    // Declarations go ahead of any shared elements already collected (e.g. copied intrinsics).
    fSharedElements->insert(
            fSharedElements->begin(), scanner.fNewElements.begin(), scanner.fNewElements.end());
}
1882 
// Resets per-program state and prepares to convert a new program on top of module `base`:
//  - converted elements will be appended to `elements`; elements borrowed from the intrinsics
//    module (copied intrinsics, builtin declarations) go into `sharedElements`. Both vectors are
//    owned by the caller and must outlive conversion.
//  - the symbol table starts from base.fSymbols, with a new scope pushed on top.
//  - the intrinsics module's "already included" tracking is reset.
//  - for non-builtin geometry programs, a global `sk_InvocationID` is declared.
//  - any `externalFunctions` are registered (without ownership) in the new scope only.
void IRGenerator::start(const ParsedModule& base,
                        bool isBuiltinCode,
                        const std::vector<std::unique_ptr<ExternalFunction>>* externalFunctions,
                        std::vector<std::unique_ptr<ProgramElement>>* elements,
                        std::vector<const ProgramElement*>* sharedElements) {
    fProgramElements = elements;
    fSharedElements = sharedElements;
    fSymbolTable = base.fSymbols;
    fIntrinsics = base.fIntrinsics.get();
    if (fIntrinsics) {
        fIntrinsics->resetAlreadyIncluded();
    }
    fIsBuiltinCode = isBuiltinCode;

    fInputs.reset();
    fInvocations = -1;
    fRTAdjust = nullptr;
    fRTAdjustInterfaceBlock = nullptr;
    fDefinedStructs.clear();
    // New scope so program-level symbols don't land directly in the base module's table.
    this->pushSymbolTable();

    if (this->programKind() == ProgramKind::kGeometry && !fIsBuiltinCode) {
        // Declare sk_InvocationID programmatically. With invocations support, it's an 'in' builtin.
        // If we're applying the workaround, then it's a plain global.
        bool workaround = !this->caps().gsInvocationsSupport();
        Modifiers m;
        if (!workaround) {
            m.fFlags = Modifiers::kIn_Flag;
            m.fLayout.fBuiltin = SK_INVOCATIONID_BUILTIN;
        }
        auto var = std::make_unique<Variable>(/*offset=*/-1, this->modifiersPool().add(m),
                                              "sk_InvocationID", fContext.fTypes.fInt.get(),
                                              /*builtin=*/false, Variable::Storage::kGlobal);
        auto decl = VarDeclaration::Make(fContext, var.get(), fContext.fTypes.fInt.get(),
                                         /*arraySize=*/0, /*value=*/nullptr);
        fSymbolTable->add(std::move(var));
        fProgramElements->push_back(std::make_unique<GlobalVarDeclaration>(std::move(decl)));
    }

    if (externalFunctions) {
        // Add any external values to the new symbol table, so they're only visible to this Program
        for (const auto& ef : *externalFunctions) {
            fSymbolTable->addWithoutOwnership(ef.get());
        }
    }
}
1929 
// Completes conversion of the current program and hands the results to the caller:
//  1. for non-builtin code with an intrinsics module, adds declarations of the builtin variables
//     the program references (findAndDeclareBuiltinVariables);
//  2. runs checkValid() on every expression, reporting leftover function/type references,
//     undefined functions, and invalid-typed expressions;
//  3. in strict ES2 mode, and only if no errors have been reported so far, validates array
//     indexing per GLSL ES 1.00 Appendix A;
//  4. moves the program elements, shared elements, and symbol table into the returned IRBundle,
//     together with a copy of fInputs. The vectors registered in start() are left moved-from.
IRGenerator::IRBundle IRGenerator::finish() {
    // Variables defined in the pre-includes need their declaring elements added to the program
    if (!fIsBuiltinCode && fIntrinsics) {
        this->findAndDeclareBuiltinVariables();
    }

    // Do a pass looking for dangling FunctionReference or TypeReference expressions
    class FindIllegalExpressions : public ProgramVisitor {
    public:
        FindIllegalExpressions(IRGenerator* generator) : fGenerator(generator) {}

        bool visitExpression(const Expression& e) override {
            fGenerator->checkValid(e);
            return INHERITED::visitExpression(e);
        }

        IRGenerator* fGenerator;
        using INHERITED = ProgramVisitor;
        using INHERITED::visitProgramElement;
    };
    for (const auto& pe : *fProgramElements) {
        FindIllegalExpressions{this}.visitProgramElement(*pe);
    }

    // If we're in ES2 mode (runtime effects), do a pass to enforce Appendix A, Section 5 of the
    // GLSL ES 1.00 spec -- Indexing. Don't bother if we've already found errors - this logic
    // assumes that all loops meet the criteria of Section 4, and if they don't, could crash.
    if (this->strictES2Mode() && this->errorReporter().errorCount() == 0) {
        for (const auto& pe : *fProgramElements) {
            Analysis::ValidateIndexingForES2(*pe, this->errorReporter());
        }
    }

    return IRBundle{std::move(*fProgramElements),
                    std::move(*fSharedElements),
                    std::move(fSymbolTable),
                    fInputs};
}
1968 
convertProgram(const ParsedModule & base,bool isBuiltinCode,const char * text,size_t length,const std::vector<std::unique_ptr<ExternalFunction>> * externalFunctions)1969 IRGenerator::IRBundle IRGenerator::convertProgram(
1970         const ParsedModule& base,
1971         bool isBuiltinCode,
1972         const char* text,
1973         size_t length,
1974         const std::vector<std::unique_ptr<ExternalFunction>>* externalFunctions) {
1975     std::vector<std::unique_ptr<ProgramElement>> elements;
1976     std::vector<const ProgramElement*> sharedElements;
1977 
1978     this->start(base, isBuiltinCode, externalFunctions, &elements, &sharedElements);
1979 
1980     Parser parser(text, length, *fSymbolTable, this->errorReporter());
1981     fFile = parser.compilationUnit();
1982     if (this->errorReporter().errorCount() == 0) {
1983         SkASSERT(fFile);
1984         for (const auto& decl : fFile->root()) {
1985             switch (decl.fKind) {
1986                 case ASTNode::Kind::kVarDeclarations:
1987                     this->convertGlobalVarDeclarations(decl);
1988                     break;
1989 
1990                 case ASTNode::Kind::kEnum:
1991                     this->convertEnum(decl);
1992                     break;
1993 
1994                 case ASTNode::Kind::kFunction:
1995                     this->convertFunction(decl);
1996                     break;
1997 
1998                 case ASTNode::Kind::kModifiers: {
1999                     std::unique_ptr<ModifiersDeclaration> f =
2000                                                             this->convertModifiersDeclaration(decl);
2001                     if (f) {
2002                         fProgramElements->push_back(std::move(f));
2003                     }
2004                     break;
2005                 }
2006                 case ASTNode::Kind::kInterfaceBlock: {
2007                     std::unique_ptr<InterfaceBlock> i = this->convertInterfaceBlock(decl);
2008                     if (i) {
2009                         fProgramElements->push_back(std::move(i));
2010                     }
2011                     break;
2012                 }
2013                 case ASTNode::Kind::kExtension: {
2014                     std::unique_ptr<Extension> e = this->convertExtension(decl.fOffset,
2015                                                                           decl.getString());
2016                     if (e) {
2017                         fProgramElements->push_back(std::move(e));
2018                     }
2019                     break;
2020                 }
2021                 case ASTNode::Kind::kSection: {
2022                     std::unique_ptr<Section> s = this->convertSection(decl);
2023                     if (s) {
2024                         fProgramElements->push_back(std::move(s));
2025                     }
2026                     break;
2027                 }
2028                 case ASTNode::Kind::kType: {
2029                     std::unique_ptr<StructDefinition> s = this->convertStructDefinition(decl);
2030                     if (s) {
2031                         fProgramElements->push_back(std::move(s));
2032                     }
2033                     break;
2034                 }
2035                 default:
2036                     SkDEBUGFAILF("unsupported declaration: %s\n", decl.description().c_str());
2037                     break;
2038             }
2039         }
2040     }
2041     return this->finish();
2042 }
2043 
2044 }  // namespace SkSL
2045