// Copyright 2020 Google LLC.
// Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.

#ifndef SkVM_opts_DEFINED
#define SkVM_opts_DEFINED

#include "include/private/SkVx.h"
#include "src/core/SkVM.h"

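// Gather K ints from ptr at the (varying) indices in ix.  On AVX2, an 8-wide gather maps
// directly onto a hardware instruction; wider vectors recurse onto their halves until they
// reach that specialization, and everything else falls back to a scalar map() loop.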
template <int N>
static inline skvx::Vec<N,int> gather32(const int* ptr, const skvx::Vec<N,int>& ix) {
#if SK_CPU_SSE_LEVEL >= SK_CPU_SSE_LEVEL_AVX2
    if constexpr (N == 8) {
        return skvx::bit_pun<skvx::Vec<N,int>>(
                _mm256_i32gather_epi32(ptr, skvx::bit_pun<__m256i>(ix), 4));
    }
#endif
    // Try to recurse on the specializations above, falling back on the standard scalar
    // map()-based impl.
    if constexpr (N > 8) {
        return join(gather32(ptr, ix.lo),
                    gather32(ptr, ix.hi));
    }
    return map([&](int i) { return ptr[i]; }, ix);
}

namespace SK_OPTS_NS {

    inline void interpret_skvm(const skvm::InterpreterInstruction insts[], const int ninsts,
                               const int nregs, const int loop,
                               const int strides[], const int nargs,
                               int n, void* args[]) {
        using namespace skvm;

        // We'll operate in SIMT style, knocking off K-size chunks from n while possible.
    #if SK_CPU_SSE_LEVEL >= SK_CPU_SSE_LEVEL_AVX2
        constexpr int K = 32;  // 1024-bit: 4 ymm or 2 zmm at a time
    #else
        constexpr int K = 8;   // 256-bit: 2 xmm, 2 v-registers, etc.
    #endif
        using I32 = skvx::Vec<K, int>;
        using I16 = skvx::Vec<K, int16_t>;
        using F32 = skvx::Vec<K, float>;
        using U64 = skvx::Vec<K, uint64_t>;
        using U32 = skvx::Vec<K, uint32_t>;
        using U16 = skvx::Vec<K, uint16_t>;
        using  U8 = skvx::Vec<K, uint8_t>;
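        // A Slot is one virtual register: K lanes that any instruction may reinterpret
        // as float, signed, or unsigned data of various widths.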
        union Slot {
            F32   f32;
            I32   i32;
            U32   u32;
            I16   i16;
            U16   u16;
        };

        Slot                     few_regs[16];
        std::unique_ptr<char[]> many_regs;

        Slot* r = few_regs;

        if (nregs > (int)SK_ARRAY_COUNT(few_regs)) {
            // Annoyingly we can't trust that malloc() or new will work with Slot because
            // the skvx::Vec types may have alignment greater than what they provide.
            // We'll overallocate one extra register so we can align manually.
            many_regs.reset(new char[ sizeof(Slot) * (nregs + 1) ]);

            uintptr_t addr = (uintptr_t)many_regs.get();
            addr += alignof(Slot) -
                     (addr & (alignof(Slot) - 1));
            SkASSERT((addr & (alignof(Slot) - 1)) == 0);
            r = (Slot*)addr;
        }

        // Step each argument pointer ahead by its stride a number of times.
        auto step_args = [&](int times) {
            for (int i = 0; i < nargs; i++) {
                args[i] = (void*)( (char*)args[i] + times * strides[i] );
            }
        };

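        // Process K values per pass while n >= K, then finish the tail one value at a time.
        // After the first pass we restart at `loop`, so instructions before that index
        // (loop-invariant, hoisted work) run only once per call.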
        int start = 0,
            stride;
        for ( ; n > 0; start = loop, n -= stride, step_args(stride)) {
            stride = n >= K ? K : 1;

            for (int i = start; i < ninsts; i++) {
                InterpreterInstruction inst = insts[i];

                // d = op(x,y,z,w, immA,immB)
                Reg   d = inst.d,
                      x = inst.x,
                      y = inst.y,
                      z = inst.z,
                      w = inst.w;
                int immA = inst.immA,
                    immB = inst.immB;

                // Ops that interact with memory need to know whether we're stride=1 or K,
                // but all non-memory ops can run the same code no matter the stride.
                switch (2*(int)inst.op + (stride == K ? 1 : 0)) {
                    default: SkUNREACHABLE;

                #define STRIDE_1(op) case 2*(int)op
                #define STRIDE_K(op) case 2*(int)op + 1
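                // e.g. Op::store32 dispatches to case 2*(int)Op::store32     when stride == 1,
                //                            or to case 2*(int)Op::store32 + 1 when stride == K.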
                    STRIDE_1(Op::store8 ): memcpy(args[immA], &r[x].i32, 1); break;
                    STRIDE_1(Op::store16): memcpy(args[immA], &r[x].i32, 2); break;
                    STRIDE_1(Op::store32): memcpy(args[immA], &r[x].i32, 4); break;
                    STRIDE_1(Op::store64): memcpy((char*)args[immA]+0, &r[x].i32, 4);
                                           memcpy((char*)args[immA]+4, &r[y].i32, 4); break;

                    STRIDE_K(Op::store8 ): skvx::cast<uint8_t> (r[x].i32).store(args[immA]); break;
                    STRIDE_K(Op::store16): skvx::cast<uint16_t>(r[x].i32).store(args[immA]); break;
                    STRIDE_K(Op::store32):                     (r[x].i32).store(args[immA]); break;
                    STRIDE_K(Op::store64): (skvx::cast<uint64_t>(r[x].u32) << 0 |
                                            skvx::cast<uint64_t>(r[y].u32) << 32).store(args[immA]);
                                           break;

                    STRIDE_1(Op::load8 ): r[d].i32 = 0; memcpy(&r[d].i32, args[immA], 1); break;
                    STRIDE_1(Op::load16): r[d].i32 = 0; memcpy(&r[d].i32, args[immA], 2); break;
                    STRIDE_1(Op::load32): r[d].i32 = 0; memcpy(&r[d].i32, args[immA], 4); break;
                    STRIDE_1(Op::load64):
                        r[d].i32 = 0; memcpy(&r[d].i32, (char*)args[immA] + 4*immB, 4); break;

                    STRIDE_K(Op::load8 ): r[d].i32= skvx::cast<int>(U8 ::Load(args[immA])); break;
                    STRIDE_K(Op::load16): r[d].i32= skvx::cast<int>(U16::Load(args[immA])); break;
                    STRIDE_K(Op::load32): r[d].i32=                 I32::Load(args[immA]) ; break;
                    STRIDE_K(Op::load64):
                        // Low 32 bits if immB=0, or high 32 bits if immB=1.
                        r[d].i32 = skvx::cast<int>(U64::Load(args[immA]) >> (32*immB)); break;

                    // The pointer we base our gather on is loaded indirectly from a uniform:
                    //     - args[immA] is the uniform holding our gather base pointer somewhere;
                    //     - (const uint8_t*)args[immA] + immB points to the gather base pointer;
                    //     - memcpy() loads that gather base into a pointer of the right type.
                    // After all that we have an ordinary (uniform) pointer `ptr` to load from,
                    // and we then gather from it using the varying indices in r[x].
                    STRIDE_1(Op::gather8): {
                        const uint8_t* ptr;
                        memcpy(&ptr, (const uint8_t*)args[immA] + immB, sizeof(ptr));
                        r[d].i32 = ptr[ r[x].i32[0] ];
                    } break;
                    STRIDE_1(Op::gather16): {
                        const uint16_t* ptr;
                        memcpy(&ptr, (const uint8_t*)args[immA] + immB, sizeof(ptr));
                        r[d].i32 = ptr[ r[x].i32[0] ];
                    } break;
                    STRIDE_1(Op::gather32): {
                        const int* ptr;
                        memcpy(&ptr, (const uint8_t*)args[immA] + immB, sizeof(ptr));
                        r[d].i32 = ptr[ r[x].i32[0] ];
                    } break;

                    STRIDE_K(Op::gather8): {
                        const uint8_t* ptr;
                        memcpy(&ptr, (const uint8_t*)args[immA] + immB, sizeof(ptr));
                        r[d].i32 = map([&](int ix) { return (int)ptr[ix]; }, r[x].i32);
                    } break;
                    STRIDE_K(Op::gather16): {
                        const uint16_t* ptr;
                        memcpy(&ptr, (const uint8_t*)args[immA] + immB, sizeof(ptr));
                        r[d].i32 = map([&](int ix) { return (int)ptr[ix]; }, r[x].i32);
                    } break;
                    STRIDE_K(Op::gather32): {
                        const int* ptr;
                        memcpy(&ptr, (const uint8_t*)args[immA] + immB, sizeof(ptr));
                        r[d].i32 = gather32(ptr, r[x].i32);
                    } break;

                #undef STRIDE_1
                #undef STRIDE_K

                    // Ops that don't interact with memory should never care about the stride.
                #define CASE(op) case 2*(int)op: /*fallthrough*/ case 2*(int)op+1

                    // These 128-bit ops are implemented serially for simplicity.
                    CASE(Op::store128): {
                        U64 lo = (skvx::cast<uint64_t>(r[x].u32) << 0 |
                                  skvx::cast<uint64_t>(r[y].u32) << 32),
                            hi = (skvx::cast<uint64_t>(r[z].u32) << 0 |
                                  skvx::cast<uint64_t>(r[w].u32) << 32);
                        for (int i = 0; i < stride; i++) {
                            memcpy((char*)args[immA] + 16*i + 0, &lo[i], 8);
                            memcpy((char*)args[immA] + 16*i + 8, &hi[i], 8);
                        }
                    } break;

                    CASE(Op::load128):
                        r[d].i32 = 0;
                        for (int i = 0; i < stride; i++) {
                            memcpy(&r[d].i32[i], (const char*)args[immA] + 16*i + 4*immB, 4);
                        } break;

                    CASE(Op::assert_true):
                    #ifdef SK_DEBUG
                        if (!all(r[x].i32)) {
                            SkDebugf("inst %d, register %d\n", i, y);
                            for (int i = 0; i < K; i++) {
                                SkDebugf("\t%2d: %08x (%g)\n", i, r[y].i32[i], r[y].f32[i]);
                            }
                            SkASSERT(false);
                        }
                    #endif
                    break;

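                    // index fills each lane with the count of values still to process,
                    // {n, n-1, n-2, ...}, which counts down as the loop strides forward.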
                    CASE(Op::index): {
                        const int iota[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,
                                            16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,
                                            32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,
                                            48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63 };
                        static_assert(K <= SK_ARRAY_COUNT(iota), "");

                        r[d].i32 = n - I32::Load(iota);
                    } break;

                    CASE(Op::uniform32):
                        r[d].i32 = *(const int*)( (const char*)args[immA] + immB );
                        break;

                    CASE(Op::splat): r[d].i32 = immA; break;

                    CASE(Op::add_f32): r[d].f32 = r[x].f32 + r[y].f32; break;
                    CASE(Op::sub_f32): r[d].f32 = r[x].f32 - r[y].f32; break;
                    CASE(Op::mul_f32): r[d].f32 = r[x].f32 * r[y].f32; break;
                    CASE(Op::div_f32): r[d].f32 = r[x].f32 / r[y].f32; break;
                    CASE(Op::min_f32): r[d].f32 = min(r[x].f32, r[y].f32); break;
                    CASE(Op::max_f32): r[d].f32 = max(r[x].f32, r[y].f32); break;

                    CASE(Op::fma_f32):  r[d].f32 = fma( r[x].f32, r[y].f32,  r[z].f32); break;
                    CASE(Op::fms_f32):  r[d].f32 = fma( r[x].f32, r[y].f32, -r[z].f32); break;
                    CASE(Op::fnma_f32): r[d].f32 = fma(-r[x].f32, r[y].f32,  r[z].f32); break;

                    CASE(Op::sqrt_f32): r[d].f32 = sqrt(r[x].f32); break;

                    CASE(Op::add_i32): r[d].i32 = r[x].i32 + r[y].i32; break;
                    CASE(Op::sub_i32): r[d].i32 = r[x].i32 - r[y].i32; break;
                    CASE(Op::mul_i32): r[d].i32 = r[x].i32 * r[y].i32; break;

                    CASE(Op::shl_i32): r[d].i32 = r[x].i32 << immA; break;
                    CASE(Op::sra_i32): r[d].i32 = r[x].i32 >> immA; break;
                    CASE(Op::shr_i32): r[d].u32 = r[x].u32 >> immA; break;

                    CASE(Op:: eq_f32): r[d].i32 = r[x].f32 == r[y].f32; break;
                    CASE(Op::neq_f32): r[d].i32 = r[x].f32 != r[y].f32; break;
                    CASE(Op:: gt_f32): r[d].i32 = r[x].f32 >  r[y].f32; break;
                    CASE(Op::gte_f32): r[d].i32 = r[x].f32 >= r[y].f32; break;

                    CASE(Op:: eq_i32): r[d].i32 = r[x].i32 == r[y].i32; break;
                    CASE(Op:: gt_i32): r[d].i32 = r[x].i32 >  r[y].i32; break;

                    CASE(Op::bit_and  ): r[d].i32 = r[x].i32 &  r[y].i32; break;
                    CASE(Op::bit_or   ): r[d].i32 = r[x].i32 |  r[y].i32; break;
                    CASE(Op::bit_xor  ): r[d].i32 = r[x].i32 ^  r[y].i32; break;
                    CASE(Op::bit_clear): r[d].i32 = r[x].i32 & ~r[y].i32; break;

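                    // The comparisons above produce all-0s/all-1s lane masks, which select
                    // consumes: if_then_else() blends bitwise, picking y where the mask in x
                    // is set and z elsewhere.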
                    CASE(Op::select): r[d].i32 = skvx::if_then_else(r[x].i32, r[y].i32, r[z].i32);
                                      break;

                    CASE(Op::ceil):   r[d].f32 =                    skvx::ceil(r[x].f32) ; break;
                    CASE(Op::floor):  r[d].f32 =                   skvx::floor(r[x].f32) ; break;
                    CASE(Op::to_f32): r[d].f32 = skvx::cast<float>(            r[x].i32 ); break;
                    CASE(Op::trunc):  r[d].i32 = skvx::cast<int>  (            r[x].f32 ); break;
                    CASE(Op::round):  r[d].i32 = skvx::cast<int>  (skvx::lrint(r[x].f32)); break;

                    CASE(Op::to_fp16):
                        r[d].i32 = skvx::cast<int>(skvx::to_half(r[x].f32));
                        break;
                    CASE(Op::from_fp16):
                        r[d].f32 = skvx::from_half(skvx::cast<uint16_t>(r[x].i32));
                        break;

                #undef CASE
                }
            }
        }
    }

}  // namespace SK_OPTS_NS

#endif//SkVM_opts_DEFINED