1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "SkSLCompiler.h"
9 
10 #include "SkSLCFGGenerator.h"
11 #include "SkSLCPPCodeGenerator.h"
12 #include "SkSLGLSLCodeGenerator.h"
13 #include "SkSLHCodeGenerator.h"
14 #include "SkSLIRGenerator.h"
15 #include "SkSLMetalCodeGenerator.h"
16 #include "SkSLPipelineStageCodeGenerator.h"
17 #include "SkSLSPIRVCodeGenerator.h"
18 #include "ir/SkSLEnum.h"
19 #include "ir/SkSLExpression.h"
20 #include "ir/SkSLExpressionStatement.h"
21 #include "ir/SkSLFunctionCall.h"
22 #include "ir/SkSLIntLiteral.h"
23 #include "ir/SkSLModifiersDeclaration.h"
24 #include "ir/SkSLNop.h"
25 #include "ir/SkSLSymbolTable.h"
26 #include "ir/SkSLTernaryExpression.h"
27 #include "ir/SkSLUnresolvedFunction.h"
28 #include "ir/SkSLVarDeclarations.h"
29 
30 #ifdef SK_ENABLE_SPIRV_VALIDATION
31 #include "spirv-tools/libspirv.hpp"
32 #endif
33 
// Include the built-in shader symbol sources as static strings.

// Turns its argument into a string literal. NOTE(review): presumably consumed by the generated
// .inc files below to wrap their SkSL source text — confirm against the .inc generator.
#define STRINGIFY(x) #x

// Built-in declarations common to all program kinds; the Compiler constructor converts this one
// first, before any stage-specific include.
static const char* SKSL_INCLUDE =
#include "sksl.inc"
;

// Vertex-stage built-ins; converted in the Compiler constructor into fVertexInclude.
static const char* SKSL_VERT_INCLUDE =
#include "sksl_vert.inc"
;

// Fragment-stage built-ins; converted in the Compiler constructor into fFragmentInclude.
static const char* SKSL_FRAG_INCLUDE =
#include "sksl_frag.inc"
;

// Geometry-stage built-ins; converted in the Compiler constructor into fGeometryInclude.
static const char* SKSL_GEOM_INCLUDE =
#include "sksl_geom.inc"
;

// Fragment-processor built-ins: the contents of sksl_enums.inc followed by sksl_fp.inc.
// NOTE(review): the two back-to-back includes presumably rely on adjacent string-literal
// concatenation, so both .inc files must expand to string literals — confirm.
static const char* SKSL_FP_INCLUDE =
#include "sksl_enums.inc"
#include "sksl_fp.inc"
;

