1 /*
2  * Copyright 2020 Google LLC.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLAnalysis.h"
9 
10 #include "include/private/SkSLModifiers.h"
11 #include "include/private/SkSLProgramElement.h"
12 #include "include/private/SkSLSampleUsage.h"
13 #include "include/private/SkSLStatement.h"
14 #include "src/sksl/SkSLCompiler.h"
15 #include "src/sksl/SkSLErrorReporter.h"
16 #include "src/sksl/ir/SkSLExpression.h"
17 #include "src/sksl/ir/SkSLProgram.h"
18 
19 // ProgramElements
20 #include "src/sksl/ir/SkSLEnum.h"
21 #include "src/sksl/ir/SkSLExtension.h"
22 #include "src/sksl/ir/SkSLFunctionDefinition.h"
23 #include "src/sksl/ir/SkSLInterfaceBlock.h"
24 #include "src/sksl/ir/SkSLSection.h"
25 #include "src/sksl/ir/SkSLVarDeclarations.h"
26 
27 // Statements
28 #include "src/sksl/ir/SkSLBlock.h"
29 #include "src/sksl/ir/SkSLBreakStatement.h"
30 #include "src/sksl/ir/SkSLContinueStatement.h"
31 #include "src/sksl/ir/SkSLDiscardStatement.h"
32 #include "src/sksl/ir/SkSLDoStatement.h"
33 #include "src/sksl/ir/SkSLExpressionStatement.h"
34 #include "src/sksl/ir/SkSLForStatement.h"
35 #include "src/sksl/ir/SkSLIfStatement.h"
36 #include "src/sksl/ir/SkSLNop.h"
37 #include "src/sksl/ir/SkSLReturnStatement.h"
38 #include "src/sksl/ir/SkSLSwitchStatement.h"
39 
40 // Expressions
41 #include "src/sksl/ir/SkSLBinaryExpression.h"
42 #include "src/sksl/ir/SkSLBoolLiteral.h"
43 #include "src/sksl/ir/SkSLConstructor.h"
44 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
45 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
46 #include "src/sksl/ir/SkSLExternalFunctionCall.h"
47 #include "src/sksl/ir/SkSLExternalFunctionReference.h"
48 #include "src/sksl/ir/SkSLFieldAccess.h"
49 #include "src/sksl/ir/SkSLFloatLiteral.h"
50 #include "src/sksl/ir/SkSLFunctionCall.h"
51 #include "src/sksl/ir/SkSLFunctionReference.h"
52 #include "src/sksl/ir/SkSLIndexExpression.h"
53 #include "src/sksl/ir/SkSLInlineMarker.h"
54 #include "src/sksl/ir/SkSLIntLiteral.h"
55 #include "src/sksl/ir/SkSLPostfixExpression.h"
56 #include "src/sksl/ir/SkSLPrefixExpression.h"
57 #include "src/sksl/ir/SkSLSetting.h"
58 #include "src/sksl/ir/SkSLSwizzle.h"
59 #include "src/sksl/ir/SkSLTernaryExpression.h"
60 #include "src/sksl/ir/SkSLTypeReference.h"
61 #include "src/sksl/ir/SkSLVariableReference.h"
62 
63 namespace SkSL {
64 
65 namespace {
66 
is_sample_call_to_fp(const FunctionCall & fc,const Variable & fp)67 static bool is_sample_call_to_fp(const FunctionCall& fc, const Variable& fp) {
68     const FunctionDeclaration& f = fc.function();
69     return f.isBuiltin() && f.name() == "sample" && fc.arguments().size() >= 1 &&
70            fc.arguments()[0]->is<VariableReference>() &&
71            fc.arguments()[0]->as<VariableReference>().variable() == &fp;
72 }
73 
74 // Visitor that determines the merged SampleUsage for a given child 'fp' in the program.
75 class MergeSampleUsageVisitor : public ProgramVisitor {
76 public:
MergeSampleUsageVisitor(const Context & context,const Variable & fp,bool writesToSampleCoords)77     MergeSampleUsageVisitor(const Context& context, const Variable& fp, bool writesToSampleCoords)
78             : fContext(context), fFP(fp), fWritesToSampleCoords(writesToSampleCoords) {}
79 
visit(const Program & program)80     SampleUsage visit(const Program& program) {
81         fUsage = SampleUsage(); // reset to none
82         INHERITED::visit(program);
83         return fUsage;
84     }
85 
86 protected:
87     const Context& fContext;
88     const Variable& fFP;
89     const bool fWritesToSampleCoords;
90     SampleUsage fUsage;
91 
visitExpression(const Expression & e)92     bool visitExpression(const Expression& e) override {
93         // Looking for sample(fp, ...)
94         if (e.is<FunctionCall>()) {
95             const FunctionCall& fc = e.as<FunctionCall>();
96             if (is_sample_call_to_fp(fc, fFP)) {
97                 // Determine the type of call at this site, and merge it with the accumulated state
98                 if (fc.arguments().size() >= 2) {
99                     const Expression* coords = fc.arguments()[1].get();
100                     if (coords->type() == *fContext.fTypes.fFloat2) {
101                         // If the coords are a direct reference to the program's sample-coords,
102                         // and those coords are never modified, we can conservatively turn this
103                         // into PassThrough sampling. In all other cases, we consider it Explicit.
104                         if (!fWritesToSampleCoords && coords->is<VariableReference>() &&
105                             coords->as<VariableReference>()
106                                             .variable()
107                                             ->modifiers()
108                                             .fLayout.fBuiltin == SK_MAIN_COORDS_BUILTIN) {
109                             fUsage.merge(SampleUsage::PassThrough());
110                         } else {
111                             fUsage.merge(SampleUsage::Explicit());
112                         }
113                     } else {
114                         // sample(fp, half4 inputColor) -> PassThrough
115                         fUsage.merge(SampleUsage::PassThrough());
116                     }
117                 } else {
118                     // sample(fp) -> PassThrough
119                     fUsage.merge(SampleUsage::PassThrough());
120                 }
121                 // NOTE: we don't return true here just because we found a sample call. We need to
122                 // process the entire program and merge across all encountered calls.
123             }
124         }
125 
126         return INHERITED::visitExpression(e);
127     }
128 
129     using INHERITED = ProgramVisitor;
130 };
131 
132 // Visitor that searches through the program for references to a particular builtin variable
133 class BuiltinVariableVisitor : public ProgramVisitor {
134 public:
BuiltinVariableVisitor(int builtin)135     BuiltinVariableVisitor(int builtin) : fBuiltin(builtin) {}
136 
visitExpression(const Expression & e)137     bool visitExpression(const Expression& e) override {
138         if (e.is<VariableReference>()) {
139             const VariableReference& var = e.as<VariableReference>();
140             return var.variable()->modifiers().fLayout.fBuiltin == fBuiltin;
141         }
142         return INHERITED::visitExpression(e);
143     }
144 
145     int fBuiltin;
146 
147     using INHERITED = ProgramVisitor;
148 };
149 
150 // Visitor that counts the number of nodes visited
151 class NodeCountVisitor : public ProgramVisitor {
152 public:
NodeCountVisitor(int limit)153     NodeCountVisitor(int limit) : fLimit(limit) {}
154 
visit(const Statement & s)155     int visit(const Statement& s) {
156         this->visitStatement(s);
157         return fCount;
158     }
159 
visitExpression(const Expression & e)160     bool visitExpression(const Expression& e) override {
161         ++fCount;
162         return (fCount >= fLimit) || INHERITED::visitExpression(e);
163     }
164 
visitProgramElement(const ProgramElement & p)165     bool visitProgramElement(const ProgramElement& p) override {
166         ++fCount;
167         return (fCount >= fLimit) || INHERITED::visitProgramElement(p);
168     }
169 
visitStatement(const Statement & s)170     bool visitStatement(const Statement& s) override {
171         ++fCount;
172         return (fCount >= fLimit) || INHERITED::visitStatement(s);
173     }
174 
175 private:
176     int fCount = 0;
177     int fLimit;
178 
179     using INHERITED = ProgramVisitor;
180 };
181 
182 class ProgramUsageVisitor : public ProgramVisitor {
183 public:
ProgramUsageVisitor(ProgramUsage * usage,int delta)184     ProgramUsageVisitor(ProgramUsage* usage, int delta) : fUsage(usage), fDelta(delta) {}
185 
visitProgramElement(const ProgramElement & pe)186     bool visitProgramElement(const ProgramElement& pe) override {
187         if (pe.is<FunctionDefinition>()) {
188             for (const Variable* param : pe.as<FunctionDefinition>().declaration().parameters()) {
189                 // Ensure function-parameter variables exist in the variable usage map. They aren't
190                 // otherwise declared, but ProgramUsage::get() should be able to find them, even if
191                 // they are unread and unwritten.
192                 fUsage->fVariableCounts[param];
193             }
194         } else if (pe.is<InterfaceBlock>()) {
195             // Ensure interface-block variables exist in the variable usage map.
196             fUsage->fVariableCounts[&pe.as<InterfaceBlock>().variable()];
197         }
198         return INHERITED::visitProgramElement(pe);
199     }
200 
visitStatement(const Statement & s)201     bool visitStatement(const Statement& s) override {
202         if (s.is<VarDeclaration>()) {
203             // Add all declared variables to the usage map (even if never otherwise accessed).
204             const VarDeclaration& vd = s.as<VarDeclaration>();
205             ProgramUsage::VariableCounts& counts = fUsage->fVariableCounts[&vd.var()];
206             counts.fDeclared += fDelta;
207             SkASSERT(counts.fDeclared >= 0);
208             if (vd.value()) {
209                 // The initial-value expression, when present, counts as a write.
210                 counts.fWrite += fDelta;
211             }
212         }
213         return INHERITED::visitStatement(s);
214     }
215 
visitExpression(const Expression & e)216     bool visitExpression(const Expression& e) override {
217         if (e.is<FunctionCall>()) {
218             const FunctionDeclaration* f = &e.as<FunctionCall>().function();
219             fUsage->fCallCounts[f] += fDelta;
220             SkASSERT(fUsage->fCallCounts[f] >= 0);
221         } else if (e.is<VariableReference>()) {
222             const VariableReference& ref = e.as<VariableReference>();
223             ProgramUsage::VariableCounts& counts = fUsage->fVariableCounts[ref.variable()];
224             switch (ref.refKind()) {
225                 case VariableRefKind::kRead:
226                     counts.fRead += fDelta;
227                     break;
228                 case VariableRefKind::kWrite:
229                     counts.fWrite += fDelta;
230                     break;
231                 case VariableRefKind::kReadWrite:
232                 case VariableRefKind::kPointer:
233                     counts.fRead += fDelta;
234                     counts.fWrite += fDelta;
235                     break;
236             }
237             SkASSERT(counts.fRead >= 0 && counts.fWrite >= 0);
238         }
239         return INHERITED::visitExpression(e);
240     }
241 
242     using ProgramVisitor::visitProgramElement;
243     using ProgramVisitor::visitStatement;
244 
245     ProgramUsage* fUsage;
246     int fDelta;
247     using INHERITED = ProgramVisitor;
248 };
249 
250 class VariableWriteVisitor : public ProgramVisitor {
251 public:
VariableWriteVisitor(const Variable * var)252     VariableWriteVisitor(const Variable* var)
253         : fVar(var) {}
254 
visit(const Statement & s)255     bool visit(const Statement& s) {
256         return this->visitStatement(s);
257     }
258 
visitExpression(const Expression & e)259     bool visitExpression(const Expression& e) override {
260         if (e.is<VariableReference>()) {
261             const VariableReference& ref = e.as<VariableReference>();
262             if (ref.variable() == fVar &&
263                 (ref.refKind() == VariableReference::RefKind::kWrite ||
264                  ref.refKind() == VariableReference::RefKind::kReadWrite ||
265                  ref.refKind() == VariableReference::RefKind::kPointer)) {
266                 return true;
267             }
268         }
269         return INHERITED::visitExpression(e);
270     }
271 
272 private:
273     const Variable* fVar;
274 
275     using INHERITED = ProgramVisitor;
276 };
277 
278 // If a caller doesn't care about errors, we can use this trivial reporter that just counts up.
279 class TrivialErrorReporter : public ErrorReporter {
280 public:
error(int offset,String)281     void error(int offset, String) override { ++fErrorCount; }
errorCount()282     int errorCount() override { return fErrorCount; }
setErrorCount(int c)283     void setErrorCount(int c) override { fErrorCount = c; }
284 
285 private:
286     int fErrorCount = 0;
287 };
288 
289 // This isn't actually using ProgramVisitor, because it only considers a subset of the fields for
290 // any given expression kind. For instance, when indexing an array (e.g. `x[1]`), we only want to
291 // know if the base (`x`) is assignable; the index expression (`1`) doesn't need to be.
292 class IsAssignableVisitor {
293 public:
IsAssignableVisitor(ErrorReporter * errors)294     IsAssignableVisitor(ErrorReporter* errors) : fErrors(errors) {}
295 
visit(Expression & expr,Analysis::AssignmentInfo * info)296     bool visit(Expression& expr, Analysis::AssignmentInfo* info) {
297         int oldErrorCount = fErrors->errorCount();
298         this->visitExpression(expr);
299         if (info) {
300             info->fAssignedVar = fAssignedVar;
301         }
302         return fErrors->errorCount() == oldErrorCount;
303     }
304 
visitExpression(Expression & expr)305     void visitExpression(Expression& expr) {
306         switch (expr.kind()) {
307             case Expression::Kind::kVariableReference: {
308                 VariableReference& varRef = expr.as<VariableReference>();
309                 const Variable* var = varRef.variable();
310                 if (var->modifiers().fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) {
311                     fErrors->error(expr.fOffset,
312                                    "cannot modify immutable variable '" + var->name() + "'");
313                 } else {
314                     SkASSERT(fAssignedVar == nullptr);
315                     fAssignedVar = &varRef;
316                 }
317                 break;
318             }
319             case Expression::Kind::kFieldAccess:
320                 this->visitExpression(*expr.as<FieldAccess>().base());
321                 break;
322 
323             case Expression::Kind::kSwizzle: {
324                 const Swizzle& swizzle = expr.as<Swizzle>();
325                 this->checkSwizzleWrite(swizzle);
326                 this->visitExpression(*swizzle.base());
327                 break;
328             }
329             case Expression::Kind::kIndex:
330                 this->visitExpression(*expr.as<IndexExpression>().base());
331                 break;
332 
333             default:
334                 fErrors->error(expr.fOffset, "cannot assign to this expression");
335                 break;
336         }
337     }
338 
339 private:
checkSwizzleWrite(const Swizzle & swizzle)340     void checkSwizzleWrite(const Swizzle& swizzle) {
341         int bits = 0;
342         for (int8_t idx : swizzle.components()) {
343             SkASSERT(idx >= SwizzleComponent::X && idx <= SwizzleComponent::W);
344             int bit = 1 << idx;
345             if (bits & bit) {
346                 fErrors->error(swizzle.fOffset,
347                                "cannot write to the same swizzle field more than once");
348                 break;
349             }
350             bits |= bit;
351         }
352     }
353 
354     ErrorReporter* fErrors;
355     VariableReference* fAssignedVar = nullptr;
356 
357     using INHERITED = ProgramVisitor;
358 };
359 
360 class SwitchCaseContainsExit : public ProgramVisitor {
361 public:
SwitchCaseContainsExit(bool conditionalExits)362     SwitchCaseContainsExit(bool conditionalExits) : fConditionalExits(conditionalExits) {}
363 
visitStatement(const Statement & stmt)364     bool visitStatement(const Statement& stmt) override {
365         switch (stmt.kind()) {
366             case Statement::Kind::kBlock:
367             case Statement::Kind::kSwitchCase:
368                 return INHERITED::visitStatement(stmt);
369 
370             case Statement::Kind::kReturn:
371                 // Returns are an early exit regardless of the surrounding control structures.
372                 return fConditionalExits ? fInConditional : !fInConditional;
373 
374             case Statement::Kind::kContinue:
375                 // Continues are an early exit from switches, but not loops.
376                 return !fInLoop &&
377                        (fConditionalExits ? fInConditional : !fInConditional);
378 
379             case Statement::Kind::kBreak:
380                 // Breaks cannot escape from switches or loops.
381                 return !fInLoop && !fInSwitch &&
382                        (fConditionalExits ? fInConditional : !fInConditional);
383 
384             case Statement::Kind::kIf: {
385                 ++fInConditional;
386                 bool result = INHERITED::visitStatement(stmt);
387                 --fInConditional;
388                 return result;
389             }
390 
391             case Statement::Kind::kFor:
392             case Statement::Kind::kDo: {
393                 // Loops are treated as conditionals because a loop could potentially execute zero
394                 // times. We don't have a straightforward way to determine that a loop definitely
395                 // executes at least once.
396                 ++fInConditional;
397                 ++fInLoop;
398                 bool result = INHERITED::visitStatement(stmt);
399                 --fInLoop;
400                 --fInConditional;
401                 return result;
402             }
403 
404             case Statement::Kind::kSwitch: {
405                 ++fInSwitch;
406                 bool result = INHERITED::visitStatement(stmt);
407                 --fInSwitch;
408                 return result;
409             }
410 
411             default:
412                 return false;
413         }
414     }
415 
416     bool fConditionalExits = false;
417     int fInConditional = 0;
418     int fInLoop = 0;
419     int fInSwitch = 0;
420     using INHERITED = ProgramVisitor;
421 };
422 
423 class ReturnsOnAllPathsVisitor : public ProgramVisitor {
424 public:
visitExpression(const Expression & expr)425     bool visitExpression(const Expression& expr) override {
426         // We can avoid processing expressions entirely.
427         return false;
428     }
429 
visitStatement(const Statement & stmt)430     bool visitStatement(const Statement& stmt) override {
431         switch (stmt.kind()) {
432             // Returns, breaks, or continues will stop the scan, so only one of these should ever be
433             // true.
434             case Statement::Kind::kReturn:
435                 fFoundReturn = true;
436                 return true;
437 
438             case Statement::Kind::kBreak:
439                 fFoundBreak = true;
440                 return true;
441 
442             case Statement::Kind::kContinue:
443                 fFoundContinue = true;
444                 return true;
445 
446             case Statement::Kind::kIf: {
447                 const IfStatement& i = stmt.as<IfStatement>();
448                 ReturnsOnAllPathsVisitor trueVisitor;
449                 ReturnsOnAllPathsVisitor falseVisitor;
450                 trueVisitor.visitStatement(*i.ifTrue());
451                 if (i.ifFalse()) {
452                     falseVisitor.visitStatement(*i.ifFalse());
453                 }
454                 // If either branch leads to a break or continue, we report the entire if as
455                 // containing a break or continue, since we don't know which side will be reached.
456                 fFoundBreak    = (trueVisitor.fFoundBreak    || falseVisitor.fFoundBreak);
457                 fFoundContinue = (trueVisitor.fFoundContinue || falseVisitor.fFoundContinue);
458                 // On the other hand, we only want to report returns that definitely happen, so we
459                 // require those to be found on both sides.
460                 fFoundReturn   = (trueVisitor.fFoundReturn   && falseVisitor.fFoundReturn);
461                 return fFoundBreak || fFoundContinue || fFoundReturn;
462             }
463             case Statement::Kind::kFor: {
464                 const ForStatement& f = stmt.as<ForStatement>();
465                 // We assume a for/while loop runs for at least one iteration; this isn't strictly
466                 // guaranteed, but it's better to be slightly over-permissive here than to fail on
467                 // reasonable code.
468                 ReturnsOnAllPathsVisitor forVisitor;
469                 forVisitor.visitStatement(*f.statement());
470                 // A for loop that contains a break or continue is safe; it won't exit the entire
471                 // function, just the loop. So we disregard those signals.
472                 fFoundReturn = forVisitor.fFoundReturn;
473                 return fFoundReturn;
474             }
475             case Statement::Kind::kDo: {
476                 const DoStatement& d = stmt.as<DoStatement>();
477                 // Do-while blocks are always entered at least once.
478                 ReturnsOnAllPathsVisitor doVisitor;
479                 doVisitor.visitStatement(*d.statement());
480                 // A do-while loop that contains a break or continue is safe; it won't exit the
481                 // entire function, just the loop. So we disregard those signals.
482                 fFoundReturn = doVisitor.fFoundReturn;
483                 return fFoundReturn;
484             }
485             case Statement::Kind::kBlock:
486                 // Blocks are definitely entered and don't imply any additional control flow.
487                 // If the block contains a break, continue or return, we want to keep that.
488                 return INHERITED::visitStatement(stmt);
489 
490             case Statement::Kind::kSwitch: {
491                 // Switches are the most complex control flow we need to deal with; fortunately we
492                 // already have good primitives for dissecting them. We need to verify that:
493                 // - a default case exists, so that every possible input value is covered
494                 // - every switch-case either (a) returns unconditionally, or
495                 //                            (b) falls through to another case that does
496                 const SwitchStatement& s = stmt.as<SwitchStatement>();
497                 bool foundDefault = false;
498                 bool fellThrough = false;
499                 for (const std::unique_ptr<Statement>& stmt : s.cases()) {
500                     // The default case is indicated by a null value. A switch without a default
501                     // case cannot definitively return, as its value might not be in the cases list.
502                     const SwitchCase& sc = stmt->as<SwitchCase>();
503                     if (!sc.value()) {
504                         foundDefault = true;
505                     }
506                     // Scan this switch-case for any exit (break, continue or return).
507                     ReturnsOnAllPathsVisitor caseVisitor;
508                     caseVisitor.visitStatement(sc);
509 
510                     // If we found a break or continue, whether conditional or not, this switch case
511                     // can't be called an unconditional return. Switches absorb breaks but not
512                     // continues.
513                     if (caseVisitor.fFoundContinue) {
514                         fFoundContinue = true;
515                         return false;
516                     }
517                     if (caseVisitor.fFoundBreak) {
518                         return false;
519                     }
520                     // We just confirmed that there weren't any breaks or continues. If we didn't
521                     // find an unconditional return either, the switch is considered fallen-through.
522                     // (There might be a conditional return, but that doesn't count.)
523                     fellThrough = !caseVisitor.fFoundReturn;
524                 }
525 
526                 // If we didn't find a default case, or the very last case fell through, this switch
527                 // doesn't meet our criteria.
528                 if (fellThrough || !foundDefault) {
529                     return false;
530                 }
531 
532                 // We scanned the entire switch, found a default case, and every section either fell
533                 // through or contained an unconditional return.
534                 fFoundReturn = true;
535                 return true;
536             }
537 
538             case Statement::Kind::kSwitchCase:
539                 // Recurse into the switch-case.
540                 return INHERITED::visitStatement(stmt);
541 
542             case Statement::Kind::kDiscard:
543             case Statement::Kind::kExpression:
544             case Statement::Kind::kInlineMarker:
545             case Statement::Kind::kNop:
546             case Statement::Kind::kVarDeclaration:
547                 // None of these statements could contain a return.
548                 break;
549         }
550 
551         return false;
552     }
553 
554     bool fFoundReturn = false;
555     bool fFoundBreak = false;
556     bool fFoundContinue = false;
557 
558     using INHERITED = ProgramVisitor;
559 };
560 
561 }  // namespace
562 
563 ////////////////////////////////////////////////////////////////////////////////
564 // Analysis
565 
GetSampleUsage(const Program & program,const Variable & fp,bool writesToSampleCoords)566 SampleUsage Analysis::GetSampleUsage(const Program& program,
567                                      const Variable& fp,
568                                      bool writesToSampleCoords) {
569     MergeSampleUsageVisitor visitor(*program.fContext, fp, writesToSampleCoords);
570     return visitor.visit(program);
571 }
572 
ReferencesBuiltin(const Program & program,int builtin)573 bool Analysis::ReferencesBuiltin(const Program& program, int builtin) {
574     BuiltinVariableVisitor visitor(builtin);
575     return visitor.visit(program);
576 }
577 
ReferencesSampleCoords(const Program & program)578 bool Analysis::ReferencesSampleCoords(const Program& program) {
579     return Analysis::ReferencesBuiltin(program, SK_MAIN_COORDS_BUILTIN);
580 }
581 
ReferencesFragCoords(const Program & program)582 bool Analysis::ReferencesFragCoords(const Program& program) {
583     return Analysis::ReferencesBuiltin(program, SK_FRAGCOORD_BUILTIN);
584 }
585 
NodeCountUpToLimit(const FunctionDefinition & function,int limit)586 int Analysis::NodeCountUpToLimit(const FunctionDefinition& function, int limit) {
587     return NodeCountVisitor{limit}.visit(*function.body());
588 }
589 
SwitchCaseContainsUnconditionalExit(Statement & stmt)590 bool Analysis::SwitchCaseContainsUnconditionalExit(Statement& stmt) {
591     return SwitchCaseContainsExit{/*conditionalExits=*/false}.visitStatement(stmt);
592 }
593 
SwitchCaseContainsConditionalExit(Statement & stmt)594 bool Analysis::SwitchCaseContainsConditionalExit(Statement& stmt) {
595     return SwitchCaseContainsExit{/*conditionalExits=*/true}.visitStatement(stmt);
596 }
597 
GetUsage(const Program & program)598 std::unique_ptr<ProgramUsage> Analysis::GetUsage(const Program& program) {
599     auto usage = std::make_unique<ProgramUsage>();
600     ProgramUsageVisitor addRefs(usage.get(), /*delta=*/+1);
601     addRefs.visit(program);
602     return usage;
603 }
604 
GetUsage(const LoadedModule & module)605 std::unique_ptr<ProgramUsage> Analysis::GetUsage(const LoadedModule& module) {
606     auto usage = std::make_unique<ProgramUsage>();
607     ProgramUsageVisitor addRefs(usage.get(), /*delta=*/+1);
608     for (const auto& element : module.fElements) {
609         addRefs.visitProgramElement(*element);
610     }
611     return usage;
612 }
613 
get(const Variable & v) const614 ProgramUsage::VariableCounts ProgramUsage::get(const Variable& v) const {
615     const VariableCounts* counts = fVariableCounts.find(&v);
616     SkASSERT(counts);
617     return *counts;
618 }
619 
isDead(const Variable & v) const620 bool ProgramUsage::isDead(const Variable& v) const {
621     const Modifiers& modifiers = v.modifiers();
622     VariableCounts counts = this->get(v);
623     if ((v.storage() != Variable::Storage::kLocal && counts.fRead) ||
624         (modifiers.fFlags &
625          (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag))) {
626         return false;
627     }
628     // Consider the variable dead if it's never read and never written (besides the initial-value).
629     return !counts.fRead && (counts.fWrite <= (v.initialValue() ? 1 : 0));
630 }
631 
get(const FunctionDeclaration & f) const632 int ProgramUsage::get(const FunctionDeclaration& f) const {
633     const int* count = fCallCounts.find(&f);
634     return count ? *count : 0;
635 }
636 
replace(const Expression * oldExpr,const Expression * newExpr)637 void ProgramUsage::replace(const Expression* oldExpr, const Expression* newExpr) {
638     if (oldExpr) {
639         ProgramUsageVisitor subRefs(this, /*delta=*/-1);
640         subRefs.visitExpression(*oldExpr);
641     }
642     if (newExpr) {
643         ProgramUsageVisitor addRefs(this, /*delta=*/+1);
644         addRefs.visitExpression(*newExpr);
645     }
646 }
647 
add(const Statement * stmt)648 void ProgramUsage::add(const Statement* stmt) {
649     ProgramUsageVisitor addRefs(this, /*delta=*/+1);
650     addRefs.visitStatement(*stmt);
651 }
652 
remove(const Expression * expr)653 void ProgramUsage::remove(const Expression* expr) {
654     ProgramUsageVisitor subRefs(this, /*delta=*/-1);
655     subRefs.visitExpression(*expr);
656 }
657 
remove(const Statement * stmt)658 void ProgramUsage::remove(const Statement* stmt) {
659     ProgramUsageVisitor subRefs(this, /*delta=*/-1);
660     subRefs.visitStatement(*stmt);
661 }
662 
remove(const ProgramElement & element)663 void ProgramUsage::remove(const ProgramElement& element) {
664     ProgramUsageVisitor subRefs(this, /*delta=*/-1);
665     subRefs.visitProgramElement(element);
666 }
667 
StatementWritesToVariable(const Statement & stmt,const Variable & var)668 bool Analysis::StatementWritesToVariable(const Statement& stmt, const Variable& var) {
669     return VariableWriteVisitor(&var).visit(stmt);
670 }
671 
IsAssignable(Expression & expr,AssignmentInfo * info,ErrorReporter * errors)672 bool Analysis::IsAssignable(Expression& expr, AssignmentInfo* info, ErrorReporter* errors) {
673     TrivialErrorReporter trivialErrors;
674     return IsAssignableVisitor{errors ? errors : &trivialErrors}.visit(expr, info);
675 }
676 
UpdateRefKind(Expression * expr,VariableRefKind refKind)677 void Analysis::UpdateRefKind(Expression* expr, VariableRefKind refKind) {
678     class RefKindWriter : public ProgramWriter {
679     public:
680         RefKindWriter(VariableReference::RefKind refKind) : fRefKind(refKind) {}
681 
682         bool visitExpression(Expression& expr) override {
683             if (expr.is<VariableReference>()) {
684                 expr.as<VariableReference>().setRefKind(fRefKind);
685             }
686             return INHERITED::visitExpression(expr);
687         }
688 
689     private:
690         VariableReference::RefKind fRefKind;
691 
692         using INHERITED = ProgramWriter;
693     };
694 
695     RefKindWriter{refKind}.visitExpression(*expr);
696 }
697 
MakeAssignmentExpr(Expression * expr,VariableReference::RefKind kind,ErrorReporter * errors)698 bool Analysis::MakeAssignmentExpr(Expression* expr,
699                                   VariableReference::RefKind kind,
700                                   ErrorReporter* errors) {
701     Analysis::AssignmentInfo info;
702     if (!Analysis::IsAssignable(*expr, &info, errors)) {
703         return false;
704     }
705     if (!info.fAssignedVar) {
706         errors->error(expr->fOffset, "can't assign to expression '" + expr->description() + "'");
707         return false;
708     }
709     info.fAssignedVar->setRefKind(kind);
710     return true;
711 }
712 
IsTrivialExpression(const Expression & expr)713 bool Analysis::IsTrivialExpression(const Expression& expr) {
714     return expr.is<IntLiteral>() ||
715            expr.is<FloatLiteral>() ||
716            expr.is<BoolLiteral>() ||
717            expr.is<VariableReference>() ||
718            (expr.is<Swizzle>() &&
719             IsTrivialExpression(*expr.as<Swizzle>().base())) ||
720            (expr.is<FieldAccess>() &&
721             IsTrivialExpression(*expr.as<FieldAccess>().base())) ||
722            (expr.isAnyConstructor() &&
723             expr.asAnyConstructor().argumentSpan().size() == 1 &&
724             IsTrivialExpression(*expr.asAnyConstructor().argumentSpan().front())) ||
725            (expr.isAnyConstructor() &&
726             expr.isConstantOrUniform()) ||
727            (expr.is<IndexExpression>() &&
728             expr.as<IndexExpression>().index()->is<IntLiteral>() &&
729             IsTrivialExpression(*expr.as<IndexExpression>().base()));
730 }
731 
IsSameExpressionTree(const Expression & left,const Expression & right)732 bool Analysis::IsSameExpressionTree(const Expression& left, const Expression& right) {
733     if (left.kind() != right.kind() || left.type() != right.type()) {
734         return false;
735     }
736 
737     // This isn't a fully exhaustive list of expressions by any stretch of the imagination; for
738     // instance, `x[y+1] = x[y+1]` isn't detected because we don't look at BinaryExpressions.
739     // Since this is intended to be used for optimization purposes, handling the common cases is
740     // sufficient.
741     switch (left.kind()) {
742         case Expression::Kind::kIntLiteral:
743             return left.as<IntLiteral>().value() == right.as<IntLiteral>().value();
744 
745         case Expression::Kind::kFloatLiteral:
746             return left.as<FloatLiteral>().value() == right.as<FloatLiteral>().value();
747 
748         case Expression::Kind::kBoolLiteral:
749             return left.as<BoolLiteral>().value() == right.as<BoolLiteral>().value();
750 
751         case Expression::Kind::kConstructorArray:
752         case Expression::Kind::kConstructorCompound:
753         case Expression::Kind::kConstructorCompoundCast:
754         case Expression::Kind::kConstructorDiagonalMatrix:
755         case Expression::Kind::kConstructorMatrixResize:
756         case Expression::Kind::kConstructorScalarCast:
757         case Expression::Kind::kConstructorStruct:
758         case Expression::Kind::kConstructorSplat: {
759             if (left.kind() != right.kind()) {
760                 return false;
761             }
762             const AnyConstructor& leftCtor = left.asAnyConstructor();
763             const AnyConstructor& rightCtor = right.asAnyConstructor();
764             const auto leftSpan = leftCtor.argumentSpan();
765             const auto rightSpan = rightCtor.argumentSpan();
766             if (leftSpan.size() != rightSpan.size()) {
767                 return false;
768             }
769             for (size_t index = 0; index < leftSpan.size(); ++index) {
770                 if (!IsSameExpressionTree(*leftSpan[index], *rightSpan[index])) {
771                     return false;
772                 }
773             }
774             return true;
775         }
776         case Expression::Kind::kFieldAccess:
777             return left.as<FieldAccess>().fieldIndex() == right.as<FieldAccess>().fieldIndex() &&
778                    IsSameExpressionTree(*left.as<FieldAccess>().base(),
779                                         *right.as<FieldAccess>().base());
780 
781         case Expression::Kind::kIndex:
782             return IsSameExpressionTree(*left.as<IndexExpression>().index(),
783                                         *right.as<IndexExpression>().index()) &&
784                    IsSameExpressionTree(*left.as<IndexExpression>().base(),
785                                         *right.as<IndexExpression>().base());
786 
787         case Expression::Kind::kSwizzle:
788             return left.as<Swizzle>().components() == right.as<Swizzle>().components() &&
789                    IsSameExpressionTree(*left.as<Swizzle>().base(), *right.as<Swizzle>().base());
790 
791         case Expression::Kind::kVariableReference:
792             return left.as<VariableReference>().variable() ==
793                    right.as<VariableReference>().variable();
794 
795         default:
796             return false;
797     }
798 }
799 
get_constant_value(const Expression & expr,double * val)800 static bool get_constant_value(const Expression& expr, double* val) {
801     const Expression* valExpr = expr.getConstantSubexpression(0);
802     if (!valExpr) {
803         return false;
804     }
805     if (valExpr->is<IntLiteral>()) {
806         *val = static_cast<double>(valExpr->as<IntLiteral>().value());
807         return true;
808     }
809     if (valExpr->is<FloatLiteral>()) {
810         *val = static_cast<double>(valExpr->as<FloatLiteral>().value());
811         return true;
812     }
813     SkDEBUGFAILF("unexpected constant type (%s)", expr.type().description().c_str());
814     return false;
815 }
816 
invalid_for_ES2(int offset,const Statement * loopInitializer,const Expression * loopTest,const Expression * loopNext,const Statement * loopStatement,Analysis::UnrollableLoopInfo & loopInfo)817 static const char* invalid_for_ES2(int offset,
818                                    const Statement* loopInitializer,
819                                    const Expression* loopTest,
820                                    const Expression* loopNext,
821                                    const Statement* loopStatement,
822                                    Analysis::UnrollableLoopInfo& loopInfo) {
823     //
824     // init_declaration has the form: type_specifier identifier = constant_expression
825     //
826     if (!loopInitializer) {
827         return "missing init declaration";
828     }
829     if (!loopInitializer->is<VarDeclaration>()) {
830         return "invalid init declaration";
831     }
832     const VarDeclaration& initDecl = loopInitializer->as<VarDeclaration>();
833     if (!initDecl.baseType().isNumber()) {
834         return "invalid type for loop index";
835     }
836     if (initDecl.arraySize() != 0) {
837         return "invalid type for loop index";
838     }
839     if (!initDecl.value()) {
840         return "missing loop index initializer";
841     }
842     if (!get_constant_value(*initDecl.value(), &loopInfo.fStart)) {
843         return "loop index initializer must be a constant expression";
844     }
845 
846     loopInfo.fIndex = &initDecl.var();
847 
848     auto is_loop_index = [&](const std::unique_ptr<Expression>& expr) {
849         return expr->is<VariableReference>() &&
850                expr->as<VariableReference>().variable() == loopInfo.fIndex;
851     };
852 
853     //
854     // condition has the form: loop_index relational_operator constant_expression
855     //
856     if (!loopTest) {
857         return "missing condition";
858     }
859     if (!loopTest->is<BinaryExpression>()) {
860         return "invalid condition";
861     }
862     const BinaryExpression& cond = loopTest->as<BinaryExpression>();
863     if (!is_loop_index(cond.left())) {
864         return "expected loop index on left hand side of condition";
865     }
866     // relational_operator is one of: > >= < <= == or !=
867     switch (cond.getOperator().kind()) {
868         case Token::Kind::TK_GT:
869         case Token::Kind::TK_GTEQ:
870         case Token::Kind::TK_LT:
871         case Token::Kind::TK_LTEQ:
872         case Token::Kind::TK_EQEQ:
873         case Token::Kind::TK_NEQ:
874             break;
875         default:
876             return "invalid relational operator";
877     }
878     double loopEnd = 0;
879     if (!get_constant_value(*cond.right(), &loopEnd)) {
880         return "loop index must be compared with a constant expression";
881     }
882 
883     //
884     // expression has one of the following forms:
885     //   loop_index++
886     //   loop_index--
887     //   loop_index += constant_expression
888     //   loop_index -= constant_expression
889     // The spec doesn't mention prefix increment and decrement, but there is some consensus that
890     // it's an oversight, so we allow those as well.
891     //
892     if (!loopNext) {
893         return "missing loop expression";
894     }
895     switch (loopNext->kind()) {
896         case Expression::Kind::kBinary: {
897             const BinaryExpression& next = loopNext->as<BinaryExpression>();
898             if (!is_loop_index(next.left())) {
899                 return "expected loop index in loop expression";
900             }
901             if (!get_constant_value(*next.right(), &loopInfo.fDelta)) {
902                 return "loop index must be modified by a constant expression";
903             }
904             switch (next.getOperator().kind()) {
905                 case Token::Kind::TK_PLUSEQ:                                      break;
906                 case Token::Kind::TK_MINUSEQ: loopInfo.fDelta = -loopInfo.fDelta; break;
907                 default:
908                     return "invalid operator in loop expression";
909             }
910         } break;
911         case Expression::Kind::kPrefix: {
912             const PrefixExpression& next = loopNext->as<PrefixExpression>();
913             if (!is_loop_index(next.operand())) {
914                 return "expected loop index in loop expression";
915             }
916             switch (next.getOperator().kind()) {
917                 case Token::Kind::TK_PLUSPLUS:   loopInfo.fDelta =  1; break;
918                 case Token::Kind::TK_MINUSMINUS: loopInfo.fDelta = -1; break;
919                 default:
920                     return "invalid operator in loop expression";
921             }
922         } break;
923         case Expression::Kind::kPostfix: {
924             const PostfixExpression& next = loopNext->as<PostfixExpression>();
925             if (!is_loop_index(next.operand())) {
926                 return "expected loop index in loop expression";
927             }
928             switch (next.getOperator().kind()) {
929                 case Token::Kind::TK_PLUSPLUS:   loopInfo.fDelta =  1; break;
930                 case Token::Kind::TK_MINUSMINUS: loopInfo.fDelta = -1; break;
931                 default:
932                     return "invalid operator in loop expression";
933             }
934         } break;
935         default:
936             return "invalid loop expression";
937     }
938 
939     //
940     // Within the body of the loop, the loop index is not statically assigned to, nor is it used as
941     // argument to a function 'out' or 'inout' parameter.
942     //
943     if (Analysis::StatementWritesToVariable(*loopStatement, initDecl.var())) {
944         return "loop index must not be modified within body of the loop";
945     }
946 
947     // Finally, compute the iteration count, based on the bounds, and the termination operator.
948     constexpr int kMaxUnrollableLoopLength = 128;
949     loopInfo.fCount = 0;
950 
951     double val = loopInfo.fStart;
952     auto evalCond = [&]() {
953         switch (cond.getOperator().kind()) {
954             case Token::Kind::TK_GT:   return val >  loopEnd;
955             case Token::Kind::TK_GTEQ: return val >= loopEnd;
956             case Token::Kind::TK_LT:   return val <  loopEnd;
957             case Token::Kind::TK_LTEQ: return val <= loopEnd;
958             case Token::Kind::TK_EQEQ: return val == loopEnd;
959             case Token::Kind::TK_NEQ:  return val != loopEnd;
960             default: SkUNREACHABLE;
961         }
962     };
963 
964     for (loopInfo.fCount = 0; loopInfo.fCount <= kMaxUnrollableLoopLength; ++loopInfo.fCount) {
965         if (!evalCond()) {
966             break;
967         }
968         val += loopInfo.fDelta;
969     }
970 
971     if (loopInfo.fCount > kMaxUnrollableLoopLength) {
972         return "loop must guarantee termination in fewer iterations";
973     }
974 
975     return nullptr;  // All checks pass
976 }
977 
ForLoopIsValidForES2(int offset,const Statement * loopInitializer,const Expression * loopTest,const Expression * loopNext,const Statement * loopStatement,Analysis::UnrollableLoopInfo * outLoopInfo,ErrorReporter * errors)978 bool Analysis::ForLoopIsValidForES2(int offset,
979                                     const Statement* loopInitializer,
980                                     const Expression* loopTest,
981                                     const Expression* loopNext,
982                                     const Statement* loopStatement,
983                                     Analysis::UnrollableLoopInfo* outLoopInfo,
984                                     ErrorReporter* errors) {
985     UnrollableLoopInfo ignored,
986                        *loopInfo = outLoopInfo ? outLoopInfo : &ignored;
987     if (const char* msg = invalid_for_ES2(
988                 offset, loopInitializer, loopTest, loopNext, loopStatement, *loopInfo)) {
989         if (errors) {
990             errors->error(offset, msg);
991         }
992         return false;
993     }
994     return true;
995 }
996 
997 // Checks for ES2 constant-expression rules, and (optionally) constant-index-expression rules
998 // (if loopIndices is non-nullptr)
999 class ConstantExpressionVisitor : public ProgramVisitor {
1000 public:
ConstantExpressionVisitor(const std::set<const Variable * > * loopIndices)1001     ConstantExpressionVisitor(const std::set<const Variable*>* loopIndices)
1002             : fLoopIndices(loopIndices) {}
1003 
visitExpression(const Expression & e)1004     bool visitExpression(const Expression& e) override {
1005         // A constant-(index)-expression is one of...
1006         switch (e.kind()) {
1007             // ... a literal value
1008             case Expression::Kind::kBoolLiteral:
1009             case Expression::Kind::kIntLiteral:
1010             case Expression::Kind::kFloatLiteral:
1011                 return false;
1012 
1013             // ... settings can appear in fragment processors; they will resolve when compiled
1014             case Expression::Kind::kSetting:
1015                 return false;
1016 
1017             // ... a global or local variable qualified as 'const', excluding function parameters.
1018             // ... loop indices as defined in section 4. [constant-index-expression]
1019             case Expression::Kind::kVariableReference: {
1020                 const Variable* v = e.as<VariableReference>().variable();
1021                 if ((v->storage() == Variable::Storage::kGlobal ||
1022                      v->storage() == Variable::Storage::kLocal) &&
1023                     (v->modifiers().fFlags & Modifiers::kConst_Flag)) {
1024                     return false;
1025                 }
1026                 return !fLoopIndices || fLoopIndices->find(v) == fLoopIndices->end();
1027             }
1028 
1029             // ... expressions composed of both of the above
1030             case Expression::Kind::kBinary:
1031             case Expression::Kind::kConstructorArray:
1032             case Expression::Kind::kConstructorCompound:
1033             case Expression::Kind::kConstructorCompoundCast:
1034             case Expression::Kind::kConstructorDiagonalMatrix:
1035             case Expression::Kind::kConstructorMatrixResize:
1036             case Expression::Kind::kConstructorScalarCast:
1037             case Expression::Kind::kConstructorSplat:
1038             case Expression::Kind::kConstructorStruct:
1039             case Expression::Kind::kFieldAccess:
1040             case Expression::Kind::kIndex:
1041             case Expression::Kind::kPrefix:
1042             case Expression::Kind::kPostfix:
1043             case Expression::Kind::kSwizzle:
1044             case Expression::Kind::kTernary:
1045                 return INHERITED::visitExpression(e);
1046 
1047             // These are completely disallowed in SkSL constant-(index)-expressions. GLSL allows
1048             // calls to built-in functions where the arguments are all constant-expressions, but
1049             // we don't guarantee that behavior. (skbug.com/10835)
1050             case Expression::Kind::kExternalFunctionCall:
1051             case Expression::Kind::kFunctionCall:
1052                 return true;
1053 
1054             // These should never appear in final IR
1055             case Expression::Kind::kExternalFunctionReference:
1056             case Expression::Kind::kFunctionReference:
1057             case Expression::Kind::kTypeReference:
1058             default:
1059                 SkDEBUGFAIL("Unexpected expression type");
1060                 return true;
1061         }
1062     }
1063 
1064 private:
1065     const std::set<const Variable*>* fLoopIndices;
1066     using INHERITED = ProgramVisitor;
1067 };
1068 
1069 class ES2IndexingVisitor : public ProgramVisitor {
1070 public:
ES2IndexingVisitor(ErrorReporter & errors)1071     ES2IndexingVisitor(ErrorReporter& errors) : fErrors(errors) {}
1072 
visitStatement(const Statement & s)1073     bool visitStatement(const Statement& s) override {
1074         if (s.is<ForStatement>()) {
1075             const ForStatement& f = s.as<ForStatement>();
1076             SkASSERT(f.initializer() && f.initializer()->is<VarDeclaration>());
1077             const Variable* var = &f.initializer()->as<VarDeclaration>().var();
1078             auto [iter, inserted] = fLoopIndices.insert(var);
1079             SkASSERT(inserted);
1080             bool result = this->visitStatement(*f.statement());
1081             fLoopIndices.erase(iter);
1082             return result;
1083         }
1084         return INHERITED::visitStatement(s);
1085     }
1086 
visitExpression(const Expression & e)1087     bool visitExpression(const Expression& e) override {
1088         if (e.is<IndexExpression>()) {
1089             const IndexExpression& i = e.as<IndexExpression>();
1090             ConstantExpressionVisitor indexerInvalid(&fLoopIndices);
1091             if (indexerInvalid.visitExpression(*i.index())) {
1092                 fErrors.error(i.fOffset, "index expression must be constant");
1093                 return true;
1094             }
1095         }
1096         return INHERITED::visitExpression(e);
1097     }
1098 
1099     using ProgramVisitor::visitProgramElement;
1100 
1101 private:
1102     ErrorReporter& fErrors;
1103     std::set<const Variable*> fLoopIndices;
1104     using INHERITED = ProgramVisitor;
1105 };
1106 
1107 
ValidateIndexingForES2(const ProgramElement & pe,ErrorReporter & errors)1108 void Analysis::ValidateIndexingForES2(const ProgramElement& pe, ErrorReporter& errors) {
1109     ES2IndexingVisitor visitor(errors);
1110     visitor.visitProgramElement(pe);
1111 }
1112 
IsConstantExpression(const Expression & expr)1113 bool Analysis::IsConstantExpression(const Expression& expr) {
1114     ConstantExpressionVisitor visitor(/*loopIndices=*/nullptr);
1115     return !visitor.visitExpression(expr);
1116 }
1117 
CanExitWithoutReturningValue(const FunctionDeclaration & funcDecl,const Statement & body)1118 bool Analysis::CanExitWithoutReturningValue(const FunctionDeclaration& funcDecl,
1119                                             const Statement& body) {
1120     if (funcDecl.returnType().isVoid()) {
1121         return false;
1122     }
1123     ReturnsOnAllPathsVisitor visitor;
1124     visitor.visitStatement(body);
1125     return !visitor.fFoundReturn;
1126 }
1127 
1128 ////////////////////////////////////////////////////////////////////////////////
1129 // ProgramVisitor
1130 
visit(const Program & program)1131 bool ProgramVisitor::visit(const Program& program) {
1132     for (const ProgramElement* pe : program.elements()) {
1133         if (this->visitProgramElement(*pe)) {
1134             return true;
1135         }
1136     }
1137     return false;
1138 }
1139 
visitExpression(typename T::Expression & e)1140 template <typename T> bool TProgramVisitor<T>::visitExpression(typename T::Expression& e) {
1141     switch (e.kind()) {
1142         case Expression::Kind::kBoolLiteral:
1143         case Expression::Kind::kExternalFunctionReference:
1144         case Expression::Kind::kFloatLiteral:
1145         case Expression::Kind::kFunctionReference:
1146         case Expression::Kind::kIntLiteral:
1147         case Expression::Kind::kSetting:
1148         case Expression::Kind::kTypeReference:
1149         case Expression::Kind::kVariableReference:
1150             // Leaf expressions return false
1151             return false;
1152 
1153         case Expression::Kind::kBinary: {
1154             auto& b = e.template as<BinaryExpression>();
1155             return (b.left() && this->visitExpressionPtr(b.left())) ||
1156                    (b.right() && this->visitExpressionPtr(b.right()));
1157         }
1158         case Expression::Kind::kConstructorArray:
1159         case Expression::Kind::kConstructorCompound:
1160         case Expression::Kind::kConstructorCompoundCast:
1161         case Expression::Kind::kConstructorDiagonalMatrix:
1162         case Expression::Kind::kConstructorMatrixResize:
1163         case Expression::Kind::kConstructorScalarCast:
1164         case Expression::Kind::kConstructorSplat:
1165         case Expression::Kind::kConstructorStruct: {
1166             auto& c = e.asAnyConstructor();
1167             for (auto& arg : c.argumentSpan()) {
1168                 if (this->visitExpressionPtr(arg)) { return true; }
1169             }
1170             return false;
1171         }
1172         case Expression::Kind::kExternalFunctionCall: {
1173             auto& c = e.template as<ExternalFunctionCall>();
1174             for (auto& arg : c.arguments()) {
1175                 if (this->visitExpressionPtr(arg)) { return true; }
1176             }
1177             return false;
1178         }
1179         case Expression::Kind::kFieldAccess:
1180             return this->visitExpressionPtr(e.template as<FieldAccess>().base());
1181 
1182         case Expression::Kind::kFunctionCall: {
1183             auto& c = e.template as<FunctionCall>();
1184             for (auto& arg : c.arguments()) {
1185                 if (arg && this->visitExpressionPtr(arg)) { return true; }
1186             }
1187             return false;
1188         }
1189         case Expression::Kind::kIndex: {
1190             auto& i = e.template as<IndexExpression>();
1191             return this->visitExpressionPtr(i.base()) || this->visitExpressionPtr(i.index());
1192         }
1193         case Expression::Kind::kPostfix:
1194             return this->visitExpressionPtr(e.template as<PostfixExpression>().operand());
1195 
1196         case Expression::Kind::kPrefix:
1197             return this->visitExpressionPtr(e.template as<PrefixExpression>().operand());
1198 
1199         case Expression::Kind::kSwizzle: {
1200             auto& s = e.template as<Swizzle>();
1201             return s.base() && this->visitExpressionPtr(s.base());
1202         }
1203 
1204         case Expression::Kind::kTernary: {
1205             auto& t = e.template as<TernaryExpression>();
1206             return this->visitExpressionPtr(t.test()) ||
1207                    (t.ifTrue() && this->visitExpressionPtr(t.ifTrue())) ||
1208                    (t.ifFalse() && this->visitExpressionPtr(t.ifFalse()));
1209         }
1210         default:
1211             SkUNREACHABLE;
1212     }
1213 }
1214 
visitStatement(typename T::Statement & s)1215 template <typename T> bool TProgramVisitor<T>::visitStatement(typename T::Statement& s) {
1216     switch (s.kind()) {
1217         case Statement::Kind::kBreak:
1218         case Statement::Kind::kContinue:
1219         case Statement::Kind::kDiscard:
1220         case Statement::Kind::kInlineMarker:
1221         case Statement::Kind::kNop:
1222             // Leaf statements just return false
1223             return false;
1224 
1225         case Statement::Kind::kBlock:
1226             for (auto& stmt : s.template as<Block>().children()) {
1227                 if (stmt && this->visitStatementPtr(stmt)) {
1228                     return true;
1229                 }
1230             }
1231             return false;
1232 
1233         case Statement::Kind::kSwitchCase: {
1234             auto& sc = s.template as<SwitchCase>();
1235             if (sc.value() && this->visitExpressionPtr(sc.value())) {
1236                 return true;
1237             }
1238             return this->visitStatementPtr(sc.statement());
1239         }
1240         case Statement::Kind::kDo: {
1241             auto& d = s.template as<DoStatement>();
1242             return this->visitExpressionPtr(d.test()) || this->visitStatementPtr(d.statement());
1243         }
1244         case Statement::Kind::kExpression:
1245             return this->visitExpressionPtr(s.template as<ExpressionStatement>().expression());
1246 
1247         case Statement::Kind::kFor: {
1248             auto& f = s.template as<ForStatement>();
1249             return (f.initializer() && this->visitStatementPtr(f.initializer())) ||
1250                    (f.test() && this->visitExpressionPtr(f.test())) ||
1251                    (f.next() && this->visitExpressionPtr(f.next())) ||
1252                    this->visitStatementPtr(f.statement());
1253         }
1254         case Statement::Kind::kIf: {
1255             auto& i = s.template as<IfStatement>();
1256             return (i.test() && this->visitExpressionPtr(i.test())) ||
1257                    (i.ifTrue() && this->visitStatementPtr(i.ifTrue())) ||
1258                    (i.ifFalse() && this->visitStatementPtr(i.ifFalse()));
1259         }
1260         case Statement::Kind::kReturn: {
1261             auto& r = s.template as<ReturnStatement>();
1262             return r.expression() && this->visitExpressionPtr(r.expression());
1263         }
1264         case Statement::Kind::kSwitch: {
1265             auto& sw = s.template as<SwitchStatement>();
1266             if (this->visitExpressionPtr(sw.value())) {
1267                 return true;
1268             }
1269             for (auto& c : sw.cases()) {
1270                 if (this->visitStatementPtr(c)) {
1271                     return true;
1272                 }
1273             }
1274             return false;
1275         }
1276         case Statement::Kind::kVarDeclaration: {
1277             auto& v = s.template as<VarDeclaration>();
1278             return v.value() && this->visitExpressionPtr(v.value());
1279         }
1280         default:
1281             SkUNREACHABLE;
1282     }
1283 }
1284 
visitProgramElement(typename T::ProgramElement & pe)1285 template <typename T> bool TProgramVisitor<T>::visitProgramElement(typename T::ProgramElement& pe) {
1286     switch (pe.kind()) {
1287         case ProgramElement::Kind::kEnum:
1288         case ProgramElement::Kind::kExtension:
1289         case ProgramElement::Kind::kFunctionPrototype:
1290         case ProgramElement::Kind::kInterfaceBlock:
1291         case ProgramElement::Kind::kModifiers:
1292         case ProgramElement::Kind::kSection:
1293         case ProgramElement::Kind::kStructDefinition:
1294             // Leaf program elements just return false by default
1295             return false;
1296 
1297         case ProgramElement::Kind::kFunction:
1298             return this->visitStatementPtr(pe.template as<FunctionDefinition>().body());
1299 
1300         case ProgramElement::Kind::kGlobalVar:
1301             return this->visitStatementPtr(pe.template as<GlobalVarDeclaration>().declaration());
1302 
1303         default:
1304             SkUNREACHABLE;
1305     }
1306 }
1307 
// Explicitly instantiate both traversal flavors, the visitor (ProgramVisitorTypes) and the writer
// (ProgramWriterTypes), so the template member definitions above are emitted in this
// translation unit.
template class TProgramVisitor<ProgramVisitorTypes>;
template class TProgramVisitor<ProgramWriterTypes>;
1310 
1311 }  // namespace SkSL
1312