1 /****************************************************************************
2  * Copyright (C) 2014-2015 Intel Corporation.   All Rights Reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  * @file streamout_jit.cpp
24  *
25  * @brief Implementation of the streamout jitter
26  *
27  * Notes:
28  *
29  ******************************************************************************/
30 #include "jit_pch.hpp"
31 #include "builder_gfx_mem.h"
32 #include "jit_api.h"
33 #include "streamout_jit.h"
34 #include "gen_state_llvm.h"
35 #include "functionpasses/passes.h"
36 
37 using namespace llvm;
38 using namespace SwrJit;
39 
40 //////////////////////////////////////////////////////////////////////////
41 /// Interface to Jitting a fetch shader
42 //////////////////////////////////////////////////////////////////////////
43 struct StreamOutJit : public BuilderGfxMem
44 {
StreamOutJitStreamOutJit45     StreamOutJit(JitManager* pJitMgr) : BuilderGfxMem(pJitMgr){};
46 
47     // returns pointer to SWR_STREAMOUT_BUFFER
getSOBufferStreamOutJit48     Value* getSOBuffer(Value* pSoCtx, uint32_t buffer)
49     {
50         return LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_pBuffer, buffer});
51     }
52 
53     //////////////////////////////////////////////////////////////////////////
54     // @brief checks if streamout buffer is oob
55     // @return <i1> true/false
oobStreamOutJit56     Value* oob(const STREAMOUT_COMPILE_STATE& state, Value* pSoCtx, uint32_t buffer)
57     {
58         Value* returnMask = C(false);
59 
60         Value* pBuf = getSOBuffer(pSoCtx, buffer);
61 
62         // load enable
63         // @todo bool data types should generate <i1> llvm type
64         Value* enabled = TRUNC(LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_enable}), IRB()->getInt1Ty());
65 
66         // load buffer size
67         Value* bufferSize = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_bufferSize});
68 
69         // load current streamOffset
70         Value* streamOffset = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
71 
72         // load buffer pitch
73         Value* pitch = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pitch});
74 
75         // buffer is considered oob if in use in a decl but not enabled
76         returnMask = OR(returnMask, NOT(enabled));
77 
78         // buffer is oob if cannot fit a prims worth of verts
79         Value* newOffset = ADD(streamOffset, MUL(pitch, C(state.numVertsPerPrim)));
80         returnMask       = OR(returnMask, ICMP_SGT(newOffset, bufferSize));
81 
82         return returnMask;
83     }
84 
85     //////////////////////////////////////////////////////////////////////////
86     // @brief converts scalar bitmask to <4 x i32> suitable for shuffle vector,
87     //        packing the active mask bits
88     //        ex. bitmask 0011 -> (0, 1, 0, 0)
89     //            bitmask 1000 -> (3, 0, 0, 0)
90     //            bitmask 1100 -> (2, 3, 0, 0)
PackMaskStreamOutJit91     Value* PackMask(uint32_t bitmask)
92     {
93         std::vector<Constant*> indices(4, C(0));
94         unsigned long          index;
95         uint32_t               elem = 0;
96         while (_BitScanForward(&index, bitmask))
97         {
98             indices[elem++] = C((int)index);
99             bitmask &= ~(1 << index);
100         }
101 
102         return ConstantVector::get(indices);
103     }
104 
105     //////////////////////////////////////////////////////////////////////////
106     // @brief convert scalar bitmask to <4xfloat> bitmask
ToMaskStreamOutJit107     Value* ToMask(uint32_t bitmask)
108     {
109         std::vector<Constant*> indices;
110         for (uint32_t i = 0; i < 4; ++i)
111         {
112             if (bitmask & (1 << i))
113             {
114                 indices.push_back(C(true));
115             }
116             else
117             {
118                 indices.push_back(C(false));
119             }
120         }
121         return ConstantVector::get(indices);
122     }
123 
124     //////////////////////////////////////////////////////////////////////////
125     // @brief processes a single decl from the streamout stream. Reads 4 components from the input
126     //        stream and writes N components to the output buffer given the componentMask or if
127     //        a hole, just increments the buffer pointer
128     // @param pStream - pointer to current attribute
129     // @param pOutBuffers - pointers to the current location of each output buffer
130     // @param decl - input decl
buildDeclStreamOutJit131     void buildDecl(Value* pStream, Value* pOutBuffers[4], const STREAMOUT_DECL& decl)
132     {
133         uint32_t numComponents = _mm_popcnt_u32(decl.componentMask);
134         uint32_t packedMask    = (1 << numComponents) - 1;
135         if (!decl.hole)
136         {
137             // increment stream pointer to correct slot
138             Value* pAttrib = GEP(pStream, C(4 * decl.attribSlot));
139 
140             // load 4 components from stream
141             Type* simd4Ty    = getVectorType(IRB()->getFloatTy(), 4);
142             Type* simd4PtrTy = PointerType::get(simd4Ty, 0);
143             pAttrib          = BITCAST(pAttrib, simd4PtrTy);
144             Value* vattrib   = LOAD(pAttrib);
145 
146             // shuffle/pack enabled components
147             Value* vpackedAttrib = VSHUFFLE(vattrib, vattrib, PackMask(decl.componentMask));
148 
149             // store to output buffer
150             // cast SO buffer to i8*, needed by maskstore
151             Value* pOut = BITCAST(pOutBuffers[decl.bufferIndex], PointerType::get(simd4Ty, 0));
152 
153             // cast input to <4xfloat>
154             Value* src = BITCAST(vpackedAttrib, simd4Ty);
155 
156             // cast mask to <4xi1>
157             Value* mask = ToMask(packedMask);
158             MASKED_STORE(src, pOut, 4, mask, PointerType::get(simd4Ty, 0), MEM_CLIENT::GFX_MEM_CLIENT_STREAMOUT);
159         }
160 
161         // increment SO buffer
162         pOutBuffers[decl.bufferIndex] = GEP(pOutBuffers[decl.bufferIndex], C(numComponents));
163     }
164 
165     //////////////////////////////////////////////////////////////////////////
166     // @brief builds a single vertex worth of data for the given stream
167     // @param streamState - state for this stream
168     // @param pCurVertex - pointer to src stream vertex data
169     // @param pOutBuffer - pointers to up to 4 SO buffers
buildVertexStreamOutJit170     void buildVertex(const STREAMOUT_STREAM& streamState, Value* pCurVertex, Value* pOutBuffer[4])
171     {
172         for (uint32_t d = 0; d < streamState.numDecls; ++d)
173         {
174             const STREAMOUT_DECL& decl = streamState.decl[d];
175             buildDecl(pCurVertex, pOutBuffer, decl);
176         }
177     }
178 
buildStreamStreamOutJit179     void buildStream(const STREAMOUT_COMPILE_STATE& state,
180                      const STREAMOUT_STREAM&        streamState,
181                      Value*                         pSoCtx,
182                      BasicBlock*                    returnBB,
183                      Function*                      soFunc)
184     {
185         // get list of active SO buffers
186         std::unordered_set<uint32_t> activeSOBuffers;
187         for (uint32_t d = 0; d < streamState.numDecls; ++d)
188         {
189             const STREAMOUT_DECL& decl = streamState.decl[d];
190             activeSOBuffers.insert(decl.bufferIndex);
191         }
192 
193         // always increment numPrimStorageNeeded
194         Value* numPrimStorageNeeded = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded});
195         numPrimStorageNeeded        = ADD(numPrimStorageNeeded, C(1));
196         STORE(numPrimStorageNeeded, pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded});
197 
198         // check OOB on active SO buffers.  If any buffer is out of bound, don't write
199         // the primitive to any buffer
200         Value* oobMask = C(false);
201         for (uint32_t buffer : activeSOBuffers)
202         {
203             oobMask = OR(oobMask, oob(state, pSoCtx, buffer));
204         }
205 
206         BasicBlock* validBB = BasicBlock::Create(JM()->mContext, "valid", soFunc);
207 
208         // early out if OOB
209         COND_BR(oobMask, returnBB, validBB);
210 
211         IRB()->SetInsertPoint(validBB);
212 
213         Value* numPrimsWritten = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimsWritten});
214         numPrimsWritten        = ADD(numPrimsWritten, C(1));
215         STORE(numPrimsWritten, pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimsWritten});
216 
217         // compute start pointer for each output buffer
218         Value* pOutBuffer[4];
219         Value* pOutBufferStartVertex[4];
220         Value* outBufferPitch[4];
221         for (uint32_t b : activeSOBuffers)
222         {
223             Value* pBuf              = getSOBuffer(pSoCtx, b);
224             Value* pData             = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pBuffer});
225             Value* streamOffset      = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
226             pOutBuffer[b] = GEP(pData, streamOffset, PointerType::get(IRB()->getInt32Ty(), 0));
227             pOutBufferStartVertex[b] = pOutBuffer[b];
228 
229             outBufferPitch[b] = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pitch});
230         }
231 
232         // loop over the vertices of the prim
233         Value* pStreamData = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_pPrimData});
234         for (uint32_t v = 0; v < state.numVertsPerPrim; ++v)
235         {
236             buildVertex(streamState, pStreamData, pOutBuffer);
237 
238             // increment stream and output buffer pointers
239             // stream verts are always 32*4 dwords apart
240             pStreamData = GEP(pStreamData, C(SWR_VTX_NUM_SLOTS * 4));
241 
242             // output buffers offset using pitch in buffer state
243             for (uint32_t b : activeSOBuffers)
244             {
245                 pOutBufferStartVertex[b] = GEP(pOutBufferStartVertex[b], outBufferPitch[b]);
246                 pOutBuffer[b]            = pOutBufferStartVertex[b];
247             }
248         }
249 
250         // update each active buffer's streamOffset
251         for (uint32_t b : activeSOBuffers)
252         {
253             Value* pBuf         = getSOBuffer(pSoCtx, b);
254             Value* streamOffset = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
255             streamOffset = ADD(streamOffset, MUL(C(state.numVertsPerPrim), outBufferPitch[b]));
256             STORE(streamOffset, pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
257         }
258     }
259 
CreateStreamOutJit260     Function* Create(const STREAMOUT_COMPILE_STATE& state)
261     {
262         std::stringstream fnName("SO_",
263                                  std::ios_base::in | std::ios_base::out | std::ios_base::ate);
264         fnName << ComputeCRC(0, &state, sizeof(state));
265 
266         std::vector<Type*> args{
267             mInt8PtrTy,
268             mInt8PtrTy,
269             PointerType::get(Gen_SWR_STREAMOUT_CONTEXT(JM()), 0), // SWR_STREAMOUT_CONTEXT*
270         };
271 
272         FunctionType* fTy    = FunctionType::get(IRB()->getVoidTy(), args, false);
273         Function*     soFunc = Function::Create(
274             fTy, GlobalValue::ExternalLinkage, fnName.str(), JM()->mpCurrentModule);
275 
276         soFunc->getParent()->setModuleIdentifier(soFunc->getName());
277 
278         // create return basic block
279         BasicBlock* entry    = BasicBlock::Create(JM()->mContext, "entry", soFunc);
280         BasicBlock* returnBB = BasicBlock::Create(JM()->mContext, "return", soFunc);
281 
282         IRB()->SetInsertPoint(entry);
283 
284         // arguments
285         auto   argitr = soFunc->arg_begin();
286 
287         Value* privateContext = &*argitr++;
288         privateContext->setName("privateContext");
289         SetPrivateContext(privateContext);
290 
291         mpWorkerData = &*argitr;
292         ++argitr;
293         mpWorkerData->setName("pWorkerData");
294 
295         Value* pSoCtx = &*argitr++;
296         pSoCtx->setName("pSoCtx");
297 
298         const STREAMOUT_STREAM& streamState = state.stream;
299         buildStream(state, streamState, pSoCtx, returnBB, soFunc);
300 
301         BR(returnBB);
302 
303         IRB()->SetInsertPoint(returnBB);
304         RET_VOID();
305 
306         JitManager::DumpToFile(soFunc, "SoFunc");
307 
308         ::FunctionPassManager passes(JM()->mpCurrentModule);
309 
310         passes.add(createBreakCriticalEdgesPass());
311         passes.add(createCFGSimplificationPass());
312         passes.add(createEarlyCSEPass());
313         passes.add(createPromoteMemoryToRegisterPass());
314         passes.add(createCFGSimplificationPass());
315         passes.add(createEarlyCSEPass());
316         passes.add(createInstructionCombiningPass());
317 #if LLVM_VERSION_MAJOR <= 11
318         passes.add(createConstantPropagationPass());
319 #endif
320         passes.add(createSCCPPass());
321         passes.add(createAggressiveDCEPass());
322 
323         passes.add(createLowerX86Pass(this));
324 
325         passes.run(*soFunc);
326 
327         JitManager::DumpToFile(soFunc, "SoFunc_optimized");
328 
329 
330         return soFunc;
331     }
332 };
333 
334 //////////////////////////////////////////////////////////////////////////
335 /// @brief JITs from streamout shader IR
336 /// @param hJitMgr - JitManager handle
337 /// @param func   - LLVM function IR
338 /// @return PFN_SO_FUNC - pointer to SOS function
JitStreamoutFunc(HANDLE hJitMgr,const HANDLE hFunc)339 PFN_SO_FUNC JitStreamoutFunc(HANDLE hJitMgr, const HANDLE hFunc)
340 {
341     llvm::Function* func    = (llvm::Function*)hFunc;
342     JitManager*     pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
343     PFN_SO_FUNC     pfnStreamOut;
344     pfnStreamOut = (PFN_SO_FUNC)(pJitMgr->mpExec->getFunctionAddress(func->getName().str()));
345     // MCJIT finalizes modules the first time you JIT code from them. After finalized, you cannot
346     // add new IR to the module
347     pJitMgr->mIsModuleFinalized = true;
348 
349     pJitMgr->DumpAsm(func, "SoFunc_optimized");
350 
351 
352     return pfnStreamOut;
353 }
354 
355 //////////////////////////////////////////////////////////////////////////
356 /// @brief JIT compiles streamout shader
357 /// @param hJitMgr - JitManager handle
358 /// @param state   - SO state to build function from
JitCompileStreamout(HANDLE hJitMgr,const STREAMOUT_COMPILE_STATE & state)359 extern "C" PFN_SO_FUNC JITCALL JitCompileStreamout(HANDLE                         hJitMgr,
360                                                    const STREAMOUT_COMPILE_STATE& state)
361 {
362     JitManager* pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
363 
364     STREAMOUT_COMPILE_STATE soState = state;
365     if (soState.offsetAttribs)
366     {
367         for (uint32_t i = 0; i < soState.stream.numDecls; ++i)
368         {
369             soState.stream.decl[i].attribSlot -= soState.offsetAttribs;
370         }
371     }
372 
373     pJitMgr->SetupNewModule();
374 
375     StreamOutJit theJit(pJitMgr);
376     HANDLE       hFunc = theJit.Create(soState);
377 
378     return JitStreamoutFunc(hJitMgr, hFunc);
379 }
380