// Pipeline-stage built-ins.
static const char* SKSL_PIPELINE_STAGE_INCLUDE =
#include "sksl_pipeline.inc"
;
62 
63 namespace SkSL {
64 
Compiler(Flags flags)65 Compiler::Compiler(Flags flags)
66 : fFlags(flags)
67 , fContext(new Context())
68 , fErrorCount(0) {
69     auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this));
70     auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this));
71     fIRGenerator = new IRGenerator(fContext.get(), symbols, *this);
72     fTypes = types;
73     #define ADD_TYPE(t) types->addWithoutOwnership(fContext->f ## t ## _Type->fName, \
74                                                    fContext->f ## t ## _Type.get())
75     ADD_TYPE(Void);
76     ADD_TYPE(Float);
77     ADD_TYPE(Float2);
78     ADD_TYPE(Float3);
79     ADD_TYPE(Float4);
80     ADD_TYPE(Half);
81     ADD_TYPE(Half2);
82     ADD_TYPE(Half3);
83     ADD_TYPE(Half4);
84     ADD_TYPE(Double);
85     ADD_TYPE(Double2);
86     ADD_TYPE(Double3);
87     ADD_TYPE(Double4);
88     ADD_TYPE(Int);
89     ADD_TYPE(Int2);
90     ADD_TYPE(Int3);
91     ADD_TYPE(Int4);
92     ADD_TYPE(UInt);
93     ADD_TYPE(UInt2);
94     ADD_TYPE(UInt3);
95     ADD_TYPE(UInt4);
96     ADD_TYPE(Short);
97     ADD_TYPE(Short2);
98     ADD_TYPE(Short3);
99     ADD_TYPE(Short4);
100     ADD_TYPE(UShort);
101     ADD_TYPE(UShort2);
102     ADD_TYPE(UShort3);
103     ADD_TYPE(UShort4);
104     ADD_TYPE(Byte);
105     ADD_TYPE(Byte2);
106     ADD_TYPE(Byte3);
107     ADD_TYPE(Byte4);
108     ADD_TYPE(UByte);
109     ADD_TYPE(UByte2);
110     ADD_TYPE(UByte3);
111     ADD_TYPE(UByte4);
112     ADD_TYPE(Bool);
113     ADD_TYPE(Bool2);
114     ADD_TYPE(Bool3);
115     ADD_TYPE(Bool4);
116     ADD_TYPE(Float2x2);
117     ADD_TYPE(Float2x3);
118     ADD_TYPE(Float2x4);
119     ADD_TYPE(Float3x2);
120     ADD_TYPE(Float3x3);
121     ADD_TYPE(Float3x4);
122     ADD_TYPE(Float4x2);
123     ADD_TYPE(Float4x3);
124     ADD_TYPE(Float4x4);
125     ADD_TYPE(Half2x2);
126     ADD_TYPE(Half2x3);
127     ADD_TYPE(Half2x4);
128     ADD_TYPE(Half3x2);
129     ADD_TYPE(Half3x3);
130     ADD_TYPE(Half3x4);
131     ADD_TYPE(Half4x2);
132     ADD_TYPE(Half4x3);
133     ADD_TYPE(Half4x4);
134     ADD_TYPE(Double2x2);
135     ADD_TYPE(Double2x3);
136     ADD_TYPE(Double2x4);
137     ADD_TYPE(Double3x2);
138     ADD_TYPE(Double3x3);
139     ADD_TYPE(Double3x4);
140     ADD_TYPE(Double4x2);
141     ADD_TYPE(Double4x3);
142     ADD_TYPE(Double4x4);
143     ADD_TYPE(GenType);
144     ADD_TYPE(GenHType);
145     ADD_TYPE(GenDType);
146     ADD_TYPE(GenIType);
147     ADD_TYPE(GenUType);
148     ADD_TYPE(GenBType);
149     ADD_TYPE(Mat);
150     ADD_TYPE(Vec);
151     ADD_TYPE(GVec);
152     ADD_TYPE(GVec2);
153     ADD_TYPE(GVec3);
154     ADD_TYPE(GVec4);
155     ADD_TYPE(HVec);
156     ADD_TYPE(DVec);
157     ADD_TYPE(IVec);
158     ADD_TYPE(UVec);
159     ADD_TYPE(SVec);
160     ADD_TYPE(USVec);
161     ADD_TYPE(ByteVec);
162     ADD_TYPE(UByteVec);
163     ADD_TYPE(BVec);
164 
165     ADD_TYPE(Sampler1D);
166     ADD_TYPE(Sampler2D);
167     ADD_TYPE(Sampler3D);
168     ADD_TYPE(SamplerExternalOES);
169     ADD_TYPE(SamplerCube);
170     ADD_TYPE(Sampler2DRect);
171     ADD_TYPE(Sampler1DArray);
172     ADD_TYPE(Sampler2DArray);
173     ADD_TYPE(SamplerCubeArray);
174     ADD_TYPE(SamplerBuffer);
175     ADD_TYPE(Sampler2DMS);
176     ADD_TYPE(Sampler2DMSArray);
177 
178     ADD_TYPE(ISampler2D);
179 
180     ADD_TYPE(Image2D);
181     ADD_TYPE(IImage2D);
182 
183     ADD_TYPE(SubpassInput);
184     ADD_TYPE(SubpassInputMS);
185 
186     ADD_TYPE(GSampler1D);
187     ADD_TYPE(GSampler2D);
188     ADD_TYPE(GSampler3D);
189     ADD_TYPE(GSamplerCube);
190     ADD_TYPE(GSampler2DRect);
191     ADD_TYPE(GSampler1DArray);
192     ADD_TYPE(GSampler2DArray);
193     ADD_TYPE(GSamplerCubeArray);
194     ADD_TYPE(GSamplerBuffer);
195     ADD_TYPE(GSampler2DMS);
196     ADD_TYPE(GSampler2DMSArray);
197 
198     ADD_TYPE(Sampler1DShadow);
199     ADD_TYPE(Sampler2DShadow);
200     ADD_TYPE(SamplerCubeShadow);
201     ADD_TYPE(Sampler2DRectShadow);
202     ADD_TYPE(Sampler1DArrayShadow);
203     ADD_TYPE(Sampler2DArrayShadow);
204     ADD_TYPE(SamplerCubeArrayShadow);
205     ADD_TYPE(GSampler2DArrayShadow);
206     ADD_TYPE(GSamplerCubeArrayShadow);
207     ADD_TYPE(FragmentProcessor);
208     ADD_TYPE(SkRasterPipeline);
209 
210     StringFragment skCapsName("sk_Caps");
211     Variable* skCaps = new Variable(-1, Modifiers(), skCapsName,
212                                     *fContext->fSkCaps_Type, Variable::kGlobal_Storage);
213     fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps));
214 
215     StringFragment skArgsName("sk_Args");
216     Variable* skArgs = new Variable(-1, Modifiers(), skArgsName,
217                                     *fContext->fSkArgs_Type, Variable::kGlobal_Storage);
218     fIRGenerator->fSymbolTable->add(skArgsName, std::unique_ptr<Symbol>(skArgs));
219 
220     std::vector<std::unique_ptr<ProgramElement>> ignored;
221     fIRGenerator->convertProgram(Program::kFragment_Kind, SKSL_INCLUDE, strlen(SKSL_INCLUDE),
222                                  *fTypes, &ignored);
223     fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
224     if (fErrorCount) {
225         printf("Unexpected errors: %s\n", fErrorText.c_str());
226     }
227     SkASSERT(!fErrorCount);
228 
229     Program::Settings settings;
230     fIRGenerator->start(&settings, nullptr);
231     fIRGenerator->convertProgram(Program::kFragment_Kind, SKSL_VERT_INCLUDE,
232                                  strlen(SKSL_VERT_INCLUDE), *fTypes, &fVertexInclude);
233     fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
234     fVertexSymbolTable = fIRGenerator->fSymbolTable;
235 
236     fIRGenerator->start(&settings, nullptr);
237     fIRGenerator->convertProgram(Program::kVertex_Kind, SKSL_FRAG_INCLUDE,
238                                  strlen(SKSL_FRAG_INCLUDE), *fTypes, &fFragmentInclude);
239     fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
240     fFragmentSymbolTable = fIRGenerator->fSymbolTable;
241 
242     fIRGenerator->start(&settings, nullptr);
243     fIRGenerator->convertProgram(Program::kGeometry_Kind, SKSL_GEOM_INCLUDE,
244                                  strlen(SKSL_GEOM_INCLUDE), *fTypes, &fGeometryInclude);
245     fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
246     fGeometrySymbolTable = fIRGenerator->fSymbolTable;
247 }
248 
~Compiler()249 Compiler::~Compiler() {
250     delete fIRGenerator;
251 }
252 
253 // add the definition created by assigning to the lvalue to the definition set
addDefinition(const Expression * lvalue,std::unique_ptr<Expression> * expr,DefinitionMap * definitions)254 void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr,
255                              DefinitionMap* definitions) {
256     switch (lvalue->fKind) {
257         case Expression::kVariableReference_Kind: {
258             const Variable& var = ((VariableReference*) lvalue)->fVariable;
259             if (var.fStorage == Variable::kLocal_Storage) {
260                 (*definitions)[&var] = expr;
261             }
262             break;
263         }
264         case Expression::kSwizzle_Kind:
265             // We consider the variable written to as long as at least some of its components have
266             // been written to. This will lead to some false negatives (we won't catch it if you
267             // write to foo.x and then read foo.y), but being stricter could lead to false positives
268             // (we write to foo.x, and then pass foo to a function which happens to only read foo.x,
269             // but since we pass foo as a whole it is flagged as an error) unless we perform a much
270             // more complicated whole-program analysis. This is probably good enough.
271             this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
272                                 (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
273                                 definitions);
274             break;
275         case Expression::kIndex_Kind:
276             // see comments in Swizzle
277             this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
278                                 (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
279                                 definitions);
280             break;
281         case Expression::kFieldAccess_Kind:
282             // see comments in Swizzle
283             this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
284                                 (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
285                                 definitions);
286             break;
287         case Expression::kTernary_Kind:
288             // To simplify analysis, we just pretend that we write to both sides of the ternary.
289             // This allows for false positives (meaning we fail to detect that a variable might not
290             // have been assigned), but is preferable to false negatives.
291             this->addDefinition(((TernaryExpression*) lvalue)->fIfTrue.get(),
292                                 (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
293                                 definitions);
294             this->addDefinition(((TernaryExpression*) lvalue)->fIfFalse.get(),
295                                 (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
296                                 definitions);
297             break;
298         default:
299             // not an lvalue, can't happen
300             SkASSERT(false);
301     }
302 }
303 
304 // add local variables defined by this node to the set
addDefinitions(const BasicBlock::Node & node,DefinitionMap * definitions)305 void Compiler::addDefinitions(const BasicBlock::Node& node,
306                               DefinitionMap* definitions) {
307     switch (node.fKind) {
308         case BasicBlock::Node::kExpression_Kind: {
309             SkASSERT(node.expression());
310             const Expression* expr = (Expression*) node.expression()->get();
311             switch (expr->fKind) {
312                 case Expression::kBinary_Kind: {
313                     BinaryExpression* b = (BinaryExpression*) expr;
314                     if (b->fOperator == Token::EQ) {
315                         this->addDefinition(b->fLeft.get(), &b->fRight, definitions);
316                     } else if (Compiler::IsAssignment(b->fOperator)) {
317                         this->addDefinition(
318                                       b->fLeft.get(),
319                                       (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
320                                       definitions);
321 
322                     }
323                     break;
324                 }
325                 case Expression::kFunctionCall_Kind: {
326                     const FunctionCall& c = (const FunctionCall&) *expr;
327                     for (size_t i = 0; i < c.fFunction.fParameters.size(); ++i) {
328                         if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
329                             this->addDefinition(
330                                       c.fArguments[i].get(),
331                                       (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
332                                       definitions);
333                         }
334                     }
335                     break;
336                 }
337                 case Expression::kPrefix_Kind: {
338                     const PrefixExpression* p = (PrefixExpression*) expr;
339                     if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
340                         this->addDefinition(
341                                       p->fOperand.get(),
342                                       (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
343                                       definitions);
344                     }
345                     break;
346                 }
347                 case Expression::kPostfix_Kind: {
348                     const PostfixExpression* p = (PostfixExpression*) expr;
349                     if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
350                         this->addDefinition(
351                                       p->fOperand.get(),
352                                       (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
353                                       definitions);
354                     }
355                     break;
356                 }
357                 case Expression::kVariableReference_Kind: {
358                     const VariableReference* v = (VariableReference*) expr;
359                     if (v->fRefKind != VariableReference::kRead_RefKind) {
360                         this->addDefinition(
361                                       v,
362                                       (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
363                                       definitions);
364                     }
365                 }
366                 default:
367                     break;
368             }
369             break;
370         }
371         case BasicBlock::Node::kStatement_Kind: {
372             const Statement* stmt = (Statement*) node.statement()->get();
373             if (stmt->fKind == Statement::kVarDeclaration_Kind) {
374                 VarDeclaration& vd = (VarDeclaration&) *stmt;
375                 if (vd.fValue) {
376                     (*definitions)[vd.fVar] = &vd.fValue;
377                 }
378             }
379             break;
380         }
381     }
382 }
383 
scanCFG(CFG * cfg,BlockId blockId,std::set<BlockId> * workList)384 void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
385     BasicBlock& block = cfg->fBlocks[blockId];
386 
387     // compute definitions after this block
388     DefinitionMap after = block.fBefore;
389     for (const BasicBlock::Node& n : block.fNodes) {
390         this->addDefinitions(n, &after);
391     }
392 
393     // propagate definitions to exits
394     for (BlockId exitId : block.fExits) {
395         if (exitId == blockId) {
396             continue;
397         }
398         BasicBlock& exit = cfg->fBlocks[exitId];
399         for (const auto& pair : after) {
400             std::unique_ptr<Expression>* e1 = pair.second;
401             auto found = exit.fBefore.find(pair.first);
402             if (found == exit.fBefore.end()) {
403                 // exit has no definition for it, just copy it
404                 workList->insert(exitId);
405                 exit.fBefore[pair.first] = e1;
406             } else {
407                 // exit has a (possibly different) value already defined
408                 std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first];
409                 if (e1 != e2) {
410                     // definition has changed, merge and add exit block to worklist
411                     workList->insert(exitId);
412                     if (e1 && e2) {
413                         exit.fBefore[pair.first] =
414                                       (std::unique_ptr<Expression>*) &fContext->fDefined_Expression;
415                     } else {
416                         exit.fBefore[pair.first] = nullptr;
417                     }
418                 }
419             }
420         }
421     }
422 }
423 
424 // returns a map which maps all local variables in the function to null, indicating that their value
425 // is initially unknown
compute_start_state(const CFG & cfg)426 static DefinitionMap compute_start_state(const CFG& cfg) {
427     DefinitionMap result;
428     for (const auto& block : cfg.fBlocks) {
429         for (const auto& node : block.fNodes) {
430             if (node.fKind == BasicBlock::Node::kStatement_Kind) {
431                 SkASSERT(node.statement());
432                 const Statement* s = node.statement()->get();
433                 if (s->fKind == Statement::kVarDeclarations_Kind) {
434                     const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
435                     for (const auto& decl : vd->fDeclaration->fVars) {
436                         if (decl->fKind == Statement::kVarDeclaration_Kind) {
437                             result[((VarDeclaration&) *decl).fVar] = nullptr;
438                         }
439                     }
440                 }
441             }
442         }
443     }
444     return result;
445 }
446 
447 /**
448  * Returns true if assigning to this lvalue has no effect.
449  */
is_dead(const Expression & lvalue)450 static bool is_dead(const Expression& lvalue) {
451     switch (lvalue.fKind) {
452         case Expression::kVariableReference_Kind:
453             return ((VariableReference&) lvalue).fVariable.dead();
454         case Expression::kSwizzle_Kind:
455             return is_dead(*((Swizzle&) lvalue).fBase);
456         case Expression::kFieldAccess_Kind:
457             return is_dead(*((FieldAccess&) lvalue).fBase);
458         case Expression::kIndex_Kind: {
459             const IndexExpression& idx = (IndexExpression&) lvalue;
460             return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects();
461         }
462         case Expression::kTernary_Kind: {
463             const TernaryExpression& t = (TernaryExpression&) lvalue;
464             return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse);
465         }
466         default:
467             ABORT("invalid lvalue: %s\n", lvalue.description().c_str());
468     }
469 }
470 
471 /**
472  * Returns true if this is an assignment which can be collapsed down to just the right hand side due
473  * to a dead target and lack of side effects on the left hand side.
474  */
dead_assignment(const BinaryExpression & b)475 static bool dead_assignment(const BinaryExpression& b) {
476     if (!Compiler::IsAssignment(b.fOperator)) {
477         return false;
478     }
479     return is_dead(*b.fLeft);
480 }
481 
computeDataFlow(CFG * cfg)482 void Compiler::computeDataFlow(CFG* cfg) {
483     cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg);
484     std::set<BlockId> workList;
485     for (BlockId i = 0; i < cfg->fBlocks.size(); i++) {
486         workList.insert(i);
487     }
488     while (workList.size()) {
489         BlockId next = *workList.begin();
490         workList.erase(workList.begin());
491         this->scanCFG(cfg, next, &workList);
492     }
493 }
494 
495 /**
496  * Attempts to replace the expression pointed to by iter with a new one (in both the CFG and the
497  * IR). If the expression can be cleanly removed, returns true and updates the iterator to point to
498  * the newly-inserted element. Otherwise updates only the IR and returns false (and the CFG will
499  * need to be regenerated).
500  */
try_replace_expression(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,std::unique_ptr<Expression> * newExpression)501 bool try_replace_expression(BasicBlock* b,
502                             std::vector<BasicBlock::Node>::iterator* iter,
503                             std::unique_ptr<Expression>* newExpression) {
504     std::unique_ptr<Expression>* target = (*iter)->expression();
505     if (!b->tryRemoveExpression(iter)) {
506         *target = std::move(*newExpression);
507         return false;
508     }
509     *target = std::move(*newExpression);
510     return b->tryInsertExpression(iter, target);
511 }
512 
513 /**
514  * Returns true if the expression is a constant numeric literal with the specified value, or a
515  * constant vector with all elements equal to the specified value.
516  */
is_constant(const Expression & expr,double value)517 bool is_constant(const Expression& expr, double value) {
518     switch (expr.fKind) {
519         case Expression::kIntLiteral_Kind:
520             return ((IntLiteral&) expr).fValue == value;
521         case Expression::kFloatLiteral_Kind:
522             return ((FloatLiteral&) expr).fValue == value;
523         case Expression::kConstructor_Kind: {
524             Constructor& c = (Constructor&) expr;
525             if (c.fType.kind() == Type::kVector_Kind && c.isConstant()) {
526                 for (int i = 0; i < c.fType.columns(); ++i) {
527                     if (!is_constant(c.getVecComponent(i), value)) {
528                         return false;
529                     }
530                 }
531                 return true;
532             }
533             return false;
534         }
535         default:
536             return false;
537     }
538 }
539 
540 /**
541  * Collapses the binary expression pointed to by iter down to just the right side (in both the IR
542  * and CFG structures).
543  */
delete_left(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)544 void delete_left(BasicBlock* b,
545                  std::vector<BasicBlock::Node>::iterator* iter,
546                  bool* outUpdated,
547                  bool* outNeedsRescan) {
548     *outUpdated = true;
549     std::unique_ptr<Expression>* target = (*iter)->expression();
550     SkASSERT((*target)->fKind == Expression::kBinary_Kind);
551     BinaryExpression& bin = (BinaryExpression&) **target;
552     SkASSERT(!bin.fLeft->hasSideEffects());
553     bool result;
554     if (bin.fOperator == Token::EQ) {
555         result = b->tryRemoveLValueBefore(iter, bin.fLeft.get());
556     } else {
557         result = b->tryRemoveExpressionBefore(iter, bin.fLeft.get());
558     }
559     *target = std::move(bin.fRight);
560     if (!result) {
561         *outNeedsRescan = true;
562         return;
563     }
564     if (*iter == b->fNodes.begin()) {
565         *outNeedsRescan = true;
566         return;
567     }
568     --(*iter);
569     if ((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
570         (*iter)->expression() != &bin.fRight) {
571         *outNeedsRescan = true;
572         return;
573     }
574     *iter = b->fNodes.erase(*iter);
575     SkASSERT((*iter)->expression() == target);
576 }
577 
578 /**
579  * Collapses the binary expression pointed to by iter down to just the left side (in both the IR and
580  * CFG structures).
581  */
delete_right(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)582 void delete_right(BasicBlock* b,
583                   std::vector<BasicBlock::Node>::iterator* iter,
584                   bool* outUpdated,
585                   bool* outNeedsRescan) {
586     *outUpdated = true;
587     std::unique_ptr<Expression>* target = (*iter)->expression();
588     SkASSERT((*target)->fKind == Expression::kBinary_Kind);
589     BinaryExpression& bin = (BinaryExpression&) **target;
590     SkASSERT(!bin.fRight->hasSideEffects());
591     if (!b->tryRemoveExpressionBefore(iter, bin.fRight.get())) {
592         *target = std::move(bin.fLeft);
593         *outNeedsRescan = true;
594         return;
595     }
596     *target = std::move(bin.fLeft);
597     if (*iter == b->fNodes.begin()) {
598         *outNeedsRescan = true;
599         return;
600     }
601     --(*iter);
602     if (((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
603         (*iter)->expression() != &bin.fLeft)) {
604         *outNeedsRescan = true;
605         return;
606     }
607     *iter = b->fNodes.erase(*iter);
608     SkASSERT((*iter)->expression() == target);
609 }
610 
611 /**
612  * Constructs the specified type using a single argument.
613  */
construct(const Type & type,std::unique_ptr<Expression> v)614 static std::unique_ptr<Expression> construct(const Type& type, std::unique_ptr<Expression> v) {
615     std::vector<std::unique_ptr<Expression>> args;
616     args.push_back(std::move(v));
617     auto result = std::unique_ptr<Expression>(new Constructor(-1, type, std::move(args)));
618     return result;
619 }
620 
621 /**
622  * Used in the implementations of vectorize_left and vectorize_right. Given a vector type and an
623  * expression x, deletes the expression pointed to by iter and replaces it with <type>(x).
624  */
vectorize(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,const Type & type,std::unique_ptr<Expression> * otherExpression,bool * outUpdated,bool * outNeedsRescan)625 static void vectorize(BasicBlock* b,
626                       std::vector<BasicBlock::Node>::iterator* iter,
627                       const Type& type,
628                       std::unique_ptr<Expression>* otherExpression,
629                       bool* outUpdated,
630                       bool* outNeedsRescan) {
631     SkASSERT((*(*iter)->expression())->fKind == Expression::kBinary_Kind);
632     SkASSERT(type.kind() == Type::kVector_Kind);
633     SkASSERT((*otherExpression)->fType.kind() == Type::kScalar_Kind);
634     *outUpdated = true;
635     std::unique_ptr<Expression>* target = (*iter)->expression();
636     if (!b->tryRemoveExpression(iter)) {
637         *target = construct(type, std::move(*otherExpression));
638         *outNeedsRescan = true;
639     } else {
640         *target = construct(type, std::move(*otherExpression));
641         if (!b->tryInsertExpression(iter, target)) {
642             *outNeedsRescan = true;
643         }
644     }
645 }
646 
647 /**
648  * Given a binary expression of the form x <op> vec<n>(y), deletes the right side and vectorizes the
649  * left to yield vec<n>(x).
650  */
vectorize_left(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)651 static void vectorize_left(BasicBlock* b,
652                            std::vector<BasicBlock::Node>::iterator* iter,
653                            bool* outUpdated,
654                            bool* outNeedsRescan) {
655     BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
656     vectorize(b, iter, bin.fRight->fType, &bin.fLeft, outUpdated, outNeedsRescan);
657 }
658 
659 /**
660  * Given a binary expression of the form vec<n>(x) <op> y, deletes the left side and vectorizes the
661  * right to yield vec<n>(y).
662  */
vectorize_right(BasicBlock * b,std::vector<BasicBlock::Node>::iterator * iter,bool * outUpdated,bool * outNeedsRescan)663 static void vectorize_right(BasicBlock* b,
664                             std::vector<BasicBlock::Node>::iterator* iter,
665                             bool* outUpdated,
666                             bool* outNeedsRescan) {
667     BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
668     vectorize(b, iter, bin.fLeft->fType, &bin.fRight, outUpdated, outNeedsRescan);
669 }
670 
671 // Mark that an expression which we were writing to is no longer being written to
clear_write(const Expression & expr)672 void clear_write(const Expression& expr) {
673     switch (expr.fKind) {
674         case Expression::kVariableReference_Kind: {
675             ((VariableReference&) expr).setRefKind(VariableReference::kRead_RefKind);
676             break;
677         }
678         case Expression::kFieldAccess_Kind:
679             clear_write(*((FieldAccess&) expr).fBase);
680             break;
681         case Expression::kSwizzle_Kind:
682             clear_write(*((Swizzle&) expr).fBase);
683             break;
684         case Expression::kIndex_Kind:
685             clear_write(*((IndexExpression&) expr).fBase);
686             break;
687         default:
688             ABORT("shouldn't be writing to this kind of expression\n");
689             break;
690     }
691 }
692 
simplifyExpression(DefinitionMap & definitions,BasicBlock & b,std::vector<BasicBlock::Node>::iterator * iter,std::unordered_set<const Variable * > * undefinedVariables,bool * outUpdated,bool * outNeedsRescan)693 void Compiler::simplifyExpression(DefinitionMap& definitions,
694                                   BasicBlock& b,
695                                   std::vector<BasicBlock::Node>::iterator* iter,
696                                   std::unordered_set<const Variable*>* undefinedVariables,
697                                   bool* outUpdated,
698                                   bool* outNeedsRescan) {
699     Expression* expr = (*iter)->expression()->get();
700     SkASSERT(expr);
701     if ((*iter)->fConstantPropagation) {
702         std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, definitions);
703         if (optimized) {
704             *outUpdated = true;
705             if (!try_replace_expression(&b, iter, &optimized)) {
706                 *outNeedsRescan = true;
707                 return;
708             }
709             SkASSERT((*iter)->fKind == BasicBlock::Node::kExpression_Kind);
710             expr = (*iter)->expression()->get();
711         }
712     }
713     switch (expr->fKind) {
714         case Expression::kVariableReference_Kind: {
715             const VariableReference& ref = (VariableReference&) *expr;
716             const Variable& var = ref.fVariable;
717             if (ref.refKind() != VariableReference::kWrite_RefKind &&
718                 ref.refKind() != VariableReference::kPointer_RefKind &&
719                 var.fStorage == Variable::kLocal_Storage && !definitions[&var] &&
720                 (*undefinedVariables).find(&var) == (*undefinedVariables).end()) {
721                 (*undefinedVariables).insert(&var);
722                 this->error(expr->fOffset,
723                             "'" + var.fName + "' has not been assigned");
724             }
725             break;
726         }
727         case Expression::kTernary_Kind: {
728             TernaryExpression* t = (TernaryExpression*) expr;
729             if (t->fTest->fKind == Expression::kBoolLiteral_Kind) {
730                 // ternary has a constant test, replace it with either the true or
731                 // false branch
732                 if (((BoolLiteral&) *t->fTest).fValue) {
733                     (*iter)->setExpression(std::move(t->fIfTrue));
734                 } else {
735                     (*iter)->setExpression(std::move(t->fIfFalse));
736                 }
737                 *outUpdated = true;
738                 *outNeedsRescan = true;
739             }
740             break;
741         }
742         case Expression::kBinary_Kind: {
743             BinaryExpression* bin = (BinaryExpression*) expr;
744             if (dead_assignment(*bin)) {
745                 delete_left(&b, iter, outUpdated, outNeedsRescan);
746                 break;
747             }
748             // collapse useless expressions like x * 1 or x + 0
749             if (((bin->fLeft->fType.kind()  != Type::kScalar_Kind) &&
750                  (bin->fLeft->fType.kind()  != Type::kVector_Kind)) ||
751                 ((bin->fRight->fType.kind() != Type::kScalar_Kind) &&
752                  (bin->fRight->fType.kind() != Type::kVector_Kind))) {
753                 break;
754             }
755             switch (bin->fOperator) {
756                 case Token::STAR:
757                     if (is_constant(*bin->fLeft, 1)) {
758                         if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
759                             bin->fRight->fType.kind() == Type::kScalar_Kind) {
760                             // float4(1) * x -> float4(x)
761                             vectorize_right(&b, iter, outUpdated, outNeedsRescan);
762                         } else {
763                             // 1 * x -> x
764                             // 1 * float4(x) -> float4(x)
765                             // float4(1) * float4(x) -> float4(x)
766                             delete_left(&b, iter, outUpdated, outNeedsRescan);
767                         }
768                     }
769                     else if (is_constant(*bin->fLeft, 0)) {
770                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
771                             bin->fRight->fType.kind() == Type::kVector_Kind &&
772                             !bin->fRight->hasSideEffects()) {
773                             // 0 * float4(x) -> float4(0)
774                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
775                         } else {
776                             // 0 * x -> 0
777                             // float4(0) * x -> float4(0)
778                             // float4(0) * float4(x) -> float4(0)
779                             if (!bin->fRight->hasSideEffects()) {
780                                 delete_right(&b, iter, outUpdated, outNeedsRescan);
781                             }
782                         }
783                     }
784                     else if (is_constant(*bin->fRight, 1)) {
785                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
786                             bin->fRight->fType.kind() == Type::kVector_Kind) {
787                             // x * float4(1) -> float4(x)
788                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
789                         } else {
790                             // x * 1 -> x
791                             // float4(x) * 1 -> float4(x)
792                             // float4(x) * float4(1) -> float4(x)
793                             delete_right(&b, iter, outUpdated, outNeedsRescan);
794                         }
795                     }
796                     else if (is_constant(*bin->fRight, 0)) {
797                         if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
798                             bin->fRight->fType.kind() == Type::kScalar_Kind &&
799                             !bin->fLeft->hasSideEffects()) {
800                             // float4(x) * 0 -> float4(0)
801                             vectorize_right(&b, iter, outUpdated, outNeedsRescan);
802                         } else {
803                             // x * 0 -> 0
804                             // x * float4(0) -> float4(0)
805                             // float4(x) * float4(0) -> float4(0)
806                             if (!bin->fLeft->hasSideEffects()) {
807                                 delete_left(&b, iter, outUpdated, outNeedsRescan);
808                             }
809                         }
810                     }
811                     break;
812                 case Token::PLUS:
813                     if (is_constant(*bin->fLeft, 0)) {
814                         if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
815                             bin->fRight->fType.kind() == Type::kScalar_Kind) {
816                             // float4(0) + x -> float4(x)
817                             vectorize_right(&b, iter, outUpdated, outNeedsRescan);
818                         } else {
819                             // 0 + x -> x
820                             // 0 + float4(x) -> float4(x)
821                             // float4(0) + float4(x) -> float4(x)
822                             delete_left(&b, iter, outUpdated, outNeedsRescan);
823                         }
824                     } else if (is_constant(*bin->fRight, 0)) {
825                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
826                             bin->fRight->fType.kind() == Type::kVector_Kind) {
827                             // x + float4(0) -> float4(x)
828                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
829                         } else {
830                             // x + 0 -> x
831                             // float4(x) + 0 -> float4(x)
832                             // float4(x) + float4(0) -> float4(x)
833                             delete_right(&b, iter, outUpdated, outNeedsRescan);
834                         }
835                     }
836                     break;
837                 case Token::MINUS:
838                     if (is_constant(*bin->fRight, 0)) {
839                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
840                             bin->fRight->fType.kind() == Type::kVector_Kind) {
841                             // x - float4(0) -> float4(x)
842                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
843                         } else {
844                             // x - 0 -> x
845                             // float4(x) - 0 -> float4(x)
846                             // float4(x) - float4(0) -> float4(x)
847                             delete_right(&b, iter, outUpdated, outNeedsRescan);
848                         }
849                     }
850                     break;
851                 case Token::SLASH:
852                     if (is_constant(*bin->fRight, 1)) {
853                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
854                             bin->fRight->fType.kind() == Type::kVector_Kind) {
855                             // x / float4(1) -> float4(x)
856                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
857                         } else {
858                             // x / 1 -> x
859                             // float4(x) / 1 -> float4(x)
860                             // float4(x) / float4(1) -> float4(x)
861                             delete_right(&b, iter, outUpdated, outNeedsRescan);
862                         }
863                     } else if (is_constant(*bin->fLeft, 0)) {
864                         if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
865                             bin->fRight->fType.kind() == Type::kVector_Kind &&
866                             !bin->fRight->hasSideEffects()) {
867                             // 0 / float4(x) -> float4(0)
868                             vectorize_left(&b, iter, outUpdated, outNeedsRescan);
869                         } else {
870                             // 0 / x -> 0
871                             // float4(0) / x -> float4(0)
872                             // float4(0) / float4(x) -> float4(0)
873                             if (!bin->fRight->hasSideEffects()) {
874                                 delete_right(&b, iter, outUpdated, outNeedsRescan);
875                             }
876                         }
877                     }
878                     break;
879                 case Token::PLUSEQ:
880                     if (is_constant(*bin->fRight, 0)) {
881                         clear_write(*bin->fLeft);
882                         delete_right(&b, iter, outUpdated, outNeedsRescan);
883                     }
884                     break;
885                 case Token::MINUSEQ:
886                     if (is_constant(*bin->fRight, 0)) {
887                         clear_write(*bin->fLeft);
888                         delete_right(&b, iter, outUpdated, outNeedsRescan);
889                     }
890                     break;
891                 case Token::STAREQ:
892                     if (is_constant(*bin->fRight, 1)) {
893                         clear_write(*bin->fLeft);
894                         delete_right(&b, iter, outUpdated, outNeedsRescan);
895                     }
896                     break;
897                 case Token::SLASHEQ:
898                     if (is_constant(*bin->fRight, 1)) {
899                         clear_write(*bin->fLeft);
900                         delete_right(&b, iter, outUpdated, outNeedsRescan);
901                     }
902                     break;
903                 default:
904                     break;
905             }
906         }
907         default:
908             break;
909     }
910 }
911 
912 // returns true if this statement could potentially execute a break at the current level (we ignore
913 // nested loops and switches, since any breaks inside of them will merely break the loop / switch)
contains_conditional_break(Statement & s,bool inConditional)914 static bool contains_conditional_break(Statement& s, bool inConditional) {
915     switch (s.fKind) {
916         case Statement::kBlock_Kind:
917             for (const auto& sub : ((Block&) s).fStatements) {
918                 if (contains_conditional_break(*sub, inConditional)) {
919                     return true;
920                 }
921             }
922             return false;
923         case Statement::kBreak_Kind:
924             return inConditional;
925         case Statement::kIf_Kind: {
926             const IfStatement& i = (IfStatement&) s;
927             return contains_conditional_break(*i.fIfTrue, true) ||
928                    (i.fIfFalse && contains_conditional_break(*i.fIfFalse, true));
929         }
930         default:
931             return false;
932     }
933 }
934 
935 // returns true if this statement definitely executes a break at the current level (we ignore
936 // nested loops and switches, since any breaks inside of them will merely break the loop / switch)
contains_unconditional_break(Statement & s)937 static bool contains_unconditional_break(Statement& s) {
938     switch (s.fKind) {
939         case Statement::kBlock_Kind:
940             for (const auto& sub : ((Block&) s).fStatements) {
941                 if (contains_unconditional_break(*sub)) {
942                     return true;
943                 }
944             }
945             return false;
946         case Statement::kBreak_Kind:
947             return true;
948         default:
949             return false;
950     }
951 }
952 
953 // Returns a block containing all of the statements that will be run if the given case matches
954 // (which, owing to the statements being owned by unique_ptrs, means the switch itself will be
955 // broken by this call and must then be discarded).
956 // Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such as
957 // when break statements appear inside conditionals.
block_for_case(SwitchStatement * s,SwitchCase * c)958 static std::unique_ptr<Statement> block_for_case(SwitchStatement* s, SwitchCase* c) {
959     bool capturing = false;
960     std::vector<std::unique_ptr<Statement>*> statementPtrs;
961     for (const auto& current : s->fCases) {
962         if (current.get() == c) {
963             capturing = true;
964         }
965         if (capturing) {
966             for (auto& stmt : current->fStatements) {
967                 if (contains_conditional_break(*stmt, s->fKind == Statement::kIf_Kind)) {
968                     return nullptr;
969                 }
970                 if (contains_unconditional_break(*stmt)) {
971                     capturing = false;
972                     break;
973                 }
974                 statementPtrs.push_back(&stmt);
975             }
976             if (!capturing) {
977                 break;
978             }
979         }
980     }
981     std::vector<std::unique_ptr<Statement>> statements;
982     for (const auto& s : statementPtrs) {
983         statements.push_back(std::move(*s));
984     }
985     return std::unique_ptr<Statement>(new Block(-1, std::move(statements), s->fSymbols));
986 }
987 
simplifyStatement(DefinitionMap & definitions,BasicBlock & b,std::vector<BasicBlock::Node>::iterator * iter,std::unordered_set<const Variable * > * undefinedVariables,bool * outUpdated,bool * outNeedsRescan)988 void Compiler::simplifyStatement(DefinitionMap& definitions,
989                                  BasicBlock& b,
990                                  std::vector<BasicBlock::Node>::iterator* iter,
991                                  std::unordered_set<const Variable*>* undefinedVariables,
992                                  bool* outUpdated,
993                                  bool* outNeedsRescan) {
994     Statement* stmt = (*iter)->statement()->get();
995     switch (stmt->fKind) {
996         case Statement::kVarDeclaration_Kind: {
997             const auto& varDecl = (VarDeclaration&) *stmt;
998             if (varDecl.fVar->dead() &&
999                 (!varDecl.fValue ||
1000                  !varDecl.fValue->hasSideEffects())) {
1001                 if (varDecl.fValue) {
1002                     SkASSERT((*iter)->statement()->get() == stmt);
1003                     if (!b.tryRemoveExpressionBefore(iter, varDecl.fValue.get())) {
1004                         *outNeedsRescan = true;
1005                     }
1006                 }
1007                 (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
1008                 *outUpdated = true;
1009             }
1010             break;
1011         }
1012         case Statement::kIf_Kind: {
1013             IfStatement& i = (IfStatement&) *stmt;
1014             if (i.fTest->fKind == Expression::kBoolLiteral_Kind) {
1015                 // constant if, collapse down to a single branch
1016                 if (((BoolLiteral&) *i.fTest).fValue) {
1017                     SkASSERT(i.fIfTrue);
1018                     (*iter)->setStatement(std::move(i.fIfTrue));
1019                 } else {
1020                     if (i.fIfFalse) {
1021                         (*iter)->setStatement(std::move(i.fIfFalse));
1022                     } else {
1023                         (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
1024                     }
1025                 }
1026                 *outUpdated = true;
1027                 *outNeedsRescan = true;
1028                 break;
1029             }
1030             if (i.fIfFalse && i.fIfFalse->isEmpty()) {
1031                 // else block doesn't do anything, remove it
1032                 i.fIfFalse.reset();
1033                 *outUpdated = true;
1034                 *outNeedsRescan = true;
1035             }
1036             if (!i.fIfFalse && i.fIfTrue->isEmpty()) {
1037                 // if block doesn't do anything, no else block
1038                 if (i.fTest->hasSideEffects()) {
1039                     // test has side effects, keep it
1040                     (*iter)->setStatement(std::unique_ptr<Statement>(
1041                                                       new ExpressionStatement(std::move(i.fTest))));
1042                 } else {
1043                     // no if, no else, no test side effects, kill the whole if
1044                     // statement
1045                     (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
1046                 }
1047                 *outUpdated = true;
1048                 *outNeedsRescan = true;
1049             }
1050             break;
1051         }
1052         case Statement::kSwitch_Kind: {
1053             SwitchStatement& s = (SwitchStatement&) *stmt;
1054             if (s.fValue->isConstant()) {
1055                 // switch is constant, replace it with the case that matches
1056                 bool found = false;
1057                 SwitchCase* defaultCase = nullptr;
1058                 for (const auto& c : s.fCases) {
1059                     if (!c->fValue) {
1060                         defaultCase = c.get();
1061                         continue;
1062                     }
1063                     SkASSERT(c->fValue->fKind == s.fValue->fKind);
1064                     found = c->fValue->compareConstant(*fContext, *s.fValue);
1065                     if (found) {
1066                         std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
1067                         if (newBlock) {
1068                             (*iter)->setStatement(std::move(newBlock));
1069                             break;
1070                         } else {
1071                             if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
1072                                 this->error(s.fOffset,
1073                                             "static switch contains non-static conditional break");
1074                                 s.fIsStatic = false;
1075                             }
1076                             return; // can't simplify
1077                         }
1078                     }
1079                 }
1080                 if (!found) {
1081                     // no matching case. use default if it exists, or kill the whole thing
1082                     if (defaultCase) {
1083                         std::unique_ptr<Statement> newBlock = block_for_case(&s, defaultCase);
1084                         if (newBlock) {
1085                             (*iter)->setStatement(std::move(newBlock));
1086                         } else {
1087                             if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
1088                                 this->error(s.fOffset,
1089                                             "static switch contains non-static conditional break");
1090                                 s.fIsStatic = false;
1091                             }
1092                             return; // can't simplify
1093                         }
1094                     } else {
1095                         (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
1096                     }
1097                 }
1098                 *outUpdated = true;
1099                 *outNeedsRescan = true;
1100             }
1101             break;
1102         }
1103         case Statement::kExpression_Kind: {
1104             ExpressionStatement& e = (ExpressionStatement&) *stmt;
1105             SkASSERT((*iter)->statement()->get() == &e);
1106             if (!e.fExpression->hasSideEffects()) {
1107                 // Expression statement with no side effects, kill it
1108                 if (!b.tryRemoveExpressionBefore(iter, e.fExpression.get())) {
1109                     *outNeedsRescan = true;
1110                 }
1111                 SkASSERT((*iter)->statement()->get() == stmt);
1112                 (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
1113                 *outUpdated = true;
1114             }
1115             break;
1116         }
1117         default:
1118             break;
1119     }
1120 }
1121 
scanCFG(FunctionDefinition & f)1122 void Compiler::scanCFG(FunctionDefinition& f) {
1123     CFG cfg = CFGGenerator().getCFG(f);
1124     this->computeDataFlow(&cfg);
1125 
1126     // check for unreachable code
1127     for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
1128         if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() &&
1129             cfg.fBlocks[i].fNodes.size()) {
1130             int offset;
1131             switch (cfg.fBlocks[i].fNodes[0].fKind) {
1132                 case BasicBlock::Node::kStatement_Kind:
1133                     offset = (*cfg.fBlocks[i].fNodes[0].statement())->fOffset;
1134                     break;
1135                 case BasicBlock::Node::kExpression_Kind:
1136                     offset = (*cfg.fBlocks[i].fNodes[0].expression())->fOffset;
1137                     break;
1138             }
1139             this->error(offset, String("unreachable"));
1140         }
1141     }
1142     if (fErrorCount) {
1143         return;
1144     }
1145 
1146     // check for dead code & undefined variables, perform constant propagation
1147     std::unordered_set<const Variable*> undefinedVariables;
1148     bool updated;
1149     bool needsRescan = false;
1150     do {
1151         if (needsRescan) {
1152             cfg = CFGGenerator().getCFG(f);
1153             this->computeDataFlow(&cfg);
1154             needsRescan = false;
1155         }
1156 
1157         updated = false;
1158         for (BasicBlock& b : cfg.fBlocks) {
1159             DefinitionMap definitions = b.fBefore;
1160 
1161             for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
1162                 if (iter->fKind == BasicBlock::Node::kExpression_Kind) {
1163                     this->simplifyExpression(definitions, b, &iter, &undefinedVariables, &updated,
1164                                              &needsRescan);
1165                 } else {
1166                     this->simplifyStatement(definitions, b, &iter, &undefinedVariables, &updated,
1167                                              &needsRescan);
1168                 }
1169                 if (needsRescan) {
1170                     break;
1171                 }
1172                 this->addDefinitions(*iter, &definitions);
1173             }
1174         }
1175     } while (updated);
1176     SkASSERT(!needsRescan);
1177 
1178     // verify static ifs & switches, clean up dead variable decls
1179     for (BasicBlock& b : cfg.fBlocks) {
1180         DefinitionMap definitions = b.fBefore;
1181 
1182         for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan;) {
1183             if (iter->fKind == BasicBlock::Node::kStatement_Kind) {
1184                 const Statement& s = **iter->statement();
1185                 switch (s.fKind) {
1186                     case Statement::kIf_Kind:
1187                         if (((const IfStatement&) s).fIsStatic &&
1188                             !(fFlags & kPermitInvalidStaticTests_Flag)) {
1189                             this->error(s.fOffset, "static if has non-static test");
1190                         }
1191                         ++iter;
1192                         break;
1193                     case Statement::kSwitch_Kind:
1194                         if (((const SwitchStatement&) s).fIsStatic &&
1195                              !(fFlags & kPermitInvalidStaticTests_Flag)) {
1196                             this->error(s.fOffset, "static switch has non-static test");
1197                         }
1198                         ++iter;
1199                         break;
1200                     case Statement::kVarDeclarations_Kind: {
1201                         VarDeclarations& decls = *((VarDeclarationsStatement&) s).fDeclaration;
1202                         for (auto varIter = decls.fVars.begin(); varIter != decls.fVars.end();) {
1203                             if ((*varIter)->fKind == Statement::kNop_Kind) {
1204                                 varIter = decls.fVars.erase(varIter);
1205                             } else {
1206                                 ++varIter;
1207                             }
1208                         }
1209                         if (!decls.fVars.size()) {
1210                             iter = b.fNodes.erase(iter);
1211                         } else {
1212                             ++iter;
1213                         }
1214                         break;
1215                     }
1216                     default:
1217                         ++iter;
1218                         break;
1219                 }
1220             } else {
1221                 ++iter;
1222             }
1223         }
1224     }
1225 
1226     // check for missing return
1227     if (f.fDeclaration.fReturnType != *fContext->fVoid_Type) {
1228         if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
1229             this->error(f.fOffset, String("function can exit without returning a value"));
1230         }
1231     }
1232 }
1233 
convertProgram(Program::Kind kind,String text,const Program::Settings & settings)1234 std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String text,
1235                                                   const Program::Settings& settings) {
1236     fErrorText = "";
1237     fErrorCount = 0;
1238     std::vector<std::unique_ptr<ProgramElement>>* inherited;
1239     std::vector<std::unique_ptr<ProgramElement>> elements;
1240     switch (kind) {
1241         case Program::kVertex_Kind:
1242             inherited = &fVertexInclude;
1243             fIRGenerator->fSymbolTable = fVertexSymbolTable;
1244             fIRGenerator->start(&settings, inherited);
1245             break;
1246         case Program::kFragment_Kind:
1247             inherited = &fFragmentInclude;
1248             fIRGenerator->fSymbolTable = fFragmentSymbolTable;
1249             fIRGenerator->start(&settings, inherited);
1250             break;
1251         case Program::kGeometry_Kind:
1252             inherited = &fGeometryInclude;
1253             fIRGenerator->fSymbolTable = fGeometrySymbolTable;
1254             fIRGenerator->start(&settings, inherited);
1255             break;
1256         case Program::kFragmentProcessor_Kind:
1257             inherited = nullptr;
1258             fIRGenerator->start(&settings, nullptr);
1259             fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes,
1260                                          &elements);
1261             fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
1262             break;
1263         case Program::kPipelineStage_Kind:
1264             inherited = nullptr;
1265             fIRGenerator->start(&settings, nullptr);
1266             fIRGenerator->convertProgram(kind, SKSL_PIPELINE_STAGE_INCLUDE,
1267                                          strlen(SKSL_PIPELINE_STAGE_INCLUDE), *fTypes, &elements);
1268             fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
1269             break;
1270     }
1271     for (auto& element : elements) {
1272         if (element->fKind == ProgramElement::kEnum_Kind) {
1273             ((Enum&) *element).fBuiltin = true;
1274         }
1275     }
1276     std::unique_ptr<String> textPtr(new String(std::move(text)));
1277     fSource = textPtr.get();
1278     fIRGenerator->convertProgram(kind, textPtr->c_str(), textPtr->size(), *fTypes, &elements);
1279     auto result = std::unique_ptr<Program>(new Program(kind,
1280                                                        std::move(textPtr),
1281                                                        settings,
1282                                                        fContext,
1283                                                        inherited,
1284                                                        std::move(elements),
1285                                                        fIRGenerator->fSymbolTable,
1286                                                        fIRGenerator->fInputs));
1287     if (fErrorCount) {
1288         return nullptr;
1289     }
1290     return result;
1291 }
1292 
optimize(Program & program)1293 bool Compiler::optimize(Program& program) {
1294     SkASSERT(!fErrorCount);
1295     if (!program.fIsOptimized) {
1296         program.fIsOptimized = true;
1297         fIRGenerator->fKind = program.fKind;
1298         fIRGenerator->fSettings = &program.fSettings;
1299         for (auto& element : program) {
1300             if (element.fKind == ProgramElement::kFunction_Kind) {
1301                 this->scanCFG((FunctionDefinition&) element);
1302             }
1303         }
1304         fSource = nullptr;
1305     }
1306     return fErrorCount == 0;
1307 }
1308 
specialize(Program & program,const std::unordered_map<SkSL::String,SkSL::Program::Settings::Value> & inputs)1309 std::unique_ptr<Program> Compiler::specialize(
1310                    Program& program,
1311                    const std::unordered_map<SkSL::String, SkSL::Program::Settings::Value>& inputs) {
1312     std::vector<std::unique_ptr<ProgramElement>> elements;
1313     for (const auto& e : program) {
1314         elements.push_back(e.clone());
1315     }
1316     Program::Settings settings;
1317     settings.fCaps = program.fSettings.fCaps;
1318     for (auto iter = inputs.begin(); iter != inputs.end(); ++iter) {
1319         settings.fArgs.insert(*iter);
1320     }
1321     std::unique_ptr<Program> result(new Program(program.fKind,
1322                                                 nullptr,
1323                                                 settings,
1324                                                 program.fContext,
1325                                                 program.fInheritedElements,
1326                                                 std::move(elements),
1327                                                 program.fSymbols,
1328                                                 program.fInputs));
1329     return result;
1330 }
1331 
toSPIRV(Program & program,OutputStream & out)1332 bool Compiler::toSPIRV(Program& program, OutputStream& out) {
1333     if (!this->optimize(program)) {
1334         return false;
1335     }
1336 #ifdef SK_ENABLE_SPIRV_VALIDATION
1337     StringStream buffer;
1338     fSource = program.fSource.get();
1339     SPIRVCodeGenerator cg(fContext.get(), &program, this, &buffer);
1340     bool result = cg.generateCode();
1341     fSource = nullptr;
1342     if (result) {
1343         spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
1344         const String& data = buffer.str();
1345         SkASSERT(0 == data.size() % 4);
1346         auto dumpmsg = [](spv_message_level_t, const char*, const spv_position_t&, const char* m) {
1347             SkDebugf("SPIR-V validation error: %s\n", m);
1348         };
1349         tools.SetMessageConsumer(dumpmsg);
1350         // Verify that the SPIR-V we produced is valid. If this SkASSERT fails, check the logs prior
1351         // to the failure to see the validation errors.
1352         SkAssertResult(tools.Validate((const uint32_t*) data.c_str(), data.size() / 4));
1353         out.write(data.c_str(), data.size());
1354     }
1355 #else
1356     fSource = program.fSource.get();
1357     SPIRVCodeGenerator cg(fContext.get(), &program, this, &out);
1358     bool result = cg.generateCode();
1359     fSource = nullptr;
1360 #endif
1361     return result;
1362 }
1363 
toSPIRV(Program & program,String * out)1364 bool Compiler::toSPIRV(Program& program, String* out) {
1365     StringStream buffer;
1366     bool result = this->toSPIRV(program, buffer);
1367     if (result) {
1368         *out = buffer.str();
1369     }
1370     return result;
1371 }
1372 
toGLSL(Program & program,OutputStream & out)1373 bool Compiler::toGLSL(Program& program, OutputStream& out) {
1374     if (!this->optimize(program)) {
1375         return false;
1376     }
1377     fSource = program.fSource.get();
1378     GLSLCodeGenerator cg(fContext.get(), &program, this, &out);
1379     bool result = cg.generateCode();
1380     fSource = nullptr;
1381     return result;
1382 }
1383 
toGLSL(Program & program,String * out)1384 bool Compiler::toGLSL(Program& program, String* out) {
1385     StringStream buffer;
1386     bool result = this->toGLSL(program, buffer);
1387     if (result) {
1388         *out = buffer.str();
1389     }
1390     return result;
1391 }
1392 
toMetal(Program & program,OutputStream & out)1393 bool Compiler::toMetal(Program& program, OutputStream& out) {
1394     if (!this->optimize(program)) {
1395         return false;
1396     }
1397     MetalCodeGenerator cg(fContext.get(), &program, this, &out);
1398     bool result = cg.generateCode();
1399     return result;
1400 }
1401 
toMetal(Program & program,String * out)1402 bool Compiler::toMetal(Program& program, String* out) {
1403     if (!this->optimize(program)) {
1404         return false;
1405     }
1406     StringStream buffer;
1407     bool result = this->toMetal(program, buffer);
1408     if (result) {
1409         *out = buffer.str();
1410     }
1411     return result;
1412 }
1413 
toCPP(Program & program,String name,OutputStream & out)1414 bool Compiler::toCPP(Program& program, String name, OutputStream& out) {
1415     if (!this->optimize(program)) {
1416         return false;
1417     }
1418     fSource = program.fSource.get();
1419     CPPCodeGenerator cg(fContext.get(), &program, this, name, &out);
1420     bool result = cg.generateCode();
1421     fSource = nullptr;
1422     return result;
1423 }
1424 
toH(Program & program,String name,OutputStream & out)1425 bool Compiler::toH(Program& program, String name, OutputStream& out) {
1426     if (!this->optimize(program)) {
1427         return false;
1428     }
1429     fSource = program.fSource.get();
1430     HCodeGenerator cg(fContext.get(), &program, this, name, &out);
1431     bool result = cg.generateCode();
1432     fSource = nullptr;
1433     return result;
1434 }
1435 
toPipelineStage(const Program & program,String * out,std::vector<FormatArg> * outFormatArgs)1436 bool Compiler::toPipelineStage(const Program& program, String* out,
1437                                std::vector<FormatArg>* outFormatArgs) {
1438     SkASSERT(program.fIsOptimized);
1439     fSource = program.fSource.get();
1440     StringStream buffer;
1441     PipelineStageCodeGenerator cg(fContext.get(), &program, this, &buffer, outFormatArgs);
1442     bool result = cg.generateCode();
1443     fSource = nullptr;
1444     if (result) {
1445         *out = buffer.str();
1446     }
1447     return result;
1448 }
1449 
OperatorName(Token::Kind kind)1450 const char* Compiler::OperatorName(Token::Kind kind) {
1451     switch (kind) {
1452         case Token::PLUS:         return "+";
1453         case Token::MINUS:        return "-";
1454         case Token::STAR:         return "*";
1455         case Token::SLASH:        return "/";
1456         case Token::PERCENT:      return "%";
1457         case Token::SHL:          return "<<";
1458         case Token::SHR:          return ">>";
1459         case Token::LOGICALNOT:   return "!";
1460         case Token::LOGICALAND:   return "&&";
1461         case Token::LOGICALOR:    return "||";
1462         case Token::LOGICALXOR:   return "^^";
1463         case Token::BITWISENOT:   return "~";
1464         case Token::BITWISEAND:   return "&";
1465         case Token::BITWISEOR:    return "|";
1466         case Token::BITWISEXOR:   return "^";
1467         case Token::EQ:           return "=";
1468         case Token::EQEQ:         return "==";
1469         case Token::NEQ:          return "!=";
1470         case Token::LT:           return "<";
1471         case Token::GT:           return ">";
1472         case Token::LTEQ:         return "<=";
1473         case Token::GTEQ:         return ">=";
1474         case Token::PLUSEQ:       return "+=";
1475         case Token::MINUSEQ:      return "-=";
1476         case Token::STAREQ:       return "*=";
1477         case Token::SLASHEQ:      return "/=";
1478         case Token::PERCENTEQ:    return "%=";
1479         case Token::SHLEQ:        return "<<=";
1480         case Token::SHREQ:        return ">>=";
1481         case Token::LOGICALANDEQ: return "&&=";
1482         case Token::LOGICALOREQ:  return "||=";
1483         case Token::LOGICALXOREQ: return "^^=";
1484         case Token::BITWISEANDEQ: return "&=";
1485         case Token::BITWISEOREQ:  return "|=";
1486         case Token::BITWISEXOREQ: return "^=";
1487         case Token::PLUSPLUS:     return "++";
1488         case Token::MINUSMINUS:   return "--";
1489         case Token::COMMA:        return ",";
1490         default:
1491             ABORT("unsupported operator: %d\n", kind);
1492     }
1493 }
1494 
1495 
IsAssignment(Token::Kind op)1496 bool Compiler::IsAssignment(Token::Kind op) {
1497     switch (op) {
1498         case Token::EQ:           // fall through
1499         case Token::PLUSEQ:       // fall through
1500         case Token::MINUSEQ:      // fall through
1501         case Token::STAREQ:       // fall through
1502         case Token::SLASHEQ:      // fall through
1503         case Token::PERCENTEQ:    // fall through
1504         case Token::SHLEQ:        // fall through
1505         case Token::SHREQ:        // fall through
1506         case Token::BITWISEOREQ:  // fall through
1507         case Token::BITWISEXOREQ: // fall through
1508         case Token::BITWISEANDEQ: // fall through
1509         case Token::LOGICALOREQ:  // fall through
1510         case Token::LOGICALXOREQ: // fall through
1511         case Token::LOGICALANDEQ:
1512             return true;
1513         default:
1514             return false;
1515     }
1516 }
1517 
position(int offset)1518 Position Compiler::position(int offset) {
1519     SkASSERT(fSource);
1520     int line = 1;
1521     int column = 1;
1522     for (int i = 0; i < offset; i++) {
1523         if ((*fSource)[i] == '\n') {
1524             ++line;
1525             column = 1;
1526         }
1527         else {
1528             ++column;
1529         }
1530     }
1531     return Position(line, column);
1532 }
1533 
error(int offset,String msg)1534 void Compiler::error(int offset, String msg) {
1535     fErrorCount++;
1536     Position pos = this->position(offset);
1537     fErrorText += "error: " + to_string(pos.fLine) + ": " + msg.c_str() + "\n";
1538 }
1539 
errorText()1540 String Compiler::errorText() {
1541     this->writeErrorCount();
1542     fErrorCount = 0;
1543     String result = fErrorText;
1544     return result;
1545 }
1546 
writeErrorCount()1547 void Compiler::writeErrorCount() {
1548     if (fErrorCount) {
1549         fErrorText += to_string(fErrorCount) + " error";
1550         if (fErrorCount > 1) {
1551             fErrorText += "s";
1552         }
1553         fErrorText += "\n";
1554     }
1555 }
1556 
1557 } // namespace
1558