1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "SkSLIRGenerator.h"
9
10 #include "limits.h"
11 #include <unordered_set>
12
13 #include "SkSLCompiler.h"
14 #include "ast/SkSLASTBoolLiteral.h"
15 #include "ast/SkSLASTFieldSuffix.h"
16 #include "ast/SkSLASTFloatLiteral.h"
17 #include "ast/SkSLASTIndexSuffix.h"
18 #include "ast/SkSLASTIntLiteral.h"
19 #include "ir/SkSLBinaryExpression.h"
20 #include "ir/SkSLBoolLiteral.h"
21 #include "ir/SkSLBreakStatement.h"
22 #include "ir/SkSLConstructor.h"
23 #include "ir/SkSLContinueStatement.h"
24 #include "ir/SkSLDiscardStatement.h"
25 #include "ir/SkSLDoStatement.h"
26 #include "ir/SkSLExpressionStatement.h"
27 #include "ir/SkSLField.h"
28 #include "ir/SkSLFieldAccess.h"
29 #include "ir/SkSLFloatLiteral.h"
30 #include "ir/SkSLForStatement.h"
31 #include "ir/SkSLFunctionCall.h"
32 #include "ir/SkSLFunctionDeclaration.h"
33 #include "ir/SkSLFunctionDefinition.h"
34 #include "ir/SkSLFunctionReference.h"
35 #include "ir/SkSLIfStatement.h"
36 #include "ir/SkSLIndexExpression.h"
37 #include "ir/SkSLInterfaceBlock.h"
38 #include "ir/SkSLIntLiteral.h"
39 #include "ir/SkSLLayout.h"
40 #include "ir/SkSLPostfixExpression.h"
41 #include "ir/SkSLPrefixExpression.h"
42 #include "ir/SkSLReturnStatement.h"
43 #include "ir/SkSLSwitchCase.h"
44 #include "ir/SkSLSwitchStatement.h"
45 #include "ir/SkSLSwizzle.h"
46 #include "ir/SkSLTernaryExpression.h"
47 #include "ir/SkSLUnresolvedFunction.h"
48 #include "ir/SkSLVariable.h"
49 #include "ir/SkSLVarDeclarations.h"
50 #include "ir/SkSLVarDeclarationsStatement.h"
51 #include "ir/SkSLVariableReference.h"
52 #include "ir/SkSLWhileStatement.h"
53
54 namespace SkSL {
55
56 class AutoSymbolTable {
57 public:
AutoSymbolTable(IRGenerator * ir)58 AutoSymbolTable(IRGenerator* ir)
59 : fIR(ir)
60 , fPrevious(fIR->fSymbolTable) {
61 fIR->pushSymbolTable();
62 }
63
~AutoSymbolTable()64 ~AutoSymbolTable() {
65 fIR->popSymbolTable();
66 ASSERT(fPrevious == fIR->fSymbolTable);
67 }
68
69 IRGenerator* fIR;
70 std::shared_ptr<SymbolTable> fPrevious;
71 };
72
73 class AutoLoopLevel {
74 public:
AutoLoopLevel(IRGenerator * ir)75 AutoLoopLevel(IRGenerator* ir)
76 : fIR(ir) {
77 fIR->fLoopLevel++;
78 }
79
~AutoLoopLevel()80 ~AutoLoopLevel() {
81 fIR->fLoopLevel--;
82 }
83
84 IRGenerator* fIR;
85 };
86
87 class AutoSwitchLevel {
88 public:
AutoSwitchLevel(IRGenerator * ir)89 AutoSwitchLevel(IRGenerator* ir)
90 : fIR(ir) {
91 fIR->fSwitchLevel++;
92 }
93
~AutoSwitchLevel()94 ~AutoSwitchLevel() {
95 fIR->fSwitchLevel--;
96 }
97
98 IRGenerator* fIR;
99 };
100
IRGenerator(const Context * context,std::shared_ptr<SymbolTable> symbolTable,ErrorReporter & errorReporter)101 IRGenerator::IRGenerator(const Context* context, std::shared_ptr<SymbolTable> symbolTable,
102 ErrorReporter& errorReporter)
103 : fContext(*context)
104 , fCurrentFunction(nullptr)
105 , fSymbolTable(std::move(symbolTable))
106 , fLoopLevel(0)
107 , fSwitchLevel(0)
108 , fErrors(errorReporter) {}
109
pushSymbolTable()110 void IRGenerator::pushSymbolTable() {
111 fSymbolTable.reset(new SymbolTable(std::move(fSymbolTable), fErrors));
112 }
113
popSymbolTable()114 void IRGenerator::popSymbolTable() {
115 fSymbolTable = fSymbolTable->fParent;
116 }
117
fill_caps(const GrShaderCaps & caps,std::unordered_map<SkString,CapValue> * capsMap)118 static void fill_caps(const GrShaderCaps& caps, std::unordered_map<SkString, CapValue>* capsMap) {
119 #define CAP(name) capsMap->insert(std::make_pair(SkString(#name), CapValue(caps.name())));
120 CAP(fbFetchSupport);
121 CAP(fbFetchNeedsCustomOutput);
122 CAP(bindlessTextureSupport);
123 CAP(dropsTileOnZeroDivide);
124 CAP(flatInterpolationSupport);
125 CAP(noperspectiveInterpolationSupport);
126 CAP(multisampleInterpolationSupport);
127 CAP(sampleVariablesSupport);
128 CAP(sampleMaskOverrideCoverageSupport);
129 CAP(externalTextureSupport);
130 CAP(texelFetchSupport);
131 CAP(imageLoadStoreSupport);
132 CAP(mustEnableAdvBlendEqs);
133 CAP(mustEnableSpecificAdvBlendEqs);
134 CAP(mustDeclareFragmentShaderOutput);
135 CAP(canUseAnyFunctionInShader);
136 #undef CAP
137 }
138
start(const Program::Settings * settings)139 void IRGenerator::start(const Program::Settings* settings) {
140 fSettings = settings;
141 fCapsMap.clear();
142 if (settings->fCaps) {
143 fill_caps(*settings->fCaps, &fCapsMap);
144 }
145 this->pushSymbolTable();
146 fInputs.reset();
147 }
148
finish()149 void IRGenerator::finish() {
150 this->popSymbolTable();
151 fSettings = nullptr;
152 }
153
convertExtension(const ASTExtension & extension)154 std::unique_ptr<Extension> IRGenerator::convertExtension(const ASTExtension& extension) {
155 return std::unique_ptr<Extension>(new Extension(extension.fPosition, extension.fName));
156 }
157
convertStatement(const ASTStatement & statement)158 std::unique_ptr<Statement> IRGenerator::convertStatement(const ASTStatement& statement) {
159 switch (statement.fKind) {
160 case ASTStatement::kBlock_Kind:
161 return this->convertBlock((ASTBlock&) statement);
162 case ASTStatement::kVarDeclaration_Kind:
163 return this->convertVarDeclarationStatement((ASTVarDeclarationStatement&) statement);
164 case ASTStatement::kExpression_Kind:
165 return this->convertExpressionStatement((ASTExpressionStatement&) statement);
166 case ASTStatement::kIf_Kind:
167 return this->convertIf((ASTIfStatement&) statement);
168 case ASTStatement::kFor_Kind:
169 return this->convertFor((ASTForStatement&) statement);
170 case ASTStatement::kWhile_Kind:
171 return this->convertWhile((ASTWhileStatement&) statement);
172 case ASTStatement::kDo_Kind:
173 return this->convertDo((ASTDoStatement&) statement);
174 case ASTStatement::kSwitch_Kind:
175 return this->convertSwitch((ASTSwitchStatement&) statement);
176 case ASTStatement::kReturn_Kind:
177 return this->convertReturn((ASTReturnStatement&) statement);
178 case ASTStatement::kBreak_Kind:
179 return this->convertBreak((ASTBreakStatement&) statement);
180 case ASTStatement::kContinue_Kind:
181 return this->convertContinue((ASTContinueStatement&) statement);
182 case ASTStatement::kDiscard_Kind:
183 return this->convertDiscard((ASTDiscardStatement&) statement);
184 default:
185 ABORT("unsupported statement type: %d\n", statement.fKind);
186 }
187 }
188
convertBlock(const ASTBlock & block)189 std::unique_ptr<Block> IRGenerator::convertBlock(const ASTBlock& block) {
190 AutoSymbolTable table(this);
191 std::vector<std::unique_ptr<Statement>> statements;
192 for (size_t i = 0; i < block.fStatements.size(); i++) {
193 std::unique_ptr<Statement> statement = this->convertStatement(*block.fStatements[i]);
194 if (!statement) {
195 return nullptr;
196 }
197 statements.push_back(std::move(statement));
198 }
199 return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements), fSymbolTable));
200 }
201
convertVarDeclarationStatement(const ASTVarDeclarationStatement & s)202 std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement(
203 const ASTVarDeclarationStatement& s) {
204 auto decl = this->convertVarDeclarations(*s.fDeclarations, Variable::kLocal_Storage);
205 if (!decl) {
206 return nullptr;
207 }
208 return std::unique_ptr<Statement>(new VarDeclarationsStatement(std::move(decl)));
209 }
210
convertVarDeclarations(const ASTVarDeclarations & decl,Variable::Storage storage)211 std::unique_ptr<VarDeclarations> IRGenerator::convertVarDeclarations(const ASTVarDeclarations& decl,
212 Variable::Storage storage) {
213 std::vector<VarDeclaration> variables;
214 const Type* baseType = this->convertType(*decl.fType);
215 if (!baseType) {
216 return nullptr;
217 }
218 for (const auto& varDecl : decl.fVars) {
219 const Type* type = baseType;
220 std::vector<std::unique_ptr<Expression>> sizes;
221 for (const auto& rawSize : varDecl.fSizes) {
222 if (rawSize) {
223 auto size = this->coerce(this->convertExpression(*rawSize), *fContext.fInt_Type);
224 if (!size) {
225 return nullptr;
226 }
227 SkString name = type->fName;
228 int64_t count;
229 if (size->fKind == Expression::kIntLiteral_Kind) {
230 count = ((IntLiteral&) *size).fValue;
231 if (count <= 0) {
232 fErrors.error(size->fPosition, "array size must be positive");
233 }
234 name += "[" + to_string(count) + "]";
235 } else {
236 count = -1;
237 name += "[]";
238 }
239 type = new Type(name, Type::kArray_Kind, *type, (int) count);
240 fSymbolTable->takeOwnership((Type*) type);
241 sizes.push_back(std::move(size));
242 } else {
243 type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1);
244 fSymbolTable->takeOwnership((Type*) type);
245 sizes.push_back(nullptr);
246 }
247 }
248 auto var = std::unique_ptr<Variable>(new Variable(decl.fPosition, decl.fModifiers,
249 varDecl.fName, *type, storage));
250 std::unique_ptr<Expression> value;
251 if (varDecl.fValue) {
252 value = this->convertExpression(*varDecl.fValue);
253 if (!value) {
254 return nullptr;
255 }
256 value = this->coerce(std::move(value), *type);
257 }
258 if (storage == Variable::kGlobal_Storage && varDecl.fName == SkString("sk_FragColor") &&
259 (*fSymbolTable)[varDecl.fName]) {
260 // already defined, ignore
261 } else if (storage == Variable::kGlobal_Storage && (*fSymbolTable)[varDecl.fName] &&
262 (*fSymbolTable)[varDecl.fName]->fKind == Symbol::kVariable_Kind &&
263 ((Variable*) (*fSymbolTable)[varDecl.fName])->fModifiers.fLayout.fBuiltin >= 0) {
264 // already defined, just update the modifiers
265 Variable* old = (Variable*) (*fSymbolTable)[varDecl.fName];
266 old->fModifiers = var->fModifiers;
267 } else {
268 variables.emplace_back(var.get(), std::move(sizes), std::move(value));
269 fSymbolTable->add(varDecl.fName, std::move(var));
270 }
271 }
272 return std::unique_ptr<VarDeclarations>(new VarDeclarations(decl.fPosition,
273 baseType,
274 std::move(variables)));
275 }
276
convertModifiersDeclaration(const ASTModifiersDeclaration & m)277 std::unique_ptr<ModifiersDeclaration> IRGenerator::convertModifiersDeclaration(
278 const ASTModifiersDeclaration& m) {
279 return std::unique_ptr<ModifiersDeclaration>(new ModifiersDeclaration(m.fModifiers));
280 }
281
convertIf(const ASTIfStatement & s)282 std::unique_ptr<Statement> IRGenerator::convertIf(const ASTIfStatement& s) {
283 std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest),
284 *fContext.fBool_Type);
285 if (!test) {
286 return nullptr;
287 }
288 std::unique_ptr<Statement> ifTrue = this->convertStatement(*s.fIfTrue);
289 if (!ifTrue) {
290 return nullptr;
291 }
292 std::unique_ptr<Statement> ifFalse;
293 if (s.fIfFalse) {
294 ifFalse = this->convertStatement(*s.fIfFalse);
295 if (!ifFalse) {
296 return nullptr;
297 }
298 }
299 if (test->fKind == Expression::kBoolLiteral_Kind) {
300 // static boolean value, fold down to a single branch
301 if (((BoolLiteral&) *test).fValue) {
302 return ifTrue;
303 } else if (s.fIfFalse) {
304 return ifFalse;
305 } else {
306 // False & no else clause. Not an error, so don't return null!
307 std::vector<std::unique_ptr<Statement>> empty;
308 return std::unique_ptr<Statement>(new Block(s.fPosition, std::move(empty),
309 fSymbolTable));
310 }
311 }
312 return std::unique_ptr<Statement>(new IfStatement(s.fPosition, std::move(test),
313 std::move(ifTrue), std::move(ifFalse)));
314 }
315
convertFor(const ASTForStatement & f)316 std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) {
317 AutoLoopLevel level(this);
318 AutoSymbolTable table(this);
319 std::unique_ptr<Statement> initializer;
320 if (f.fInitializer) {
321 initializer = this->convertStatement(*f.fInitializer);
322 if (!initializer) {
323 return nullptr;
324 }
325 }
326 std::unique_ptr<Expression> test;
327 if (f.fTest) {
328 test = this->coerce(this->convertExpression(*f.fTest), *fContext.fBool_Type);
329 if (!test) {
330 return nullptr;
331 }
332 }
333 std::unique_ptr<Expression> next;
334 if (f.fNext) {
335 next = this->convertExpression(*f.fNext);
336 if (!next) {
337 return nullptr;
338 }
339 this->checkValid(*next);
340 }
341 std::unique_ptr<Statement> statement = this->convertStatement(*f.fStatement);
342 if (!statement) {
343 return nullptr;
344 }
345 return std::unique_ptr<Statement>(new ForStatement(f.fPosition, std::move(initializer),
346 std::move(test), std::move(next),
347 std::move(statement), fSymbolTable));
348 }
349
convertWhile(const ASTWhileStatement & w)350 std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) {
351 AutoLoopLevel level(this);
352 std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest),
353 *fContext.fBool_Type);
354 if (!test) {
355 return nullptr;
356 }
357 std::unique_ptr<Statement> statement = this->convertStatement(*w.fStatement);
358 if (!statement) {
359 return nullptr;
360 }
361 return std::unique_ptr<Statement>(new WhileStatement(w.fPosition, std::move(test),
362 std::move(statement)));
363 }
364
convertDo(const ASTDoStatement & d)365 std::unique_ptr<Statement> IRGenerator::convertDo(const ASTDoStatement& d) {
366 AutoLoopLevel level(this);
367 std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest),
368 *fContext.fBool_Type);
369 if (!test) {
370 return nullptr;
371 }
372 std::unique_ptr<Statement> statement = this->convertStatement(*d.fStatement);
373 if (!statement) {
374 return nullptr;
375 }
376 return std::unique_ptr<Statement>(new DoStatement(d.fPosition, std::move(statement),
377 std::move(test)));
378 }
379
convertSwitch(const ASTSwitchStatement & s)380 std::unique_ptr<Statement> IRGenerator::convertSwitch(const ASTSwitchStatement& s) {
381 AutoSwitchLevel level(this);
382 std::unique_ptr<Expression> value = this->convertExpression(*s.fValue);
383 if (!value) {
384 return nullptr;
385 }
386 if (value->fType != *fContext.fUInt_Type) {
387 value = this->coerce(std::move(value), *fContext.fInt_Type);
388 if (!value) {
389 return nullptr;
390 }
391 }
392 AutoSymbolTable table(this);
393 std::unordered_set<int> caseValues;
394 std::vector<std::unique_ptr<SwitchCase>> cases;
395 for (const auto& c : s.fCases) {
396 std::unique_ptr<Expression> caseValue;
397 if (c->fValue) {
398 caseValue = this->convertExpression(*c->fValue);
399 if (!caseValue) {
400 return nullptr;
401 }
402 if (caseValue->fType != *fContext.fUInt_Type) {
403 caseValue = this->coerce(std::move(caseValue), *fContext.fInt_Type);
404 if (!caseValue) {
405 return nullptr;
406 }
407 }
408 if (!caseValue->isConstant()) {
409 fErrors.error(caseValue->fPosition, "case value must be a constant");
410 return nullptr;
411 }
412 ASSERT(caseValue->fKind == Expression::kIntLiteral_Kind);
413 int64_t v = ((IntLiteral&) *caseValue).fValue;
414 if (caseValues.find(v) != caseValues.end()) {
415 fErrors.error(caseValue->fPosition, "duplicate case value");
416 }
417 caseValues.insert(v);
418 }
419 std::vector<std::unique_ptr<Statement>> statements;
420 for (const auto& s : c->fStatements) {
421 std::unique_ptr<Statement> converted = this->convertStatement(*s);
422 if (!converted) {
423 return nullptr;
424 }
425 statements.push_back(std::move(converted));
426 }
427 cases.emplace_back(new SwitchCase(c->fPosition, std::move(caseValue),
428 std::move(statements)));
429 }
430 return std::unique_ptr<Statement>(new SwitchStatement(s.fPosition, std::move(value),
431 std::move(cases)));
432 }
433
convertExpressionStatement(const ASTExpressionStatement & s)434 std::unique_ptr<Statement> IRGenerator::convertExpressionStatement(
435 const ASTExpressionStatement& s) {
436 std::unique_ptr<Expression> e = this->convertExpression(*s.fExpression);
437 if (!e) {
438 return nullptr;
439 }
440 this->checkValid(*e);
441 return std::unique_ptr<Statement>(new ExpressionStatement(std::move(e)));
442 }
443
convertReturn(const ASTReturnStatement & r)444 std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement& r) {
445 ASSERT(fCurrentFunction);
446 if (r.fExpression) {
447 std::unique_ptr<Expression> result = this->convertExpression(*r.fExpression);
448 if (!result) {
449 return nullptr;
450 }
451 if (fCurrentFunction->fReturnType == *fContext.fVoid_Type) {
452 fErrors.error(result->fPosition, "may not return a value from a void function");
453 } else {
454 result = this->coerce(std::move(result), fCurrentFunction->fReturnType);
455 if (!result) {
456 return nullptr;
457 }
458 }
459 return std::unique_ptr<Statement>(new ReturnStatement(std::move(result)));
460 } else {
461 if (fCurrentFunction->fReturnType != *fContext.fVoid_Type) {
462 fErrors.error(r.fPosition, "expected function to return '" +
463 fCurrentFunction->fReturnType.description() + "'");
464 }
465 return std::unique_ptr<Statement>(new ReturnStatement(r.fPosition));
466 }
467 }
468
convertBreak(const ASTBreakStatement & b)469 std::unique_ptr<Statement> IRGenerator::convertBreak(const ASTBreakStatement& b) {
470 if (fLoopLevel > 0 || fSwitchLevel > 0) {
471 return std::unique_ptr<Statement>(new BreakStatement(b.fPosition));
472 } else {
473 fErrors.error(b.fPosition, "break statement must be inside a loop or switch");
474 return nullptr;
475 }
476 }
477
convertContinue(const ASTContinueStatement & c)478 std::unique_ptr<Statement> IRGenerator::convertContinue(const ASTContinueStatement& c) {
479 if (fLoopLevel > 0) {
480 return std::unique_ptr<Statement>(new ContinueStatement(c.fPosition));
481 } else {
482 fErrors.error(c.fPosition, "continue statement must be inside a loop");
483 return nullptr;
484 }
485 }
486
convertDiscard(const ASTDiscardStatement & d)487 std::unique_ptr<Statement> IRGenerator::convertDiscard(const ASTDiscardStatement& d) {
488 return std::unique_ptr<Statement>(new DiscardStatement(d.fPosition));
489 }
490
convertFunction(const ASTFunction & f)491 std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFunction& f) {
492 const Type* returnType = this->convertType(*f.fReturnType);
493 if (!returnType) {
494 return nullptr;
495 }
496 std::vector<const Variable*> parameters;
497 for (const auto& param : f.fParameters) {
498 const Type* type = this->convertType(*param->fType);
499 if (!type) {
500 return nullptr;
501 }
502 for (int j = (int) param->fSizes.size() - 1; j >= 0; j--) {
503 int size = param->fSizes[j];
504 SkString name = type->name() + "[" + to_string(size) + "]";
505 Type* newType = new Type(std::move(name), Type::kArray_Kind, *type, size);
506 fSymbolTable->takeOwnership(newType);
507 type = newType;
508 }
509 SkString name = param->fName;
510 Position pos = param->fPosition;
511 Variable* var = new Variable(pos, param->fModifiers, std::move(name), *type,
512 Variable::kParameter_Storage);
513 fSymbolTable->takeOwnership(var);
514 parameters.push_back(var);
515 }
516
517 // find existing declaration
518 const FunctionDeclaration* decl = nullptr;
519 auto entry = (*fSymbolTable)[f.fName];
520 if (entry) {
521 std::vector<const FunctionDeclaration*> functions;
522 switch (entry->fKind) {
523 case Symbol::kUnresolvedFunction_Kind:
524 functions = ((UnresolvedFunction*) entry)->fFunctions;
525 break;
526 case Symbol::kFunctionDeclaration_Kind:
527 functions.push_back((FunctionDeclaration*) entry);
528 break;
529 default:
530 fErrors.error(f.fPosition, "symbol '" + f.fName + "' was already defined");
531 return nullptr;
532 }
533 for (const auto& other : functions) {
534 ASSERT(other->fName == f.fName);
535 if (parameters.size() == other->fParameters.size()) {
536 bool match = true;
537 for (size_t i = 0; i < parameters.size(); i++) {
538 if (parameters[i]->fType != other->fParameters[i]->fType) {
539 match = false;
540 break;
541 }
542 }
543 if (match) {
544 if (*returnType != other->fReturnType) {
545 FunctionDeclaration newDecl(f.fPosition, f.fName, parameters, *returnType);
546 fErrors.error(f.fPosition, "functions '" + newDecl.description() +
547 "' and '" + other->description() +
548 "' differ only in return type");
549 return nullptr;
550 }
551 decl = other;
552 for (size_t i = 0; i < parameters.size(); i++) {
553 if (parameters[i]->fModifiers != other->fParameters[i]->fModifiers) {
554 fErrors.error(f.fPosition, "modifiers on parameter " +
555 to_string((uint64_t) i + 1) +
556 " differ between declaration and "
557 "definition");
558 return nullptr;
559 }
560 }
561 if (other->fDefined) {
562 fErrors.error(f.fPosition, "duplicate definition of " +
563 other->description());
564 }
565 break;
566 }
567 }
568 }
569 }
570 if (!decl) {
571 // couldn't find an existing declaration
572 auto newDecl = std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration(f.fPosition,
573 f.fName,
574 parameters,
575 *returnType));
576 decl = newDecl.get();
577 fSymbolTable->add(decl->fName, std::move(newDecl));
578 }
579 if (f.fBody) {
580 ASSERT(!fCurrentFunction);
581 fCurrentFunction = decl;
582 decl->fDefined = true;
583 std::shared_ptr<SymbolTable> old = fSymbolTable;
584 AutoSymbolTable table(this);
585 for (size_t i = 0; i < parameters.size(); i++) {
586 fSymbolTable->addWithoutOwnership(parameters[i]->fName, decl->fParameters[i]);
587 }
588 std::unique_ptr<Block> body = this->convertBlock(*f.fBody);
589 fCurrentFunction = nullptr;
590 if (!body) {
591 return nullptr;
592 }
593 return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, *decl,
594 std::move(body)));
595 }
596 return nullptr;
597 }
598
convertInterfaceBlock(const ASTInterfaceBlock & intf)599 std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTInterfaceBlock& intf) {
600 std::shared_ptr<SymbolTable> old = fSymbolTable;
601 AutoSymbolTable table(this);
602 std::vector<Type::Field> fields;
603 for (size_t i = 0; i < intf.fDeclarations.size(); i++) {
604 std::unique_ptr<VarDeclarations> decl = this->convertVarDeclarations(
605 *intf.fDeclarations[i],
606 Variable::kGlobal_Storage);
607 if (!decl) {
608 return nullptr;
609 }
610 for (const auto& var : decl->fVars) {
611 fields.push_back(Type::Field(var.fVar->fModifiers, var.fVar->fName,
612 &var.fVar->fType));
613 if (var.fValue) {
614 fErrors.error(decl->fPosition,
615 "initializers are not permitted on interface block fields");
616 }
617 if (var.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag |
618 Modifiers::kOut_Flag |
619 Modifiers::kUniform_Flag |
620 Modifiers::kConst_Flag)) {
621 fErrors.error(decl->fPosition,
622 "interface block fields may not have storage qualifiers");
623 }
624 }
625 }
626 Type* type = new Type(intf.fPosition, intf.fTypeName, fields);
627 old->takeOwnership(type);
628 std::vector<std::unique_ptr<Expression>> sizes;
629 for (const auto& size : intf.fSizes) {
630 if (size) {
631 std::unique_ptr<Expression> converted = this->convertExpression(*size);
632 if (!converted) {
633 return nullptr;
634 }
635 SkString name = type->fName;
636 int64_t count;
637 if (converted->fKind == Expression::kIntLiteral_Kind) {
638 count = ((IntLiteral&) *converted).fValue;
639 if (count <= 0) {
640 fErrors.error(converted->fPosition, "array size must be positive");
641 }
642 name += "[" + to_string(count) + "]";
643 } else {
644 count = -1;
645 name += "[]";
646 }
647 type = new Type(name, Type::kArray_Kind, *type, (int) count);
648 fSymbolTable->takeOwnership((Type*) type);
649 sizes.push_back(std::move(converted));
650 } else {
651 type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1);
652 fSymbolTable->takeOwnership((Type*) type);
653 sizes.push_back(nullptr);
654 }
655 }
656 Variable* var = new Variable(intf.fPosition, intf.fModifiers,
657 intf.fInstanceName.size() ? intf.fInstanceName : intf.fTypeName,
658 *type, Variable::kGlobal_Storage);
659 old->takeOwnership(var);
660 if (intf.fInstanceName.size()) {
661 old->addWithoutOwnership(intf.fInstanceName, var);
662 } else {
663 for (size_t i = 0; i < fields.size(); i++) {
664 old->add(fields[i].fName, std::unique_ptr<Field>(new Field(intf.fPosition, *var,
665 (int) i)));
666 }
667 }
668 return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, *var,
669 intf.fTypeName,
670 intf.fInstanceName,
671 std::move(sizes),
672 fSymbolTable));
673 }
674
convertType(const ASTType & type)675 const Type* IRGenerator::convertType(const ASTType& type) {
676 const Symbol* result = (*fSymbolTable)[type.fName];
677 if (result && result->fKind == Symbol::kType_Kind) {
678 for (int size : type.fSizes) {
679 SkString name = result->fName + "[";
680 if (size != -1) {
681 name += to_string(size);
682 }
683 name += "]";
684 result = new Type(name, Type::kArray_Kind, (const Type&) *result, size);
685 fSymbolTable->takeOwnership((Type*) result);
686 }
687 return (const Type*) result;
688 }
689 fErrors.error(type.fPosition, "unknown type '" + type.fName + "'");
690 return nullptr;
691 }
692
convertExpression(const ASTExpression & expr)693 std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTExpression& expr) {
694 switch (expr.fKind) {
695 case ASTExpression::kIdentifier_Kind:
696 return this->convertIdentifier((ASTIdentifier&) expr);
697 case ASTExpression::kBool_Kind:
698 return std::unique_ptr<Expression>(new BoolLiteral(fContext, expr.fPosition,
699 ((ASTBoolLiteral&) expr).fValue));
700 case ASTExpression::kInt_Kind:
701 return std::unique_ptr<Expression>(new IntLiteral(fContext, expr.fPosition,
702 ((ASTIntLiteral&) expr).fValue));
703 case ASTExpression::kFloat_Kind:
704 return std::unique_ptr<Expression>(new FloatLiteral(fContext, expr.fPosition,
705 ((ASTFloatLiteral&) expr).fValue));
706 case ASTExpression::kBinary_Kind:
707 return this->convertBinaryExpression((ASTBinaryExpression&) expr);
708 case ASTExpression::kPrefix_Kind:
709 return this->convertPrefixExpression((ASTPrefixExpression&) expr);
710 case ASTExpression::kSuffix_Kind:
711 return this->convertSuffixExpression((ASTSuffixExpression&) expr);
712 case ASTExpression::kTernary_Kind:
713 return this->convertTernaryExpression((ASTTernaryExpression&) expr);
714 default:
715 ABORT("unsupported expression type: %d\n", expr.fKind);
716 }
717 }
718
convertIdentifier(const ASTIdentifier & identifier)719 std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& identifier) {
720 const Symbol* result = (*fSymbolTable)[identifier.fText];
721 if (!result) {
722 fErrors.error(identifier.fPosition, "unknown identifier '" + identifier.fText + "'");
723 return nullptr;
724 }
725 switch (result->fKind) {
726 case Symbol::kFunctionDeclaration_Kind: {
727 std::vector<const FunctionDeclaration*> f = {
728 (const FunctionDeclaration*) result
729 };
730 return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
731 identifier.fPosition,
732 f));
733 }
734 case Symbol::kUnresolvedFunction_Kind: {
735 const UnresolvedFunction* f = (const UnresolvedFunction*) result;
736 return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
737 identifier.fPosition,
738 f->fFunctions));
739 }
740 case Symbol::kVariable_Kind: {
741 const Variable* var = (const Variable*) result;
742 if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
743 fInputs.fFlipY = true;
744 if (fSettings->fFlipY &&
745 (!fSettings->fCaps ||
746 !fSettings->fCaps->fragCoordConventionsExtensionString())) {
747 fInputs.fRTHeight = true;
748 }
749 }
750 // default to kRead_RefKind; this will be corrected later if the variable is written to
751 return std::unique_ptr<VariableReference>(new VariableReference(
752 identifier.fPosition,
753 *var,
754 VariableReference::kRead_RefKind));
755 }
756 case Symbol::kField_Kind: {
757 const Field* field = (const Field*) result;
758 VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner,
759 VariableReference::kRead_RefKind);
760 return std::unique_ptr<Expression>(new FieldAccess(
761 std::unique_ptr<Expression>(base),
762 field->fFieldIndex,
763 FieldAccess::kAnonymousInterfaceBlock_OwnerKind));
764 }
765 case Symbol::kType_Kind: {
766 const Type* t = (const Type*) result;
767 return std::unique_ptr<TypeReference>(new TypeReference(fContext, identifier.fPosition,
768 *t));
769 }
770 default:
771 ABORT("unsupported symbol type %d\n", result->fKind);
772 }
773
774 }
775
coerce(std::unique_ptr<Expression> expr,const Type & type)776 std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr,
777 const Type& type) {
778 if (!expr) {
779 return nullptr;
780 }
781 if (expr->fType == type) {
782 return expr;
783 }
784 this->checkValid(*expr);
785 if (expr->fType == *fContext.fInvalid_Type) {
786 return nullptr;
787 }
788 if (!expr->fType.canCoerceTo(type)) {
789 fErrors.error(expr->fPosition, "expected '" + type.description() + "', but found '" +
790 expr->fType.description() + "'");
791 return nullptr;
792 }
793 if (type.kind() == Type::kScalar_Kind) {
794 std::vector<std::unique_ptr<Expression>> args;
795 args.push_back(std::move(expr));
796 ASTIdentifier id(Position(), type.description());
797 std::unique_ptr<Expression> ctor = this->convertIdentifier(id);
798 ASSERT(ctor);
799 return this->call(Position(), std::move(ctor), std::move(args));
800 }
801 std::vector<std::unique_ptr<Expression>> args;
802 args.push_back(std::move(expr));
803 return std::unique_ptr<Expression>(new Constructor(Position(), type, std::move(args)));
804 }
805
is_matrix_multiply(const Type & left,const Type & right)806 static bool is_matrix_multiply(const Type& left, const Type& right) {
807 if (left.kind() == Type::kMatrix_Kind) {
808 return right.kind() == Type::kMatrix_Kind || right.kind() == Type::kVector_Kind;
809 }
810 return left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind;
811 }
812
813 /**
814 * Determines the operand and result types of a binary expression. Returns true if the expression is
815 * legal, false otherwise. If false, the values of the out parameters are undefined.
816 */
determine_binary_type(const Context & context,Token::Kind op,const Type & left,const Type & right,const Type ** outLeftType,const Type ** outRightType,const Type ** outResultType,bool tryFlipped)817 static bool determine_binary_type(const Context& context,
818 Token::Kind op,
819 const Type& left,
820 const Type& right,
821 const Type** outLeftType,
822 const Type** outRightType,
823 const Type** outResultType,
824 bool tryFlipped) {
825 bool isLogical;
826 bool validMatrixOrVectorOp;
827 switch (op) {
828 case Token::EQ:
829 *outLeftType = &left;
830 *outRightType = &left;
831 *outResultType = &left;
832 return right.canCoerceTo(left);
833 case Token::EQEQ: // fall through
834 case Token::NEQ:
835 isLogical = true;
836 validMatrixOrVectorOp = true;
837 break;
838 case Token::LT: // fall through
839 case Token::GT: // fall through
840 case Token::LTEQ: // fall through
841 case Token::GTEQ:
842 isLogical = true;
843 validMatrixOrVectorOp = false;
844 break;
845 case Token::LOGICALOR: // fall through
846 case Token::LOGICALAND: // fall through
847 case Token::LOGICALXOR: // fall through
848 case Token::LOGICALOREQ: // fall through
849 case Token::LOGICALANDEQ: // fall through
850 case Token::LOGICALXOREQ:
851 *outLeftType = context.fBool_Type.get();
852 *outRightType = context.fBool_Type.get();
853 *outResultType = context.fBool_Type.get();
854 return left.canCoerceTo(*context.fBool_Type) &&
855 right.canCoerceTo(*context.fBool_Type);
856 case Token::STAR: // fall through
857 case Token::STAREQ:
858 if (is_matrix_multiply(left, right)) {
859 // determine final component type
860 if (determine_binary_type(context, Token::STAR, left.componentType(),
861 right.componentType(), outLeftType, outRightType,
862 outResultType, false)) {
863 *outLeftType = &(*outResultType)->toCompound(context, left.columns(),
864 left.rows());;
865 *outRightType = &(*outResultType)->toCompound(context, right.columns(),
866 right.rows());;
867 int leftColumns = left.columns();
868 int leftRows = left.rows();
869 int rightColumns;
870 int rightRows;
871 if (right.kind() == Type::kVector_Kind) {
872 // matrix * vector treats the vector as a column vector, so we need to
873 // transpose it
874 rightColumns = right.rows();
875 rightRows = right.columns();
876 ASSERT(rightColumns == 1);
877 } else {
878 rightColumns = right.columns();
879 rightRows = right.rows();
880 }
881 if (rightColumns > 1) {
882 *outResultType = &(*outResultType)->toCompound(context, rightColumns,
883 leftRows);
884 } else {
885 // result was a column vector, transpose it back to a row
886 *outResultType = &(*outResultType)->toCompound(context, leftRows,
887 rightColumns);
888 }
889 return leftColumns == rightRows;
890 } else {
891 return false;
892 }
893 }
894 isLogical = false;
895 validMatrixOrVectorOp = true;
896 break;
897 case Token::PLUS: // fall through
898 case Token::PLUSEQ: // fall through
899 case Token::MINUS: // fall through
900 case Token::MINUSEQ: // fall through
901 case Token::SLASH: // fall through
902 case Token::SLASHEQ:
903 isLogical = false;
904 validMatrixOrVectorOp = true;
905 break;
906 default:
907 isLogical = false;
908 validMatrixOrVectorOp = false;
909 }
910 bool isVectorOrMatrix = left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind;
911 // FIXME: incorrect for shift
912 if (right.canCoerceTo(left) && (left.kind() == Type::kScalar_Kind ||
913 (isVectorOrMatrix && validMatrixOrVectorOp))) {
914 *outLeftType = &left;
915 *outRightType = &left;
916 if (isLogical) {
917 *outResultType = context.fBool_Type.get();
918 } else {
919 *outResultType = &left;
920 }
921 return true;
922 }
923 if ((left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind) &&
924 (right.kind() == Type::kScalar_Kind)) {
925 if (determine_binary_type(context, op, left.componentType(), right, outLeftType,
926 outRightType, outResultType, false)) {
927 *outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows());
928 if (!isLogical) {
929 *outResultType = &(*outResultType)->toCompound(context, left.columns(),
930 left.rows());
931 }
932 return true;
933 }
934 return false;
935 }
936 if (tryFlipped) {
937 return determine_binary_type(context, op, right, left, outRightType, outLeftType,
938 outResultType, false);
939 }
940 return false;
941 }
942
constantFold(const Expression & left,Token::Kind op,const Expression & right) const943 std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left,
944 Token::Kind op,
945 const Expression& right) const {
946 // Note that we expressly do not worry about precision and overflow here -- we use the maximum
947 // precision to calculate the results and hope the result makes sense. The plan is to move the
948 // Skia caps into SkSL, so we have access to all of them including the precisions of the various
949 // types, which will let us be more intelligent about this.
950 if (left.fKind == Expression::kBoolLiteral_Kind &&
951 right.fKind == Expression::kBoolLiteral_Kind) {
952 bool leftVal = ((BoolLiteral&) left).fValue;
953 bool rightVal = ((BoolLiteral&) right).fValue;
954 bool result;
955 switch (op) {
956 case Token::LOGICALAND: result = leftVal && rightVal; break;
957 case Token::LOGICALOR: result = leftVal || rightVal; break;
958 case Token::LOGICALXOR: result = leftVal ^ rightVal; break;
959 default: return nullptr;
960 }
961 return std::unique_ptr<Expression>(new BoolLiteral(fContext, left.fPosition, result));
962 }
963 #define RESULT(t, op) std::unique_ptr<Expression>(new t ## Literal(fContext, left.fPosition, \
964 leftVal op rightVal))
965 if (left.fKind == Expression::kIntLiteral_Kind && right.fKind == Expression::kIntLiteral_Kind) {
966 int64_t leftVal = ((IntLiteral&) left).fValue;
967 int64_t rightVal = ((IntLiteral&) right).fValue;
968 switch (op) {
969 case Token::PLUS: return RESULT(Int, +);
970 case Token::MINUS: return RESULT(Int, -);
971 case Token::STAR: return RESULT(Int, *);
972 case Token::SLASH:
973 if (rightVal) {
974 return RESULT(Int, /);
975 }
976 fErrors.error(right.fPosition, "division by zero");
977 return nullptr;
978 case Token::PERCENT:
979 if (rightVal) {
980 return RESULT(Int, %);
981 }
982 fErrors.error(right.fPosition, "division by zero");
983 return nullptr;
984 case Token::BITWISEAND: return RESULT(Int, &);
985 case Token::BITWISEOR: return RESULT(Int, |);
986 case Token::BITWISEXOR: return RESULT(Int, ^);
987 case Token::SHL: return RESULT(Int, <<);
988 case Token::SHR: return RESULT(Int, >>);
989 case Token::EQEQ: return RESULT(Bool, ==);
990 case Token::NEQ: return RESULT(Bool, !=);
991 case Token::GT: return RESULT(Bool, >);
992 case Token::GTEQ: return RESULT(Bool, >=);
993 case Token::LT: return RESULT(Bool, <);
994 case Token::LTEQ: return RESULT(Bool, <=);
995 default: return nullptr;
996 }
997 }
998 if (left.fKind == Expression::kFloatLiteral_Kind &&
999 right.fKind == Expression::kFloatLiteral_Kind) {
1000 double leftVal = ((FloatLiteral&) left).fValue;
1001 double rightVal = ((FloatLiteral&) right).fValue;
1002 switch (op) {
1003 case Token::PLUS: return RESULT(Float, +);
1004 case Token::MINUS: return RESULT(Float, -);
1005 case Token::STAR: return RESULT(Float, *);
1006 case Token::SLASH:
1007 if (rightVal) {
1008 return RESULT(Float, /);
1009 }
1010 fErrors.error(right.fPosition, "division by zero");
1011 return nullptr;
1012 case Token::EQEQ: return RESULT(Bool, ==);
1013 case Token::NEQ: return RESULT(Bool, !=);
1014 case Token::GT: return RESULT(Bool, >);
1015 case Token::GTEQ: return RESULT(Bool, >=);
1016 case Token::LT: return RESULT(Bool, <);
1017 case Token::LTEQ: return RESULT(Bool, <=);
1018 default: return nullptr;
1019 }
1020 }
1021 #undef RESULT
1022 return nullptr;
1023 }
1024
convertBinaryExpression(const ASTBinaryExpression & expression)1025 std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
1026 const ASTBinaryExpression& expression) {
1027 std::unique_ptr<Expression> left = this->convertExpression(*expression.fLeft);
1028 if (!left) {
1029 return nullptr;
1030 }
1031 std::unique_ptr<Expression> right = this->convertExpression(*expression.fRight);
1032 if (!right) {
1033 return nullptr;
1034 }
1035 const Type* leftType;
1036 const Type* rightType;
1037 const Type* resultType;
1038 if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType,
1039 &rightType, &resultType,
1040 !Token::IsAssignment(expression.fOperator))) {
1041 fErrors.error(expression.fPosition, "type mismatch: '" +
1042 Token::OperatorName(expression.fOperator) +
1043 "' cannot operate on '" + left->fType.fName +
1044 "', '" + right->fType.fName + "'");
1045 return nullptr;
1046 }
1047 if (Token::IsAssignment(expression.fOperator)) {
1048 this->markWrittenTo(*left, expression.fOperator != Token::EQ);
1049 }
1050 left = this->coerce(std::move(left), *leftType);
1051 right = this->coerce(std::move(right), *rightType);
1052 if (!left || !right) {
1053 return nullptr;
1054 }
1055 std::unique_ptr<Expression> result = this->constantFold(*left.get(), expression.fOperator,
1056 *right.get());
1057 if (!result) {
1058 result = std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition,
1059 std::move(left),
1060 expression.fOperator,
1061 std::move(right),
1062 *resultType));
1063 }
1064 return result;
1065 }
1066
convertTernaryExpression(const ASTTernaryExpression & expression)1067 std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(
1068 const ASTTernaryExpression& expression) {
1069 std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*expression.fTest),
1070 *fContext.fBool_Type);
1071 if (!test) {
1072 return nullptr;
1073 }
1074 std::unique_ptr<Expression> ifTrue = this->convertExpression(*expression.fIfTrue);
1075 if (!ifTrue) {
1076 return nullptr;
1077 }
1078 std::unique_ptr<Expression> ifFalse = this->convertExpression(*expression.fIfFalse);
1079 if (!ifFalse) {
1080 return nullptr;
1081 }
1082 const Type* trueType;
1083 const Type* falseType;
1084 const Type* resultType;
1085 if (!determine_binary_type(fContext, Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType,
1086 &falseType, &resultType, true) || trueType != falseType) {
1087 fErrors.error(expression.fPosition, "ternary operator result mismatch: '" +
1088 ifTrue->fType.fName + "', '" +
1089 ifFalse->fType.fName + "'");
1090 return nullptr;
1091 }
1092 ifTrue = this->coerce(std::move(ifTrue), *trueType);
1093 if (!ifTrue) {
1094 return nullptr;
1095 }
1096 ifFalse = this->coerce(std::move(ifFalse), *falseType);
1097 if (!ifFalse) {
1098 return nullptr;
1099 }
1100 if (test->fKind == Expression::kBoolLiteral_Kind) {
1101 // static boolean test, just return one of the branches
1102 if (((BoolLiteral&) *test).fValue) {
1103 return ifTrue;
1104 } else {
1105 return ifFalse;
1106 }
1107 }
1108 return std::unique_ptr<Expression>(new TernaryExpression(expression.fPosition,
1109 std::move(test),
1110 std::move(ifTrue),
1111 std::move(ifFalse)));
1112 }
1113
call(Position position,const FunctionDeclaration & function,std::vector<std::unique_ptr<Expression>> arguments)1114 std::unique_ptr<Expression> IRGenerator::call(Position position,
1115 const FunctionDeclaration& function,
1116 std::vector<std::unique_ptr<Expression>> arguments) {
1117 if (function.fParameters.size() != arguments.size()) {
1118 SkString msg = "call to '" + function.fName + "' expected " +
1119 to_string((uint64_t) function.fParameters.size()) +
1120 " argument";
1121 if (function.fParameters.size() != 1) {
1122 msg += "s";
1123 }
1124 msg += ", but found " + to_string((uint64_t) arguments.size());
1125 fErrors.error(position, msg);
1126 return nullptr;
1127 }
1128 std::vector<const Type*> types;
1129 const Type* returnType;
1130 if (!function.determineFinalTypes(arguments, &types, &returnType)) {
1131 SkString msg = "no match for " + function.fName + "(";
1132 SkString separator;
1133 for (size_t i = 0; i < arguments.size(); i++) {
1134 msg += separator;
1135 separator = ", ";
1136 msg += arguments[i]->fType.description();
1137 }
1138 msg += ")";
1139 fErrors.error(position, msg);
1140 return nullptr;
1141 }
1142 for (size_t i = 0; i < arguments.size(); i++) {
1143 arguments[i] = this->coerce(std::move(arguments[i]), *types[i]);
1144 if (!arguments[i]) {
1145 return nullptr;
1146 }
1147 if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
1148 this->markWrittenTo(*arguments[i], true);
1149 }
1150 }
1151 return std::unique_ptr<FunctionCall>(new FunctionCall(position, *returnType, function,
1152 std::move(arguments)));
1153 }
1154
1155 /**
1156 * Determines the cost of coercing the arguments of a function to the required types. Returns true
1157 * if the cost could be computed, false if the call is not valid. Cost has no particular meaning
1158 * other than "lower costs are preferred".
1159 */
determineCallCost(const FunctionDeclaration & function,const std::vector<std::unique_ptr<Expression>> & arguments,int * outCost)1160 bool IRGenerator::determineCallCost(const FunctionDeclaration& function,
1161 const std::vector<std::unique_ptr<Expression>>& arguments,
1162 int* outCost) {
1163 if (function.fParameters.size() != arguments.size()) {
1164 return false;
1165 }
1166 int total = 0;
1167 std::vector<const Type*> types;
1168 const Type* ignored;
1169 if (!function.determineFinalTypes(arguments, &types, &ignored)) {
1170 return false;
1171 }
1172 for (size_t i = 0; i < arguments.size(); i++) {
1173 int cost;
1174 if (arguments[i]->fType.determineCoercionCost(*types[i], &cost)) {
1175 total += cost;
1176 } else {
1177 return false;
1178 }
1179 }
1180 *outCost = total;
1181 return true;
1182 }
1183
call(Position position,std::unique_ptr<Expression> functionValue,std::vector<std::unique_ptr<Expression>> arguments)1184 std::unique_ptr<Expression> IRGenerator::call(Position position,
1185 std::unique_ptr<Expression> functionValue,
1186 std::vector<std::unique_ptr<Expression>> arguments) {
1187 if (functionValue->fKind == Expression::kTypeReference_Kind) {
1188 return this->convertConstructor(position,
1189 ((TypeReference&) *functionValue).fValue,
1190 std::move(arguments));
1191 }
1192 if (functionValue->fKind != Expression::kFunctionReference_Kind) {
1193 fErrors.error(position, "'" + functionValue->description() + "' is not a function");
1194 return nullptr;
1195 }
1196 FunctionReference* ref = (FunctionReference*) functionValue.get();
1197 int bestCost = INT_MAX;
1198 const FunctionDeclaration* best = nullptr;
1199 if (ref->fFunctions.size() > 1) {
1200 for (const auto& f : ref->fFunctions) {
1201 int cost;
1202 if (this->determineCallCost(*f, arguments, &cost) && cost < bestCost) {
1203 bestCost = cost;
1204 best = f;
1205 }
1206 }
1207 if (best) {
1208 return this->call(position, *best, std::move(arguments));
1209 }
1210 SkString msg = "no match for " + ref->fFunctions[0]->fName + "(";
1211 SkString separator;
1212 for (size_t i = 0; i < arguments.size(); i++) {
1213 msg += separator;
1214 separator = ", ";
1215 msg += arguments[i]->fType.description();
1216 }
1217 msg += ")";
1218 fErrors.error(position, msg);
1219 return nullptr;
1220 }
1221 return this->call(position, *ref->fFunctions[0], std::move(arguments));
1222 }
1223
convertNumberConstructor(Position position,const Type & type,std::vector<std::unique_ptr<Expression>> args)1224 std::unique_ptr<Expression> IRGenerator::convertNumberConstructor(
1225 Position position,
1226 const Type& type,
1227 std::vector<std::unique_ptr<Expression>> args) {
1228 ASSERT(type.isNumber());
1229 if (args.size() != 1) {
1230 fErrors.error(position, "invalid arguments to '" + type.description() +
1231 "' constructor, (expected exactly 1 argument, but found " +
1232 to_string((uint64_t) args.size()) + ")");
1233 return nullptr;
1234 }
1235 if (type == *fContext.fFloat_Type && args.size() == 1 &&
1236 args[0]->fKind == Expression::kIntLiteral_Kind) {
1237 int64_t value = ((IntLiteral&) *args[0]).fValue;
1238 return std::unique_ptr<Expression>(new FloatLiteral(fContext, position, (double) value));
1239 }
1240 if (args[0]->fKind == Expression::kIntLiteral_Kind && (type == *fContext.fInt_Type ||
1241 type == *fContext.fUInt_Type)) {
1242 return std::unique_ptr<Expression>(new IntLiteral(fContext,
1243 position,
1244 ((IntLiteral&) *args[0]).fValue,
1245 &type));
1246 }
1247 if (args[0]->fType == *fContext.fBool_Type) {
1248 std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, position, 0));
1249 std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, position, 1));
1250 return std::unique_ptr<Expression>(
1251 new TernaryExpression(position, std::move(args[0]),
1252 this->coerce(std::move(one), type),
1253 this->coerce(std::move(zero),
1254 type)));
1255 }
1256 if (!args[0]->fType.isNumber()) {
1257 fErrors.error(position, "invalid argument to '" + type.description() +
1258 "' constructor (expected a number or bool, but found '" +
1259 args[0]->fType.description() + "')");
1260 return nullptr;
1261 }
1262 return std::unique_ptr<Expression>(new Constructor(position, std::move(type), std::move(args)));
1263 }
1264
component_count(const Type & type)1265 int component_count(const Type& type) {
1266 switch (type.kind()) {
1267 case Type::kVector_Kind:
1268 return type.columns();
1269 case Type::kMatrix_Kind:
1270 return type.columns() * type.rows();
1271 default:
1272 return 1;
1273 }
1274 }
1275
convertCompoundConstructor(Position position,const Type & type,std::vector<std::unique_ptr<Expression>> args)1276 std::unique_ptr<Expression> IRGenerator::convertCompoundConstructor(
1277 Position position,
1278 const Type& type,
1279 std::vector<std::unique_ptr<Expression>> args) {
1280 ASSERT(type.kind() == Type::kVector_Kind || type.kind() == Type::kMatrix_Kind);
1281 if (type.kind() == Type::kMatrix_Kind && args.size() == 1 &&
1282 args[0]->fType.kind() == Type::kMatrix_Kind) {
1283 // matrix from matrix is always legal
1284 return std::unique_ptr<Expression>(new Constructor(position, std::move(type),
1285 std::move(args)));
1286 }
1287 int actual = 0;
1288 int expected = type.rows() * type.columns();
1289 if (args.size() != 1 || expected != component_count(args[0]->fType) ||
1290 type.componentType().isNumber() != args[0]->fType.componentType().isNumber()) {
1291 for (size_t i = 0; i < args.size(); i++) {
1292 if (args[i]->fType.kind() == Type::kVector_Kind) {
1293 if (type.componentType().isNumber() !=
1294 args[i]->fType.componentType().isNumber()) {
1295 fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid "
1296 "parameter to '" + type.description() +
1297 "' constructor");
1298 return nullptr;
1299 }
1300 actual += args[i]->fType.columns();
1301 } else if (args[i]->fType.kind() == Type::kScalar_Kind) {
1302 actual += 1;
1303 if (type.kind() != Type::kScalar_Kind) {
1304 args[i] = this->coerce(std::move(args[i]), type.componentType());
1305 if (!args[i]) {
1306 return nullptr;
1307 }
1308 }
1309 } else {
1310 fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid "
1311 "parameter to '" + type.description() + "' constructor");
1312 return nullptr;
1313 }
1314 }
1315 if (actual != 1 && actual != expected) {
1316 fErrors.error(position, "invalid arguments to '" + type.description() +
1317 "' constructor (expected " + to_string(expected) +
1318 " scalars, but found " + to_string(actual) + ")");
1319 return nullptr;
1320 }
1321 }
1322 return std::unique_ptr<Expression>(new Constructor(position, std::move(type), std::move(args)));
1323 }
1324
convertConstructor(Position position,const Type & type,std::vector<std::unique_ptr<Expression>> args)1325 std::unique_ptr<Expression> IRGenerator::convertConstructor(
1326 Position position,
1327 const Type& type,
1328 std::vector<std::unique_ptr<Expression>> args) {
1329 // FIXME: add support for structs
1330 Type::Kind kind = type.kind();
1331 if (args.size() == 1 && args[0]->fType == type) {
1332 // argument is already the right type, just return it
1333 return std::move(args[0]);
1334 }
1335 if (type.isNumber()) {
1336 return this->convertNumberConstructor(position, type, std::move(args));
1337 } else if (kind == Type::kArray_Kind) {
1338 const Type& base = type.componentType();
1339 for (size_t i = 0; i < args.size(); i++) {
1340 args[i] = this->coerce(std::move(args[i]), base);
1341 if (!args[i]) {
1342 return nullptr;
1343 }
1344 }
1345 return std::unique_ptr<Expression>(new Constructor(position, std::move(type),
1346 std::move(args)));
1347 } else if (kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) {
1348 return this->convertCompoundConstructor(position, type, std::move(args));
1349 } else {
1350 fErrors.error(position, "cannot construct '" + type.description() + "'");
1351 return nullptr;
1352 }
1353 }
1354
convertPrefixExpression(const ASTPrefixExpression & expression)1355 std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
1356 const ASTPrefixExpression& expression) {
1357 std::unique_ptr<Expression> base = this->convertExpression(*expression.fOperand);
1358 if (!base) {
1359 return nullptr;
1360 }
1361 switch (expression.fOperator) {
1362 case Token::PLUS:
1363 if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
1364 fErrors.error(expression.fPosition,
1365 "'+' cannot operate on '" + base->fType.description() + "'");
1366 return nullptr;
1367 }
1368 return base;
1369 case Token::MINUS:
1370 if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
1371 fErrors.error(expression.fPosition,
1372 "'-' cannot operate on '" + base->fType.description() + "'");
1373 return nullptr;
1374 }
1375 if (base->fKind == Expression::kIntLiteral_Kind) {
1376 return std::unique_ptr<Expression>(new IntLiteral(fContext, base->fPosition,
1377 -((IntLiteral&) *base).fValue));
1378 }
1379 if (base->fKind == Expression::kFloatLiteral_Kind) {
1380 double value = -((FloatLiteral&) *base).fValue;
1381 return std::unique_ptr<Expression>(new FloatLiteral(fContext, base->fPosition,
1382 value));
1383 }
1384 return std::unique_ptr<Expression>(new PrefixExpression(Token::MINUS, std::move(base)));
1385 case Token::PLUSPLUS:
1386 if (!base->fType.isNumber()) {
1387 fErrors.error(expression.fPosition,
1388 "'" + Token::OperatorName(expression.fOperator) +
1389 "' cannot operate on '" + base->fType.description() + "'");
1390 return nullptr;
1391 }
1392 this->markWrittenTo(*base, true);
1393 break;
1394 case Token::MINUSMINUS:
1395 if (!base->fType.isNumber()) {
1396 fErrors.error(expression.fPosition,
1397 "'" + Token::OperatorName(expression.fOperator) +
1398 "' cannot operate on '" + base->fType.description() + "'");
1399 return nullptr;
1400 }
1401 this->markWrittenTo(*base, true);
1402 break;
1403 case Token::LOGICALNOT:
1404 if (base->fType != *fContext.fBool_Type) {
1405 fErrors.error(expression.fPosition,
1406 "'" + Token::OperatorName(expression.fOperator) +
1407 "' cannot operate on '" + base->fType.description() + "'");
1408 return nullptr;
1409 }
1410 if (base->fKind == Expression::kBoolLiteral_Kind) {
1411 return std::unique_ptr<Expression>(new BoolLiteral(fContext, base->fPosition,
1412 !((BoolLiteral&) *base).fValue));
1413 }
1414 break;
1415 case Token::BITWISENOT:
1416 if (base->fType != *fContext.fInt_Type) {
1417 fErrors.error(expression.fPosition,
1418 "'" + Token::OperatorName(expression.fOperator) +
1419 "' cannot operate on '" + base->fType.description() + "'");
1420 return nullptr;
1421 }
1422 break;
1423 default:
1424 ABORT("unsupported prefix operator\n");
1425 }
1426 return std::unique_ptr<Expression>(new PrefixExpression(expression.fOperator,
1427 std::move(base)));
1428 }
1429
convertIndex(std::unique_ptr<Expression> base,const ASTExpression & index)1430 std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression> base,
1431 const ASTExpression& index) {
1432 if (base->fKind == Expression::kTypeReference_Kind) {
1433 if (index.fKind == ASTExpression::kInt_Kind) {
1434 const Type& oldType = ((TypeReference&) *base).fValue;
1435 int64_t size = ((const ASTIntLiteral&) index).fValue;
1436 Type* newType = new Type(oldType.name() + "[" + to_string(size) + "]",
1437 Type::kArray_Kind, oldType, size);
1438 fSymbolTable->takeOwnership(newType);
1439 return std::unique_ptr<Expression>(new TypeReference(fContext, base->fPosition,
1440 *newType));
1441
1442 } else {
1443 fErrors.error(base->fPosition, "array size must be a constant");
1444 return nullptr;
1445 }
1446 }
1447 if (base->fType.kind() != Type::kArray_Kind && base->fType.kind() != Type::kMatrix_Kind &&
1448 base->fType.kind() != Type::kVector_Kind) {
1449 fErrors.error(base->fPosition, "expected array, but found '" + base->fType.description() +
1450 "'");
1451 return nullptr;
1452 }
1453 std::unique_ptr<Expression> converted = this->convertExpression(index);
1454 if (!converted) {
1455 return nullptr;
1456 }
1457 if (converted->fType != *fContext.fUInt_Type) {
1458 converted = this->coerce(std::move(converted), *fContext.fInt_Type);
1459 if (!converted) {
1460 return nullptr;
1461 }
1462 }
1463 return std::unique_ptr<Expression>(new IndexExpression(fContext, std::move(base),
1464 std::move(converted)));
1465 }
1466
convertField(std::unique_ptr<Expression> base,const SkString & field)1467 std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base,
1468 const SkString& field) {
1469 auto fields = base->fType.fields();
1470 for (size_t i = 0; i < fields.size(); i++) {
1471 if (fields[i].fName == field) {
1472 return std::unique_ptr<Expression>(new FieldAccess(std::move(base), (int) i));
1473 }
1474 }
1475 fErrors.error(base->fPosition, "type '" + base->fType.description() + "' does not have a "
1476 "field named '" + field + "");
1477 return nullptr;
1478 }
1479
convertSwizzle(std::unique_ptr<Expression> base,const SkString & fields)1480 std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base,
1481 const SkString& fields) {
1482 if (base->fType.kind() != Type::kVector_Kind) {
1483 fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType.description() + "'");
1484 return nullptr;
1485 }
1486 std::vector<int> swizzleComponents;
1487 for (size_t i = 0; i < fields.size(); i++) {
1488 switch (fields[i]) {
1489 case 'x': // fall through
1490 case 'r': // fall through
1491 case 's':
1492 swizzleComponents.push_back(0);
1493 break;
1494 case 'y': // fall through
1495 case 'g': // fall through
1496 case 't':
1497 if (base->fType.columns() >= 2) {
1498 swizzleComponents.push_back(1);
1499 break;
1500 }
1501 // fall through
1502 case 'z': // fall through
1503 case 'b': // fall through
1504 case 'p':
1505 if (base->fType.columns() >= 3) {
1506 swizzleComponents.push_back(2);
1507 break;
1508 }
1509 // fall through
1510 case 'w': // fall through
1511 case 'a': // fall through
1512 case 'q':
1513 if (base->fType.columns() >= 4) {
1514 swizzleComponents.push_back(3);
1515 break;
1516 }
1517 // fall through
1518 default:
1519 fErrors.error(base->fPosition, SkStringPrintf("invalid swizzle component '%c'",
1520 fields[i]));
1521 return nullptr;
1522 }
1523 }
1524 ASSERT(swizzleComponents.size() > 0);
1525 if (swizzleComponents.size() > 4) {
1526 fErrors.error(base->fPosition, "too many components in swizzle mask '" + fields + "'");
1527 return nullptr;
1528 }
1529 return std::unique_ptr<Expression>(new Swizzle(fContext, std::move(base), swizzleComponents));
1530 }
1531
getCap(Position position,SkString name)1532 std::unique_ptr<Expression> IRGenerator::getCap(Position position, SkString name) {
1533 auto found = fCapsMap.find(name);
1534 if (found == fCapsMap.end()) {
1535 fErrors.error(position, "unknown capability flag '" + name + "'");
1536 return nullptr;
1537 }
1538 switch (found->second.fKind) {
1539 case CapValue::kBool_Kind:
1540 return std::unique_ptr<Expression>(new BoolLiteral(fContext, position,
1541 (bool) found->second.fValue));
1542 case CapValue::kInt_Kind:
1543 return std::unique_ptr<Expression>(new IntLiteral(fContext, position,
1544 found->second.fValue));
1545 }
1546 ASSERT(false);
1547 return nullptr;
1548 }
1549
convertSuffixExpression(const ASTSuffixExpression & expression)1550 std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
1551 const ASTSuffixExpression& expression) {
1552 std::unique_ptr<Expression> base = this->convertExpression(*expression.fBase);
1553 if (!base) {
1554 return nullptr;
1555 }
1556 switch (expression.fSuffix->fKind) {
1557 case ASTSuffix::kIndex_Kind: {
1558 const ASTExpression* expr = ((ASTIndexSuffix&) *expression.fSuffix).fExpression.get();
1559 if (expr) {
1560 return this->convertIndex(std::move(base), *expr);
1561 } else if (base->fKind == Expression::kTypeReference_Kind) {
1562 const Type& oldType = ((TypeReference&) *base).fValue;
1563 Type* newType = new Type(oldType.name() + "[]", Type::kArray_Kind, oldType,
1564 -1);
1565 fSymbolTable->takeOwnership(newType);
1566 return std::unique_ptr<Expression>(new TypeReference(fContext, base->fPosition,
1567 *newType));
1568 } else {
1569 fErrors.error(expression.fPosition, "'[]' must follow a type name");
1570 return nullptr;
1571 }
1572 }
1573 case ASTSuffix::kCall_Kind: {
1574 auto rawArguments = &((ASTCallSuffix&) *expression.fSuffix).fArguments;
1575 std::vector<std::unique_ptr<Expression>> arguments;
1576 for (size_t i = 0; i < rawArguments->size(); i++) {
1577 std::unique_ptr<Expression> converted =
1578 this->convertExpression(*(*rawArguments)[i]);
1579 if (!converted) {
1580 return nullptr;
1581 }
1582 arguments.push_back(std::move(converted));
1583 }
1584 return this->call(expression.fPosition, std::move(base), std::move(arguments));
1585 }
1586 case ASTSuffix::kField_Kind: {
1587 if (base->fType == *fContext.fSkCaps_Type) {
1588 return this->getCap(expression.fPosition,
1589 ((ASTFieldSuffix&) *expression.fSuffix).fField);
1590 }
1591 switch (base->fType.kind()) {
1592 case Type::kVector_Kind:
1593 return this->convertSwizzle(std::move(base),
1594 ((ASTFieldSuffix&) *expression.fSuffix).fField);
1595 case Type::kStruct_Kind:
1596 return this->convertField(std::move(base),
1597 ((ASTFieldSuffix&) *expression.fSuffix).fField);
1598 default:
1599 fErrors.error(base->fPosition, "cannot swizzle value of type '" +
1600 base->fType.description() + "'");
1601 return nullptr;
1602 }
1603 }
1604 case ASTSuffix::kPostIncrement_Kind:
1605 if (!base->fType.isNumber()) {
1606 fErrors.error(expression.fPosition,
1607 "'++' cannot operate on '" + base->fType.description() + "'");
1608 return nullptr;
1609 }
1610 this->markWrittenTo(*base, true);
1611 return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
1612 Token::PLUSPLUS));
1613 case ASTSuffix::kPostDecrement_Kind:
1614 if (!base->fType.isNumber()) {
1615 fErrors.error(expression.fPosition,
1616 "'--' cannot operate on '" + base->fType.description() + "'");
1617 return nullptr;
1618 }
1619 this->markWrittenTo(*base, true);
1620 return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
1621 Token::MINUSMINUS));
1622 default:
1623 ABORT("unsupported suffix operator");
1624 }
1625 }
1626
checkValid(const Expression & expr)1627 void IRGenerator::checkValid(const Expression& expr) {
1628 switch (expr.fKind) {
1629 case Expression::kFunctionReference_Kind:
1630 fErrors.error(expr.fPosition, "expected '(' to begin function call");
1631 break;
1632 case Expression::kTypeReference_Kind:
1633 fErrors.error(expr.fPosition, "expected '(' to begin constructor invocation");
1634 break;
1635 default:
1636 if (expr.fType == *fContext.fInvalid_Type) {
1637 fErrors.error(expr.fPosition, "invalid expression");
1638 }
1639 }
1640 }
1641
has_duplicates(const Swizzle & swizzle)1642 static bool has_duplicates(const Swizzle& swizzle) {
1643 int bits = 0;
1644 for (int idx : swizzle.fComponents) {
1645 ASSERT(idx >= 0 && idx <= 3);
1646 int bit = 1 << idx;
1647 if (bits & bit) {
1648 return true;
1649 }
1650 bits |= bit;
1651 }
1652 return false;
1653 }
1654
markWrittenTo(const Expression & expr,bool readWrite)1655 void IRGenerator::markWrittenTo(const Expression& expr, bool readWrite) {
1656 switch (expr.fKind) {
1657 case Expression::kVariableReference_Kind: {
1658 const Variable& var = ((VariableReference&) expr).fVariable;
1659 if (var.fModifiers.fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) {
1660 fErrors.error(expr.fPosition,
1661 "cannot modify immutable variable '" + var.fName + "'");
1662 }
1663 ((VariableReference&) expr).setRefKind(readWrite ? VariableReference::kReadWrite_RefKind
1664 : VariableReference::kWrite_RefKind);
1665 break;
1666 }
1667 case Expression::kFieldAccess_Kind:
1668 this->markWrittenTo(*((FieldAccess&) expr).fBase, readWrite);
1669 break;
1670 case Expression::kSwizzle_Kind:
1671 if (has_duplicates((Swizzle&) expr)) {
1672 fErrors.error(expr.fPosition,
1673 "cannot write to the same swizzle field more than once");
1674 }
1675 this->markWrittenTo(*((Swizzle&) expr).fBase, readWrite);
1676 break;
1677 case Expression::kIndex_Kind:
1678 this->markWrittenTo(*((IndexExpression&) expr).fBase, readWrite);
1679 break;
1680 default:
1681 fErrors.error(expr.fPosition, "cannot assign to '" + expr.description() + "'");
1682 break;
1683 }
1684 }
1685
1686 }
1687