1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9 
10 #include "src/core/SkScopeExit.h"
11 #include "src/sksl/SkSLCompiler.h"
12 #include "src/sksl/SkSLMemoryLayout.h"
13 #include "src/sksl/ir/SkSLConstructorArray.h"
14 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
15 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
16 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
17 #include "src/sksl/ir/SkSLConstructorSplat.h"
18 #include "src/sksl/ir/SkSLConstructorStruct.h"
19 #include "src/sksl/ir/SkSLExpressionStatement.h"
20 #include "src/sksl/ir/SkSLExtension.h"
21 #include "src/sksl/ir/SkSLIndexExpression.h"
22 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
23 #include "src/sksl/ir/SkSLNop.h"
24 #include "src/sksl/ir/SkSLStructDefinition.h"
25 #include "src/sksl/ir/SkSLVariableReference.h"
26 
27 #include <algorithm>
28 
29 namespace SkSL {
30 
OperatorName(Operator op)31 const char* MetalCodeGenerator::OperatorName(Operator op) {
32     switch (op.kind()) {
33         case Token::Kind::TK_LOGICALXOR:  return "!=";
34         default:                          return op.operatorName();
35     }
36 }
37 
// Abstract callback interface used while walking global declarations. Implementations (not
// visible in this chunk — presumably defined later in the file) receive one call per visited
// interface block, texture, sampler, or global variable.
class MetalCodeGenerator::GlobalStructVisitor {
public:
    virtual ~GlobalStructVisitor() = default;
    // Invoked for an interface block, together with the name it is declared under.
    virtual void visitInterfaceBlock(const InterfaceBlock& block, const String& blockName) = 0;
    // Invoked for a texture global with the given type and name.
    virtual void visitTexture(const Type& type, const String& name) = 0;
    // Invoked for a sampler global with the given type and name.
    virtual void visitSampler(const Type& type, const String& name) = 0;
    // Invoked for a global variable; `value` is presumably its initializer, or null when it has
    // none — TODO confirm against the callers.
    virtual void visitVariable(const Variable& var, const Expression* value) = 0;
};
46 
write(const char * s)47 void MetalCodeGenerator::write(const char* s) {
48     if (!s[0]) {
49         return;
50     }
51     if (fAtLineStart) {
52         for (int i = 0; i < fIndentation; i++) {
53             fOut->writeText("    ");
54         }
55     }
56     fOut->writeText(s);
57     fAtLineStart = false;
58 }
59 
writeLine(const char * s)60 void MetalCodeGenerator::writeLine(const char* s) {
61     this->write(s);
62     this->writeLine();
63 }
64 
write(const String & s)65 void MetalCodeGenerator::write(const String& s) {
66     this->write(s.c_str());
67 }
68 
writeLine(const String & s)69 void MetalCodeGenerator::writeLine(const String& s) {
70     this->writeLine(s.c_str());
71 }
72 
writeLine()73 void MetalCodeGenerator::writeLine() {
74     fOut->writeText(fLineEnding);
75     fAtLineStart = true;
76 }
77 
finishLine()78 void MetalCodeGenerator::finishLine() {
79     if (!fAtLineStart) {
80         this->writeLine();
81     }
82 }
83 
writeExtension(const Extension & ext)84 void MetalCodeGenerator::writeExtension(const Extension& ext) {
85     this->writeLine("#extension " + ext.name() + " : enable");
86 }
87 
typeName(const Type & type)88 String MetalCodeGenerator::typeName(const Type& type) {
89     switch (type.typeKind()) {
90         case Type::TypeKind::kArray:
91             SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
92             return String::printf("array<%s, %d>",
93                                   this->typeName(type.componentType()).c_str(), type.columns());
94 
95         case Type::TypeKind::kVector:
96             return this->typeName(type.componentType()) + to_string(type.columns());
97 
98         case Type::TypeKind::kMatrix:
99             return this->typeName(type.componentType()) + to_string(type.columns()) + "x" +
100                                   to_string(type.rows());
101 
102         case Type::TypeKind::kSampler:
103             return "texture2d<float>"; // FIXME - support other texture types
104 
105         case Type::TypeKind::kEnum:
106             return "int";
107 
108         default:
109             if (type == *fContext.fTypes.fHalf) {
110                 // FIXME - Currently only supporting floats in MSL to avoid type coercion issues.
111                 return fContext.fTypes.fFloat->name();
112             } else {
113                 return type.name();
114             }
115     }
116 }
117 
writeStructDefinition(const StructDefinition & s)118 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
119     const Type& type = s.type();
120     this->writeLine("struct " + type.name() + " {");
121     fIndentation++;
122     this->writeFields(type.fields(), type.fOffset);
123     fIndentation--;
124     this->writeLine("};");
125 }
126 
writeType(const Type & type)127 void MetalCodeGenerator::writeType(const Type& type) {
128     this->write(this->typeName(type));
129 }
130 
writeExpression(const Expression & expr,Precedence parentPrecedence)131 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
132     switch (expr.kind()) {
133         case Expression::Kind::kBinary:
134             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
135             break;
136         case Expression::Kind::kBoolLiteral:
137             this->writeBoolLiteral(expr.as<BoolLiteral>());
138             break;
139         case Expression::Kind::kConstructorArray:
140         case Expression::Kind::kConstructorStruct:
141             this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
142             break;
143         case Expression::Kind::kConstructorCompound:
144             this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
145             break;
146         case Expression::Kind::kConstructorDiagonalMatrix:
147         case Expression::Kind::kConstructorSplat:
148             this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
149             break;
150         case Expression::Kind::kConstructorMatrixResize:
151             this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
152                                                parentPrecedence);
153             break;
154         case Expression::Kind::kConstructorScalarCast:
155         case Expression::Kind::kConstructorCompoundCast:
156             this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
157             break;
158         case Expression::Kind::kIntLiteral:
159             this->writeIntLiteral(expr.as<IntLiteral>());
160             break;
161         case Expression::Kind::kFieldAccess:
162             this->writeFieldAccess(expr.as<FieldAccess>());
163             break;
164         case Expression::Kind::kFloatLiteral:
165             this->writeFloatLiteral(expr.as<FloatLiteral>());
166             break;
167         case Expression::Kind::kFunctionCall:
168             this->writeFunctionCall(expr.as<FunctionCall>());
169             break;
170         case Expression::Kind::kPrefix:
171             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
172             break;
173         case Expression::Kind::kPostfix:
174             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
175             break;
176         case Expression::Kind::kSetting:
177             this->writeSetting(expr.as<Setting>());
178             break;
179         case Expression::Kind::kSwizzle:
180             this->writeSwizzle(expr.as<Swizzle>());
181             break;
182         case Expression::Kind::kVariableReference:
183             this->writeVariableReference(expr.as<VariableReference>());
184             break;
185         case Expression::Kind::kTernary:
186             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
187             break;
188         case Expression::Kind::kIndex:
189             this->writeIndexExpression(expr.as<IndexExpression>());
190             break;
191         default:
192             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
193             break;
194     }
195 }
196 
// Synthesizes a Metal helper function that emulates GLSL out-parameter semantics for `call`, and
// returns the helper's name. writeFunctionCall writes that name in place of the callee's name.
//
// While this runs, output is redirected to fExtraFunctions by the AutoOutputStream below. For a
// non-builtin callee, a prototype of the callee is emitted first.
//
// Helper signature:
//   - Each out-param argument is replaced by the underlying assigned variable (outVars[index]),
//     taken by reference, using that variable's type and name. When the same variable appears
//     more than once, only its first occurrence is named.
//   - Every other argument becomes a parameter named `_var<index>`.
//
// Helper body:
//   1. Declares a local `_var<index>` for each out-param argument. The local is initialized from
//      the original argument expression when the parameter is also `in`.
//   2. Calls the callee with all the `_var<index>` values, keeping the result if non-void.
//   3. Assigns each `_var<index>` back through the original argument expression (so swizzles and
//      indices are honored).
//   4. Returns the saved result.
//
// `arguments` holds the call's argument expressions. `outVars` is parallel to `arguments`: it
// holds the assigned variable for each out-param argument, and null elsewhere.
String MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
                                             const ExpressionArray& arguments,
                                             const SkTArray<VariableReference*>& outVars) {
    AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
    const FunctionDeclaration& function = call.function();

    // The counter makes each helper name unique even when the same function is wrapped twice.
    String name = "_skOutParamHelper" + to_string(fSwizzleHelperCount++) +
                  "_" + function.mangledName();
    const char* separator = "";

    // Emit a prototype for the function we'll be calling through to in our helper.
    if (!function.isBuiltin()) {
        this->writeFunctionDeclaration(function);
        this->writeLine(";");
    }

    // Synthesize a helper function that takes the same inputs as `function`, except in places where
    // `outVars` is non-null; in those places, we take the type of the VariableReference.
    //
    // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
    this->writeType(call.type());
    this->write(" ");
    this->write(name);
    this->write("(");
    // Presumably emits any implicit requirement parameters (e.g. globals) and advances
    // `separator` accordingly — TODO confirm against its definition.
    this->writeFunctionRequirementParams(function, separator);

    SkASSERT(outVars.size() == arguments.size());
    SkASSERT(outVars.size() == function.parameters().size());

    // We need to detect cases where the caller passes the same variable as an out-param more than
    // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
    // redundant input parameter entirely, and not give it any name.)
    std::unordered_set<const Variable*> writtenVars;

    for (int index = 0; index < arguments.count(); ++index) {
        this->write(separator);
        separator = ", ";

        const Variable* param = function.parameters()[index];
        this->writeModifiers(param->modifiers(), /*globalContext=*/false);

        const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
        this->writeType(*type);

        // Out-params are passed by reference so the helper can write back into the caller's var.
        if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
            this->write("&");
        }
        if (outVars[index]) {
            auto [iter, didInsert] = writtenVars.insert(outVars[index]->variable());
            if (didInsert) {
                this->write(" ");
                // Write the bare variable name, without any modifier-dependent decoration.
                fIgnoreVariableReferenceModifiers = true;
                this->writeVariableReference(*outVars[index]);
                fIgnoreVariableReferenceModifiers = false;
            }
        } else {
            this->write(" _var");
            this->write(to_string(index));
        }
    }
    this->writeLine(") {");

    ++fIndentation;
    // Declare a temporary for each out-param argument; `inout` params also copy the current value.
    for (int index = 0; index < outVars.count(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // float3 _var2[ = outParam.zyx];
        this->writeType(arguments[index]->type());
        this->write(" _var");
        this->write(to_string(index));

        const Variable* param = function.parameters()[index];
        if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
            this->write(" = ");
            fIgnoreVariableReferenceModifiers = true;
            this->writeExpression(*arguments[index], Precedence::kAssignment);
            fIgnoreVariableReferenceModifiers = false;
        }

        this->writeLine(";");
    }

    // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
    bool hasResult = (call.type().name() != "void");
    if (hasResult) {
        this->writeType(call.type());
        this->write(" _skResult = ");
    }

    this->writeName(function.mangledName());
    this->write("(");
    separator = "";
    this->writeFunctionRequirementArgs(function, separator);

    for (int index = 0; index < arguments.count(); ++index) {
        this->write(separator);
        separator = ", ";

        this->write("_var");
        this->write(to_string(index));
    }
    this->writeLine(");");

    // Copy the out-param temporaries back only after the call returns, matching GLSL semantics.
    for (int index = 0; index < outVars.count(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // outParam.zyx = _var2;
        fIgnoreVariableReferenceModifiers = true;
        this->writeExpression(*arguments[index], Precedence::kAssignment);
        fIgnoreVariableReferenceModifiers = false;
        this->write(" = _var");
        this->write(to_string(index));
        this->writeLine(";");
    }

    if (hasResult) {
        this->writeLine("return _skResult;");
    }

    --fIndentation;
    this->writeLine("}");

    return name;
}
323 
getBitcastIntrinsic(const Type & outType)324 String MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
325     return "as_type<" +  outType.displayName() + ">";
326 }
327 
writeFunctionCall(const FunctionCall & c)328 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
329     const FunctionDeclaration& function = c.function();
330 
331     // Many intrinsics need to be rewritten in Metal.
332     if (function.isIntrinsic()) {
333         if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
334             return;
335         }
336     }
337 
338     // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
339     // helper function. (Specifically, out-parameters in GLSL are only written back to the original
340     // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
341     // allow a swizzle to be passed to a `floatN&`.)
342     const ExpressionArray& arguments = c.arguments();
343     const std::vector<const Variable*>& parameters = function.parameters();
344     SkASSERT(arguments.size() == parameters.size());
345 
346     bool foundOutParam = false;
347     SkSTArray<16, VariableReference*> outVars;
348     outVars.push_back_n(arguments.count(), (VariableReference*)nullptr);
349 
350     for (int index = 0; index < arguments.count(); ++index) {
351         // If this is an out parameter...
352         if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
353             // Find the expression's inner variable being written to.
354             Analysis::AssignmentInfo info;
355             // Assignability was verified at IRGeneration time, so this should always succeed.
356             SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
357             outVars[index] = info.fAssignedVar;
358             foundOutParam = true;
359         }
360     }
361 
362     if (foundOutParam) {
363         // Out parameters need to be written back to at the end of the function. To do this, we
364         // synthesize a helper function which evaluates the out-param expression into a temporary
365         // variable, calls the original function, then writes the temp var back into the out param
366         // using the original out-param expression. (This lets us support things like swizzles and
367         // array indices.)
368         this->write(getOutParamHelper(c, arguments, outVars));
369     } else {
370         this->write(function.mangledName());
371     }
372 
373     this->write("(");
374     const char* separator = "";
375     this->writeFunctionRequirementArgs(function, separator);
376     for (int i = 0; i < arguments.count(); ++i) {
377         this->write(separator);
378         separator = ", ";
379 
380         if (outVars[i]) {
381             this->writeExpression(*outVars[i], Precedence::kSequence);
382         } else {
383             this->writeExpression(*arguments[i], Precedence::kSequence);
384         }
385     }
386     this->write(")");
387 }
388 
// MSL source for the `inverse(float2x2)` polyfill: the adjugate scaled by 1/determinant(m).
// There is no guard for singular matrices. getInversePolyfill appends this to the extra
// functions at most once.
static constexpr char kInverse2x2[] = R"(
float2x2 float2x2_inverse(float2x2 m) {
    return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
}
)";
394 
// MSL source for the `inverse(float3x3)` polyfill. It builds the cofactors, uses the first three
// (b01, b11, b21) for the determinant, and scales the adjugate by 1/det. There is no guard for
// det == 0. getInversePolyfill appends this to the extra functions at most once.
static constexpr char kInverse3x3[] = R"(
float3x3 float3x3_inverse(float3x3 m) {
    float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
    float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
    float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
    float b01 =  a22*a11 - a12*a21;
    float b11 = -a22*a10 + a12*a20;
    float b21 =  a21*a10 - a11*a20;
    float det = a00*b01 + a01*b11 + a02*b21;
    return float3x3(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
                    b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
                    b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
409 
// MSL source for the `inverse(float4x4)` polyfill. It precomputes twelve 2x2 sub-determinants
// (b00..b11) from column pairs 0/1 and 2/3, combines them into the determinant, and scales the
// adjugate by 1/det. There is no guard for det == 0. getInversePolyfill appends this to the
// extra functions at most once.
static constexpr char kInverse4x4[] = R"(
float4x4 float4x4_inverse(float4x4 m) {
    float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
    float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
    float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
    float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
    float b00 = a00*a11 - a01*a10;
    float b01 = a00*a12 - a02*a10;
    float b02 = a00*a13 - a03*a10;
    float b03 = a01*a12 - a02*a11;
    float b04 = a01*a13 - a03*a11;
    float b05 = a02*a13 - a03*a12;
    float b06 = a20*a31 - a21*a30;
    float b07 = a20*a32 - a22*a30;
    float b08 = a20*a33 - a23*a30;
    float b09 = a21*a32 - a22*a31;
    float b10 = a21*a33 - a23*a31;
    float b11 = a22*a33 - a23*a32;
    float det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
    return float4x4(a11*b11 - a12*b10 + a13*b09,
                    a02*b10 - a01*b11 - a03*b09,
                    a31*b05 - a32*b04 + a33*b03,
                    a22*b04 - a21*b05 - a23*b03,
                    a12*b08 - a10*b11 - a13*b07,
                    a00*b11 - a02*b08 + a03*b07,
                    a32*b02 - a30*b05 - a33*b01,
                    a20*b05 - a22*b02 + a23*b01,
                    a10*b10 - a11*b08 + a13*b06,
                    a01*b08 - a00*b10 - a03*b06,
                    a30*b04 - a31*b02 + a33*b00,
                    a21*b02 - a20*b04 - a23*b00,
                    a11*b07 - a10*b09 - a12*b06,
                    a00*b09 - a01*b07 + a02*b06,
                    a31*b01 - a30*b03 - a32*b00,
                    a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
447 
getInversePolyfill(const ExpressionArray & arguments)448 String MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
449     // Only use polyfills for a function taking a single-argument square matrix.
450     if (arguments.size() == 1) {
451         const Type& type = arguments.front()->type();
452         if (type.isMatrix() && type.rows() == type.columns()) {
453             // Inject the correct polyfill based on the matrix size.
454             String name = this->typeName(type) + "_inverse";
455             auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
456             if (didInsert) {
457                 switch (type.rows()) {
458                     case 2:
459                         fExtraFunctions.writeText(kInverse2x2);
460                         break;
461                     case 3:
462                         fExtraFunctions.writeText(kInverse3x3);
463                         break;
464                     case 4:
465                         fExtraFunctions.writeText(kInverse4x4);
466                         break;
467                 }
468             }
469             return name;
470         }
471     }
472     // This isn't the built-in `inverse`. We don't want to polyfill it at all.
473     return "inverse";
474 }
475 
// MSL source for a templated `matrixCompMult` polyfill. It multiplies two same-sized float
// matrices component-wise, one column vector at a time. writeMatrixCompMult appends this to the
// extra functions at most once.
static constexpr char kMatrixCompMult[] = R"(
template <int C, int R>
matrix<float, C, R> matrixCompMult(matrix<float, C, R> a, matrix<float, C, R> b) {
    matrix<float, C, R> result;
    for (int c = 0; c < C; ++c) {
        result[c] = a[c] * b[c];
    }
    return result;
}
)";
486 
writeMatrixCompMult()487 void MetalCodeGenerator::writeMatrixCompMult() {
488     String name = "matrixCompMult";
489     if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
490         fWrittenIntrinsics.insert(name);
491         fExtraFunctions.writeText(kMatrixCompMult);
492     }
493 }
494 
getTempVariable(const Type & type)495 String MetalCodeGenerator::getTempVariable(const Type& type) {
496     String tempVar = "_skTemp" + to_string(fVarCount++);
497     this->fFunctionHeader += "    " + this->typeName(type) + " " + tempVar + ";\n";
498     return tempVar;
499 }
500 
writeSimpleIntrinsic(const FunctionCall & c)501 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
502     // Write out an intrinsic function call exactly as-is. No muss no fuss.
503     this->write(c.function().name());
504     this->writeArgumentList(c.arguments());
505 }
506 
writeArgumentList(const ExpressionArray & arguments)507 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
508     this->write("(");
509     const char* separator = "";
510     for (const std::unique_ptr<Expression>& arg : arguments) {
511         this->write(separator);
512         separator = ", ";
513         this->writeExpression(*arg, Precedence::kSequence);
514     }
515     this->write(")");
516 }
517 
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)518 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
519     const ExpressionArray& arguments = c.arguments();
520     switch (kind) {
521         case k_sample_IntrinsicKind: {
522             this->writeExpression(*arguments[0], Precedence::kSequence);
523             this->write(".sample(");
524             this->writeExpression(*arguments[0], Precedence::kSequence);
525             this->write(SAMPLER_SUFFIX);
526             this->write(", ");
527             const Type& arg1Type = arguments[1]->type();
528             if (arg1Type.columns() == 3) {
529                 // have to store the vector in a temp variable to avoid double evaluating it
530                 String tmpVar = this->getTempVariable(arg1Type);
531                 this->write("(" + tmpVar + " = ");
532                 this->writeExpression(*arguments[1], Precedence::kSequence);
533                 this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))");
534             } else {
535                 SkASSERT(arg1Type.columns() == 2);
536                 this->writeExpression(*arguments[1], Precedence::kSequence);
537                 this->write(")");
538             }
539             return true;
540         }
541         case k_mod_IntrinsicKind: {
542             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
543             String tmpX = this->getTempVariable(arguments[0]->type());
544             String tmpY = this->getTempVariable(arguments[1]->type());
545             this->write("(" + tmpX + " = ");
546             this->writeExpression(*arguments[0], Precedence::kSequence);
547             this->write(", " + tmpY + " = ");
548             this->writeExpression(*arguments[1], Precedence::kSequence);
549             this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
550             return true;
551         }
552         // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
553         case k_distance_IntrinsicKind: {
554             if (arguments[0]->type().columns() == 1) {
555                 this->write("abs(");
556                 this->writeExpression(*arguments[0], Precedence::kAdditive);
557                 this->write(" - ");
558                 this->writeExpression(*arguments[1], Precedence::kAdditive);
559                 this->write(")");
560             } else {
561                 this->writeSimpleIntrinsic(c);
562             }
563             return true;
564         }
565         case k_dot_IntrinsicKind: {
566             if (arguments[0]->type().columns() == 1) {
567                 this->write("(");
568                 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
569                 this->write(" * ");
570                 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
571                 this->write(")");
572             } else {
573                 this->writeSimpleIntrinsic(c);
574             }
575             return true;
576         }
577         case k_faceforward_IntrinsicKind: {
578             if (arguments[0]->type().columns() == 1) {
579                 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
580                 this->write("((((");
581                 this->writeExpression(*arguments[2], Precedence::kSequence);
582                 this->write(") * (");
583                 this->writeExpression(*arguments[1], Precedence::kSequence);
584                 this->write(") < 0) ? 1 : -1) * (");
585                 this->writeExpression(*arguments[0], Precedence::kSequence);
586                 this->write("))");
587             } else {
588                 this->writeSimpleIntrinsic(c);
589             }
590             return true;
591         }
592         case k_length_IntrinsicKind: {
593             this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
594             this->writeExpression(*arguments[0], Precedence::kSequence);
595             this->write(")");
596             return true;
597         }
598         case k_normalize_IntrinsicKind: {
599             this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
600             this->writeExpression(*arguments[0], Precedence::kSequence);
601             this->write(")");
602             return true;
603         }
604 
605         case k_floatBitsToInt_IntrinsicKind:
606         case k_floatBitsToUint_IntrinsicKind:
607         case k_intBitsToFloat_IntrinsicKind:
608         case k_uintBitsToFloat_IntrinsicKind: {
609             this->write(this->getBitcastIntrinsic(c.type()));
610             this->write("(");
611             this->writeExpression(*arguments[0], Precedence::kSequence);
612             this->write(")");
613             return true;
614         }
615         case k_degrees_IntrinsicKind: {
616             this->write("((");
617             this->writeExpression(*arguments[0], Precedence::kSequence);
618             this->write(") * 57.2957795)");
619             return true;
620         }
621         case k_radians_IntrinsicKind: {
622             this->write("((");
623             this->writeExpression(*arguments[0], Precedence::kSequence);
624             this->write(") * 0.0174532925)");
625             return true;
626         }
627         case k_dFdx_IntrinsicKind: {
628             this->write("dfdx");
629             this->writeArgumentList(c.arguments());
630             return true;
631         }
632         case k_dFdy_IntrinsicKind: {
633             // Flipping Y also negates the Y derivatives.
634             if (fProgram.fConfig->fSettings.fFlipY) {
635                 this->write("-");
636             }
637             this->write("dfdy");
638             this->writeArgumentList(c.arguments());
639             return true;
640         }
641         case k_inverse_IntrinsicKind: {
642             this->write(this->getInversePolyfill(arguments));
643             this->writeArgumentList(c.arguments());
644             return true;
645         }
646         case k_inversesqrt_IntrinsicKind: {
647             this->write("rsqrt");
648             this->writeArgumentList(c.arguments());
649             return true;
650         }
651         case k_atan_IntrinsicKind: {
652             this->write(c.arguments().size() == 2 ? "atan2" : "atan");
653             this->writeArgumentList(c.arguments());
654             return true;
655         }
656         case k_reflect_IntrinsicKind: {
657             if (arguments[0]->type().columns() == 1) {
658                 // We need to synthesize `I - 2 * N * I * N`.
659                 String tmpI = this->getTempVariable(arguments[0]->type());
660                 String tmpN = this->getTempVariable(arguments[1]->type());
661 
662                 // (_skTempI = ...
663                 this->write("(" + tmpI + " = ");
664                 this->writeExpression(*arguments[0], Precedence::kSequence);
665 
666                 // , _skTempN = ...
667                 this->write(", " + tmpN + " = ");
668                 this->writeExpression(*arguments[1], Precedence::kSequence);
669 
670                 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
671                 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
672             } else {
673                 this->writeSimpleIntrinsic(c);
674             }
675             return true;
676         }
677         case k_refract_IntrinsicKind: {
678             if (arguments[0]->type().columns() == 1) {
679                 // Metal does implement refract for vectors; rather than reimplementing refract from
680                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
681                 this->write("(refract(float2(");
682                 this->writeExpression(*arguments[0], Precedence::kSequence);
683                 this->write(", 0), float2(");
684                 this->writeExpression(*arguments[1], Precedence::kSequence);
685                 this->write(", 0), ");
686                 this->writeExpression(*arguments[2], Precedence::kSequence);
687                 this->write(").x)");
688             } else {
689                 this->writeSimpleIntrinsic(c);
690             }
691             return true;
692         }
693         case k_roundEven_IntrinsicKind: {
694             this->write("rint");
695             this->writeArgumentList(c.arguments());
696             return true;
697         }
698         case k_bitCount_IntrinsicKind: {
699             this->write("popcount(");
700             this->writeExpression(*arguments[0], Precedence::kSequence);
701             this->write(")");
702             return true;
703         }
704         case k_findLSB_IntrinsicKind: {
705             // Create a temp variable to store the expression, to avoid double-evaluating it.
706             String skTemp = this->getTempVariable(arguments[0]->type());
707             String exprType = this->typeName(arguments[0]->type());
708 
709             // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
710             // Use select to detect zero inputs and force a -1 result.
711 
712             // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
713             this->write("(");
714             this->write(skTemp);
715             this->write(" = (");
716             this->writeExpression(*arguments[0], Precedence::kSequence);
717             this->write("), select(ctz(");
718             this->write(skTemp);
719             this->write("), ");
720             this->write(exprType);
721             this->write("(-1), ");
722             this->write(skTemp);
723             this->write(" == ");
724             this->write(exprType);
725             this->write("(0)))");
726             return true;
727         }
728         case k_findMSB_IntrinsicKind: {
729             // Create a temp variable to store the expression, to avoid double-evaluating it.
730             String skTemp1 = this->getTempVariable(arguments[0]->type());
731             String exprType = this->typeName(arguments[0]->type());
732 
733             // GLSL findMSB is actually quite different from Metal's clz:
734             // - For signed negative numbers, it returns the first zero bit, not the first one bit!
735             // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
736 
737             // (_skTemp1 = (.....),
738             this->write("(");
739             this->write(skTemp1);
740             this->write(" = (");
741             this->writeExpression(*arguments[0], Precedence::kSequence);
742             this->write("), ");
743 
744             // Signed input types might be negative; we need another helper variable to negate the
745             // input (since we can only find one bits, not zero bits).
746             String skTemp2;
747             if (arguments[0]->type().isSigned()) {
748                 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
749                 skTemp2 = this->getTempVariable(arguments[0]->type());
750                 this->write(skTemp2);
751                 this->write(" = (select(");
752                 this->write(skTemp1);
753                 this->write(", ~");
754                 this->write(skTemp1);
755                 this->write(", ");
756                 this->write(skTemp1);
757                 this->write(" < 0)), ");
758             } else {
759                 skTemp2 = skTemp1;
760             }
761 
762             // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
763             this->write("select(");
764             this->write(this->typeName(c.type()));
765             this->write("(clz(");
766             this->write(skTemp2);
767             this->write(")), ");
768             this->write(this->typeName(c.type()));
769             this->write("(-1), ");
770             this->write(skTemp2);
771             this->write(" == ");
772             this->write(exprType);
773             this->write("(0)))");
774             return true;
775         }
776         case k_matrixCompMult_IntrinsicKind: {
777             this->writeMatrixCompMult();
778             this->writeSimpleIntrinsic(c);
779             return true;
780         }
781         case k_equal_IntrinsicKind:
782         case k_greaterThan_IntrinsicKind:
783         case k_greaterThanEqual_IntrinsicKind:
784         case k_lessThan_IntrinsicKind:
785         case k_lessThanEqual_IntrinsicKind:
786         case k_notEqual_IntrinsicKind: {
787             this->write("(");
788             this->writeExpression(*c.arguments()[0], Precedence::kRelational);
789             switch (kind) {
790                 case k_equal_IntrinsicKind:
791                     this->write(" == ");
792                     break;
793                 case k_notEqual_IntrinsicKind:
794                     this->write(" != ");
795                     break;
796                 case k_lessThan_IntrinsicKind:
797                     this->write(" < ");
798                     break;
799                 case k_lessThanEqual_IntrinsicKind:
800                     this->write(" <= ");
801                     break;
802                 case k_greaterThan_IntrinsicKind:
803                     this->write(" > ");
804                     break;
805                 case k_greaterThanEqual_IntrinsicKind:
806                     this->write(" >= ");
807                     break;
808                 default:
809                     SK_ABORT("unsupported comparison intrinsic kind");
810             }
811             this->writeExpression(*c.arguments()[1], Precedence::kRelational);
812             this->write(")");
813             return true;
814         }
815         default:
816             return false;
817     }
818 }
819 
820 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
821 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int rows,int columns)822 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
823     SkASSERT(rows <= 4);
824     SkASSERT(columns <= 4);
825 
826     const char* columnSeparator = "";
827     for (int c = 0; c < columns; ++c) {
828         fExtraFunctions.printf("%sfloat%d(", columnSeparator, rows);
829         columnSeparator = "), ";
830 
831         // Determine how many values to take from the source matrix for this row.
832         int swizzleLength = 0;
833         if (c < sourceMatrix.columns()) {
834             swizzleLength = std::min<>(rows, sourceMatrix.rows());
835         }
836 
837         // Emit all the values from the source matrix row.
838         bool firstItem;
839         switch (swizzleLength) {
840             case 0:  firstItem = true;                                            break;
841             case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
842             case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
843             case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
844             case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
845             default: SkUNREACHABLE;
846         }
847 
848         // Emit the placeholder identity-matrix cells.
849         for (int r = swizzleLength; r < rows; ++r) {
850             fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
851             firstItem = false;
852         }
853     }
854 
855     fExtraFunctions.writeText(")");
856 }
857 
858 // Assembles a matrix of type floatRxC by concatenating an arbitrary mix of values, named `x0`,
859 // `x1`, etc. An error is written if the expression list don't contain exactly R*C scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int rows,int columns)860 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
861                                                        int rows, int columns) {
862     size_t argIndex = 0;
863     int argPosition = 0;
864     auto args = ctor.argumentSpan();
865 
866     const char* columnSeparator = "";
867     for (int c = 0; c < columns; ++c) {
868         fExtraFunctions.printf("%sfloat%d(", columnSeparator, rows);
869         columnSeparator = "), ";
870 
871         const char* rowSeparator = "";
872         for (int r = 0; r < rows; ++r) {
873             fExtraFunctions.writeText(rowSeparator);
874             rowSeparator = ", ";
875 
876             if (argIndex < args.size()) {
877                 const Type& argType = args[argIndex]->type();
878                 switch (argType.typeKind()) {
879                     case Type::TypeKind::kScalar: {
880                         fExtraFunctions.printf("x%zu", argIndex);
881                         break;
882                     }
883                     case Type::TypeKind::kVector: {
884                         fExtraFunctions.printf("x%zu[%d]", argIndex, argPosition);
885                         break;
886                     }
887                     case Type::TypeKind::kMatrix: {
888                         fExtraFunctions.printf("x%zu[%d][%d]", argIndex,
889                                                argPosition / argType.rows(),
890                                                argPosition % argType.rows());
891                         break;
892                     }
893                     default: {
894                         SkDEBUGFAIL("incorrect type of argument for matrix constructor");
895                         fExtraFunctions.writeText("<error>");
896                         break;
897                     }
898                 }
899 
900                 ++argPosition;
901                 if (argPosition >= argType.columns() * argType.rows()) {
902                     ++argIndex;
903                     argPosition = 0;
904                 }
905             } else {
906                 SkDEBUGFAIL("not enough arguments for matrix constructor");
907                 fExtraFunctions.writeText("<error>");
908             }
909         }
910     }
911 
912     if (argPosition != 0 || argIndex != args.size()) {
913         SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
914         fExtraFunctions.writeText(", <error>");
915     }
916 
917     fExtraFunctions.writeText(")");
918 }
919 
920 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
921 // Keeps track of previously generated constructors so that we won't generate more than one
922 // constructor for any given permutation of input argument types. Returns the name of the
923 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)924 String MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
925     const Type& matrix = c.type();
926     int columns = matrix.columns();
927     int rows = matrix.rows();
928     auto args = c.argumentSpan();
929 
930     // Create the helper-method name and use it as our lookup key.
931     String name;
932     name.appendf("float%dx%d_from", columns, rows);
933     for (const std::unique_ptr<Expression>& expr : args) {
934         name.appendf("_%s", this->typeName(expr->type()).c_str());
935     }
936 
937     // If a helper-method has already been synthesized, we don't need to synthesize it again.
938     auto [iter, newlyCreated] = fHelpers.insert(name);
939     if (!newlyCreated) {
940         return name;
941     }
942 
943     // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
944     // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
945     // supply a mixture of scalars and vectors.)
946     fExtraFunctions.printf("float%dx%d %s(", columns, rows, name.c_str());
947 
948     size_t argIndex = 0;
949     const char* argSeparator = "";
950     for (const std::unique_ptr<Expression>& expr : args) {
951         fExtraFunctions.printf("%s%s x%zu", argSeparator,
952                                this->typeName(expr->type()).c_str(), argIndex++);
953         argSeparator = ", ";
954     }
955 
956     fExtraFunctions.printf(") {\n    return float%dx%d(", columns, rows);
957 
958     if (args.size() == 1 && args.front()->type().isMatrix()) {
959         this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
960     } else {
961         this->assembleMatrixFromExpressions(c, rows, columns);
962     }
963 
964     fExtraFunctions.writeText(");\n}\n");
965     return name;
966 }
967 
canCoerce(const Type & t1,const Type & t2)968 bool MetalCodeGenerator::canCoerce(const Type& t1, const Type& t2) {
969     if (t1.columns() != t2.columns() || t1.rows() != t2.rows()) {
970         return false;
971     }
972     if (t1.columns() > 1) {
973         return this->canCoerce(t1.componentType(), t2.componentType());
974     }
975     return t1.isFloat() && t2.isFloat();
976 }
977 
matrixConstructHelperIsNeeded(const ConstructorCompound & c)978 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
979     SkASSERT(c.type().isMatrix());
980 
981     // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
982     // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
983     // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
984     // scalars can be constructed trivially. In more complex cases, we generate a helper function
985     // that converts our inputs into a properly-shaped matrix.
986     // A matrix construct helper method is always used if any input argument is a matrix.
987     // Helper methods are also necessary when any argument would span multiple rows. For instance:
988     //
989     // float2 x = (1, 2);
990     // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
991     //                           | 2 4 6 |
992     //
993     // float2 x = (2, 3);
994     // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
995     //                           | 2 4 6 |
996     //
997     // float4 x = (1, 2, 3, 4);
998     // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
999     //               | 2 4 |
1000     //
1001 
1002     int position = 0;
1003     for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1004         // If an input argument is a matrix, we need a helper function.
1005         if (expr->type().isMatrix()) {
1006             return true;
1007         }
1008         position += expr->type().columns();
1009         if (position > c.type().rows()) {
1010             // An input argument would span multiple rows; a helper function is required.
1011             return true;
1012         }
1013         if (position == c.type().rows()) {
1014             // We've advanced to the end of a row. Wrap to the start of the next row.
1015             position = 0;
1016         }
1017     }
1018 
1019     return false;
1020 }
1021 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1022 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1023                                                       Precedence parentPrecedence) {
1024     // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1025     // matrix-construct helper here.
1026     this->write(this->getMatrixConstructHelper(c));
1027     this->write("(");
1028     this->writeExpression(*c.argument(), Precedence::kSequence);
1029     this->write(")");
1030 }
1031 
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1032 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1033                                                   Precedence parentPrecedence) {
1034     if (c.type().isMatrix()) {
1035         this->writeConstructorCompoundMatrix(c, parentPrecedence);
1036     } else {
1037         this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1038     }
1039 }
1040 
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1041 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1042                                                         Precedence parentPrecedence) {
1043     // Emit and invoke a matrix-constructor helper method if one is necessary.
1044     if (this->matrixConstructHelperIsNeeded(c)) {
1045         this->write(this->getMatrixConstructHelper(c));
1046         this->write("(");
1047         const char* separator = "";
1048         for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1049             this->write(separator);
1050             separator = ", ";
1051             this->writeExpression(*expr, Precedence::kSequence);
1052         }
1053         this->write(")");
1054         return;
1055     }
1056 
1057     // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1058     // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1059     // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1060     // can group our inputs up and synthesize a constructor for each column.
1061     const Type& matrixType = c.type();
1062     const Type& columnType = matrixType.componentType().toCompound(
1063             fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1064 
1065     this->writeType(matrixType);
1066     this->write("(");
1067     const char* separator = "";
1068     int scalarCount = 0;
1069     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1070         this->write(separator);
1071         separator = ", ";
1072         if (arg->type().columns() < matrixType.rows()) {
1073             // Write a `floatN(` constructor to group scalars and smaller vectors together.
1074             if (!scalarCount) {
1075                 this->writeType(columnType);
1076                 this->write("(");
1077             }
1078             scalarCount += arg->type().columns();
1079         }
1080         this->writeExpression(*arg, Precedence::kSequence);
1081         if (scalarCount && scalarCount == matrixType.rows()) {
1082             // Close our `floatN(...` constructor block from above.
1083             this->write(")");
1084             scalarCount = 0;
1085         }
1086     }
1087     this->write(")");
1088 }
1089 
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1090 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1091                                              const char* leftBracket,
1092                                              const char* rightBracket,
1093                                              Precedence parentPrecedence) {
1094     this->writeType(c.type());
1095     this->write(leftBracket);
1096     const char* separator = "";
1097     for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1098         this->write(separator);
1099         separator = ", ";
1100         this->writeExpression(*arg, Precedence::kSequence);
1101     }
1102     this->write(rightBracket);
1103 }
1104 
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1105 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1106                                               const char* leftBracket,
1107                                               const char* rightBracket,
1108                                               Precedence parentPrecedence) {
1109     // If the type is coercible, emit it directly without the cast.
1110     auto args = c.argumentSpan();
1111     if (args.size() == 1) {
1112         if (this->canCoerce(c.type(), args.front()->type())) {
1113             this->writeExpression(*args.front(), parentPrecedence);
1114             return;
1115         }
1116     }
1117 
1118     return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1119 }
1120 
writeFragCoord()1121 void MetalCodeGenerator::writeFragCoord() {
1122     if (fRTHeightName.length()) {
1123         this->write("float4(_fragCoord.x, ");
1124         this->write(fRTHeightName.c_str());
1125         this->write(" - _fragCoord.y, 0.0, _fragCoord.w)");
1126     } else {
1127         this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1128     }
1129 }
1130 
writeVariableReference(const VariableReference & ref)1131 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1132     // When assembling out-param helper functions, we copy variables into local clones with matching
1133     // names. We never want to prepend "_in." or "_globals." when writing these variables since
1134     // we're actually targeting the clones.
1135     if (fIgnoreVariableReferenceModifiers) {
1136         this->writeName(ref.variable()->name());
1137         return;
1138     }
1139 
1140     switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1141         case SK_FRAGCOLOR_BUILTIN:
1142             this->write("_out.sk_FragColor");
1143             break;
1144         case SK_FRAGCOORD_BUILTIN:
1145             this->writeFragCoord();
1146             break;
1147         case SK_VERTEXID_BUILTIN:
1148             this->write("sk_VertexID");
1149             break;
1150         case SK_INSTANCEID_BUILTIN:
1151             this->write("sk_InstanceID");
1152             break;
1153         case SK_CLOCKWISE_BUILTIN:
1154             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1155             // clockwise to match Skia convention.
1156             this->write(fProgram.fConfig->fSettings.fFlipY ? "_frontFacing" : "(!_frontFacing)");
1157             break;
1158         default:
1159             const Variable& var = *ref.variable();
1160             if (var.storage() == Variable::Storage::kGlobal) {
1161                 if (var.modifiers().fFlags & Modifiers::kIn_Flag) {
1162                     this->write("_in.");
1163                 } else if (var.modifiers().fFlags & Modifiers::kOut_Flag) {
1164                     this->write("_out.");
1165                 } else if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1166                            var.type().typeKind() != Type::TypeKind::kSampler) {
1167                     this->write("_uniforms.");
1168                 } else {
1169                     this->write("_globals.");
1170                 }
1171             }
1172             this->writeName(var.name());
1173     }
1174 }
1175 
writeIndexExpression(const IndexExpression & expr)1176 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1177     this->writeExpression(*expr.base(), Precedence::kPostfix);
1178     this->write("[");
1179     this->writeExpression(*expr.index(), Precedence::kTopLevel);
1180     this->write("]");
1181 }
1182 
writeFieldAccess(const FieldAccess & f)1183 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1184     const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1185     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1186         this->writeExpression(*f.base(), Precedence::kPostfix);
1187         this->write(".");
1188     }
1189     switch (field->fModifiers.fLayout.fBuiltin) {
1190         case SK_POSITION_BUILTIN:
1191             this->write("_out.sk_Position");
1192             break;
1193         default:
1194             if (field->fName == "sk_PointSize") {
1195                 this->write("_out.sk_PointSize");
1196             } else {
1197                 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1198                     this->write("_globals.");
1199                     this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1200                     this->write("->");
1201                 }
1202                 this->writeName(field->fName);
1203             }
1204     }
1205 }
1206 
writeSwizzle(const Swizzle & swizzle)1207 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1208     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1209     this->write(".");
1210     for (int c : swizzle.components()) {
1211         SkASSERT(c >= 0 && c <= 3);
1212         this->write(&("x\0y\0z\0w\0"[c * 2]));
1213     }
1214 }
1215 
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1216 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1217                                                      const Type& result) {
1218     String key = "TimesEqual " + this->typeName(left) + ":" + this->typeName(right);
1219 
1220     auto [iter, wasInserted] = fHelpers.insert(key);
1221     if (wasInserted) {
1222         fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1223                                "    left = left * right;\n"
1224                                "    return left;\n"
1225                                "}\n",
1226                                this->typeName(result).c_str(), this->typeName(left).c_str(),
1227                                this->typeName(right).c_str());
1228     }
1229 }
1230 
writeMatrixEqualityHelpers(const Type & left,const Type & right)1231 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1232     SkASSERT(left.isMatrix());
1233     SkASSERT(right.isMatrix());
1234     SkASSERT(left.rows() == right.rows());
1235     SkASSERT(left.columns() == right.columns());
1236 
1237     String key = "MatrixEquality " + this->typeName(left) + ":" + this->typeName(right);
1238 
1239     auto [iter, wasInserted] = fHelpers.insert(key);
1240     if (wasInserted) {
1241         fExtraFunctions.printf(
1242                 "thread bool operator==(const %s left, const %s right) {\n"
1243                 "    return ",
1244                 this->typeName(left).c_str(), this->typeName(right).c_str());
1245 
1246         const char* separator = "";
1247         for (int index=0; index<left.columns(); ++index) {
1248             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1249             separator = " &&\n           ";
1250         }
1251 
1252         fExtraFunctions.printf(
1253                 ";\n"
1254                 "}\n"
1255                 "thread bool operator!=(const %s left, const %s right) {\n"
1256                 "    return !(left == right);\n"
1257                 "}\n",
1258                 this->typeName(left).c_str(), this->typeName(right).c_str());
1259     }
1260 }
1261 
writeArrayEqualityHelpers(const Type & type)1262 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1263     SkASSERT(type.isArray());
1264 
1265     // If the array's component type needs a helper as well, we need to emit that one first.
1266     this->writeEqualityHelpers(type.componentType(), type.componentType());
1267 
1268     auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
1269     if (wasInserted) {
1270         fExtraFunctions.writeText(R"(
1271 template <typename T, size_t N>
1272 bool operator==(thread const array<T, N>& left, thread const array<T, N>& right) {
1273     for (size_t index = 0; index < N; ++index) {
1274         if (!(left[index] == right[index])) {
1275             return false;
1276         }
1277     }
1278     return true;
1279 }
1280 
1281 template <typename T, size_t N>
1282 bool operator!=(thread const array<T, N>& left, thread const array<T, N>& right) {
1283     return !(left == right);
1284 }
1285 )");
1286     }
1287 }
1288 
writeStructEqualityHelpers(const Type & type)1289 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1290     SkASSERT(type.isStruct());
1291     String key = "StructEquality " + this->typeName(type);
1292 
1293     auto [iter, wasInserted] = fHelpers.insert(key);
1294     if (wasInserted) {
1295         // If one of the struct's fields needs a helper as well, we need to emit that one first.
1296         for (const Type::Field& field : type.fields()) {
1297             this->writeEqualityHelpers(*field.fType, *field.fType);
1298         }
1299 
1300         // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1301         // and GLSL but do not exist by default in Metal.
1302         fExtraFunctions.printf(
1303                 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1304                 "    return ",
1305                 this->typeName(type).c_str(),
1306                 this->typeName(type).c_str());
1307 
1308         const char* separator = "";
1309         for (const Type::Field& field : type.fields()) {
1310             fExtraFunctions.printf("%s(left.%.*s == right.%.*s)",
1311                                    separator,
1312                                    (int)field.fName.size(), field.fName.data(),
1313                                    (int)field.fName.size(), field.fName.data());
1314             separator = " &&\n           ";
1315         }
1316         fExtraFunctions.printf(
1317                 ";\n"
1318                 "}\n"
1319                 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1320                 "    return !(left == right);\n"
1321                 "}\n",
1322                 this->typeName(type).c_str(),
1323                 this->typeName(type).c_str());
1324     }
1325 }
1326 
writeEqualityHelpers(const Type & leftType,const Type & rightType)1327 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1328     if (leftType.isArray() && rightType.isArray()) {
1329         this->writeArrayEqualityHelpers(leftType);
1330         return;
1331     }
1332     if (leftType.isStruct() && rightType.isStruct()) {
1333         this->writeStructEqualityHelpers(leftType);
1334         return;
1335     }
1336     if (leftType.isMatrix() && rightType.isMatrix()) {
1337         this->writeMatrixEqualityHelpers(leftType, rightType);
1338         return;
1339     }
1340 }
1341 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)1342 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1343                                                Precedence parentPrecedence) {
1344     const Expression& left = *b.left();
1345     const Expression& right = *b.right();
1346     const Type& leftType = left.type();
1347     const Type& rightType = right.type();
1348     Operator op = b.getOperator();
1349     Precedence precedence = op.getBinaryPrecedence();
1350     bool needParens = precedence >= parentPrecedence;
1351     switch (op.kind()) {
1352         case Token::Kind::TK_EQEQ:
1353             this->writeEqualityHelpers(leftType, rightType);
1354             if (leftType.isVector()) {
1355                 this->write("all");
1356                 needParens = true;
1357             }
1358             break;
1359         case Token::Kind::TK_NEQ:
1360             this->writeEqualityHelpers(leftType, rightType);
1361             if (leftType.isVector()) {
1362                 this->write("any");
1363                 needParens = true;
1364             }
1365             break;
1366         default:
1367             break;
1368     }
1369     if (needParens) {
1370         this->write("(");
1371     }
1372     if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
1373         this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
1374     }
1375     this->writeExpression(left, precedence);
1376     if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
1377         left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
1378         // This doesn't compile in Metal:
1379         // float4 x = float4(1);
1380         // x.xy *= float2x2(...);
1381         // with the error message "non-const reference cannot bind to vector element",
1382         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
1383         // as long as the LHS has no side effects, and hope for the best otherwise.
1384         this->write(" = ");
1385         this->writeExpression(left, Precedence::kAssignment);
1386         this->write(" ");
1387         String opName = OperatorName(op);
1388         SkASSERT(opName.endsWith("="));
1389         this->write(opName.substr(0, opName.size() - 1).c_str());
1390         this->write(" ");
1391     } else {
1392         this->write(String(" ") + OperatorName(op) + " ");
1393     }
1394     this->writeExpression(right, precedence);
1395     if (needParens) {
1396         this->write(")");
1397     }
1398 }
1399 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)1400 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1401                                                Precedence parentPrecedence) {
1402     if (Precedence::kTernary >= parentPrecedence) {
1403         this->write("(");
1404     }
1405     this->writeExpression(*t.test(), Precedence::kTernary);
1406     this->write(" ? ");
1407     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1408     this->write(" : ");
1409     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1410     if (Precedence::kTernary >= parentPrecedence) {
1411         this->write(")");
1412     }
1413 }
1414 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)1415 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1416                                               Precedence parentPrecedence) {
1417     if (Precedence::kPrefix >= parentPrecedence) {
1418         this->write("(");
1419     }
1420     this->write(OperatorName(p.getOperator()));
1421     this->writeExpression(*p.operand(), Precedence::kPrefix);
1422     if (Precedence::kPrefix >= parentPrecedence) {
1423         this->write(")");
1424     }
1425 }
1426 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)1427 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1428                                                Precedence parentPrecedence) {
1429     if (Precedence::kPostfix >= parentPrecedence) {
1430         this->write("(");
1431     }
1432     this->writeExpression(*p.operand(), Precedence::kPostfix);
1433     this->write(OperatorName(p.getOperator()));
1434     if (Precedence::kPostfix >= parentPrecedence) {
1435         this->write(")");
1436     }
1437 }
1438 
writeBoolLiteral(const BoolLiteral & b)1439 void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
1440     this->write(b.value() ? "true" : "false");
1441 }
1442 
writeIntLiteral(const IntLiteral & i)1443 void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) {
1444     const Type& type = i.type();
1445     if (type == *fContext.fTypes.fUInt) {
1446         this->write(to_string(i.value() & 0xffffffff) + "u");
1447     } else if (type == *fContext.fTypes.fUShort) {
1448         this->write(to_string(i.value() & 0xffff) + "u");
1449     } else {
1450         this->write(to_string(i.value()));
1451     }
1452 }
1453 
writeFloatLiteral(const FloatLiteral & f)1454 void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
1455     this->write(to_string(f.value()));
1456 }
1457 
writeSetting(const Setting & s)1458 void MetalCodeGenerator::writeSetting(const Setting& s) {
1459     SK_ABORT("internal error; setting was not folded to a constant during compilation\n");
1460 }
1461 
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)1462 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1463                                                       const char*& separator) {
1464     Requirements requirements = this->requirements(f);
1465     if (requirements & kInputs_Requirement) {
1466         this->write(separator);
1467         this->write("_in");
1468         separator = ", ";
1469     }
1470     if (requirements & kOutputs_Requirement) {
1471         this->write(separator);
1472         this->write("_out");
1473         separator = ", ";
1474     }
1475     if (requirements & kUniforms_Requirement) {
1476         this->write(separator);
1477         this->write("_uniforms");
1478         separator = ", ";
1479     }
1480     if (requirements & kGlobals_Requirement) {
1481         this->write(separator);
1482         this->write("_globals");
1483         separator = ", ";
1484     }
1485     if (requirements & kFragCoord_Requirement) {
1486         this->write(separator);
1487         this->write("_fragCoord");
1488         separator = ", ";
1489     }
1490 }
1491 
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)1492 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
1493                                                         const char*& separator) {
1494     Requirements requirements = this->requirements(f);
1495     if (requirements & kInputs_Requirement) {
1496         this->write(separator);
1497         this->write("Inputs _in");
1498         separator = ", ";
1499     }
1500     if (requirements & kOutputs_Requirement) {
1501         this->write(separator);
1502         this->write("thread Outputs& _out");
1503         separator = ", ";
1504     }
1505     if (requirements & kUniforms_Requirement) {
1506         this->write(separator);
1507         this->write("Uniforms _uniforms");
1508         separator = ", ";
1509     }
1510     if (requirements & kGlobals_Requirement) {
1511         this->write(separator);
1512         this->write("thread Globals& _globals");
1513         separator = ", ";
1514     }
1515     if (requirements & kFragCoord_Requirement) {
1516         this->write(separator);
1517         this->write("float4 _fragCoord");
1518         separator = ", ";
1519     }
1520 }
1521 
getUniformBinding(const Modifiers & m)1522 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
1523     return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
1524                                      : fProgram.fConfig->fSettings.fDefaultUniformBinding;
1525 }
1526 
getUniformSet(const Modifiers & m)1527 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
1528     return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
1529                                  : fProgram.fConfig->fSettings.fDefaultUniformSet;
1530 }
1531 
writeFunctionDeclaration(const FunctionDeclaration & f)1532 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
1533     fRTHeightName = fProgram.fInputs.fRTHeight ? "_globals._anonInterface0->u_skRTHeight" : "";
1534     const char* separator = "";
1535     if (f.isMain()) {
1536         switch (fProgram.fConfig->fKind) {
1537             case ProgramKind::kFragment:
1538                 this->write("fragment Outputs fragmentMain");
1539                 break;
1540             case ProgramKind::kVertex:
1541                 this->write("vertex Outputs vertexMain");
1542                 break;
1543             default:
1544                 fErrors.error(-1, "unsupported kind of program");
1545                 return false;
1546         }
1547         this->write("(Inputs _in [[stage_in]]");
1548         if (-1 != fUniformBuffer) {
1549             this->write(", constant Uniforms& _uniforms [[buffer(" +
1550                         to_string(fUniformBuffer) + ")]]");
1551         }
1552         for (const ProgramElement* e : fProgram.elements()) {
1553             if (e->is<GlobalVarDeclaration>()) {
1554                 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1555                 const VarDeclaration& var = decls.declaration()->as<VarDeclaration>();
1556                 if (var.var().type().typeKind() == Type::TypeKind::kSampler) {
1557                     if (var.var().modifiers().fLayout.fBinding < 0) {
1558                         fErrors.error(decls.fOffset,
1559                                       "Metal samplers must have 'layout(binding=...)'");
1560                         return false;
1561                     }
1562                     if (var.var().type().dimensions() != SpvDim2D) {
1563                         // Not yet implemented--Skia currently only uses 2D textures.
1564                         fErrors.error(decls.fOffset, "Unsupported texture dimensions");
1565                         return false;
1566                     }
1567                     this->write(", texture2d<float> ");
1568                     this->writeName(var.var().name());
1569                     this->write("[[texture(");
1570                     this->write(to_string(var.var().modifiers().fLayout.fBinding));
1571                     this->write(")]]");
1572                     this->write(", sampler ");
1573                     this->writeName(var.var().name());
1574                     this->write(SAMPLER_SUFFIX);
1575                     this->write("[[sampler(");
1576                     this->write(to_string(var.var().modifiers().fLayout.fBinding));
1577                     this->write(")]]");
1578                 }
1579             } else if (e->is<InterfaceBlock>()) {
1580                 const InterfaceBlock& intf = e->as<InterfaceBlock>();
1581                 if (intf.typeName() == "sk_PerVertex") {
1582                     continue;
1583                 }
1584                 this->write(", constant ");
1585                 this->writeType(intf.variable().type());
1586                 this->write("& " );
1587                 this->write(fInterfaceBlockNameMap[&intf]);
1588                 this->write(" [[buffer(");
1589                 this->write(to_string(this->getUniformBinding(intf.variable().modifiers())));
1590                 this->write(")]]");
1591             }
1592         }
1593         if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
1594             if (fProgram.fInputs.fRTHeight && fInterfaceBlockNameMap.empty()) {
1595                 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
1596                 fRTHeightName = "_anonInterface0.u_skRTHeight";
1597             }
1598             this->write(", bool _frontFacing [[front_facing]]");
1599             this->write(", float4 _fragCoord [[position]]");
1600         } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
1601             this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
1602         }
1603         separator = ", ";
1604     } else {
1605         this->writeType(f.returnType());
1606         this->write(" ");
1607         this->writeName(f.mangledName());
1608         this->write("(");
1609         this->writeFunctionRequirementParams(f, separator);
1610     }
1611     for (const auto& param : f.parameters()) {
1612         if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
1613             continue;
1614         }
1615         this->write(separator);
1616         separator = ", ";
1617         this->writeModifiers(param->modifiers(), /*globalContext=*/false);
1618         const Type* type = &param->type();
1619         this->writeType(*type);
1620         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
1621             this->write("&");
1622         }
1623         this->write(" ");
1624         this->writeName(param->name());
1625     }
1626     this->write(")");
1627     return true;
1628 }
1629 
writeFunctionPrototype(const FunctionPrototype & f)1630 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
1631     this->writeFunctionDeclaration(f.declaration());
1632     this->writeLine(";");
1633 }
1634 
is_block_ending_with_return(const Statement * stmt)1635 static bool is_block_ending_with_return(const Statement* stmt) {
1636     // This function detects (potentially nested) blocks that end in a return statement.
1637     if (!stmt->is<Block>()) {
1638         return false;
1639     }
1640     const StatementArray& block = stmt->as<Block>().children();
1641     for (int index = block.count(); index--; ) {
1642         const Statement& stmt = *block[index];
1643         if (stmt.is<ReturnStatement>()) {
1644             return true;
1645         }
1646         if (stmt.is<Block>()) {
1647             return is_block_ending_with_return(&stmt);
1648         }
1649         if (!stmt.is<Nop>()) {
1650             break;
1651         }
1652     }
1653     return false;
1654 }
1655 
writeFunction(const FunctionDefinition & f)1656 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
1657     SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
1658 
1659     if (!this->writeFunctionDeclaration(f.declaration())) {
1660         return;
1661     }
1662 
1663     fCurrentFunction = &f.declaration();
1664     SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
1665 
1666     this->writeLine(" {");
1667 
1668     if (f.declaration().isMain()) {
1669         this->writeGlobalInit();
1670         this->writeLine("    Outputs _out;");
1671         this->writeLine("    (void)_out;");
1672     }
1673 
1674     fFunctionHeader.clear();
1675     StringStream buffer;
1676     {
1677         AutoOutputStream outputToBuffer(this, &buffer);
1678         fIndentation++;
1679         for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
1680             if (!stmt->isEmpty()) {
1681                 this->writeStatement(*stmt);
1682                 this->finishLine();
1683             }
1684         }
1685         if (f.declaration().isMain()) {
1686             // If the main function doesn't end with a return, we need to synthesize one here.
1687             if (!is_block_ending_with_return(f.body().get())) {
1688                 this->writeReturnStatementFromMain();
1689                 this->finishLine();
1690             }
1691         }
1692         fIndentation--;
1693         this->writeLine("}");
1694     }
1695     this->write(fFunctionHeader);
1696     this->write(buffer.str());
1697 }
1698 
writeModifiers(const Modifiers & modifiers,bool globalContext)1699 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
1700                                         bool globalContext) {
1701     if (modifiers.fFlags & Modifiers::kOut_Flag) {
1702         this->write("thread ");
1703     }
1704     if (modifiers.fFlags & Modifiers::kConst_Flag) {
1705         this->write("const ");
1706     }
1707 }
1708 
writeInterfaceBlock(const InterfaceBlock & intf)1709 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
1710     if ("sk_PerVertex" == intf.typeName()) {
1711         return;
1712     }
1713     this->writeModifiers(intf.variable().modifiers(), /*globalContext=*/true);
1714     this->write("struct ");
1715     this->writeLine(intf.typeName() + " {");
1716     const Type* structType = &intf.variable().type();
1717     if (structType->isArray()) {
1718         structType = &structType->componentType();
1719     }
1720     fIndentation++;
1721     this->writeFields(structType->fields(), structType->fOffset, &intf);
1722     if (fProgram.fInputs.fRTHeight) {
1723         this->writeLine("float u_skRTHeight;");
1724     }
1725     fIndentation--;
1726     this->write("}");
1727     if (intf.instanceName().size()) {
1728         this->write(" ");
1729         this->write(intf.instanceName());
1730         if (intf.arraySize() > 0) {
1731             this->write("[");
1732             this->write(to_string(intf.arraySize()));
1733             this->write("]");
1734         } else if (intf.arraySize() == Type::kUnsizedArray){
1735             this->write("[]");
1736         }
1737         fInterfaceBlockNameMap[&intf] = intf.instanceName();
1738     } else {
1739         fInterfaceBlockNameMap[&intf] = "_anonInterface" +  to_string(fAnonInterfaceCount++);
1740     }
1741     this->writeLine(";");
1742 }
1743 
writeFields(const std::vector<Type::Field> & fields,int parentOffset,const InterfaceBlock * parentIntf)1744 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentOffset,
1745                                      const InterfaceBlock* parentIntf) {
1746     MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
1747     int currentOffset = 0;
1748     for (const Type::Field& field : fields) {
1749         int fieldOffset = field.fModifiers.fLayout.fOffset;
1750         const Type* fieldType = field.fType;
1751         if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
1752             fErrors.error(parentOffset, "type '" + fieldType->name() + "' is not permitted here");
1753             return;
1754         }
1755         if (fieldOffset != -1) {
1756             if (currentOffset > fieldOffset) {
1757                 fErrors.error(parentOffset,
1758                               "offset of field '" + field.fName + "' must be at least " +
1759                               to_string((int) currentOffset));
1760                 return;
1761             } else if (currentOffset < fieldOffset) {
1762                 this->write("char pad");
1763                 this->write(to_string(fPaddingCount++));
1764                 this->write("[");
1765                 this->write(to_string(fieldOffset - currentOffset));
1766                 this->writeLine("];");
1767                 currentOffset = fieldOffset;
1768             }
1769             int alignment = memoryLayout.alignment(*fieldType);
1770             if (fieldOffset % alignment) {
1771                 fErrors.error(parentOffset,
1772                               "offset of field '" + field.fName + "' must be a multiple of " +
1773                               to_string((int) alignment));
1774                 return;
1775             }
1776         }
1777         size_t fieldSize = memoryLayout.size(*fieldType);
1778         if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
1779             fErrors.error(parentOffset, "field offset overflow");
1780             return;
1781         }
1782         currentOffset += fieldSize;
1783         this->writeModifiers(field.fModifiers, /*globalContext=*/false);
1784         this->writeType(*fieldType);
1785         this->write(" ");
1786         this->writeName(field.fName);
1787         this->writeLine(";");
1788         if (parentIntf) {
1789             fInterfaceBlockMap[&field] = parentIntf;
1790         }
1791     }
1792 }
1793 
writeVarInitializer(const Variable & var,const Expression & value)1794 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
1795     this->writeExpression(value, Precedence::kTopLevel);
1796 }
1797 
writeName(const String & name)1798 void MetalCodeGenerator::writeName(const String& name) {
1799     if (fReservedWords.find(name) != fReservedWords.end()) {
1800         this->write("_"); // adding underscore before name to avoid conflict with reserved words
1801     }
1802     this->write(name);
1803 }
1804 
writeVarDeclaration(const VarDeclaration & varDecl,bool global)1805 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, bool global) {
1806     if (global && !(varDecl.var().modifiers().fFlags & Modifiers::kConst_Flag)) {
1807         return;
1808     }
1809     this->writeModifiers(varDecl.var().modifiers(), global);
1810     this->writeType(varDecl.var().type());
1811     this->write(" ");
1812     this->writeName(varDecl.var().name());
1813     if (varDecl.value()) {
1814         this->write(" = ");
1815         this->writeVarInitializer(varDecl.var(), *varDecl.value());
1816     }
1817     this->write(";");
1818 }
1819 
writeStatement(const Statement & s)1820 void MetalCodeGenerator::writeStatement(const Statement& s) {
1821     switch (s.kind()) {
1822         case Statement::Kind::kBlock:
1823             this->writeBlock(s.as<Block>());
1824             break;
1825         case Statement::Kind::kExpression:
1826             this->writeExpression(*s.as<ExpressionStatement>().expression(), Precedence::kTopLevel);
1827             this->write(";");
1828             break;
1829         case Statement::Kind::kReturn:
1830             this->writeReturnStatement(s.as<ReturnStatement>());
1831             break;
1832         case Statement::Kind::kVarDeclaration:
1833             this->writeVarDeclaration(s.as<VarDeclaration>(), false);
1834             break;
1835         case Statement::Kind::kIf:
1836             this->writeIfStatement(s.as<IfStatement>());
1837             break;
1838         case Statement::Kind::kFor:
1839             this->writeForStatement(s.as<ForStatement>());
1840             break;
1841         case Statement::Kind::kDo:
1842             this->writeDoStatement(s.as<DoStatement>());
1843             break;
1844         case Statement::Kind::kSwitch:
1845             this->writeSwitchStatement(s.as<SwitchStatement>());
1846             break;
1847         case Statement::Kind::kBreak:
1848             this->write("break;");
1849             break;
1850         case Statement::Kind::kContinue:
1851             this->write("continue;");
1852             break;
1853         case Statement::Kind::kDiscard:
1854             this->write("discard_fragment();");
1855             break;
1856         case Statement::Kind::kInlineMarker:
1857         case Statement::Kind::kNop:
1858             this->write(";");
1859             break;
1860         default:
1861             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
1862             break;
1863     }
1864 }
1865 
writeBlock(const Block & b)1866 void MetalCodeGenerator::writeBlock(const Block& b) {
1867     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
1868     // something here to make the code valid).
1869     bool isScope = b.isScope() || b.isEmpty();
1870     if (isScope) {
1871         this->writeLine("{");
1872         fIndentation++;
1873     }
1874     for (const std::unique_ptr<Statement>& stmt : b.children()) {
1875         if (!stmt->isEmpty()) {
1876             this->writeStatement(*stmt);
1877             this->finishLine();
1878         }
1879     }
1880     if (isScope) {
1881         fIndentation--;
1882         this->write("}");
1883     }
1884 }
1885 
writeIfStatement(const IfStatement & stmt)1886 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
1887     this->write("if (");
1888     this->writeExpression(*stmt.test(), Precedence::kTopLevel);
1889     this->write(") ");
1890     this->writeStatement(*stmt.ifTrue());
1891     if (stmt.ifFalse()) {
1892         this->write(" else ");
1893         this->writeStatement(*stmt.ifFalse());
1894     }
1895 }
1896 
writeForStatement(const ForStatement & f)1897 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
1898     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
1899     if (!f.initializer() && f.test() && !f.next()) {
1900         this->write("while (");
1901         this->writeExpression(*f.test(), Precedence::kTopLevel);
1902         this->write(") ");
1903         this->writeStatement(*f.statement());
1904         return;
1905     }
1906 
1907     this->write("for (");
1908     if (f.initializer() && !f.initializer()->isEmpty()) {
1909         this->writeStatement(*f.initializer());
1910     } else {
1911         this->write("; ");
1912     }
1913     if (f.test()) {
1914         this->writeExpression(*f.test(), Precedence::kTopLevel);
1915     }
1916     this->write("; ");
1917     if (f.next()) {
1918         this->writeExpression(*f.next(), Precedence::kTopLevel);
1919     }
1920     this->write(") ");
1921     this->writeStatement(*f.statement());
1922 }
1923 
writeDoStatement(const DoStatement & d)1924 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
1925     this->write("do ");
1926     this->writeStatement(*d.statement());
1927     this->write(" while (");
1928     this->writeExpression(*d.test(), Precedence::kTopLevel);
1929     this->write(");");
1930 }
1931 
writeSwitchStatement(const SwitchStatement & s)1932 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
1933     this->write("switch (");
1934     this->writeExpression(*s.value(), Precedence::kTopLevel);
1935     this->writeLine(") {");
1936     fIndentation++;
1937     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
1938         const SwitchCase& c = stmt->as<SwitchCase>();
1939         if (c.value()) {
1940             this->write("case ");
1941             this->writeExpression(*c.value(), Precedence::kTopLevel);
1942             this->writeLine(":");
1943         } else {
1944             this->writeLine("default:");
1945         }
1946         if (!c.statement()->isEmpty()) {
1947             fIndentation++;
1948             this->writeStatement(*c.statement());
1949             this->finishLine();
1950             fIndentation--;
1951         }
1952     }
1953     fIndentation--;
1954     this->write("}");
1955 }
1956 
writeReturnStatementFromMain()1957 void MetalCodeGenerator::writeReturnStatementFromMain() {
1958     // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
1959     switch (fProgram.fConfig->fKind) {
1960         case ProgramKind::kFragment:
1961             this->write("return _out;");
1962             break;
1963         case ProgramKind::kVertex:
1964             this->write("return (_out.sk_Position.y = -_out.sk_Position.y, _out);");
1965             break;
1966         default:
1967             SkDEBUGFAIL("unsupported kind of program");
1968     }
1969 }
1970 
writeReturnStatement(const ReturnStatement & r)1971 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
1972     if (fCurrentFunction && fCurrentFunction->isMain()) {
1973         if (r.expression()) {
1974             if (r.expression()->type() == *fContext.fTypes.fHalf4) {
1975                 this->write("_out.sk_FragColor = ");
1976                 this->writeExpression(*r.expression(), Precedence::kTopLevel);
1977                 this->writeLine(";");
1978             } else {
1979                 fErrors.error(r.fOffset, "Metal does not support returning '" +
1980                                          r.expression()->type().description() + "' from main()");
1981             }
1982         }
1983         this->writeReturnStatementFromMain();
1984         return;
1985     }
1986 
1987     this->write("return");
1988     if (r.expression()) {
1989         this->write(" ");
1990         this->writeExpression(*r.expression(), Precedence::kTopLevel);
1991     }
1992     this->write(";");
1993 }
1994 
writeHeader()1995 void MetalCodeGenerator::writeHeader() {
1996     this->write("#include <metal_stdlib>\n");
1997     this->write("#include <simd/simd.h>\n");
1998     this->write("using namespace metal;\n");
1999 }
2000 
writeUniformStruct()2001 void MetalCodeGenerator::writeUniformStruct() {
2002     for (const ProgramElement* e : fProgram.elements()) {
2003         if (e->is<GlobalVarDeclaration>()) {
2004             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2005             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2006             if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2007                 var.type().typeKind() != Type::TypeKind::kSampler) {
2008                 int uniformSet = this->getUniformSet(var.modifiers());
2009                 // Make sure that the program's uniform-set value is consistent throughout.
2010                 if (-1 == fUniformBuffer) {
2011                     this->write("struct Uniforms {\n");
2012                     fUniformBuffer = uniformSet;
2013                 } else if (uniformSet != fUniformBuffer) {
2014                     fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have "
2015                                                  "the same 'layout(set=...)'");
2016                 }
2017                 this->write("    ");
2018                 this->writeType(var.type());
2019                 this->write(" ");
2020                 this->writeName(var.name());
2021                 this->write(";\n");
2022             }
2023         }
2024     }
2025     if (-1 != fUniformBuffer) {
2026         this->write("};\n");
2027     }
2028 }
2029 
writeInputStruct()2030 void MetalCodeGenerator::writeInputStruct() {
2031     this->write("struct Inputs {\n");
2032     for (const ProgramElement* e : fProgram.elements()) {
2033         if (e->is<GlobalVarDeclaration>()) {
2034             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2035             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2036             if (var.modifiers().fFlags & Modifiers::kIn_Flag &&
2037                 -1 == var.modifiers().fLayout.fBuiltin) {
2038                 this->write("    ");
2039                 this->writeType(var.type());
2040                 this->write(" ");
2041                 this->writeName(var.name());
2042                 if (-1 != var.modifiers().fLayout.fLocation) {
2043                     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2044                         this->write("  [[attribute(" +
2045                                     to_string(var.modifiers().fLayout.fLocation) + ")]]");
2046                     } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2047                         this->write("  [[user(locn" +
2048                                     to_string(var.modifiers().fLayout.fLocation) + ")]]");
2049                     }
2050                 }
2051                 this->write(";\n");
2052             }
2053         }
2054     }
2055     this->write("};\n");
2056 }
2057 
writeOutputStruct()2058 void MetalCodeGenerator::writeOutputStruct() {
2059     this->write("struct Outputs {\n");
2060     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2061         this->write("    float4 sk_Position [[position]];\n");
2062     } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2063         this->write("    float4 sk_FragColor [[color(0)]];\n");
2064     }
2065     for (const ProgramElement* e : fProgram.elements()) {
2066         if (e->is<GlobalVarDeclaration>()) {
2067             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2068             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2069             if (var.modifiers().fFlags & Modifiers::kOut_Flag &&
2070                 -1 == var.modifiers().fLayout.fBuiltin) {
2071                 this->write("    ");
2072                 this->writeType(var.type());
2073                 this->write(" ");
2074                 this->writeName(var.name());
2075 
2076                 int location = var.modifiers().fLayout.fLocation;
2077                 if (location < 0) {
2078                     fErrors.error(var.fOffset,
2079                                   "Metal out variables must have 'layout(location=...)'");
2080                 } else if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2081                     this->write(" [[user(locn" + to_string(location) + ")]]");
2082                 } else if (fProgram.fConfig->fKind == ProgramKind::kFragment) {
2083                     this->write(" [[color(" + to_string(location) + ")");
2084                     int colorIndex = var.modifiers().fLayout.fIndex;
2085                     if (colorIndex) {
2086                         this->write(", index(" + to_string(colorIndex) + ")");
2087                     }
2088                     this->write("]]");
2089                 }
2090                 this->write(";\n");
2091             }
2092         }
2093     }
2094     if (fProgram.fConfig->fKind == ProgramKind::kVertex) {
2095         this->write("    float sk_PointSize [[point_size]];\n");
2096     }
2097     this->write("};\n");
2098 }
2099 
writeInterfaceBlocks()2100 void MetalCodeGenerator::writeInterfaceBlocks() {
2101     bool wroteInterfaceBlock = false;
2102     for (const ProgramElement* e : fProgram.elements()) {
2103         if (e->is<InterfaceBlock>()) {
2104             this->writeInterfaceBlock(e->as<InterfaceBlock>());
2105             wroteInterfaceBlock = true;
2106         }
2107     }
2108     if (!wroteInterfaceBlock && fProgram.fInputs.fRTHeight) {
2109         this->writeLine("struct sksl_synthetic_uniforms {");
2110         this->writeLine("    float u_skRTHeight;");
2111         this->writeLine("};");
2112     }
2113 }
2114 
writeStructDefinitions()2115 void MetalCodeGenerator::writeStructDefinitions() {
2116     for (const ProgramElement* e : fProgram.elements()) {
2117         if (e->is<StructDefinition>()) {
2118             this->writeStructDefinition(e->as<StructDefinition>());
2119         }
2120     }
2121 }
2122 
visitGlobalStruct(GlobalStructVisitor * visitor)2123 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2124     // Visit the interface blocks.
2125     for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) {
2126         visitor->visitInterfaceBlock(*interfaceType, interfaceName);
2127     }
2128     for (const ProgramElement* element : fProgram.elements()) {
2129         if (!element->is<GlobalVarDeclaration>()) {
2130             continue;
2131         }
2132         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2133         const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2134         const Variable& var = decl.var();
2135         if ((!var.modifiers().fFlags && -1 == var.modifiers().fLayout.fBuiltin) ||
2136             var.type().typeKind() == Type::TypeKind::kSampler) {
2137             if (var.type().typeKind() == Type::TypeKind::kSampler) {
2138                 // Samplers are represented as a "texture/sampler" duo in the global struct.
2139                 visitor->visitTexture(var.type(), var.name());
2140                 visitor->visitSampler(var.type(), String(var.name()) + SAMPLER_SUFFIX);
2141             } else {
2142                 // Visit a regular variable.
2143                 visitor->visitVariable(var, decl.value().get());
2144             }
2145         }
2146     }
2147 }
2148 
writeGlobalStruct()2149 void MetalCodeGenerator::writeGlobalStruct() {
2150     class : public GlobalStructVisitor {
2151     public:
2152         void visitInterfaceBlock(const InterfaceBlock& block, const String& blockName) override {
2153             this->addElement();
2154             fCodeGen->write("    constant ");
2155             fCodeGen->write(block.typeName());
2156             fCodeGen->write("* ");
2157             fCodeGen->writeName(blockName);
2158             fCodeGen->write(";\n");
2159         }
2160         void visitTexture(const Type& type, const String& name) override {
2161             this->addElement();
2162             fCodeGen->write("    ");
2163             fCodeGen->writeType(type);
2164             fCodeGen->write(" ");
2165             fCodeGen->writeName(name);
2166             fCodeGen->write(";\n");
2167         }
2168         void visitSampler(const Type&, const String& name) override {
2169             this->addElement();
2170             fCodeGen->write("    sampler ");
2171             fCodeGen->writeName(name);
2172             fCodeGen->write(";\n");
2173         }
2174         void visitVariable(const Variable& var, const Expression* value) override {
2175             this->addElement();
2176             fCodeGen->write("    ");
2177             fCodeGen->writeType(var.type());
2178             fCodeGen->write(" ");
2179             fCodeGen->writeName(var.name());
2180             fCodeGen->write(";\n");
2181         }
2182         void addElement() {
2183             if (fFirst) {
2184                 fCodeGen->write("struct Globals {\n");
2185                 fFirst = false;
2186             }
2187         }
2188         void finish() {
2189             if (!fFirst) {
2190                 fCodeGen->writeLine("};");
2191                 fFirst = true;
2192             }
2193         }
2194 
2195         MetalCodeGenerator* fCodeGen = nullptr;
2196         bool fFirst = true;
2197     } visitor;
2198 
2199     visitor.fCodeGen = this;
2200     this->visitGlobalStruct(&visitor);
2201     visitor.finish();
2202 }
2203 
writeGlobalInit()2204 void MetalCodeGenerator::writeGlobalInit() {
2205     class : public GlobalStructVisitor {
2206     public:
2207         void visitInterfaceBlock(const InterfaceBlock& blockType,
2208                                  const String& blockName) override {
2209             this->addElement();
2210             fCodeGen->write("&");
2211             fCodeGen->writeName(blockName);
2212         }
2213         void visitTexture(const Type&, const String& name) override {
2214             this->addElement();
2215             fCodeGen->writeName(name);
2216         }
2217         void visitSampler(const Type&, const String& name) override {
2218             this->addElement();
2219             fCodeGen->writeName(name);
2220         }
2221         void visitVariable(const Variable& var, const Expression* value) override {
2222             this->addElement();
2223             if (value) {
2224                 fCodeGen->writeVarInitializer(var, *value);
2225             } else {
2226                 fCodeGen->write("{}");
2227             }
2228         }
2229         void addElement() {
2230             if (fFirst) {
2231                 fCodeGen->write("    Globals _globals{");
2232                 fFirst = false;
2233             } else {
2234                 fCodeGen->write(", ");
2235             }
2236         }
2237         void finish() {
2238             if (!fFirst) {
2239                 fCodeGen->writeLine("};");
2240                 fCodeGen->writeLine("    (void)_globals;");
2241             }
2242         }
2243         MetalCodeGenerator* fCodeGen = nullptr;
2244         bool fFirst = true;
2245     } visitor;
2246 
2247     visitor.fCodeGen = this;
2248     this->visitGlobalStruct(&visitor);
2249     visitor.finish();
2250 }
2251 
writeProgramElement(const ProgramElement & e)2252 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
2253     switch (e.kind()) {
2254         case ProgramElement::Kind::kExtension:
2255             break;
2256         case ProgramElement::Kind::kGlobalVar: {
2257             const GlobalVarDeclaration& global = e.as<GlobalVarDeclaration>();
2258             const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2259             int builtin = decl.var().modifiers().fLayout.fBuiltin;
2260             if (-1 == builtin) {
2261                 // normal var
2262                 this->writeVarDeclaration(decl, true);
2263                 this->finishLine();
2264             } else if (SK_FRAGCOLOR_BUILTIN == builtin) {
2265                 // ignore
2266             }
2267             break;
2268         }
2269         case ProgramElement::Kind::kInterfaceBlock:
2270             // handled in writeInterfaceBlocks, do nothing
2271             break;
2272         case ProgramElement::Kind::kStructDefinition:
2273             // Handled in writeStructDefinitions. Do nothing.
2274             break;
2275         case ProgramElement::Kind::kFunction:
2276             this->writeFunction(e.as<FunctionDefinition>());
2277             break;
2278         case ProgramElement::Kind::kFunctionPrototype:
2279             this->writeFunctionPrototype(e.as<FunctionPrototype>());
2280             break;
2281         case ProgramElement::Kind::kModifiers:
2282             this->writeModifiers(e.as<ModifiersDeclaration>().modifiers(),
2283                                  /*globalContext=*/true);
2284             this->writeLine(";");
2285             break;
2286         case ProgramElement::Kind::kEnum:
2287             break;
2288         default:
2289             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
2290             break;
2291     }
2292 }
2293 
requirements(const Expression * e)2294 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) {
2295     if (!e) {
2296         return kNo_Requirements;
2297     }
2298     switch (e->kind()) {
2299         case Expression::Kind::kFunctionCall: {
2300             const FunctionCall& f = e->as<FunctionCall>();
2301             Requirements result = this->requirements(f.function());
2302             for (const auto& arg : f.arguments()) {
2303                 result |= this->requirements(arg.get());
2304             }
2305             return result;
2306         }
2307         case Expression::Kind::kConstructorCompound:
2308         case Expression::Kind::kConstructorCompoundCast:
2309         case Expression::Kind::kConstructorArray:
2310         case Expression::Kind::kConstructorDiagonalMatrix:
2311         case Expression::Kind::kConstructorScalarCast:
2312         case Expression::Kind::kConstructorSplat:
2313         case Expression::Kind::kConstructorStruct: {
2314             const AnyConstructor& c = e->asAnyConstructor();
2315             Requirements result = kNo_Requirements;
2316             for (const auto& arg : c.argumentSpan()) {
2317                 result |= this->requirements(arg.get());
2318             }
2319             return result;
2320         }
2321         case Expression::Kind::kFieldAccess: {
2322             const FieldAccess& f = e->as<FieldAccess>();
2323             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
2324                 return kGlobals_Requirement;
2325             }
2326             return this->requirements(f.base().get());
2327         }
2328         case Expression::Kind::kSwizzle:
2329             return this->requirements(e->as<Swizzle>().base().get());
2330         case Expression::Kind::kBinary: {
2331             const BinaryExpression& bin = e->as<BinaryExpression>();
2332             return this->requirements(bin.left().get()) |
2333                    this->requirements(bin.right().get());
2334         }
2335         case Expression::Kind::kIndex: {
2336             const IndexExpression& idx = e->as<IndexExpression>();
2337             return this->requirements(idx.base().get()) | this->requirements(idx.index().get());
2338         }
2339         case Expression::Kind::kPrefix:
2340             return this->requirements(e->as<PrefixExpression>().operand().get());
2341         case Expression::Kind::kPostfix:
2342             return this->requirements(e->as<PostfixExpression>().operand().get());
2343         case Expression::Kind::kTernary: {
2344             const TernaryExpression& t = e->as<TernaryExpression>();
2345             return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
2346                    this->requirements(t.ifFalse().get());
2347         }
2348         case Expression::Kind::kVariableReference: {
2349             const VariableReference& v = e->as<VariableReference>();
2350             const Modifiers& modifiers = v.variable()->modifiers();
2351             Requirements result = kNo_Requirements;
2352             if (modifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
2353                 result = kGlobals_Requirement | kFragCoord_Requirement;
2354             } else if (Variable::Storage::kGlobal == v.variable()->storage()) {
2355                 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2356                     result = kInputs_Requirement;
2357                 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2358                     result = kOutputs_Requirement;
2359                 } else if (modifiers.fFlags & Modifiers::kUniform_Flag &&
2360                            v.variable()->type().typeKind() != Type::TypeKind::kSampler) {
2361                     result = kUniforms_Requirement;
2362                 } else {
2363                     result = kGlobals_Requirement;
2364                 }
2365             }
2366             return result;
2367         }
2368         default:
2369             return kNo_Requirements;
2370     }
2371 }
2372 
requirements(const Statement * s)2373 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
2374     if (!s) {
2375         return kNo_Requirements;
2376     }
2377     switch (s->kind()) {
2378         case Statement::Kind::kBlock: {
2379             Requirements result = kNo_Requirements;
2380             for (const std::unique_ptr<Statement>& child : s->as<Block>().children()) {
2381                 result |= this->requirements(child.get());
2382             }
2383             return result;
2384         }
2385         case Statement::Kind::kVarDeclaration: {
2386             const VarDeclaration& var = s->as<VarDeclaration>();
2387             return this->requirements(var.value().get());
2388         }
2389         case Statement::Kind::kExpression:
2390             return this->requirements(s->as<ExpressionStatement>().expression().get());
2391         case Statement::Kind::kReturn: {
2392             const ReturnStatement& r = s->as<ReturnStatement>();
2393             return this->requirements(r.expression().get());
2394         }
2395         case Statement::Kind::kIf: {
2396             const IfStatement& i = s->as<IfStatement>();
2397             return this->requirements(i.test().get()) |
2398                    this->requirements(i.ifTrue().get()) |
2399                    this->requirements(i.ifFalse().get());
2400         }
2401         case Statement::Kind::kFor: {
2402             const ForStatement& f = s->as<ForStatement>();
2403             return this->requirements(f.initializer().get()) |
2404                    this->requirements(f.test().get()) |
2405                    this->requirements(f.next().get()) |
2406                    this->requirements(f.statement().get());
2407         }
2408         case Statement::Kind::kDo: {
2409             const DoStatement& d = s->as<DoStatement>();
2410             return this->requirements(d.test().get()) |
2411                    this->requirements(d.statement().get());
2412         }
2413         case Statement::Kind::kSwitch: {
2414             const SwitchStatement& sw = s->as<SwitchStatement>();
2415             Requirements result = this->requirements(sw.value().get());
2416             for (const std::unique_ptr<Statement>& sc : sw.cases()) {
2417                 result |= this->requirements(sc->as<SwitchCase>().statement().get());
2418             }
2419             return result;
2420         }
2421         default:
2422             return kNo_Requirements;
2423     }
2424 }
2425 
requirements(const FunctionDeclaration & f)2426 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
2427     if (f.isBuiltin()) {
2428         return kNo_Requirements;
2429     }
2430     auto found = fRequirements.find(&f);
2431     if (found == fRequirements.end()) {
2432         fRequirements[&f] = kNo_Requirements;
2433         for (const ProgramElement* e : fProgram.elements()) {
2434             if (e->is<FunctionDefinition>()) {
2435                 const FunctionDefinition& def = e->as<FunctionDefinition>();
2436                 if (&def.declaration() == &f) {
2437                     Requirements reqs = this->requirements(def.body().get());
2438                     fRequirements[&f] = reqs;
2439                     return reqs;
2440                 }
2441             }
2442         }
2443         // We never found a definition for this declared function, but it's legal to prototype a
2444         // function without ever giving a definition, as long as you don't call it.
2445         return kNo_Requirements;
2446     }
2447     return found->second;
2448 }
2449 
generateCode()2450 bool MetalCodeGenerator::generateCode() {
2451     StringStream header;
2452     {
2453         AutoOutputStream outputToHeader(this, &header, &fIndentation);
2454         this->writeHeader();
2455         this->writeStructDefinitions();
2456         this->writeUniformStruct();
2457         this->writeInputStruct();
2458         this->writeOutputStruct();
2459         this->writeInterfaceBlocks();
2460         this->writeGlobalStruct();
2461     }
2462     StringStream body;
2463     {
2464         AutoOutputStream outputToBody(this, &body, &fIndentation);
2465         for (const ProgramElement* e : fProgram.elements()) {
2466             this->writeProgramElement(*e);
2467         }
2468     }
2469     write_stringstream(header, *fOut);
2470     write_stringstream(fExtraFunctions, *fOut);
2471     write_stringstream(body, *fOut);
2472     return 0 == fErrors.errorCount();
2473 }
2474 
2475 }  // namespace SkSL
2476