/*
 * Copyright © 2019 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

/**
 * Although it's called a load/store "vectorization" pass, this also combines
 * intersecting and identical loads/stores. It currently supports derefs, ubo,
 * ssbo and push constant loads/stores.
 *
 * This doesn't handle copy_deref intrinsics and assumes that
 * nir_lower_alu_to_scalar() has been called and that the IR is free from ALU
 * modifiers. It also assumes that derefs have explicitly laid out types.
 *
 * After vectorization, the backend may want to call nir_lower_alu_to_scalar()
 * and nir_lower_pack(). This pass also creates cast instructions taking derefs
 * as a source, which some parts of NIR may not be able to handle well.
 *
 * There are a few situations where this doesn't vectorize as well as it could:
 * - It won't turn four consecutive vec3 loads into three vec4 loads.
 * - It doesn't do global vectorization.
 * Handling these cases probably wouldn't provide much benefit though.
 *
 * This probably doesn't handle big-endian GPUs correctly.
 */

#include "nir.h"
#include "nir_deref.h"
#include "nir_builder.h"
#include "nir_worklist.h"
#include "util/u_dynarray.h"

#include <stdlib.h>

struct intrinsic_info {
   nir_variable_mode mode; /* 0 if the mode is obtained from the deref. */
   nir_intrinsic_op op;
   bool is_atomic;
   /* Indices into nir_intrinsic::src[] or -1 if not applicable. */
   int resource_src; /* resource (e.g. from vulkan_resource_index) */
   int base_src; /* offset it loads/stores from */
   int deref_src; /* deref it loads/stores from */
   int value_src; /* the data it is storing */
};

static const struct intrinsic_info *
get_info(nir_intrinsic_op op) {
   switch (op) {
#define INFO(mode, op, atomic, res, base, deref, val) \
case nir_intrinsic_##op: {\
   static const struct intrinsic_info op##_info = {mode, nir_intrinsic_##op, atomic, res, base, deref, val};\
   return &op##_info;\
}
#define LOAD(mode, op, res, base, deref) INFO(mode, load_##op, false, res, base, deref, -1)
#define STORE(mode, op, res, base, deref, val) INFO(mode, store_##op, false, res, base, deref, val)
#define ATOMIC(mode, type, op, res, base, deref, val) INFO(mode, type##_atomic_##op, true, res, base, deref, val)
   LOAD(nir_var_mem_push_const, push_constant, -1, 0, -1)
   LOAD(nir_var_mem_ubo, ubo, 0, 1, -1)
   LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1)
   STORE(nir_var_mem_ssbo, ssbo, 1, 2, -1, 0)
   LOAD(0, deref, -1, -1, 0)
   STORE(0, deref, -1, -1, 0, 1)
   LOAD(nir_var_mem_shared, shared, -1, 0, -1)
   STORE(nir_var_mem_shared, shared, -1, 1, -1, 0)
   LOAD(nir_var_mem_global, global, -1, 0, -1)
   STORE(nir_var_mem_global, global, -1, 1, -1, 0)
   ATOMIC(nir_var_mem_ssbo, ssbo, add, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, imin, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, umin, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, imax, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, umax, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, and, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, or, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, xor, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, exchange, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, comp_swap, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, fadd, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, fmin, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, fmax, 0, 1, -1, 2)
   ATOMIC(nir_var_mem_ssbo, ssbo, fcomp_swap, 0, 1, -1, 2)
   ATOMIC(0, deref, add, -1, -1, 0, 1)
   ATOMIC(0, deref, imin, -1, -1, 0, 1)
   ATOMIC(0, deref, umin, -1, -1, 0, 1)
   ATOMIC(0, deref, imax, -1, -1, 0, 1)
   ATOMIC(0, deref, umax, -1, -1, 0, 1)
   ATOMIC(0, deref, and, -1, -1, 0, 1)
   ATOMIC(0, deref, or, -1, -1, 0, 1)
   ATOMIC(0, deref, xor, -1, -1, 0, 1)
   ATOMIC(0, deref, exchange, -1, -1, 0, 1)
   ATOMIC(0, deref, comp_swap, -1, -1, 0, 1)
   ATOMIC(0, deref, fadd, -1, -1, 0, 1)
   ATOMIC(0, deref, fmin, -1, -1, 0, 1)
   ATOMIC(0, deref, fmax, -1, -1, 0, 1)
   ATOMIC(0, deref, fcomp_swap, -1, -1, 0, 1)
   ATOMIC(nir_var_mem_shared, shared, add, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, imin, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, umin, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, imax, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, umax, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, and, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, or, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, xor, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, exchange, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, comp_swap, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, fadd, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, fmin, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, fmax, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_shared, shared, fcomp_swap, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, add, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, imin, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, umin, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, imax, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, umax, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, and, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, or, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, xor, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, exchange, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, comp_swap, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, fadd, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, fmin, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, fmax, -1, 0, -1, 1)
   ATOMIC(nir_var_mem_global, global, fcomp_swap, -1, 0, -1, 1)
   default:
      break;
#undef ATOMIC
#undef STORE
#undef LOAD
#undef INFO
   }
   return NULL;
}

/*
 * Information used to compare memory operations.
 * It canonically represents an offset as:
 * `offset_defs[0]*offset_defs_mul[0] + offset_defs[1]*offset_defs_mul[1] + ...`
 * "offset_defs" is sorted in ascending order by the ssa definition's index.
 * "resource" or "var" may be NULL.
 */
struct entry_key {
   nir_ssa_def *resource;
   nir_variable *var;
   unsigned offset_def_count;
   nir_ssa_def **offset_defs;
   uint64_t *offset_defs_mul;
};

/* Information on a single memory operation. */
struct entry {
   struct list_head head;
   unsigned index;

   struct entry_key *key;
   union {
      uint64_t offset; /* sign-extended */
      int64_t offset_signed;
   };
   uint32_t align_mul;
   uint32_t align_offset;

   nir_instr *instr;
   nir_intrinsic_instr *intrin;
   const struct intrinsic_info *info;
   enum gl_access_qualifier access;
   bool is_store;

   nir_deref_instr *deref;
};

struct vectorize_ctx {
   nir_variable_mode modes;
   nir_should_vectorize_mem_func callback;
   nir_variable_mode robust_modes;
   struct list_head entries[nir_num_variable_modes];
   struct hash_table *loads[nir_num_variable_modes];
   struct hash_table *stores[nir_num_variable_modes];
};

static uint32_t hash_entry_key(const void *key_)
{
   /* this is careful not to include pointers in the hash calculation so that
    * the order of the hash table walk is deterministic */
   struct entry_key *key = (struct entry_key*)key_;

   uint32_t hash = 0;
   if (key->resource)
      hash = XXH32(&key->resource->index, sizeof(key->resource->index), hash);
   if (key->var) {
      hash = XXH32(&key->var->index, sizeof(key->var->index), hash);
      unsigned mode = key->var->data.mode;
      hash = XXH32(&mode, sizeof(mode), hash);
   }

   for (unsigned i = 0; i < key->offset_def_count; i++)
      hash = XXH32(&key->offset_defs[i]->index, sizeof(key->offset_defs[i]->index), hash);

   hash = XXH32(key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t), hash);

   return hash;
}

static bool entry_key_equals(const void *a_, const void *b_)
{
   struct entry_key *a = (struct entry_key*)a_;
   struct entry_key *b = (struct entry_key*)b_;

   if (a->var != b->var || a->resource != b->resource)
      return false;

   if (a->offset_def_count != b->offset_def_count)
      return false;

   size_t offset_def_size = a->offset_def_count * sizeof(nir_ssa_def *);
   size_t offset_def_mul_size = a->offset_def_count * sizeof(uint64_t);
   if (a->offset_def_count &&
       (memcmp(a->offset_defs, b->offset_defs, offset_def_size) ||
        memcmp(a->offset_defs_mul, b->offset_defs_mul, offset_def_mul_size)))
      return false;

   return true;
}

static void delete_entry_dynarray(struct hash_entry *entry)
{
   struct util_dynarray *arr = (struct util_dynarray *)entry->data;
   ralloc_free(arr);
}

static int sort_entries(const void *a_, const void *b_)
{
   struct entry *a = *(struct entry*const*)a_;
   struct entry *b = *(struct entry*const*)b_;

   if (a->offset_signed > b->offset_signed)
      return 1;
   else if (a->offset_signed < b->offset_signed)
      return -1;
   else
      return 0;
}

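/* Bit size of the data being loaded/stored, with 1-bit booleans treated as
 * 32-bit values (the size they are vectorized as). */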
static unsigned
get_bit_size(struct entry *entry)
{
   unsigned size = entry->is_store ?
                   entry->intrin->src[entry->info->value_src].ssa->bit_size :
                   entry->intrin->dest.ssa.bit_size;
   return size == 1 ? 32u : size;
}

/* If "def" is from an alu instruction with the opcode "op" and one of its
 * sources is a constant, update "def" to be the non-constant source, fill "c"
 * with the constant and return true. */
static bool
parse_alu(nir_ssa_def **def, nir_op op, uint64_t *c)
{
   nir_ssa_scalar scalar;
   scalar.def = *def;
   scalar.comp = 0;

   if (!nir_ssa_scalar_is_alu(scalar) || nir_ssa_scalar_alu_op(scalar) != op)
      return false;

   nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(scalar, 0);
   nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(scalar, 1);
   if (op != nir_op_ishl && nir_ssa_scalar_is_const(src0) && src1.comp == 0) {
      *c = nir_ssa_scalar_as_uint(src0);
      *def = src1.def;
   } else if (nir_ssa_scalar_is_const(src1) && src0.comp == 0) {
      *c = nir_ssa_scalar_as_uint(src1);
      *def = src0.def;
   } else {
      return false;
   }
   return true;
}

/* Parses an offset expression such as "a * 16 + 4" or "(a * 16 + 4) * 64 + 32". */
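/* For example, "(a * 16 + 4) * 64 + 32" is canonicalized to base = a,
 * base_mul = 1024 and offset = 288, i.e. the expression is treated as
 * "a * 1024 + 288". */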
static void
parse_offset(nir_ssa_def **base, uint64_t *base_mul, uint64_t *offset)
{
   if ((*base)->parent_instr->type == nir_instr_type_load_const) {
      *offset = nir_src_comp_as_uint(nir_src_for_ssa(*base), 0);
      *base = NULL;
      return;
   }

   uint64_t mul = 1;
   uint64_t add = 0;
   bool progress = false;
   do {
      uint64_t mul2 = 1, add2 = 0;

      progress = parse_alu(base, nir_op_imul, &mul2);
      mul *= mul2;

      mul2 = 0;
      progress |= parse_alu(base, nir_op_ishl, &mul2);
      mul <<= mul2;

      progress |= parse_alu(base, nir_op_iadd, &add2);
      add += add2 * mul;
   } while (progress);

   *base_mul = mul;
   *offset = add;
}

static unsigned
type_scalar_size_bytes(const struct glsl_type *type)
{
   assert(glsl_type_is_vector_or_scalar(type) ||
          glsl_type_is_matrix(type));
   return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
}

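/* Sign-extend a "bit_size"-bit value to 64 bits, e.g. a 16-bit 0xffff becomes
 * UINT64_MAX (-1). */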
static uint64_t
mask_sign_extend(uint64_t val, unsigned bit_size)
{
   return (int64_t)(val << (64 - bit_size)) >> (64 - bit_size);
}

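/* Insert "def" with multiplier "mul" into the offset_defs/offset_defs_mul
 * arrays, keeping them sorted by the SSA def's index. If "def" is already
 * present, the multipliers are summed instead. Returns the number of array
 * elements added (0 or 1). */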
static unsigned
add_to_entry_key(nir_ssa_def **offset_defs, uint64_t *offset_defs_mul,
                 unsigned offset_def_count, nir_ssa_def *def, uint64_t mul)
{
   mul = mask_sign_extend(mul, def->bit_size);

   for (unsigned i = 0; i <= offset_def_count; i++) {
      if (i == offset_def_count || def->index > offset_defs[i]->index) {
         /* insert before i */
         memmove(offset_defs + i + 1, offset_defs + i,
                 (offset_def_count - i) * sizeof(nir_ssa_def *));
         memmove(offset_defs_mul + i + 1, offset_defs_mul + i,
                 (offset_def_count - i) * sizeof(uint64_t));
         offset_defs[i] = def;
         offset_defs_mul[i] = mul;
         return 1;
      } else if (def->index == offset_defs[i]->index) {
         /* merge with offset_def at i */
         offset_defs_mul[i] += mul;
         return 0;
      }
   }
   unreachable("Unreachable.");
   return 0;
}

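/* Build an entry_key from a deref chain: the variable or resource comes from
 * the head of the path, constant array/struct offsets are accumulated into
 * "offset_base" and indirect array indices are added to the offset_defs. */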
static struct entry_key *
create_entry_key_from_deref(void *mem_ctx,
                            struct vectorize_ctx *ctx,
                            nir_deref_path *path,
                            uint64_t *offset_base)
{
   unsigned path_len = 0;
   while (path->path[path_len])
      path_len++;

   nir_ssa_def *offset_defs_stack[32];
   uint64_t offset_defs_mul_stack[32];
   nir_ssa_def **offset_defs = offset_defs_stack;
   uint64_t *offset_defs_mul = offset_defs_mul_stack;
   if (path_len > 32) {
      offset_defs = malloc(path_len * sizeof(nir_ssa_def *));
      offset_defs_mul = malloc(path_len * sizeof(uint64_t));
   }
   unsigned offset_def_count = 0;

   struct entry_key* key = ralloc(mem_ctx, struct entry_key);
   key->resource = NULL;
   key->var = NULL;
   *offset_base = 0;

   for (unsigned i = 0; i < path_len; i++) {
      nir_deref_instr *parent = i ? path->path[i - 1] : NULL;
      nir_deref_instr *deref = path->path[i];

      switch (deref->deref_type) {
      case nir_deref_type_var: {
         assert(!parent);
         key->var = deref->var;
         break;
      }
      case nir_deref_type_array:
      case nir_deref_type_ptr_as_array: {
         assert(parent);
         nir_ssa_def *index = deref->arr.index.ssa;
         uint32_t stride = nir_deref_instr_array_stride(deref);

         nir_ssa_def *base = index;
         uint64_t offset = 0, base_mul = 1;
         parse_offset(&base, &base_mul, &offset);
         offset = mask_sign_extend(offset, index->bit_size);

         *offset_base += offset * stride;
         if (base) {
            offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
                                                 offset_def_count,
                                                 base, base_mul * stride);
         }
         break;
      }
      case nir_deref_type_struct: {
         assert(parent);
         int offset = glsl_get_struct_field_offset(parent->type, deref->strct.index);
         *offset_base += offset;
         break;
      }
      case nir_deref_type_cast: {
         if (!parent)
            key->resource = deref->parent.ssa;
         break;
      }
      default:
         unreachable("Unhandled deref type");
      }
   }

   key->offset_def_count = offset_def_count;
   key->offset_defs = ralloc_array(mem_ctx, nir_ssa_def *, offset_def_count);
   key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, offset_def_count);
   memcpy(key->offset_defs, offset_defs, offset_def_count * sizeof(nir_ssa_def *));
   memcpy(key->offset_defs_mul, offset_defs_mul, offset_def_count * sizeof(uint64_t));

   if (offset_defs != offset_defs_stack)
      free(offset_defs);
   if (offset_defs_mul != offset_defs_mul_stack)
      free(offset_defs_mul);

   return key;
}

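/* Recursively splits "base" at iadds (up to "left" more terms), accumulating
 * constant parts into "offset" and adding the remaining SSA defs to the key.
 * Returns how many offset_defs were added. */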
static unsigned
parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
                            nir_ssa_def *base, uint64_t base_mul, uint64_t *offset)
{
   uint64_t new_mul;
   uint64_t new_offset;
   parse_offset(&base, &new_mul, &new_offset);
   *offset += new_offset * base_mul;

   if (!base)
      return 0;

   base_mul *= new_mul;

   assert(left >= 1);

   if (left >= 2) {
      nir_ssa_scalar scalar;
      scalar.def = base;
      scalar.comp = 0;
      if (nir_ssa_scalar_is_alu(scalar) && nir_ssa_scalar_alu_op(scalar) == nir_op_iadd) {
         nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(scalar, 0);
         nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(scalar, 1);
         if (src0.comp == 0 && src1.comp == 0) {
            unsigned amount = parse_entry_key_from_offset(key, size, left - 1, src0.def, base_mul, offset);
            amount += parse_entry_key_from_offset(key, size + amount, left - amount, src1.def, base_mul, offset);
            return amount;
         }
      }
   }

   return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
}

static struct entry_key *
create_entry_key_from_offset(void *mem_ctx, nir_ssa_def *base, uint64_t base_mul, uint64_t *offset)
{
   struct entry_key *key = ralloc(mem_ctx, struct entry_key);
   key->resource = NULL;
   key->var = NULL;
   if (base) {
      nir_ssa_def *offset_defs[32];
      uint64_t offset_defs_mul[32];
      key->offset_defs = offset_defs;
      key->offset_defs_mul = offset_defs_mul;

      key->offset_def_count = parse_entry_key_from_offset(key, 0, 32, base, base_mul, offset);

      key->offset_defs = ralloc_array(mem_ctx, nir_ssa_def *, key->offset_def_count);
      key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, key->offset_def_count);
      memcpy(key->offset_defs, offset_defs, key->offset_def_count * sizeof(nir_ssa_def *));
      memcpy(key->offset_defs_mul, offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
   } else {
      key->offset_def_count = 0;
      key->offset_defs = NULL;
      key->offset_defs_mul = NULL;
   }
   return key;
}

static nir_variable_mode
get_variable_mode(struct entry *entry)
{
   if (entry->info->mode)
      return entry->info->mode;
   assert(entry->deref && util_bitcount(entry->deref->modes) == 1);
   return entry->deref->modes;
}

static unsigned
mode_to_index(nir_variable_mode mode)
{
   assert(util_bitcount(mode) == 1);

   /* Globals and SSBOs should be tracked together */
   if (mode == nir_var_mem_global)
      mode = nir_var_mem_ssbo;

   return ffs(mode) - 1;
}

static nir_variable_mode
aliasing_modes(nir_variable_mode modes)
{
   /* Global and SSBO can alias */
   if (modes & (nir_var_mem_ssbo | nir_var_mem_global))
      modes |= nir_var_mem_ssbo | nir_var_mem_global;
   return modes;
}

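/* Derive align_mul/align_offset for the access: align_mul is the largest
 * power of two dividing every offset_defs_mul entry (capped at 2^30) and
 * align_offset is the constant offset modulo align_mul. The intrinsic's
 * existing ALIGN_MUL/ALIGN_OFFSET indices are used instead if they specify a
 * larger align_mul. */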
static void
calc_alignment(struct entry *entry)
{
   uint32_t align_mul = 31;
   for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
      if (entry->key->offset_defs_mul[i])
         align_mul = MIN2(align_mul, ffsll(entry->key->offset_defs_mul[i]));
   }

   entry->align_mul = 1u << (align_mul - 1);
   bool has_align = nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
   if (!has_align || entry->align_mul >= nir_intrinsic_align_mul(entry->intrin)) {
      entry->align_offset = entry->offset % entry->align_mul;
   } else {
      entry->align_mul = nir_intrinsic_align_mul(entry->intrin);
      entry->align_offset = nir_intrinsic_align_offset(entry->intrin);
   }
}

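/* Gather everything needed to compare and combine a load/store intrinsic into
 * a new entry: its canonical key, constant offset, access qualifiers and
 * alignment. */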
static struct entry *
create_entry(struct vectorize_ctx *ctx,
             const struct intrinsic_info *info,
             nir_intrinsic_instr *intrin)
{
   struct entry *entry = rzalloc(ctx, struct entry);
   entry->intrin = intrin;
   entry->instr = &intrin->instr;
   entry->info = info;
   entry->is_store = entry->info->value_src >= 0;

   if (entry->info->deref_src >= 0) {
      entry->deref = nir_src_as_deref(intrin->src[entry->info->deref_src]);
      nir_deref_path path;
      nir_deref_path_init(&path, entry->deref, NULL);
      entry->key = create_entry_key_from_deref(entry, ctx, &path, &entry->offset);
      nir_deref_path_finish(&path);
   } else {
      nir_ssa_def *base = entry->info->base_src >= 0 ?
                          intrin->src[entry->info->base_src].ssa : NULL;
      uint64_t offset = 0;
      if (nir_intrinsic_has_base(intrin))
         offset += nir_intrinsic_base(intrin);
      entry->key = create_entry_key_from_offset(entry, base, 1, &offset);
      entry->offset = offset;

      if (base)
         entry->offset = mask_sign_extend(entry->offset, base->bit_size);
   }

   if (entry->info->resource_src >= 0)
      entry->key->resource = intrin->src[entry->info->resource_src].ssa;

   if (nir_intrinsic_has_access(intrin))
      entry->access = nir_intrinsic_access(intrin);
   else if (entry->key->var)
      entry->access = entry->key->var->data.access;

   uint32_t restrict_modes = nir_var_shader_in | nir_var_shader_out;
   restrict_modes |= nir_var_shader_temp | nir_var_function_temp;
   restrict_modes |= nir_var_uniform | nir_var_mem_push_const;
   restrict_modes |= nir_var_system_value | nir_var_mem_shared;
   if (get_variable_mode(entry) & restrict_modes)
      entry->access |= ACCESS_RESTRICT;

   calc_alignment(entry);

   return entry;
}

static nir_deref_instr *
cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref_instr *deref)
{
   if (glsl_get_components(deref->type) == num_components &&
       type_scalar_size_bytes(deref->type)*8u == bit_size)
      return deref;

   enum glsl_base_type types[] = {
      GLSL_TYPE_UINT8, GLSL_TYPE_UINT16, GLSL_TYPE_UINT, GLSL_TYPE_UINT64};
   enum glsl_base_type base = types[ffs(bit_size / 8u) - 1u];
   const struct glsl_type *type = glsl_vector_type(base, num_components);

   if (deref->type == type)
      return deref;

   return nir_build_deref_cast(b, &deref->dest.ssa, deref->modes, type, 0);
}

/* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
 * of "low" and "high". */
static bool
new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
                       struct entry *low, struct entry *high, unsigned size)
{
   if (size % new_bit_size != 0)
      return false;

   unsigned new_num_components = size / new_bit_size;
   if (!nir_num_components_valid(new_num_components))
      return false;

   unsigned high_offset = high->offset_signed - low->offset_signed;

   /* check nir_extract_bits limitations */
   unsigned common_bit_size = MIN2(get_bit_size(low), get_bit_size(high));
   common_bit_size = MIN2(common_bit_size, new_bit_size);
   if (high_offset > 0)
      common_bit_size = MIN2(common_bit_size, (1u << (ffs(high_offset * 8) - 1)));
   if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
      return false;

   if (!ctx->callback(low->align_mul,
                      low->align_offset,
                      new_bit_size, new_num_components,
                      low->intrin, high->intrin))
      return false;

   if (low->is_store) {
      unsigned low_size = low->intrin->num_components * get_bit_size(low);
      unsigned high_size = high->intrin->num_components * get_bit_size(high);

      if (low_size % new_bit_size != 0)
         return false;
      if (high_size % new_bit_size != 0)
         return false;

      unsigned write_mask = nir_intrinsic_write_mask(low->intrin);
      if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
         return false;

      write_mask = nir_intrinsic_write_mask(high->intrin);
      if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
         return false;
   }

   return true;
}

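/* Build a deref that points "offset" bytes before "deref", preferably by
 * adjusting a constant array index; otherwise it casts to a uint8_t pointer
 * and offsets that. */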
static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
{
   /* avoid adding another deref to the path */
   if (deref->deref_type == nir_deref_type_ptr_as_array &&
       nir_src_is_const(deref->arr.index) &&
       offset % nir_deref_instr_array_stride(deref) == 0) {
      unsigned stride = nir_deref_instr_array_stride(deref);
      nir_ssa_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
                                          deref->dest.ssa.bit_size);
      return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
   }

   if (deref->deref_type == nir_deref_type_array &&
       nir_src_is_const(deref->arr.index)) {
      nir_deref_instr *parent = nir_deref_instr_parent(deref);
      unsigned stride = glsl_get_explicit_stride(parent->type);
      if (offset % stride == 0)
         return nir_build_deref_array_imm(
            b, parent, nir_src_as_int(deref->arr.index) - offset / stride);
   }

   deref = nir_build_deref_cast(b, &deref->dest.ssa, deref->modes,
                                glsl_scalar_type(GLSL_TYPE_UINT8), 1);
   return nir_build_deref_ptr_as_array(
      b, deref, nir_imm_intN_t(b, -offset, deref->dest.ssa.bit_size));
}

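/* Turn the two loads into a single load of new_num_components x new_bit_size
 * by widening whichever load comes first in the program ("first"),
 * re-extracting the original values with nir_extract_bits() and removing the
 * second load. "low"/"high" are the same entries ordered by offset, and
 * "high_start" is the offset of "high" within the new load, in bits. */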
static void
vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
                struct entry *low, struct entry *high,
                struct entry *first, struct entry *second,
                unsigned new_bit_size, unsigned new_num_components,
                unsigned high_start)
{
   unsigned low_bit_size = get_bit_size(low);
   unsigned high_bit_size = get_bit_size(high);
   bool low_bool = low->intrin->dest.ssa.bit_size == 1;
   bool high_bool = high->intrin->dest.ssa.bit_size == 1;
   nir_ssa_def *data = &first->intrin->dest.ssa;

   b->cursor = nir_after_instr(first->instr);

   /* update the load's destination size and extract data for each of the original loads */
   data->num_components = new_num_components;
   data->bit_size = new_bit_size;

   nir_ssa_def *low_def = nir_extract_bits(
      b, &data, 1, 0, low->intrin->num_components, low_bit_size);
   nir_ssa_def *high_def = nir_extract_bits(
      b, &data, 1, high_start, high->intrin->num_components, high_bit_size);

   /* convert booleans */
   low_def = low_bool ? nir_i2b(b, low_def) : nir_mov(b, low_def);
   high_def = high_bool ? nir_i2b(b, high_def) : nir_mov(b, high_def);

   /* update uses */
   if (first == low) {
      nir_ssa_def_rewrite_uses_after(&low->intrin->dest.ssa, nir_src_for_ssa(low_def),
                                     high_def->parent_instr);
      nir_ssa_def_rewrite_uses(&high->intrin->dest.ssa, nir_src_for_ssa(high_def));
   } else {
      nir_ssa_def_rewrite_uses(&low->intrin->dest.ssa, nir_src_for_ssa(low_def));
      nir_ssa_def_rewrite_uses_after(&high->intrin->dest.ssa, nir_src_for_ssa(high_def),
                                     high_def->parent_instr);
   }

   /* update the intrinsic */
   first->intrin->num_components = new_num_components;

   const struct intrinsic_info *info = first->info;

   /* update the offset */
   if (first != low && info->base_src >= 0) {
      /* let nir_opt_algebraic() remove this addition. this doesn't cause much
       * trouble when subtracting 16 from expressions like "(i + 1) * 16" because
       * nir_opt_algebraic() turns them into "i * 16 + 16" */
      b->cursor = nir_before_instr(first->instr);

      nir_ssa_def *new_base = first->intrin->src[info->base_src].ssa;
      new_base = nir_iadd_imm(b, new_base, -(int)(high_start / 8u));

      nir_instr_rewrite_src(first->instr, &first->intrin->src[info->base_src],
                            nir_src_for_ssa(new_base));
   }

   /* update the deref */
   if (info->deref_src >= 0) {
      b->cursor = nir_before_instr(first->instr);

      nir_deref_instr *deref = nir_src_as_deref(first->intrin->src[info->deref_src]);
      if (first != low && high_start != 0)
         deref = subtract_deref(b, deref, high_start / 8u);
      first->deref = cast_deref(b, new_num_components, new_bit_size, deref);

      nir_instr_rewrite_src(first->instr, &first->intrin->src[info->deref_src],
                            nir_src_for_ssa(&first->deref->dest.ssa));
   }

   /* update base/align */
   if (first != low && nir_intrinsic_has_base(first->intrin))
      nir_intrinsic_set_base(first->intrin, nir_intrinsic_base(low->intrin));

   if (nir_intrinsic_has_range_base(first->intrin)) {
      uint32_t low_base = nir_intrinsic_range_base(low->intrin);
      uint32_t high_base = nir_intrinsic_range_base(high->intrin);
      uint32_t low_end = low_base + nir_intrinsic_range(low->intrin);
      uint32_t high_end = high_base + nir_intrinsic_range(high->intrin);

      nir_intrinsic_set_range_base(first->intrin, low_base);
      nir_intrinsic_set_range(first->intrin, MAX2(low_end, high_end) - low_base);
   }

   first->key = low->key;
   first->offset = low->offset;

   first->align_mul = low->align_mul;
   first->align_offset = low->align_offset;

   nir_instr_remove(second->instr);
}

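/* Combine the two stores into a single store by rewriting whichever store
 * comes last in the program ("second") with the merged write mask and data and
 * removing the other store. As with loads, "low"/"high" are ordered by offset
 * and "high_start" is a bit offset. */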
static void
vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
                 struct entry *low, struct entry *high,
                 struct entry *first, struct entry *second,
                 unsigned new_bit_size, unsigned new_num_components,
                 unsigned high_start)
{
   ASSERTED unsigned low_size = low->intrin->num_components * get_bit_size(low);
   assert(low_size % new_bit_size == 0);

   b->cursor = nir_before_instr(second->instr);

   /* get new writemasks */
   uint32_t low_write_mask = nir_intrinsic_write_mask(low->intrin);
   uint32_t high_write_mask = nir_intrinsic_write_mask(high->intrin);
   low_write_mask = nir_component_mask_reinterpret(low_write_mask,
                                                   get_bit_size(low),
                                                   new_bit_size);
   high_write_mask = nir_component_mask_reinterpret(high_write_mask,
                                                    get_bit_size(high),
                                                    new_bit_size);
   high_write_mask <<= high_start / new_bit_size;

   uint32_t write_mask = low_write_mask | high_write_mask;

   /* convert booleans */
   nir_ssa_def *low_val = low->intrin->src[low->info->value_src].ssa;
   nir_ssa_def *high_val = high->intrin->src[high->info->value_src].ssa;
   low_val = low_val->bit_size == 1 ? nir_b2i(b, low_val, 32) : low_val;
   high_val = high_val->bit_size == 1 ? nir_b2i(b, high_val, 32) : high_val;

   /* combine the data */
   nir_ssa_def *data_channels[NIR_MAX_VEC_COMPONENTS];
   for (unsigned i = 0; i < new_num_components; i++) {
      bool set_low = low_write_mask & (1 << i);
      bool set_high = high_write_mask & (1 << i);

      if (set_low && (!set_high || low == second)) {
         unsigned offset = i * new_bit_size;
         data_channels[i] = nir_extract_bits(b, &low_val, 1, offset, 1, new_bit_size);
      } else if (set_high) {
         assert(!set_low || high == second);
         unsigned offset = i * new_bit_size - high_start;
         data_channels[i] = nir_extract_bits(b, &high_val, 1, offset, 1, new_bit_size);
      } else {
         data_channels[i] = nir_ssa_undef(b, 1, new_bit_size);
      }
   }
   nir_ssa_def *data = nir_vec(b, data_channels, new_num_components);

   /* update the intrinsic */
   nir_intrinsic_set_write_mask(second->intrin, write_mask);
   second->intrin->num_components = data->num_components;

   const struct intrinsic_info *info = second->info;
   assert(info->value_src >= 0);
   nir_instr_rewrite_src(second->instr, &second->intrin->src[info->value_src],
                         nir_src_for_ssa(data));

   /* update the offset */
   if (second != low && info->base_src >= 0)
      nir_instr_rewrite_src(second->instr, &second->intrin->src[info->base_src],
                            low->intrin->src[info->base_src]);

   /* update the deref */
   if (info->deref_src >= 0) {
      b->cursor = nir_before_instr(second->instr);
      second->deref = cast_deref(b, new_num_components, new_bit_size,
                                 nir_src_as_deref(low->intrin->src[info->deref_src]));
      nir_instr_rewrite_src(second->instr, &second->intrin->src[info->deref_src],
                            nir_src_for_ssa(&second->deref->dest.ssa));
   }

   /* update base/align */
   if (second != low && nir_intrinsic_has_base(second->intrin))
      nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));

   second->key = low->key;
   second->offset = low->offset;

   second->align_mul = low->align_mul;
   second->align_offset = low->align_offset;

   list_del(&first->head);
   nir_instr_remove(first->instr);
}

/* Returns true if it can prove that "a" and "b" point to different resources. */
static bool
resources_different(nir_ssa_def *a, nir_ssa_def *b)
{
   if (!a || !b)
      return false;

   if (a->parent_instr->type == nir_instr_type_load_const &&
       b->parent_instr->type == nir_instr_type_load_const) {
      return nir_src_as_uint(nir_src_for_ssa(a)) != nir_src_as_uint(nir_src_for_ssa(b));
   }

   if (a->parent_instr->type == nir_instr_type_intrinsic &&
       b->parent_instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *aintrin = nir_instr_as_intrinsic(a->parent_instr);
      nir_intrinsic_instr *bintrin = nir_instr_as_intrinsic(b->parent_instr);
      if (aintrin->intrinsic == nir_intrinsic_vulkan_resource_index &&
          bintrin->intrinsic == nir_intrinsic_vulkan_resource_index) {
         return nir_intrinsic_desc_set(aintrin) != nir_intrinsic_desc_set(bintrin) ||
                nir_intrinsic_binding(aintrin) != nir_intrinsic_binding(bintrin) ||
                resources_different(aintrin->src[0].ssa, bintrin->src[0].ssa);
      }
   }

   return false;
}

static int64_t
compare_entries(struct entry *a, struct entry *b)
{
   if (!entry_key_equals(a->key, b->key))
      return INT64_MAX;
   return b->offset_signed - a->offset_signed;
}

static bool
may_alias(struct entry *a, struct entry *b)
{
   assert(mode_to_index(get_variable_mode(a)) ==
          mode_to_index(get_variable_mode(b)));

   /* if the resources/variables are definitively different and both have
    * ACCESS_RESTRICT, we can assume they do not alias. */
   bool res_different = a->key->var != b->key->var ||
                        resources_different(a->key->resource, b->key->resource);
   if (res_different && (a->access & ACCESS_RESTRICT) && (b->access & ACCESS_RESTRICT))
      return false;

   /* we can't compare offsets if the resources/variables might be different */
   if (a->key->var != b->key->var || a->key->resource != b->key->resource)
      return true;

   /* use adjacency information */
   /* TODO: we can look closer at the entry keys */
   int64_t diff = compare_entries(a, b);
   if (diff != INT64_MAX) {
      /* with atomics, intrin->num_components can be 0 */
      if (diff < 0)
         return llabs(diff) < MAX2(b->intrin->num_components, 1u) * (get_bit_size(b) / 8u);
      else
         return diff < MAX2(a->intrin->num_components, 1u) * (get_bit_size(a) / 8u);
   }

   /* TODO: we can use deref information */

   return true;
}

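/* Return true if there may be an aliasing access between "first" and "second"
 * in program order, which would make combining them unsafe. */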
static bool
check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
{
   nir_variable_mode mode = get_variable_mode(first);
   if (mode & (nir_var_uniform | nir_var_system_value |
               nir_var_mem_push_const | nir_var_mem_ubo))
      return false;

   unsigned mode_index = mode_to_index(mode);
   if (first->is_store) {
      /* find first entry that aliases "first" */
      list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
         if (next == first)
            continue;
         if (next == second)
            return false;
         if (may_alias(first, next))
            return true;
      }
   } else {
      /* find previous store that aliases this load */
      list_for_each_entry_from_rev(struct entry, prev, second, &ctx->entries[mode_index], head) {
         if (prev == second)
            continue;
         if (prev == first)
            return false;
         if (prev->is_store && may_alias(second, prev))
            return true;
      }
   }

   return false;
}

static bool
check_for_robustness(struct vectorize_ctx *ctx, struct entry *low)
{
   nir_variable_mode mode = get_variable_mode(low);
   if (mode & ctx->robust_modes) {
      unsigned low_bit_size = get_bit_size(low);
      unsigned low_size = low->intrin->num_components * low_bit_size;

      /* don't attempt to vectorize accesses if the offset can overflow. */
      /* TODO: handle indirect accesses. */
      return low->offset_signed < 0 && low->offset_signed + low_size >= 0;
   }

   return false;
}

static bool
is_strided_vector(const struct glsl_type *type)
{
   if (glsl_type_is_vector(type)) {
      unsigned explicit_stride = glsl_get_explicit_stride(type);
      return explicit_stride != 0 && explicit_stride !=
             type_scalar_size_bytes(glsl_get_array_element(type));
   } else {
      return false;
   }
}

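/* Try to combine the accesses "low"/"high" (sorted by offset; "first"/"second"
 * are the same two entries in program order): pick a bit size accepted by the
 * backend callback and rewrite the instructions. Returns true on success. */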
static bool
try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
              struct entry *low, struct entry *high,
              struct entry *first, struct entry *second)
{
   if (!(get_variable_mode(first) & ctx->modes) ||
       !(get_variable_mode(second) & ctx->modes))
      return false;

   if (check_for_aliasing(ctx, first, second))
      return false;

   if (check_for_robustness(ctx, low))
      return false;

   /* we can only vectorize non-volatile loads/stores of the same type and with
    * the same access */
   if (first->info != second->info || first->access != second->access ||
       (first->access & ACCESS_VOLATILE) || first->info->is_atomic)
      return false;

   /* don't attempt to vectorize accesses of row-major matrix columns */
   if (first->deref) {
      const struct glsl_type *first_type = first->deref->type;
      const struct glsl_type *second_type = second->deref->type;
      if (is_strided_vector(first_type) || is_strided_vector(second_type))
         return false;
   }

   /* gather information */
   uint64_t diff = high->offset_signed - low->offset_signed;
   unsigned low_bit_size = get_bit_size(low);
   unsigned high_bit_size = get_bit_size(high);
   unsigned low_size = low->intrin->num_components * low_bit_size;
   unsigned high_size = high->intrin->num_components * high_bit_size;
   unsigned new_size = MAX2(diff * 8u + high_size, low_size);

   /* find a good bit size for the new load/store */
   unsigned new_bit_size = 0;
   if (new_bitsize_acceptable(ctx, low_bit_size, low, high, new_size)) {
      new_bit_size = low_bit_size;
   } else if (low_bit_size != high_bit_size &&
              new_bitsize_acceptable(ctx, high_bit_size, low, high, new_size)) {
      new_bit_size = high_bit_size;
   } else {
      new_bit_size = 64;
      for (; new_bit_size >= 8; new_bit_size /= 2) {
         /* don't repeat trying out bitsizes */
         if (new_bit_size == low_bit_size || new_bit_size == high_bit_size)
            continue;
         if (new_bitsize_acceptable(ctx, new_bit_size, low, high, new_size))
            break;
      }
      if (new_bit_size < 8)
         return false;
   }
   unsigned new_num_components = new_size / new_bit_size;

   /* vectorize the loads/stores */
   nir_builder b;
   nir_builder_init(&b, impl);

   if (first->is_store)
      vectorize_stores(&b, ctx, low, high, first, second,
                       new_bit_size, new_num_components, diff * 8u);
   else
      vectorize_loads(&b, ctx, low, high, first, second,
                      new_bit_size, new_num_components, diff * 8u);

   return true;
}

static bool
update_align(struct entry *entry)
{
   if (nir_intrinsic_has_align_mul(entry->intrin) &&
       (entry->align_mul != nir_intrinsic_align_mul(entry->intrin) ||
        entry->align_offset != nir_intrinsic_align_offset(entry->intrin))) {
      nir_intrinsic_set_align(entry->intrin, entry->align_mul, entry->align_offset);
      return true;
   }
   return false;
}

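/* For each bucket of entries with equal keys, sort by offset and greedily try
 * to combine overlapping or adjacent entries, then write back any improved
 * alignment information. */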
static bool
vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
{
   if (!ht)
      return false;

   bool progress = false;
   hash_table_foreach(ht, entry) {
      struct util_dynarray *arr = entry->data;
      if (!arr->size)
         continue;

      qsort(util_dynarray_begin(arr),
            util_dynarray_num_elements(arr, struct entry *),
            sizeof(struct entry *), &sort_entries);

      unsigned num_entries = util_dynarray_num_elements(arr, struct entry *);

      for (unsigned first_idx = 0; first_idx < num_entries; first_idx++) {
         struct entry *low = *util_dynarray_element(arr, struct entry *, first_idx);
         if (!low)
            continue;

         for (unsigned second_idx = first_idx + 1; second_idx < num_entries; second_idx++) {
            struct entry *high = *util_dynarray_element(arr, struct entry *, second_idx);
            if (!high)
               continue;

            uint64_t diff = high->offset_signed - low->offset_signed;
            if (diff > get_bit_size(low) / 8u * low->intrin->num_components)
               break;

            struct entry *first = low->index < high->index ? low : high;
            struct entry *second = low->index < high->index ? high : low;

            if (try_vectorize(impl, ctx, low, high, first, second)) {
               low = low->is_store ? second : first;
               *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
               progress = true;
            }
         }

         *util_dynarray_element(arr, struct entry *, first_idx) = low;
      }

      util_dynarray_foreach(arr, struct entry *, elem) {
         if (*elem)
            progress |= update_align(*elem);
      }
   }

   _mesa_hash_table_clear(ht, delete_entry_dynarray);

   return progress;
}

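/* If "instr" is a barrier-like instruction (memory barrier, discard/demote,
 * call, ...), flush the loads/stores gathered so far for the affected modes by
 * vectorizing them now. Returns true if "instr" was handled as a barrier. */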
static bool
handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *impl, nir_instr *instr)
{
   unsigned modes = 0;
   bool acquire = true;
   bool release = true;
   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      switch (intrin->intrinsic) {
      case nir_intrinsic_group_memory_barrier:
      case nir_intrinsic_memory_barrier:
         modes = nir_var_mem_ssbo | nir_var_mem_shared | nir_var_mem_global;
         break;
      /* prevent speculative loads/stores */
      case nir_intrinsic_discard_if:
      case nir_intrinsic_discard:
      case nir_intrinsic_terminate_if:
      case nir_intrinsic_terminate:
         modes = nir_var_all;
         break;
      case nir_intrinsic_demote_if:
      case nir_intrinsic_demote:
         acquire = false;
         modes = nir_var_all;
         break;
      case nir_intrinsic_memory_barrier_buffer:
         modes = nir_var_mem_ssbo | nir_var_mem_global;
         break;
      case nir_intrinsic_memory_barrier_shared:
         modes = nir_var_mem_shared;
         break;
      case nir_intrinsic_scoped_barrier:
         if (nir_intrinsic_memory_scope(intrin) == NIR_SCOPE_NONE)
            break;

         modes = nir_intrinsic_memory_modes(intrin) & (nir_var_mem_ssbo |
                                                       nir_var_mem_shared |
                                                       nir_var_mem_global);
         acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
         release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
         switch (nir_intrinsic_memory_scope(intrin)) {
         case NIR_SCOPE_INVOCATION:
            /* a barrier should never be required for correctness with these scopes */
            modes = 0;
            break;
         default:
            break;
         }
         break;
      default:
         return false;
      }
   } else if (instr->type == nir_instr_type_call) {
      modes = nir_var_all;
   } else {
      return false;
   }

   while (modes) {
      unsigned mode_index = u_bit_scan(&modes);
      if ((1 << mode_index) == nir_var_mem_global) {
         /* Global should be rolled in with SSBO */
         assert(list_is_empty(&ctx->entries[mode_index]));
         assert(ctx->loads[mode_index] == NULL);
         assert(ctx->stores[mode_index] == NULL);
         continue;
      }

      if (acquire)
         *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
      if (release)
         *progress |= vectorize_entries(ctx, impl, ctx->stores[mode_index]);
   }

   return true;
}

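/* Gather an entry for every supported load/store intrinsic in the block,
 * grouping them per mode and per entry key, and vectorize each group. Barriers
 * split the block into regions that are vectorized separately. */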
static bool
process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *block)
{
   bool progress = false;

   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      list_inithead(&ctx->entries[i]);
      if (ctx->loads[i])
         _mesa_hash_table_clear(ctx->loads[i], delete_entry_dynarray);
      if (ctx->stores[i])
         _mesa_hash_table_clear(ctx->stores[i], delete_entry_dynarray);
   }

   /* create entries */
   unsigned next_index = 0;

   nir_foreach_instr_safe(instr, block) {
      if (handle_barrier(ctx, &progress, impl, instr))
         continue;

      /* gather information */
      if (instr->type != nir_instr_type_intrinsic)
         continue;
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

      const struct intrinsic_info *info = get_info(intrin->intrinsic);
      if (!info)
         continue;

      nir_variable_mode mode = info->mode;
      if (!mode)
         mode = nir_src_as_deref(intrin->src[info->deref_src])->modes;
      if (!(mode & aliasing_modes(ctx->modes)))
         continue;
      unsigned mode_index = mode_to_index(mode);

      /* create entry */
      struct entry *entry = create_entry(ctx, info, intrin);
      entry->index = next_index++;

      list_addtail(&entry->head, &ctx->entries[mode_index]);

      /* add the entry to a hash table */

      struct hash_table *adj_ht = NULL;
      if (entry->is_store) {
         if (!ctx->stores[mode_index])
            ctx->stores[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->stores[mode_index];
      } else {
         if (!ctx->loads[mode_index])
            ctx->loads[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->loads[mode_index];
      }

      uint32_t key_hash = hash_entry_key(entry->key);
      struct hash_entry *adj_entry = _mesa_hash_table_search_pre_hashed(adj_ht, key_hash, entry->key);
      struct util_dynarray *arr;
      if (adj_entry && adj_entry->data) {
         arr = (struct util_dynarray *)adj_entry->data;
      } else {
         arr = ralloc(ctx, struct util_dynarray);
         util_dynarray_init(arr, arr);
         _mesa_hash_table_insert_pre_hashed(adj_ht, key_hash, entry->key, arr);
      }
      util_dynarray_append(arr, struct entry *, entry);
   }

   /* sort and combine entries */
   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      progress |= vectorize_entries(ctx, impl, ctx->loads[i]);
      progress |= vectorize_entries(ctx, impl, ctx->stores[i]);
   }

   return progress;
}

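/* Illustrative usage sketch (not part of this file): a backend might run the
 * pass with a callback that checks the combined access against its alignment
 * and size limits. The callback name and the limits below are hypothetical;
 * only the signature follows how ctx->callback is invoked above.
 *
 *    static bool
 *    mem_vectorize_cb(unsigned align_mul, unsigned align_offset,
 *                     unsigned bit_size, unsigned num_components,
 *                     nir_intrinsic_instr *low, nir_intrinsic_instr *high)
 *    {
 *       // accept accesses of at most 128 bits that are naturally aligned
 *       unsigned align = align_offset ? 1u << (ffs(align_offset) - 1) : align_mul;
 *       return align >= bit_size / 8u && bit_size * num_components <= 128;
 *    }
 *
 *    NIR_PASS(progress, shader, nir_opt_load_store_vectorize,
 *             nir_var_mem_ssbo | nir_var_mem_ubo | nir_var_mem_shared,
 *             mem_vectorize_cb, robust_modes);
 */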
bool
nir_opt_load_store_vectorize(nir_shader *shader, nir_variable_mode modes,
                             nir_should_vectorize_mem_func callback,
                             nir_variable_mode robust_modes)
{
   bool progress = false;

   struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
   ctx->modes = modes;
   ctx->callback = callback;
   ctx->robust_modes = robust_modes;

   nir_shader_index_vars(shader, modes);

   nir_foreach_function(function, shader) {
      if (function->impl) {
         if (modes & nir_var_function_temp)
            nir_function_impl_index_vars(function->impl);

         nir_foreach_block(block, function->impl)
            progress |= process_block(function->impl, ctx, block);

         nir_metadata_preserve(function->impl,
                               nir_metadata_block_index |
                               nir_metadata_dominance |
                               nir_metadata_live_ssa_defs);
      }
   }

   ralloc_free(ctx);
   return progress;
}