1 /*
2  * Copyright 2018 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_JIT
9 #define SKSL_JIT
10 
11 #ifdef SK_LLVM_AVAILABLE
12 
13 #include "ir/SkSLBinaryExpression.h"
14 #include "ir/SkSLBreakStatement.h"
15 #include "ir/SkSLContinueStatement.h"
16 #include "ir/SkSLExpression.h"
17 #include "ir/SkSLDoStatement.h"
18 #include "ir/SkSLForStatement.h"
19 #include "ir/SkSLFunctionCall.h"
20 #include "ir/SkSLFunctionDefinition.h"
21 #include "ir/SkSLIfStatement.h"
22 #include "ir/SkSLIndexExpression.h"
23 #include "ir/SkSLPrefixExpression.h"
24 #include "ir/SkSLPostfixExpression.h"
25 #include "ir/SkSLProgram.h"
26 #include "ir/SkSLReturnStatement.h"
27 #include "ir/SkSLStatement.h"
28 #include "ir/SkSLSwizzle.h"
29 #include "ir/SkSLTernaryExpression.h"
30 #include "ir/SkSLVarDeclarationsStatement.h"
31 #include "ir/SkSLVariableReference.h"
32 #include "ir/SkSLWhileStatement.h"
33 
#include "llvm-c/Analysis.h"
#include "llvm-c/Core.h"
#include "llvm-c/OrcBindings.h"
#include "llvm-c/Support.h"
#include "llvm-c/Target.h"
#include "llvm-c/Transforms/PassManagerBuilder.h"
#include "llvm-c/Types.h"

#include <memory>
#include <set>
#include <stack>
#include <unordered_map>
#include <vector>
42 
43 class SkRasterPipeline;
44 
45 namespace SkSL {
46 
47 struct AppendStage;
48 
49 /**
50  * A just-in-time compiler for SkSL code which uses an LLVM backend. Only available when the
51  * skia_llvm_path gn arg is set.
52  *
53  * Example of using SkSLJIT to set up an SkJumper pipeline stage:
54  *
55  * #ifdef SK_LLVM_AVAILABLE
56  *   SkSL::Compiler compiler;
57  *   SkSL::Program::Settings settings;
58  *   std::unique_ptr<SkSL::Program> program = compiler.convertProgram(
59          SkSL::Program::kPipelineStage_Kind,
60  *       "void swap(int x, int y, inout float4 color) {"
61  *       "    color.rb = color.br;"
62  *       "}",
63  *       settings);
64  *   if (!program) {
65  *       printf("%s\n", compiler.errorText().c_str());
66  *       abort();
67  *   }
68  *   SkSL::JIT& jit = *scratch->make<SkSL::JIT>(&compiler);
69  *   std::unique_ptr<SkSL::JIT::Module> module = jit.compile(std::move(program));
70  *   void* func = module->getJumperStage("swap");
71  *   p->append(func, nullptr);
72  * #endif
73  */
74 class JIT {
75     typedef int StackIndex;
76 
77 public:
78     class Module {
79     public:
80         /**
81          * Returns the address of a symbol in the module.
82          */
83         void* getSymbol(const char* name);
84 
85         /**
86          * Returns the address of a function as an SkJumper pipeline stage. The function must have
87          * the signature void <name>(int x, int y, inout float4 color). The returned function will
88          * have the correct signature to function as an SkJumper stage (meaning it will actually
89          * have a different signature at runtime, accepting vector parameters and operating on
90          * multiple pixels simultaneously as is normal for SkJumper stages).
91          */
92         void* getJumperStage(const char* name);
93 
~Module()94         ~Module() {
95             LLVMOrcDisposeSharedModuleRef(fSharedModule);
96         }
97 
98     private:
Module(std::unique_ptr<Program> program,LLVMSharedModuleRef sharedModule,LLVMOrcJITStackRef jitStack)99         Module(std::unique_ptr<Program> program,
100                LLVMSharedModuleRef sharedModule,
101                LLVMOrcJITStackRef jitStack)
102         : fProgram(std::move(program))
103         , fSharedModule(sharedModule)
104         , fJITStack(jitStack) {}
105 
106         std::unique_ptr<Program> fProgram;
107         LLVMSharedModuleRef fSharedModule;
108         LLVMOrcJITStackRef fJITStack;
109 
110         friend class JIT;
111     };
112 
113     JIT(Compiler* compiler);
114 
115     ~JIT();
116 
117     /**
118      * Just-in-time compiles an SkSL program and returns the resulting Module. The JIT must not be
119      * destroyed before all of its Modules are destroyed.
120      */
121     std::unique_ptr<Module> compile(std::unique_ptr<Program> program);
122 
123 private:
124     static constexpr int CHANNELS = 4;
125 
126     enum TypeKind {
127         kFloat_TypeKind,
128         kInt_TypeKind,
129         kUInt_TypeKind,
130         kBool_TypeKind
131     };
132 
133     class LValue {
134     public:
~LValue()135         virtual ~LValue() {}
136 
137         virtual LLVMValueRef load(LLVMBuilderRef builder) = 0;
138 
139         virtual void store(LLVMBuilderRef builder, LLVMValueRef value) = 0;
140     };
141 
142     void addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
143                             std::vector<LLVMTypeRef> parameters);
144 
145     void loadBuiltinFunctions();
146 
147     void setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block);
148 
149     LLVMTypeRef getType(const Type& type);
150 
151     TypeKind typeKind(const Type& type);
152 
153     std::unique_ptr<LValue> getLValue(LLVMBuilderRef builder, const Expression& expr);
154 
155     void vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns);
156 
157     void vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
158                    LLVMValueRef* right);
159 
160     LLVMValueRef compileBinary(LLVMBuilderRef builder, const BinaryExpression& b);
161 
162     LLVMValueRef compileConstructor(LLVMBuilderRef builder, const Constructor& c);
163 
164     LLVMValueRef compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc);
165 
166     LLVMValueRef compileIndex(LLVMBuilderRef builder, const IndexExpression& v);
167 
168     LLVMValueRef compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p);
169 
170     LLVMValueRef compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p);
171 
172     LLVMValueRef compileSwizzle(LLVMBuilderRef builder, const Swizzle& s);
173 
174     LLVMValueRef compileVariableReference(LLVMBuilderRef builder, const VariableReference& v);
175 
176     LLVMValueRef compileTernary(LLVMBuilderRef builder, const TernaryExpression& t);
177 
178     LLVMValueRef compileExpression(LLVMBuilderRef builder, const Expression& expr);
179 
180     void appendStage(LLVMBuilderRef builder, const AppendStage& a);
181 
182     void compileBlock(LLVMBuilderRef builder, const Block& block);
183 
184     void compileBreak(LLVMBuilderRef builder, const BreakStatement& b);
185 
186     void compileContinue(LLVMBuilderRef builder, const ContinueStatement& c);
187 
188     void compileDo(LLVMBuilderRef builder, const DoStatement& d);
189 
190     void compileFor(LLVMBuilderRef builder, const ForStatement& f);
191 
192     void compileIf(LLVMBuilderRef builder, const IfStatement& i);
193 
194     void compileReturn(LLVMBuilderRef builder, const ReturnStatement& r);
195 
196     void compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls);
197 
198     void compileWhile(LLVMBuilderRef builder, const WhileStatement& w);
199 
200     void compileStatement(LLVMBuilderRef builder, const Statement& stmt);
201 
202     // The "Vector" variants of functions attempt to compile a given expression or statement as part
203     // of a vectorized SkJumper stage function - that is, with r, g, b, and a each being vectors of
204     // fVectorCount floats. So a statement like "color.r = 0;" looks like it modifies a single
205     // channel of a single pixel, but the compiled code will actually modify the red channel of
206     // fVectorCount pixels at once.
207     //
208     // As not everything can be vectorized, these calls return a bool to indicate whether they were
209     // successful. If anything anywhere in the function cannot be vectorized, the JIT will fall back
210     // to looping over the pixels instead.
211     //
212     // Since we process multiple pixels at once, and each pixel consists of multiple color channels,
213     // expressions may effectively result in a vector-of-vectors. We produce zero to four outputs
214     // when compiling expression, each of which is a vector, so that e.g. float2(1, 0) actually
215     // produces two vectors, one containing all 1s, the other all 0s. The out parameter always
216     // allows for 4 channels, but the functions produce 0 to 4 channels depending on the type they
217     // are operating on. Thus evaluating "color.rgb" actually fills in out[0] through out[2],
218     // leaving out[3] uninitialized.
219     // As the number of outputs can be inferred from the type of the expression, it is not
220     // explicitly signalled anywhere.
221     bool compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
222                              LLVMValueRef out[CHANNELS]);
223 
224     bool compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
225                                   LLVMValueRef out[CHANNELS]);
226 
227     bool compileVectorFloatLiteral(LLVMBuilderRef builder, const FloatLiteral& f,
228                                    LLVMValueRef out[CHANNELS]);
229 
230     bool compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
231                               LLVMValueRef out[CHANNELS]);
232 
233     bool compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
234                                         LLVMValueRef out[CHANNELS]);
235 
236     bool compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
237                                  LLVMValueRef out[CHANNELS]);
238 
239     bool getVectorLValue(LLVMBuilderRef builder, const Expression& e, LLVMValueRef out[CHANNELS]);
240 
241     /**
242      * Evaluates the left and right operands of a binary operation, promoting one of them to a
243      * vector if necessary to make the types match.
244      */
245     bool getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
246                                  LLVMValueRef outLeft[CHANNELS], const Expression& right,
247                                  LLVMValueRef outRight[CHANNELS]);
248 
249     bool compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt);
250 
251     /**
252      * Returns true if this function has the signature void(int, int, inout float4) and thus can be
253      * used as an SkJumper stage.
254      */
255     bool hasStageSignature(const FunctionDeclaration& f);
256 
257     /**
258      * Attempts to compile a vectorized stage function, returning true on success. A stage function
259      * of e.g. "color.r = 0;" will produce code which sets the entire red vector to zeros in a
260      * single instruction, thus calculating several pixels at once.
261      */
262     bool compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc);
263 
264     /**
265      * Fallback function which loops over the pixels, for when vectorization fails. A stage function
266      * of e.g. "color.r = 0;" will produce a loop which iterates over the entries in the red vector,
267      * setting each one to zero individually.
268      */
269     void compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc);
270 
271     /**
272      * Called when compiling a function which has the signature of an SkJumper stage. Produces a
273      * version of the function which can be plugged into SkJumper (thus having a signature which
274      * accepts four vectors, one for each color channel, containing the color data of multiple
275      * pixels at once). To go from SkSL code which operates on a single pixel at a time to CPU code
276      * which operates on multiple pixels at once, the code is either vectorized using
277      * compileStageFunctionVector or wrapped in a loop using compileStageFunctionLoop.
278      */
279     LLVMValueRef compileStageFunction(const FunctionDefinition& f);
280 
281     /**
282      * Compiles an SkSL function to an LLVM function. If the function has the signature of an
283      * SkJumper stage, it will *also* be compiled by compileStageFunction, resulting in both a stage
284      * and non-stage version of the function.
285      */
286     LLVMValueRef compileFunction(const FunctionDefinition& f);
287 
288     void createModule();
289 
290     void optimize();
291 
292     bool isColorRef(const Expression& expr);
293 
294     static uint64_t resolveSymbol(const char* name, JIT* jit);
295 
296     const char* fCPU;
297     int fVectorCount;
298     Compiler& fCompiler;
299     std::unique_ptr<Program> fProgram;
300     LLVMContextRef fContext;
301     LLVMModuleRef fModule;
302     LLVMSharedModuleRef fSharedModule;
303     LLVMOrcJITStackRef fJITStack;
304     LLVMValueRef fCurrentFunction;
305     LLVMBasicBlockRef fAllocaBlock;
306     LLVMBasicBlockRef fCurrentBlock;
307     LLVMTypeRef fVoidType;
308     LLVMTypeRef fInt1Type;
309     LLVMTypeRef fInt1VectorType;
310     LLVMTypeRef fInt1Vector2Type;
311     LLVMTypeRef fInt1Vector3Type;
312     LLVMTypeRef fInt1Vector4Type;
313     LLVMTypeRef fInt8Type;
314     LLVMTypeRef fInt8PtrType;
315     LLVMTypeRef fInt32Type;
316     LLVMTypeRef fInt32VectorType;
317     LLVMTypeRef fInt32Vector2Type;
318     LLVMTypeRef fInt32Vector3Type;
319     LLVMTypeRef fInt32Vector4Type;
320     LLVMTypeRef fInt64Type;
321     LLVMTypeRef fSizeTType;
322     LLVMTypeRef fFloat32Type;
323     LLVMTypeRef fFloat32VectorType;
324     LLVMTypeRef fFloat32Vector2Type;
325     LLVMTypeRef fFloat32Vector3Type;
326     LLVMTypeRef fFloat32Vector4Type;
327     // Our SkSL stage functions have a single float4 for color, but the actual SkJumper stage
328     // function has four separate vectors, one for each channel. These four values are references to
329     // the red, green, blue, and alpha vectors respectively.
330     LLVMValueRef fChannels[CHANNELS];
331     // when processing a stage function, this points to the SkSL color parameter (an inout float4)
332     const Variable* fColorParam;
333     std::unordered_map<const FunctionDeclaration*, LLVMValueRef> fFunctions;
334     std::unordered_map<const Variable*, LLVMValueRef> fVariables;
335     // LLVM function parameters are read-only, so when modifying function parameters we need to
336     // first promote them to variables. This keeps track of which parameters have been promoted.
337     std::set<const Variable*> fPromotedParameters;
338     std::vector<LLVMBasicBlockRef> fBreakTarget;
339     std::vector<LLVMBasicBlockRef> fContinueTarget;
340 
341     LLVMValueRef fFoldAnd2Func;
342     LLVMValueRef fFoldOr2Func;
343     LLVMValueRef fFoldAnd3Func;
344     LLVMValueRef fFoldOr3Func;
345     LLVMValueRef fFoldAnd4Func;
346     LLVMValueRef fFoldOr4Func;
347     LLVMValueRef fAppendFunc;
348     LLVMValueRef fAppendCallbackFunc;
349     LLVMValueRef fDebugFunc;
350 };
351 
352 } // namespace
353 
354 #endif // SK_LLVM_AVAILABLE
355 
356 #endif // SKSL_JIT
357