1 /*
2  * Copyright 2018 Collabora Ltd.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * on the rights to use, copy, modify, merge, publish, distribute, sub
8  * license, and/or sell copies of the Software, and to permit persons to whom
9  * the Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18  * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19  * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20  * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21  * USE OR OTHER DEALINGS IN THE SOFTWARE.
22  */
23 
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26 
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31 
/* Sentinel stored in ntv_context::shader_slot_map for a NIR varying slot that
 * has not been assigned a SPIR-V Location yet (see handle_slot()).
 */
#define SLOT_UNSET ((unsigned char) -1)

/* Translation state for converting one NIR shader into SPIR-V. */
struct ntv_context {
   void *mem_ctx;

   struct spirv_builder builder;

   /* presumably the id of the imported "GLSL.std.450" extended instruction
    * set — not assigned in this part of the file; verify */
   SpvId GLSL_std_450;

   gl_shader_stage stage;
   const struct zink_so_info *so_info;

   /* UBO variables created by emit_ubo(), in emission order */
   SpvId ubos[128];
   size_t num_ubos;
   /* image type and variable per sampler index (data.binding, plus the
    * element index for sampler arrays), filled by emit_sampler() */
   SpvId image_types[PIPE_MAX_SAMPLERS];
   SpvId samplers[PIPE_MAX_SAMPLERS];
   /* bit i set <=> samplers[i] / image_types[i] are populated */
   unsigned samplers_used : PIPE_MAX_SAMPLERS;
   /* input/output variables appended by emit_input()/emit_output();
    * presumably the OpEntryPoint interface list — confirm at the use site */
   SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
   size_t num_entry_ifaces;

   /* SPIR-V id per NIR SSA def, indexed by nir_ssa_def::index */
   SpvId *defs;
   size_t num_defs;

   /* variable id per NIR register, indexed by nir_register::index */
   SpvId *regs;
   size_t num_regs;

   struct hash_table *vars; /* nir_variable -> SpvId */
   struct hash_table *so_outputs; /* pipe_stream_output -> SpvId */
   /* indexed by varying slot (var->data.location); written by emit_output()
    * for non-fragment stages: variable id, GLSL type and SPIR-V type */
   unsigned outputs[VARYING_SLOT_MAX];
   const struct glsl_type *so_output_gl_types[VARYING_SLOT_MAX];
   SpvId so_output_types[VARYING_SLOT_MAX];

   /* label id per nir_block::index, looked up by block_label() */
   const SpvId *block_ids;
   size_t num_blocks;
   bool block_started;
   SpvId loop_break, loop_cont;

   /* NIR varying slot -> compact SPIR-V Location; SLOT_UNSET entries are
    * assigned lazily by handle_slot() */
   unsigned char *shader_slot_map;
   /* number of Locations handed out so far by reserve_slot() */
   unsigned char shader_slots_reserved;

   SpvId front_face_var, instance_id_var, vertex_id_var,
         primitive_id_var, invocation_id_var, // geometry
         sample_mask_type, sample_id_var, sample_pos_var;
};
76 
/* Forward declarations for static helpers that the code below calls before
 * their definitions appear.
 */
static SpvId
get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
                  unsigned num_components, float value);

static SpvId
get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
                  unsigned num_components, uint32_t value);

static SpvId
get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
                  unsigned num_components, int32_t value);

static SpvId
emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);

static SpvId
emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
           SpvId src0, SpvId src1);

static SpvId
emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
           SpvId src0, SpvId src1, SpvId src2);
99 
100 static SpvId
get_bvec_type(struct ntv_context * ctx,int num_components)101 get_bvec_type(struct ntv_context *ctx, int num_components)
102 {
103    SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
104    if (num_components > 1)
105       return spirv_builder_type_vector(&ctx->builder, bool_type,
106                                        num_components);
107 
108    assert(num_components == 1);
109    return bool_type;
110 }
111 
112 static SpvId
block_label(struct ntv_context * ctx,nir_block * block)113 block_label(struct ntv_context *ctx, nir_block *block)
114 {
115    assert(block->index < ctx->num_blocks);
116    return ctx->block_ids[block->index];
117 }
118 
119 static SpvId
emit_float_const(struct ntv_context * ctx,int bit_size,float value)120 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
121 {
122    assert(bit_size == 32);
123    return spirv_builder_const_float(&ctx->builder, bit_size, value);
124 }
125 
126 static SpvId
emit_uint_const(struct ntv_context * ctx,int bit_size,uint32_t value)127 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
128 {
129    assert(bit_size == 32);
130    return spirv_builder_const_uint(&ctx->builder, bit_size, value);
131 }
132 
133 static SpvId
emit_int_const(struct ntv_context * ctx,int bit_size,int32_t value)134 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
135 {
136    assert(bit_size == 32);
137    return spirv_builder_const_int(&ctx->builder, bit_size, value);
138 }
139 
140 static SpvId
get_fvec_type(struct ntv_context * ctx,unsigned bit_size,unsigned num_components)141 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
142 {
143    assert(bit_size == 32); // only 32-bit floats supported so far
144 
145    SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
146    if (num_components > 1)
147       return spirv_builder_type_vector(&ctx->builder, float_type,
148                                        num_components);
149 
150    assert(num_components == 1);
151    return float_type;
152 }
153 
154 static SpvId
get_ivec_type(struct ntv_context * ctx,unsigned bit_size,unsigned num_components)155 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
156 {
157    assert(bit_size == 32); // only 32-bit ints supported so far
158 
159    SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
160    if (num_components > 1)
161       return spirv_builder_type_vector(&ctx->builder, int_type,
162                                        num_components);
163 
164    assert(num_components == 1);
165    return int_type;
166 }
167 
168 static SpvId
get_uvec_type(struct ntv_context * ctx,unsigned bit_size,unsigned num_components)169 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
170 {
171    assert(bit_size == 32); // only 32-bit uints supported so far
172 
173    SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
174    if (num_components > 1)
175       return spirv_builder_type_vector(&ctx->builder, uint_type,
176                                        num_components);
177 
178    assert(num_components == 1);
179    return uint_type;
180 }
181 
182 static SpvId
get_dest_uvec_type(struct ntv_context * ctx,nir_dest * dest)183 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
184 {
185    unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
186    return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
187 }
188 
189 static SpvId
get_glsl_basetype(struct ntv_context * ctx,enum glsl_base_type type)190 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
191 {
192    switch (type) {
193    case GLSL_TYPE_BOOL:
194       return spirv_builder_type_bool(&ctx->builder);
195 
196    case GLSL_TYPE_FLOAT:
197       return spirv_builder_type_float(&ctx->builder, 32);
198 
199    case GLSL_TYPE_INT:
200       return spirv_builder_type_int(&ctx->builder, 32);
201 
202    case GLSL_TYPE_UINT:
203       return spirv_builder_type_uint(&ctx->builder, 32);
204    /* TODO: handle more types */
205 
206    default:
207       unreachable("unknown GLSL type");
208    }
209 }
210 
211 static SpvId
get_glsl_type(struct ntv_context * ctx,const struct glsl_type * type)212 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
213 {
214    assert(type);
215    if (glsl_type_is_scalar(type))
216       return get_glsl_basetype(ctx, glsl_get_base_type(type));
217 
218    if (glsl_type_is_vector(type))
219       return spirv_builder_type_vector(&ctx->builder,
220          get_glsl_basetype(ctx, glsl_get_base_type(type)),
221          glsl_get_vector_elements(type));
222 
223    if (glsl_type_is_array(type)) {
224       SpvId ret = spirv_builder_type_array(&ctx->builder,
225          get_glsl_type(ctx, glsl_get_array_element(type)),
226          emit_uint_const(ctx, 32, glsl_get_length(type)));
227       uint32_t stride = glsl_get_explicit_stride(type);
228       if (stride)
229          spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
230       return ret;
231    }
232 
233    if (glsl_type_is_matrix(type))
234       return spirv_builder_type_matrix(&ctx->builder,
235                                        spirv_builder_type_vector(&ctx->builder,
236                                                                  get_glsl_basetype(ctx, glsl_get_base_type(type)),
237                                                                  glsl_get_vector_elements(type)),
238                                        glsl_get_matrix_columns(type));
239 
240    unreachable("we shouldn't get here, I think...");
241 }
242 
243 static inline unsigned char
reserve_slot(struct ntv_context * ctx)244 reserve_slot(struct ntv_context *ctx)
245 {
246    /* TODO: this should actually be clamped to the limits value as in the table
247     * in 14.1.4 of the vulkan spec, though there's not really any recourse
248     * other than aborting if we do hit it...
249     */
250    assert(ctx->shader_slots_reserved < MAX_VARYING);
251    return ctx->shader_slots_reserved++;
252 }
253 
254 static inline unsigned
handle_slot(struct ntv_context * ctx,unsigned slot)255 handle_slot(struct ntv_context *ctx, unsigned slot)
256 {
257    if (ctx->shader_slot_map[slot] == SLOT_UNSET)
258       ctx->shader_slot_map[slot] = reserve_slot(ctx);
259    slot = ctx->shader_slot_map[slot];
260    assert(slot < MAX_VARYING);
261    return slot;
262 }
263 
/* Expands to a `case VARYING_SLOT_<SLOT>:` label that decorates `var_id` with
 * BuiltIn <BUILTIN> and breaks out of the switch. It refers to `ctx` and
 * `var_id` from the enclosing scope. The expansion ends in `break` without a
 * semicolon, so each use site supplies it.
 */
#define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN) \
      case VARYING_SLOT_##SLOT: \
         spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltIn##BUILTIN); \
         break
268 
269 
270 static void
emit_input(struct ntv_context * ctx,struct nir_variable * var)271 emit_input(struct ntv_context *ctx, struct nir_variable *var)
272 {
273    SpvId var_type = get_glsl_type(ctx, var->type);
274    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
275                                                    SpvStorageClassInput,
276                                                    var_type);
277    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
278                                          SpvStorageClassInput);
279 
280    if (var->name)
281       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
282 
283    unsigned slot = var->data.location;
284    if (ctx->stage == MESA_SHADER_VERTEX)
285       spirv_builder_emit_location(&ctx->builder, var_id,
286                                   var->data.driver_location);
287    else if (ctx->stage == MESA_SHADER_FRAGMENT) {
288       switch (slot) {
289       HANDLE_EMIT_BUILTIN(POS, FragCoord);
290       HANDLE_EMIT_BUILTIN(PNTC, PointCoord);
291       HANDLE_EMIT_BUILTIN(LAYER, Layer);
292       HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
293       HANDLE_EMIT_BUILTIN(CLIP_DIST0, ClipDistance);
294       HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
295       HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
296       HANDLE_EMIT_BUILTIN(FACE, FrontFacing);
297 
298       default:
299          slot = handle_slot(ctx, slot);
300          spirv_builder_emit_location(&ctx->builder, var_id, slot);
301       }
302       if (var->data.centroid)
303          spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationCentroid);
304       else if (var->data.sample)
305          spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationSample);
306    } else if (ctx->stage < MESA_SHADER_FRAGMENT) {
307       switch (slot) {
308       HANDLE_EMIT_BUILTIN(POS, Position);
309       HANDLE_EMIT_BUILTIN(PSIZ, PointSize);
310       HANDLE_EMIT_BUILTIN(LAYER, Layer);
311       HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
312       HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
313       HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
314       HANDLE_EMIT_BUILTIN(TESS_LEVEL_OUTER, TessLevelOuter);
315       HANDLE_EMIT_BUILTIN(TESS_LEVEL_INNER, TessLevelInner);
316 
317       case VARYING_SLOT_CLIP_DIST0:
318          assert(glsl_type_is_array(var->type));
319          spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
320          break;
321 
322       default:
323          slot = handle_slot(ctx, slot);
324          spirv_builder_emit_location(&ctx->builder, var_id, slot);
325       }
326    }
327 
328    if (var->data.location_frac)
329       spirv_builder_emit_component(&ctx->builder, var_id,
330                                    var->data.location_frac);
331 
332    if (var->data.interpolation == INTERP_MODE_FLAT)
333       spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
334 
335    _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
336 
337    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
338    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
339 }
340 
/* Declare an Output variable for `var` and decorate it.
 *
 * Non-fragment stages: built-in slots are decorated BuiltIn, and other slots
 * get a compact Location from handle_slot(). The variable id, GLSL type and
 * SPIR-V type are recorded per varying slot in ctx->outputs,
 * ctx->so_output_gl_types and ctx->so_output_types (presumably consumed by
 * stream-output handling — confirm).
 *
 * Fragment stage: FRAG_RESULT_DATAn maps to Location n with the variable's
 * blend Index; depth and sample mask become built-ins.
 *
 * Every output then gets Component and interpolation decorations. It is
 * recorded in ctx->vars and appended to ctx->entry_ifaces.
 */
static void
emit_output(struct ntv_context *ctx, struct nir_variable *var)
{
   SpvId var_type = get_glsl_type(ctx, var->type);

   /* SampleMask is always an array in spirv */
   if (ctx->stage == MESA_SHADER_FRAGMENT && var->data.location == FRAG_RESULT_SAMPLE_MASK)
      ctx->sample_mask_type = var_type = spirv_builder_type_array(&ctx->builder, var_type, emit_uint_const(ctx, 32, 1));
   SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                   SpvStorageClassOutput,
                                                   var_type);
   SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
                                         SpvStorageClassOutput);
   if (var->name)
      spirv_builder_emit_name(&ctx->builder, var_id, var->name);

   unsigned slot = var->data.location;
   if (ctx->stage != MESA_SHADER_FRAGMENT) {
      switch (slot) {
      HANDLE_EMIT_BUILTIN(POS, Position);
      HANDLE_EMIT_BUILTIN(PSIZ, PointSize);
      HANDLE_EMIT_BUILTIN(LAYER, Layer);
      HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
      HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
      HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
      HANDLE_EMIT_BUILTIN(TESS_LEVEL_OUTER, TessLevelOuter);
      HANDLE_EMIT_BUILTIN(TESS_LEVEL_INNER, TessLevelInner);

      case VARYING_SLOT_CLIP_DIST0:
         assert(glsl_type_is_array(var->type));
         spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
         /* this can be as large as 2x vec4, which requires 2 slots */
         ctx->outputs[VARYING_SLOT_CLIP_DIST1] = var_id;
         ctx->so_output_gl_types[VARYING_SLOT_CLIP_DIST1] = var->type;
         ctx->so_output_types[VARYING_SLOT_CLIP_DIST1] = var_type;
         break;

      default:
         slot = handle_slot(ctx, slot);
         spirv_builder_emit_location(&ctx->builder, var_id, slot);
      }
      /* indexed by the original NIR slot, not the remapped Location */
      ctx->outputs[var->data.location] = var_id;
      ctx->so_output_gl_types[var->data.location] = var->type;
      ctx->so_output_types[var->data.location] = var_type;
   } else {
      if (var->data.location >= FRAG_RESULT_DATA0) {
         /* color output n -> Location n; Index selects the dual-source
          * blend input */
         spirv_builder_emit_location(&ctx->builder, var_id,
                                     var->data.location - FRAG_RESULT_DATA0);
         spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
      } else {
         switch (var->data.location) {
         case FRAG_RESULT_COLOR:
            unreachable("gl_FragColor should be lowered by now");

         case FRAG_RESULT_DEPTH:
            spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
            break;

         case FRAG_RESULT_SAMPLE_MASK:
            spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInSampleMask);
            break;

         default:
            slot = handle_slot(ctx, slot);
            spirv_builder_emit_location(&ctx->builder, var_id, slot);
            spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
         }
      }
      if (var->data.sample)
         spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationSample);
   }

   if (var->data.location_frac)
      spirv_builder_emit_component(&ctx->builder, var_id,
                                   var->data.location_frac);

   switch (var->data.interpolation) {
   case INTERP_MODE_NONE:
   case INTERP_MODE_SMOOTH: /* XXX spirv doesn't seem to have anything for this */
      break;
   case INTERP_MODE_FLAT:
      spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
      break;
   case INTERP_MODE_EXPLICIT:
      spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationExplicitInterpAMD);
      break;
   case INTERP_MODE_NOPERSPECTIVE:
      spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationNoPerspective);
      break;
   default:
      unreachable("unknown interpolation value");
   }

   _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);

   assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
   ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
}
439 
440 static SpvDim
type_to_dim(enum glsl_sampler_dim gdim,bool * is_ms)441 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
442 {
443    *is_ms = false;
444    switch (gdim) {
445    case GLSL_SAMPLER_DIM_1D:
446       return SpvDim1D;
447    case GLSL_SAMPLER_DIM_2D:
448       return SpvDim2D;
449    case GLSL_SAMPLER_DIM_3D:
450       return SpvDim3D;
451    case GLSL_SAMPLER_DIM_CUBE:
452       return SpvDimCube;
453    case GLSL_SAMPLER_DIM_RECT:
454       return SpvDim2D;
455    case GLSL_SAMPLER_DIM_BUF:
456       return SpvDimBuffer;
457    case GLSL_SAMPLER_DIM_EXTERNAL:
458       return SpvDim2D; /* seems dodgy... */
459    case GLSL_SAMPLER_DIM_MS:
460       *is_ms = true;
461       return SpvDim2D;
462    default:
463       fprintf(stderr, "unknown sampler type %d\n", gdim);
464       break;
465    }
466    return SpvDim2D;
467 }
468 
469 uint32_t
zink_binding(gl_shader_stage stage,VkDescriptorType type,int index)470 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
471 {
472    if (stage == MESA_SHADER_NONE ||
473        stage >= MESA_SHADER_COMPUTE) {
474       unreachable("not supported");
475    } else {
476       uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
477                                                  PIPE_MAX_SHADER_SAMPLER_VIEWS);
478 
479       switch (type) {
480       case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
481          assert(index < PIPE_MAX_CONSTANT_BUFFERS);
482          return stage_offset + index;
483 
484       case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
485       case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
486          assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
487          return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
488 
489       default:
490          unreachable("unexpected type");
491       }
492    }
493 }
494 
495 static void
emit_sampler(struct ntv_context * ctx,struct nir_variable * var)496 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
497 {
498    const struct glsl_type *type = glsl_without_array(var->type);
499 
500    bool is_ms;
501    SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
502 
503    SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
504    SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
505                                                dimension, false,
506                                                glsl_sampler_type_is_array(type),
507                                                is_ms, 1,
508                                                SpvImageFormatUnknown);
509 
510    SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
511                                                          image_type);
512    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
513                                                    SpvStorageClassUniformConstant,
514                                                    sampled_type);
515 
516    if (glsl_type_is_array(var->type)) {
517       /* ARB_arrays_of_arrays from GLSL 1.30 allows nesting of arrays, so we just
518        * use the total array size if we encounter a nested array
519        */
520       unsigned size = glsl_get_aoa_size(var->type);
521       for (int i = 0; i < size; ++i) {
522          SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
523                                                SpvStorageClassUniformConstant);
524 
525          if (var->name) {
526             char element_name[100];
527             snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
528             spirv_builder_emit_name(&ctx->builder, var_id, var->name);
529          }
530 
531          int index = var->data.binding + i;
532          assert(!(ctx->samplers_used & (1 << index)));
533          assert(!ctx->image_types[index]);
534          ctx->image_types[index] = image_type;
535          ctx->samplers[index] = var_id;
536          ctx->samplers_used |= 1 << index;
537 
538          spirv_builder_emit_descriptor_set(&ctx->builder, var_id, 0);
539          int binding = zink_binding(ctx->stage,
540                                     zink_sampler_type(glsl_without_array(var->type)),
541                                     var->data.binding + i);
542          spirv_builder_emit_binding(&ctx->builder, var_id, binding);
543       }
544    } else {
545       SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
546                                             SpvStorageClassUniformConstant);
547 
548       if (var->name)
549          spirv_builder_emit_name(&ctx->builder, var_id, var->name);
550 
551       int index = var->data.binding;
552       assert(!(ctx->samplers_used & (1 << index)));
553       assert(!ctx->image_types[index]);
554       ctx->image_types[index] = image_type;
555       ctx->samplers[index] = var_id;
556       ctx->samplers_used |= 1 << index;
557 
558       spirv_builder_emit_descriptor_set(&ctx->builder, var_id, 0);
559       int binding = zink_binding(ctx->stage,
560                                  zink_sampler_type(var->type),
561                                  var->data.binding);
562       spirv_builder_emit_binding(&ctx->builder, var_id, binding);
563    }
564 }
565 
566 static void
emit_ubo(struct ntv_context * ctx,struct nir_variable * var)567 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
568 {
569    /* variables accessed inside a uniform block will get merged into a big
570     * memory blob and accessed by offset
571     */
572    if (var->data.location)
573       return;
574 
575    uint32_t size = glsl_count_attribute_slots(var->interface_type, false);
576    SpvId vec4_type = get_uvec_type(ctx, 32, 4);
577    SpvId array_length = emit_uint_const(ctx, 32, size);
578    SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
579                                                array_length);
580    spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
581 
582    // wrap UBO-array in a struct
583    SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
584    if (var->name) {
585       char struct_name[100];
586       snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
587       spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
588    }
589 
590    spirv_builder_emit_decoration(&ctx->builder, struct_type,
591                                  SpvDecorationBlock);
592    spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
593 
594 
595    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
596                                                    SpvStorageClassUniform,
597                                                    struct_type);
598 
599    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
600                                          SpvStorageClassUniform);
601    if (var->name)
602       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
603 
604    assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
605    ctx->ubos[ctx->num_ubos++] = var_id;
606 
607    spirv_builder_emit_descriptor_set(&ctx->builder, var_id, 0);
608    int binding = zink_binding(ctx->stage,
609                               VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
610                               var->data.binding);
611    spirv_builder_emit_binding(&ctx->builder, var_id, binding);
612 }
613 
614 static void
emit_uniform(struct ntv_context * ctx,struct nir_variable * var)615 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
616 {
617    if (var->data.mode == nir_var_mem_ubo)
618       emit_ubo(ctx, var);
619    else {
620       assert(var->data.mode == nir_var_uniform);
621       if (glsl_type_is_sampler(glsl_without_array(var->type)))
622          emit_sampler(ctx, var);
623    }
624 }
625 
626 static SpvId
get_vec_from_bit_size(struct ntv_context * ctx,uint32_t bit_size,uint32_t num_components)627 get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_components)
628 {
629    if (bit_size == 1)
630       return get_bvec_type(ctx, num_components);
631    if (bit_size == 32)
632       return get_uvec_type(ctx, bit_size, num_components);
633    unreachable("unhandled register bit size");
634    return 0;
635 }
636 
637 static SpvId
get_src_ssa(struct ntv_context * ctx,const nir_ssa_def * ssa)638 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
639 {
640    assert(ssa->index < ctx->num_defs);
641    assert(ctx->defs[ssa->index] != 0);
642    return ctx->defs[ssa->index];
643 }
644 
645 static SpvId
get_var_from_reg(struct ntv_context * ctx,nir_register * reg)646 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
647 {
648    assert(reg->index < ctx->num_regs);
649    assert(ctx->regs[reg->index] != 0);
650    return ctx->regs[reg->index];
651 }
652 
653 static SpvId
get_src_reg(struct ntv_context * ctx,const nir_reg_src * reg)654 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
655 {
656    assert(reg->reg);
657    assert(!reg->indirect);
658    assert(!reg->base_offset);
659 
660    SpvId var = get_var_from_reg(ctx, reg->reg);
661    SpvId type = get_vec_from_bit_size(ctx, reg->reg->bit_size, reg->reg->num_components);
662    return spirv_builder_emit_load(&ctx->builder, type, var);
663 }
664 
665 static SpvId
get_src(struct ntv_context * ctx,nir_src * src)666 get_src(struct ntv_context *ctx, nir_src *src)
667 {
668    if (src->is_ssa)
669       return get_src_ssa(ctx, src->ssa);
670    else
671       return get_src_reg(ctx, &src->reg);
672 }
673 
/* Return the SPIR-V value of ALU source `src` with its swizzle applied.
 *
 * Only channels for which nir_alu_instr_channel_used() is true are gathered.
 * When a swizzle is needed, the result is typed as bool (1-bit sources) or
 * uint (32-bit sources); other bit sizes and the negate/abs source modifiers
 * are not supported (asserted).
 */
static SpvId
get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
{
   assert(!alu->src[src].negate);
   assert(!alu->src[src].abs);

   SpvId def = get_src(ctx, &alu->src[src].src);

   /* count used channels and detect any non-identity swizzle among them */
   unsigned used_channels = 0;
   bool need_swizzle = false;
   for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
      if (!nir_alu_instr_channel_used(alu, src, i))
         continue;

      used_channels++;

      if (alu->src[src].swizzle[i] != i)
         need_swizzle = true;
   }
   assert(used_channels != 0);

   /* a width change (fewer or more channels than the source has) also
    * requires rebuilding the value */
   unsigned live_channels = nir_src_num_components(alu->src[src].src);
   if (used_channels != live_channels)
      need_swizzle = true;

   /* identity swizzle over all source channels: use the value as-is */
   if (!need_swizzle)
      return def;

   int bit_size = nir_src_bit_size(alu->src[src].src);
   assert(bit_size == 1 || bit_size == 32);

   SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
                                    spirv_builder_type_uint(&ctx->builder, bit_size);

   if (used_channels == 1) {
      /* single channel: extract it from the source vector.
       * NOTE(review): uses swizzle[0] regardless of which channel is marked
       * used; only correct if that channel is channel 0 — confirm for
       * write-masked register destinations. */
      uint32_t indices[] =  { alu->src[src].swizzle[0] };
      return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
                                                  def, indices,
                                                  ARRAY_SIZE(indices));
   } else if (live_channels == 1) {
      /* scalar source read by several channels: splat it into a vector */
      SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
                                                     raw_type,
                                                     used_channels);

      SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
      for (unsigned i = 0; i < used_channels; ++i)
        constituents[i] = def;

      return spirv_builder_emit_composite_construct(&ctx->builder,
                                                    raw_vec_type,
                                                    constituents,
                                                    used_channels);
   } else {
      /* general case: shuffle the swizzled components of the used channels */
      SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
                                                     raw_type,
                                                     used_channels);

      uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
      size_t num_components = 0;
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
         if (!nir_alu_instr_channel_used(alu, src, i))
            continue;

         components[num_components++] = alu->src[src].swizzle[i];
      }

      return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
                                               def, def, components,
                                               num_components);
   }
}
745 
746 static void
store_ssa_def(struct ntv_context * ctx,nir_ssa_def * ssa,SpvId result)747 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
748 {
749    assert(result != 0);
750    assert(ssa->index < ctx->num_defs);
751    ctx->defs[ssa->index] = result;
752 }
753 
754 static SpvId
emit_select(struct ntv_context * ctx,SpvId type,SpvId cond,SpvId if_true,SpvId if_false)755 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
756             SpvId if_true, SpvId if_false)
757 {
758    return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
759 }
760 
761 static SpvId
uvec_to_bvec(struct ntv_context * ctx,SpvId value,unsigned num_components)762 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
763 {
764    SpvId type = get_bvec_type(ctx, num_components);
765    SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
766    return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
767 }
768 
769 static SpvId
emit_bitcast(struct ntv_context * ctx,SpvId type,SpvId value)770 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
771 {
772    return emit_unop(ctx, SpvOpBitcast, type, value);
773 }
774 
775 static SpvId
bitcast_to_uvec(struct ntv_context * ctx,SpvId value,unsigned bit_size,unsigned num_components)776 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
777                 unsigned num_components)
778 {
779    SpvId type = get_uvec_type(ctx, bit_size, num_components);
780    return emit_bitcast(ctx, type, value);
781 }
782 
783 static SpvId
bitcast_to_ivec(struct ntv_context * ctx,SpvId value,unsigned bit_size,unsigned num_components)784 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
785                 unsigned num_components)
786 {
787    SpvId type = get_ivec_type(ctx, bit_size, num_components);
788    return emit_bitcast(ctx, type, value);
789 }
790 
791 static SpvId
bitcast_to_fvec(struct ntv_context * ctx,SpvId value,unsigned bit_size,unsigned num_components)792 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
793                unsigned num_components)
794 {
795    SpvId type = get_fvec_type(ctx, bit_size, num_components);
796    return emit_bitcast(ctx, type, value);
797 }
798 
799 static void
store_reg_def(struct ntv_context * ctx,nir_reg_dest * reg,SpvId result)800 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
801 {
802    SpvId var = get_var_from_reg(ctx, reg->reg);
803    assert(var);
804    spirv_builder_emit_store(&ctx->builder, var, result);
805 }
806 
807 static void
store_dest_raw(struct ntv_context * ctx,nir_dest * dest,SpvId result)808 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
809 {
810    if (dest->is_ssa)
811       store_ssa_def(ctx, &dest->ssa, result);
812    else
813       store_reg_def(ctx, &dest->reg, result);
814 }
815 
816 static SpvId
store_dest(struct ntv_context * ctx,nir_dest * dest,SpvId result,nir_alu_type type)817 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
818 {
819    unsigned num_components = nir_dest_num_components(*dest);
820    unsigned bit_size = nir_dest_bit_size(*dest);
821 
822    if (bit_size != 1) {
823       switch (nir_alu_type_get_base_type(type)) {
824       case nir_type_bool:
825          assert("bool should have bit-size 1");
826 
827       case nir_type_uint:
828          break; /* nothing to do! */
829 
830       case nir_type_int:
831       case nir_type_float:
832          result = bitcast_to_uvec(ctx, result, bit_size, num_components);
833          break;
834 
835       default:
836          unreachable("unsupported nir_alu_type");
837       }
838    }
839 
840    store_dest_raw(ctx, dest, result);
841    return result;
842 }
843 
844 static SpvId
emit_unop(struct ntv_context * ctx,SpvOp op,SpvId type,SpvId src)845 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
846 {
847    return spirv_builder_emit_unop(&ctx->builder, op, type, src);
848 }
849 
850 /* return the intended xfb output vec type based on base type and vector size */
851 static SpvId
get_output_type(struct ntv_context * ctx,unsigned register_index,unsigned num_components)852 get_output_type(struct ntv_context *ctx, unsigned register_index, unsigned num_components)
853 {
854    const struct glsl_type *out_type = ctx->so_output_gl_types[register_index];
855    enum glsl_base_type base_type = glsl_get_base_type(out_type);
856    if (base_type == GLSL_TYPE_ARRAY)
857       base_type = glsl_get_base_type(glsl_without_array(out_type));
858 
859    switch (base_type) {
860    case GLSL_TYPE_BOOL:
861       return get_bvec_type(ctx, num_components);
862 
863    case GLSL_TYPE_FLOAT:
864       return get_fvec_type(ctx, 32, num_components);
865 
866    case GLSL_TYPE_INT:
867       return get_ivec_type(ctx, 32, num_components);
868 
869    case GLSL_TYPE_UINT:
870       return get_uvec_type(ctx, 32, num_components);
871 
872    default:
873       break;
874    }
875    unreachable("unknown type");
876    return 0;
877 }
878 
879 /* for streamout create new outputs, as streamout can be done on individual components,
880    from complete outputs, so we just can't use the created packed outputs */
881 static void
emit_so_info(struct ntv_context * ctx,const struct zink_so_info * so_info)882 emit_so_info(struct ntv_context *ctx, const struct zink_so_info *so_info)
883 {
884    for (unsigned i = 0; i < so_info->so_info.num_outputs; i++) {
885       struct pipe_stream_output so_output = so_info->so_info.output[i];
886       unsigned slot = so_info->so_info_slots[i];
887       SpvId out_type = get_output_type(ctx, slot, so_output.num_components);
888       SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
889                                                       SpvStorageClassOutput,
890                                                       out_type);
891       SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
892                                             SpvStorageClassOutput);
893       char name[10];
894 
895       snprintf(name, 10, "xfb%d", i);
896       spirv_builder_emit_name(&ctx->builder, var_id, name);
897       spirv_builder_emit_offset(&ctx->builder, var_id, (so_output.dst_offset * 4));
898       spirv_builder_emit_xfb_buffer(&ctx->builder, var_id, so_output.output_buffer);
899       spirv_builder_emit_xfb_stride(&ctx->builder, var_id, so_info->so_info.stride[so_output.output_buffer] * 4);
900 
901       /* output location is incremented by VARYING_SLOT_VAR0 for non-builtins in vtn,
902        * so we need to ensure that the new xfb location slot doesn't conflict with any previously-emitted
903        * outputs.
904        */
905       uint32_t location = reserve_slot(ctx);
906       assert(location < VARYING_SLOT_VAR0);
907       spirv_builder_emit_location(&ctx->builder, var_id, location);
908 
909       /* note: gl_ClipDistance[4] can the 0-indexed member of VARYING_SLOT_CLIP_DIST1 here,
910        * so this is still the 0 component
911        */
912       if (so_output.start_component)
913          spirv_builder_emit_component(&ctx->builder, var_id, so_output.start_component);
914 
915       uint32_t *key = ralloc_size(ctx->mem_ctx, sizeof(uint32_t));
916       *key = (uint32_t)so_output.register_index << 2 | so_output.start_component;
917       _mesa_hash_table_insert(ctx->so_outputs, key, (void *)(intptr_t)var_id);
918 
919       assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
920       ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
921    }
922 }
923 
924 static void
emit_so_outputs(struct ntv_context * ctx,const struct zink_so_info * so_info)925 emit_so_outputs(struct ntv_context *ctx,
926                 const struct zink_so_info *so_info)
927 {
928    SpvId loaded_outputs[VARYING_SLOT_MAX] = {};
929    for (unsigned i = 0; i < so_info->so_info.num_outputs; i++) {
930       uint32_t components[NIR_MAX_VEC_COMPONENTS];
931       unsigned slot = so_info->so_info_slots[i];
932       struct pipe_stream_output so_output = so_info->so_info.output[i];
933       uint32_t so_key = (uint32_t) so_output.register_index << 2 | so_output.start_component;
934       struct hash_entry *he = _mesa_hash_table_search(ctx->so_outputs, &so_key);
935       assert(he);
936       SpvId so_output_var_id = (SpvId)(intptr_t)he->data;
937 
938       SpvId type = get_output_type(ctx, slot, so_output.num_components);
939       SpvId output = ctx->outputs[slot];
940       SpvId output_type = ctx->so_output_types[slot];
941       const struct glsl_type *out_type = ctx->so_output_gl_types[slot];
942 
943       if (!loaded_outputs[slot])
944          loaded_outputs[slot] = spirv_builder_emit_load(&ctx->builder, output_type, output);
945       SpvId src = loaded_outputs[slot];
946 
947       SpvId result;
948 
949       for (unsigned c = 0; c < so_output.num_components; c++) {
950          components[c] = so_output.start_component + c;
951          /* this is the second half of a 2 * vec4 array */
952          if (slot == VARYING_SLOT_CLIP_DIST1)
953             components[c] += 4;
954       }
955 
956       /* if we're emitting a scalar or the type we're emitting matches the output's original type and we're
957        * emitting the same number of components, then we can skip any sort of conversion here
958        */
959       if (glsl_type_is_scalar(out_type) || (type == output_type && glsl_get_length(out_type) == so_output.num_components))
960          result = src;
961       else {
962          /* OpCompositeExtract can only extract scalars for our use here */
963          if (so_output.num_components == 1) {
964             result = spirv_builder_emit_composite_extract(&ctx->builder, type, src, components, so_output.num_components);
965          } else if (glsl_type_is_vector(out_type)) {
966             /* OpVectorShuffle can select vector members into a differently-sized vector */
967             result = spirv_builder_emit_vector_shuffle(&ctx->builder, type,
968                                                              src, src,
969                                                              components, so_output.num_components);
970             result = emit_unop(ctx, SpvOpBitcast, type, result);
971          } else {
972              /* for arrays, we need to manually extract each desired member
973               * and re-pack them into the desired output type
974               */
975              for (unsigned c = 0; c < so_output.num_components; c++) {
976                 uint32_t member[] = { so_output.start_component + c };
977                 SpvId base_type = get_glsl_type(ctx, glsl_without_array(out_type));
978 
979                 if (slot == VARYING_SLOT_CLIP_DIST1)
980                    member[0] += 4;
981                 components[c] = spirv_builder_emit_composite_extract(&ctx->builder, base_type, src, member, 1);
982              }
983              result = spirv_builder_emit_composite_construct(&ctx->builder, type, components, so_output.num_components);
984          }
985       }
986 
987       spirv_builder_emit_store(&ctx->builder, so_output_var_id, result);
988    }
989 }
990 
991 static SpvId
emit_binop(struct ntv_context * ctx,SpvOp op,SpvId type,SpvId src0,SpvId src1)992 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
993            SpvId src0, SpvId src1)
994 {
995    return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
996 }
997 
998 static SpvId
emit_triop(struct ntv_context * ctx,SpvOp op,SpvId type,SpvId src0,SpvId src1,SpvId src2)999 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
1000            SpvId src0, SpvId src1, SpvId src2)
1001 {
1002    return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
1003 }
1004 
1005 static SpvId
emit_builtin_unop(struct ntv_context * ctx,enum GLSLstd450 op,SpvId type,SpvId src)1006 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1007                   SpvId src)
1008 {
1009    SpvId args[] = { src };
1010    return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1011                                       op, args, ARRAY_SIZE(args));
1012 }
1013 
1014 static SpvId
emit_builtin_binop(struct ntv_context * ctx,enum GLSLstd450 op,SpvId type,SpvId src0,SpvId src1)1015 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1016                    SpvId src0, SpvId src1)
1017 {
1018    SpvId args[] = { src0, src1 };
1019    return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1020                                       op, args, ARRAY_SIZE(args));
1021 }
1022 
1023 static SpvId
emit_builtin_triop(struct ntv_context * ctx,enum GLSLstd450 op,SpvId type,SpvId src0,SpvId src1,SpvId src2)1024 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1025                    SpvId src0, SpvId src1, SpvId src2)
1026 {
1027    SpvId args[] = { src0, src1, src2 };
1028    return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1029                                       op, args, ARRAY_SIZE(args));
1030 }
1031 
1032 static SpvId
get_fvec_constant(struct ntv_context * ctx,unsigned bit_size,unsigned num_components,float value)1033 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
1034                   unsigned num_components, float value)
1035 {
1036    assert(bit_size == 32);
1037 
1038    SpvId result = emit_float_const(ctx, bit_size, value);
1039    if (num_components == 1)
1040       return result;
1041 
1042    assert(num_components > 1);
1043    SpvId components[num_components];
1044    for (int i = 0; i < num_components; i++)
1045       components[i] = result;
1046 
1047    SpvId type = get_fvec_type(ctx, bit_size, num_components);
1048    return spirv_builder_const_composite(&ctx->builder, type, components,
1049                                         num_components);
1050 }
1051 
1052 static SpvId
get_uvec_constant(struct ntv_context * ctx,unsigned bit_size,unsigned num_components,uint32_t value)1053 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
1054                   unsigned num_components, uint32_t value)
1055 {
1056    assert(bit_size == 32);
1057 
1058    SpvId result = emit_uint_const(ctx, bit_size, value);
1059    if (num_components == 1)
1060       return result;
1061 
1062    assert(num_components > 1);
1063    SpvId components[num_components];
1064    for (int i = 0; i < num_components; i++)
1065       components[i] = result;
1066 
1067    SpvId type = get_uvec_type(ctx, bit_size, num_components);
1068    return spirv_builder_const_composite(&ctx->builder, type, components,
1069                                         num_components);
1070 }
1071 
1072 static SpvId
get_ivec_constant(struct ntv_context * ctx,unsigned bit_size,unsigned num_components,int32_t value)1073 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
1074                   unsigned num_components, int32_t value)
1075 {
1076    assert(bit_size == 32);
1077 
1078    SpvId result = emit_int_const(ctx, bit_size, value);
1079    if (num_components == 1)
1080       return result;
1081 
1082    assert(num_components > 1);
1083    SpvId components[num_components];
1084    for (int i = 0; i < num_components; i++)
1085       components[i] = result;
1086 
1087    SpvId type = get_ivec_type(ctx, bit_size, num_components);
1088    return spirv_builder_const_composite(&ctx->builder, type, components,
1089                                         num_components);
1090 }
1091 
1092 static inline unsigned
alu_instr_src_components(const nir_alu_instr * instr,unsigned src)1093 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
1094 {
1095    if (nir_op_infos[instr->op].input_sizes[src] > 0)
1096       return nir_op_infos[instr->op].input_sizes[src];
1097 
1098    if (instr->dest.dest.is_ssa)
1099       return instr->dest.dest.ssa.num_components;
1100    else
1101       return instr->dest.dest.reg.reg->num_components;
1102 }
1103 
1104 static SpvId
get_alu_src(struct ntv_context * ctx,nir_alu_instr * alu,unsigned src)1105 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
1106 {
1107    SpvId raw_value = get_alu_src_raw(ctx, alu, src);
1108 
1109    unsigned num_components = alu_instr_src_components(alu, src);
1110    unsigned bit_size = nir_src_bit_size(alu->src[src].src);
1111    nir_alu_type type = nir_op_infos[alu->op].input_types[src];
1112 
1113    if (bit_size == 1)
1114       return raw_value;
1115    else {
1116       switch (nir_alu_type_get_base_type(type)) {
1117       case nir_type_bool:
1118          unreachable("bool should have bit-size 1");
1119 
1120       case nir_type_int:
1121          return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
1122 
1123       case nir_type_uint:
1124          return raw_value;
1125 
1126       case nir_type_float:
1127          return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
1128 
1129       default:
1130          unreachable("unknown nir_alu_type");
1131       }
1132    }
1133 }
1134 
1135 static SpvId
store_alu_result(struct ntv_context * ctx,nir_alu_instr * alu,SpvId result)1136 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
1137 {
1138    assert(!alu->dest.saturate);
1139    return store_dest(ctx, &alu->dest.dest, result,
1140                      nir_op_infos[alu->op].output_type);
1141 }
1142 
1143 static SpvId
get_dest_type(struct ntv_context * ctx,nir_dest * dest,nir_alu_type type)1144 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
1145 {
1146    unsigned num_components = nir_dest_num_components(*dest);
1147    unsigned bit_size = nir_dest_bit_size(*dest);
1148 
1149    if (bit_size == 1)
1150       return get_bvec_type(ctx, num_components);
1151 
1152    switch (nir_alu_type_get_base_type(type)) {
1153    case nir_type_bool:
1154       unreachable("bool should have bit-size 1");
1155 
1156    case nir_type_int:
1157       return get_ivec_type(ctx, bit_size, num_components);
1158 
1159    case nir_type_uint:
1160       return get_uvec_type(ctx, bit_size, num_components);
1161 
1162    case nir_type_float:
1163       return get_fvec_type(ctx, bit_size, num_components);
1164 
1165    default:
1166       unreachable("unsupported nir_alu_type");
1167    }
1168 }
1169 
1170 static void
emit_alu(struct ntv_context * ctx,nir_alu_instr * alu)1171 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
1172 {
1173    SpvId src[nir_op_infos[alu->op].num_inputs];
1174    unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
1175    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
1176       src[i] = get_alu_src(ctx, alu, i);
1177       in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
1178    }
1179 
1180    SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
1181                                    nir_op_infos[alu->op].output_type);
1182    unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
1183    unsigned num_components = nir_dest_num_components(alu->dest.dest);
1184 
1185    SpvId result = 0;
1186    switch (alu->op) {
1187    case nir_op_mov:
1188       assert(nir_op_infos[alu->op].num_inputs == 1);
1189       result = src[0];
1190       break;
1191 
1192 #define UNOP(nir_op, spirv_op) \
1193    case nir_op: \
1194       assert(nir_op_infos[alu->op].num_inputs == 1); \
1195       result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
1196       break;
1197 
1198    UNOP(nir_op_ineg, SpvOpSNegate)
1199    UNOP(nir_op_fneg, SpvOpFNegate)
1200    UNOP(nir_op_fddx, SpvOpDPdx)
1201    UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
1202    UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
1203    UNOP(nir_op_fddy, SpvOpDPdy)
1204    UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
1205    UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
1206    UNOP(nir_op_f2i32, SpvOpConvertFToS)
1207    UNOP(nir_op_f2u32, SpvOpConvertFToU)
1208    UNOP(nir_op_i2f32, SpvOpConvertSToF)
1209    UNOP(nir_op_u2f32, SpvOpConvertUToF)
1210    UNOP(nir_op_bitfield_reverse, SpvOpBitReverse)
1211 #undef UNOP
1212 
1213    case nir_op_inot:
1214       if (bit_size == 1)
1215          result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
1216       else
1217          result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
1218       break;
1219 
1220    case nir_op_b2i32:
1221       assert(nir_op_infos[alu->op].num_inputs == 1);
1222       result = emit_select(ctx, dest_type, src[0],
1223                            get_ivec_constant(ctx, 32, num_components, 1),
1224                            get_ivec_constant(ctx, 32, num_components, 0));
1225       break;
1226 
1227    case nir_op_b2f32:
1228       assert(nir_op_infos[alu->op].num_inputs == 1);
1229       result = emit_select(ctx, dest_type, src[0],
1230                            get_fvec_constant(ctx, 32, num_components, 1),
1231                            get_fvec_constant(ctx, 32, num_components, 0));
1232       break;
1233 
1234 #define BUILTIN_UNOP(nir_op, spirv_op) \
1235    case nir_op: \
1236       assert(nir_op_infos[alu->op].num_inputs == 1); \
1237       result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
1238       break;
1239 
1240    BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
1241    BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
1242    BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1243    BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1244    BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1245    BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1246    BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1247    BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1248    BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1249    BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1250    BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1251    BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1252    BUILTIN_UNOP(nir_op_isign, GLSLstd450SSign)
1253    BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1254    BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1255 #undef BUILTIN_UNOP
1256 
1257    case nir_op_frcp:
1258       assert(nir_op_infos[alu->op].num_inputs == 1);
1259       result = emit_binop(ctx, SpvOpFDiv, dest_type,
1260                           get_fvec_constant(ctx, bit_size, num_components, 1),
1261                           src[0]);
1262       break;
1263 
1264    case nir_op_f2b1:
1265       assert(nir_op_infos[alu->op].num_inputs == 1);
1266       result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1267                           get_fvec_constant(ctx,
1268                                             nir_src_bit_size(alu->src[0].src),
1269                                             num_components, 0));
1270       break;
1271    case nir_op_i2b1:
1272       assert(nir_op_infos[alu->op].num_inputs == 1);
1273       result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1274                           get_ivec_constant(ctx,
1275                                             nir_src_bit_size(alu->src[0].src),
1276                                             num_components, 0));
1277       break;
1278 
1279 
1280 #define BINOP(nir_op, spirv_op) \
1281    case nir_op: \
1282       assert(nir_op_infos[alu->op].num_inputs == 2); \
1283       result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1284       break;
1285 
1286    BINOP(nir_op_iadd, SpvOpIAdd)
1287    BINOP(nir_op_isub, SpvOpISub)
1288    BINOP(nir_op_imul, SpvOpIMul)
1289    BINOP(nir_op_idiv, SpvOpSDiv)
1290    BINOP(nir_op_udiv, SpvOpUDiv)
1291    BINOP(nir_op_umod, SpvOpUMod)
1292    BINOP(nir_op_fadd, SpvOpFAdd)
1293    BINOP(nir_op_fsub, SpvOpFSub)
1294    BINOP(nir_op_fmul, SpvOpFMul)
1295    BINOP(nir_op_fdiv, SpvOpFDiv)
1296    BINOP(nir_op_fmod, SpvOpFMod)
1297    BINOP(nir_op_ilt, SpvOpSLessThan)
1298    BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1299    BINOP(nir_op_ult, SpvOpULessThan)
1300    BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1301    BINOP(nir_op_flt, SpvOpFOrdLessThan)
1302    BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1303    BINOP(nir_op_feq, SpvOpFOrdEqual)
1304    BINOP(nir_op_fneu, SpvOpFUnordNotEqual)
1305    BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1306    BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1307    BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1308    BINOP(nir_op_ixor, SpvOpBitwiseXor)
1309 #undef BINOP
1310 
1311 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1312    case nir_op: \
1313       assert(nir_op_infos[alu->op].num_inputs == 2); \
1314       if (nir_src_bit_size(alu->src[0].src) == 1) \
1315          result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1316       else \
1317          result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1318       break;
1319 
1320    BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1321    BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1322    BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1323    BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1324 #undef BINOP_LOG
1325 
1326 #define BUILTIN_BINOP(nir_op, spirv_op) \
1327    case nir_op: \
1328       assert(nir_op_infos[alu->op].num_inputs == 2); \
1329       result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1330       break;
1331 
1332    BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1333    BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1334    BUILTIN_BINOP(nir_op_imin, GLSLstd450SMin)
1335    BUILTIN_BINOP(nir_op_imax, GLSLstd450SMax)
1336    BUILTIN_BINOP(nir_op_umin, GLSLstd450UMin)
1337    BUILTIN_BINOP(nir_op_umax, GLSLstd450UMax)
1338 #undef BUILTIN_BINOP
1339 
1340    case nir_op_fdot2:
1341    case nir_op_fdot3:
1342    case nir_op_fdot4:
1343       assert(nir_op_infos[alu->op].num_inputs == 2);
1344       result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1345       break;
1346 
1347    case nir_op_fdph:
1348       unreachable("should already be lowered away");
1349 
1350    case nir_op_seq:
1351    case nir_op_sne:
1352    case nir_op_slt:
1353    case nir_op_sge: {
1354       assert(nir_op_infos[alu->op].num_inputs == 2);
1355       int num_components = nir_dest_num_components(alu->dest.dest);
1356       SpvId bool_type = get_bvec_type(ctx, num_components);
1357 
1358       SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1359       SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1360       if (num_components > 1) {
1361          SpvId zero_comps[num_components], one_comps[num_components];
1362          for (int i = 0; i < num_components; i++) {
1363             zero_comps[i] = zero;
1364             one_comps[i] = one;
1365          }
1366 
1367          zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1368                                               zero_comps, num_components);
1369          one = spirv_builder_const_composite(&ctx->builder, dest_type,
1370                                              one_comps, num_components);
1371       }
1372 
1373       SpvOp op;
1374       switch (alu->op) {
1375       case nir_op_seq: op = SpvOpFOrdEqual; break;
1376       case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1377       case nir_op_slt: op = SpvOpFOrdLessThan; break;
1378       case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1379       default: unreachable("unexpected op");
1380       }
1381 
1382       result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1383       result = emit_select(ctx, dest_type, result, one, zero);
1384       }
1385       break;
1386 
1387    case nir_op_flrp:
1388       assert(nir_op_infos[alu->op].num_inputs == 3);
1389       result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1390                                   src[0], src[1], src[2]);
1391       break;
1392 
1393    case nir_op_fcsel:
1394       result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1395                           get_bvec_type(ctx, num_components),
1396                           src[0],
1397                           get_fvec_constant(ctx,
1398                                             nir_src_bit_size(alu->src[0].src),
1399                                             num_components, 0));
1400       result = emit_select(ctx, dest_type, result, src[1], src[2]);
1401       break;
1402 
1403    case nir_op_bcsel:
1404       assert(nir_op_infos[alu->op].num_inputs == 3);
1405       result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1406       break;
1407 
1408    case nir_op_bany_fnequal2:
1409    case nir_op_bany_fnequal3:
1410    case nir_op_bany_fnequal4: {
1411       assert(nir_op_infos[alu->op].num_inputs == 2);
1412       assert(alu_instr_src_components(alu, 0) ==
1413              alu_instr_src_components(alu, 1));
1414       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1415       /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1416       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1417       result = emit_binop(ctx, op,
1418                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1419                           src[0], src[1]);
1420       result = emit_unop(ctx, SpvOpAny, dest_type, result);
1421       break;
1422    }
1423 
1424    case nir_op_ball_fequal2:
1425    case nir_op_ball_fequal3:
1426    case nir_op_ball_fequal4: {
1427       assert(nir_op_infos[alu->op].num_inputs == 2);
1428       assert(alu_instr_src_components(alu, 0) ==
1429              alu_instr_src_components(alu, 1));
1430       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1431       /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1432       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1433       result = emit_binop(ctx, op,
1434                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1435                           src[0], src[1]);
1436       result = emit_unop(ctx, SpvOpAll, dest_type, result);
1437       break;
1438    }
1439 
1440    case nir_op_bany_inequal2:
1441    case nir_op_bany_inequal3:
1442    case nir_op_bany_inequal4: {
1443       assert(nir_op_infos[alu->op].num_inputs == 2);
1444       assert(alu_instr_src_components(alu, 0) ==
1445              alu_instr_src_components(alu, 1));
1446       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1447       /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1448       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1449       result = emit_binop(ctx, op,
1450                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1451                           src[0], src[1]);
1452       result = emit_unop(ctx, SpvOpAny, dest_type, result);
1453       break;
1454    }
1455 
1456    case nir_op_ball_iequal2:
1457    case nir_op_ball_iequal3:
1458    case nir_op_ball_iequal4: {
1459       assert(nir_op_infos[alu->op].num_inputs == 2);
1460       assert(alu_instr_src_components(alu, 0) ==
1461              alu_instr_src_components(alu, 1));
1462       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1463       /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1464       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1465       result = emit_binop(ctx, op,
1466                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1467                           src[0], src[1]);
1468       result = emit_unop(ctx, SpvOpAll, dest_type, result);
1469       break;
1470    }
1471 
1472    case nir_op_vec2:
1473    case nir_op_vec3:
1474    case nir_op_vec4: {
1475       int num_inputs = nir_op_infos[alu->op].num_inputs;
1476       assert(2 <= num_inputs && num_inputs <= 4);
1477       result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1478                                                       src, num_inputs);
1479    }
1480    break;
1481 
1482    default:
1483       fprintf(stderr, "emit_alu: not implemented (%s)\n",
1484               nir_op_infos[alu->op].name);
1485 
1486       unreachable("unsupported opcode");
1487       return;
1488    }
1489 
1490    store_alu_result(ctx, alu, result);
1491 }
1492 
1493 static void
emit_load_const(struct ntv_context * ctx,nir_load_const_instr * load_const)1494 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1495 {
1496    unsigned bit_size = load_const->def.bit_size;
1497    unsigned num_components = load_const->def.num_components;
1498 
1499    SpvId constant;
1500    if (num_components > 1) {
1501       SpvId components[num_components];
1502       SpvId type = get_vec_from_bit_size(ctx, bit_size, num_components);
1503       if (bit_size == 1) {
1504          for (int i = 0; i < num_components; i++)
1505             components[i] = spirv_builder_const_bool(&ctx->builder,
1506                                                      load_const->value[i].b);
1507 
1508       } else {
1509          for (int i = 0; i < num_components; i++)
1510             components[i] = emit_uint_const(ctx, bit_size,
1511                                             load_const->value[i].u32);
1512 
1513       }
1514       constant = spirv_builder_const_composite(&ctx->builder, type,
1515                                                components, num_components);
1516    } else {
1517       assert(num_components == 1);
1518       if (bit_size == 1)
1519          constant = spirv_builder_const_bool(&ctx->builder,
1520                                              load_const->value[0].b);
1521       else
1522          constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1523    }
1524 
1525    store_ssa_def(ctx, &load_const->def, constant);
1526 }
1527 
1528 static void
emit_load_ubo_vec4(struct ntv_context * ctx,nir_intrinsic_instr * intr)1529 emit_load_ubo_vec4(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1530 {
1531    ASSERTED nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1532    assert(const_block_index); // no dynamic indexing for now
1533 
1534    SpvId offset = get_src(ctx, &intr->src[1]);
1535    SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1536    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1537                                                    SpvStorageClassUniform,
1538                                                    uvec4_type);
1539 
1540    SpvId member = emit_uint_const(ctx, 32, 0);
1541    SpvId offsets[] = { member, offset };
1542    SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1543                                                ctx->ubos[const_block_index->u32],
1544                                                offsets, ARRAY_SIZE(offsets));
1545    SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1546 
1547    SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1548    unsigned num_components = nir_dest_num_components(intr->dest);
1549    if (num_components == 1) {
1550       uint32_t components[] = { 0 };
1551       result = spirv_builder_emit_composite_extract(&ctx->builder,
1552                                                     type,
1553                                                     result, components,
1554                                                     1);
1555    } else if (num_components < 4) {
1556       SpvId constituents[num_components];
1557       SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1558       for (uint32_t i = 0; i < num_components; ++i)
1559          constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1560                                                                 uint_type,
1561                                                                 result, &i,
1562                                                                 1);
1563 
1564       result = spirv_builder_emit_composite_construct(&ctx->builder,
1565                                                       type,
1566                                                       constituents,
1567                                                       num_components);
1568    }
1569 
1570    if (nir_dest_bit_size(intr->dest) == 1)
1571       result = uvec_to_bvec(ctx, result, num_components);
1572 
1573    store_dest(ctx, &intr->dest, result, nir_type_uint);
1574 }
1575 
1576 static void
emit_discard(struct ntv_context * ctx,nir_intrinsic_instr * intr)1577 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1578 {
1579    assert(ctx->block_started);
1580    spirv_builder_emit_kill(&ctx->builder);
1581    /* discard is weird in NIR, so let's just create an unreachable block after
1582       it and hope that the vulkan driver will DCE any instructinos in it. */
1583    spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1584 }
1585 
1586 static void
emit_load_deref(struct ntv_context * ctx,nir_intrinsic_instr * intr)1587 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1588 {
1589    SpvId ptr = get_src(ctx, intr->src);
1590 
1591    SpvId result = spirv_builder_emit_load(&ctx->builder,
1592                                           get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type),
1593                                           ptr);
1594    unsigned num_components = nir_dest_num_components(intr->dest);
1595    unsigned bit_size = nir_dest_bit_size(intr->dest);
1596    if (ctx->stage > MESA_SHADER_VERTEX && ctx->stage <= MESA_SHADER_GEOMETRY &&
1597        (nir_deref_instr_get_variable(nir_src_as_deref(intr->src[0]))->data.location == VARYING_SLOT_POS)) {
1598       /* we previously transformed opengl gl_Position -> vulkan gl_Position in vertex shader,
1599        * so now we have to reverse that and construct a new gl_Position:
1600 
1601          gl_Position.z = gl_Position.z * 2 - gl_Position.w
1602 
1603        */
1604       SpvId components[4];
1605       SpvId f_type = get_fvec_type(ctx, 32, 1);
1606       SpvId base_type = get_fvec_type(ctx, 32, 4);
1607       for (unsigned c = 0; c < 4; c++) {
1608          uint32_t member[] = { c };
1609 
1610          components[c] = spirv_builder_emit_composite_extract(&ctx->builder, f_type, result, member, 1);
1611       }
1612       components[2] = emit_binop(ctx, SpvOpFMul, f_type, components[2], emit_float_const(ctx, 32, 2.0));
1613       components[2] = emit_binop(ctx, SpvOpFSub, f_type, components[2], components[3]);
1614 
1615       result = spirv_builder_emit_composite_construct(&ctx->builder, base_type,
1616                                                       components, 4);
1617    }
1618    result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1619    store_dest(ctx, &intr->dest, result, nir_type_uint);
1620 }
1621 
1622 static void
emit_store_deref(struct ntv_context * ctx,nir_intrinsic_instr * intr)1623 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1624 {
1625    SpvId ptr = get_src(ctx, &intr->src[0]);
1626    SpvId src = get_src(ctx, &intr->src[1]);
1627 
1628    SpvId type = get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type);
1629    nir_variable *var = nir_deref_instr_get_variable(nir_src_as_deref(intr->src[0]));
1630    SpvId result;
1631    if (ctx->stage == MESA_SHADER_FRAGMENT && var->data.location == FRAG_RESULT_SAMPLE_MASK) {
1632       src = emit_bitcast(ctx, type, src);
1633       /* SampleMask is always an array in spirv, so we need to construct it into one */
1634       result = spirv_builder_emit_composite_construct(&ctx->builder, ctx->sample_mask_type, &src, 1);
1635    } else
1636       result = emit_bitcast(ctx, type, src);
1637    spirv_builder_emit_store(&ctx->builder, ptr, result);
1638 }
1639 
1640 static SpvId
create_builtin_var(struct ntv_context * ctx,SpvId var_type,SpvStorageClass storage_class,const char * name,SpvBuiltIn builtin)1641 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1642                    SpvStorageClass storage_class,
1643                    const char *name, SpvBuiltIn builtin)
1644 {
1645    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1646                                                    storage_class,
1647                                                    var_type);
1648    SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1649                                       storage_class);
1650    spirv_builder_emit_name(&ctx->builder, var, name);
1651    spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1652 
1653    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1654    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1655    return var;
1656 }
1657 
1658 static void
emit_load_front_face(struct ntv_context * ctx,nir_intrinsic_instr * intr)1659 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1660 {
1661    SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1662    if (!ctx->front_face_var)
1663       ctx->front_face_var = create_builtin_var(ctx, var_type,
1664                                                SpvStorageClassInput,
1665                                                "gl_FrontFacing",
1666                                                SpvBuiltInFrontFacing);
1667 
1668    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1669                                           ctx->front_face_var);
1670    assert(1 == nir_dest_num_components(intr->dest));
1671    store_dest(ctx, &intr->dest, result, nir_type_bool);
1672 }
1673 
1674 static void
emit_load_uint_input(struct ntv_context * ctx,nir_intrinsic_instr * intr,SpvId * var_id,const char * var_name,SpvBuiltIn builtin)1675 emit_load_uint_input(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId *var_id, const char *var_name, SpvBuiltIn builtin)
1676 {
1677    SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1678    if (!*var_id)
1679       *var_id = create_builtin_var(ctx, var_type,
1680                                    SpvStorageClassInput,
1681                                    var_name,
1682                                    builtin);
1683 
1684    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type, *var_id);
1685    assert(1 == nir_dest_num_components(intr->dest));
1686    store_dest(ctx, &intr->dest, result, nir_type_uint);
1687 }
1688 
1689 static void
emit_load_vec_input(struct ntv_context * ctx,nir_intrinsic_instr * intr,SpvId * var_id,const char * var_name,SpvBuiltIn builtin,nir_alu_type type)1690 emit_load_vec_input(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId *var_id, const char *var_name, SpvBuiltIn builtin, nir_alu_type type)
1691 {
1692    SpvId var_type;
1693 
1694    switch (type) {
1695    case nir_type_bool:
1696       var_type = get_bvec_type(ctx, nir_dest_num_components(intr->dest));
1697       break;
1698    case nir_type_int:
1699       var_type = get_ivec_type(ctx, nir_dest_bit_size(intr->dest), nir_dest_num_components(intr->dest));
1700       break;
1701    case nir_type_uint:
1702       var_type = get_uvec_type(ctx, nir_dest_bit_size(intr->dest), nir_dest_num_components(intr->dest));
1703       break;
1704    case nir_type_float:
1705       var_type = get_fvec_type(ctx, nir_dest_bit_size(intr->dest), nir_dest_num_components(intr->dest));
1706       break;
1707    default:
1708       unreachable("unknown type passed");
1709    }
1710    if (!*var_id)
1711       *var_id = create_builtin_var(ctx, var_type,
1712                                    SpvStorageClassInput,
1713                                    var_name,
1714                                    builtin);
1715 
1716    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type, *var_id);
1717    store_dest(ctx, &intr->dest, result, type);
1718 }
1719 
1720 static void
emit_intrinsic(struct ntv_context * ctx,nir_intrinsic_instr * intr)1721 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1722 {
1723    switch (intr->intrinsic) {
1724    case nir_intrinsic_load_ubo_vec4:
1725       emit_load_ubo_vec4(ctx, intr);
1726       break;
1727 
1728    case nir_intrinsic_discard:
1729       emit_discard(ctx, intr);
1730       break;
1731 
1732    case nir_intrinsic_load_deref:
1733       emit_load_deref(ctx, intr);
1734       break;
1735 
1736    case nir_intrinsic_store_deref:
1737       emit_store_deref(ctx, intr);
1738       break;
1739 
1740    case nir_intrinsic_load_front_face:
1741       emit_load_front_face(ctx, intr);
1742       break;
1743 
1744    case nir_intrinsic_load_instance_id:
1745       emit_load_uint_input(ctx, intr, &ctx->instance_id_var, "gl_InstanceId", SpvBuiltInInstanceIndex);
1746       break;
1747 
1748    case nir_intrinsic_load_vertex_id:
1749       emit_load_uint_input(ctx, intr, &ctx->vertex_id_var, "gl_VertexId", SpvBuiltInVertexIndex);
1750       break;
1751 
1752    case nir_intrinsic_load_primitive_id:
1753       emit_load_uint_input(ctx, intr, &ctx->primitive_id_var, "gl_PrimitiveIdIn", SpvBuiltInPrimitiveId);
1754       break;
1755 
1756    case nir_intrinsic_load_invocation_id:
1757       emit_load_uint_input(ctx, intr, &ctx->invocation_id_var, "gl_InvocationId", SpvBuiltInInvocationId);
1758       break;
1759 
1760    case nir_intrinsic_load_sample_id:
1761       emit_load_uint_input(ctx, intr, &ctx->sample_id_var, "gl_SampleId", SpvBuiltInSampleId);
1762       break;
1763 
1764    case nir_intrinsic_load_sample_pos:
1765       emit_load_vec_input(ctx, intr, &ctx->sample_pos_var, "gl_SamplePosition", SpvBuiltInSamplePosition, nir_type_float);
1766       break;
1767 
1768    case nir_intrinsic_emit_vertex_with_counter:
1769       /* geometry shader emits copied xfb outputs just prior to EmitVertex(),
1770        * since that's the end of the shader
1771        */
1772       if (ctx->so_info)
1773          emit_so_outputs(ctx, ctx->so_info);
1774       spirv_builder_emit_vertex(&ctx->builder);
1775       break;
1776 
1777    case nir_intrinsic_set_vertex_and_primitive_count:
1778       /* do nothing */
1779       break;
1780 
1781    case nir_intrinsic_end_primitive_with_counter:
1782       spirv_builder_end_primitive(&ctx->builder);
1783       break;
1784 
1785    default:
1786       fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1787               nir_intrinsic_infos[intr->intrinsic].name);
1788       unreachable("unsupported intrinsic");
1789    }
1790 }
1791 
1792 static void
emit_undef(struct ntv_context * ctx,nir_ssa_undef_instr * undef)1793 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1794 {
1795    SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1796                               undef->def.num_components);
1797 
1798    store_ssa_def(ctx, &undef->def,
1799                  spirv_builder_emit_undef(&ctx->builder, type));
1800 }
1801 
1802 static SpvId
get_src_float(struct ntv_context * ctx,nir_src * src)1803 get_src_float(struct ntv_context *ctx, nir_src *src)
1804 {
1805    SpvId def = get_src(ctx, src);
1806    unsigned num_components = nir_src_num_components(*src);
1807    unsigned bit_size = nir_src_bit_size(*src);
1808    return bitcast_to_fvec(ctx, def, bit_size, num_components);
1809 }
1810 
1811 static SpvId
get_src_int(struct ntv_context * ctx,nir_src * src)1812 get_src_int(struct ntv_context *ctx, nir_src *src)
1813 {
1814    SpvId def = get_src(ctx, src);
1815    unsigned num_components = nir_src_num_components(*src);
1816    unsigned bit_size = nir_src_bit_size(*src);
1817    return bitcast_to_ivec(ctx, def, bit_size, num_components);
1818 }
1819 
1820 static inline bool
tex_instr_is_lod_allowed(nir_tex_instr * tex)1821 tex_instr_is_lod_allowed(nir_tex_instr *tex)
1822 {
1823    /* This can only be used with an OpTypeImage that has a Dim operand of 1D, 2D, 3D, or Cube
1824     * - SPIR-V: 3.14. Image Operands
1825     */
1826 
1827    return (tex->sampler_dim == GLSL_SAMPLER_DIM_1D ||
1828            tex->sampler_dim == GLSL_SAMPLER_DIM_2D ||
1829            tex->sampler_dim == GLSL_SAMPLER_DIM_3D ||
1830            tex->sampler_dim == GLSL_SAMPLER_DIM_CUBE);
1831 }
1832 
1833 static SpvId
pad_coord_vector(struct ntv_context * ctx,SpvId orig,unsigned old_size,unsigned new_size)1834 pad_coord_vector(struct ntv_context *ctx, SpvId orig, unsigned old_size, unsigned new_size)
1835 {
1836     SpvId int_type = spirv_builder_type_int(&ctx->builder, 32);
1837     SpvId type = get_ivec_type(ctx, 32, new_size);
1838     SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
1839     SpvId zero = emit_int_const(ctx, 32, 0);
1840     assert(new_size < NIR_MAX_VEC_COMPONENTS);
1841 
1842     if (old_size == 1)
1843        constituents[0] = orig;
1844     else {
1845        for (unsigned i = 0; i < old_size; i++)
1846           constituents[i] = spirv_builder_emit_vector_extract(&ctx->builder, int_type, orig, i);
1847     }
1848 
1849     for (unsigned i = old_size; i < new_size; i++)
1850        constituents[i] = zero;
1851 
1852     return spirv_builder_emit_composite_construct(&ctx->builder, type,
1853                                                   constituents, new_size);
1854 }
1855 
1856 static void
emit_tex(struct ntv_context * ctx,nir_tex_instr * tex)1857 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1858 {
1859    assert(tex->op == nir_texop_tex ||
1860           tex->op == nir_texop_txb ||
1861           tex->op == nir_texop_txl ||
1862           tex->op == nir_texop_txd ||
1863           tex->op == nir_texop_txf ||
1864           tex->op == nir_texop_txf_ms ||
1865           tex->op == nir_texop_txs ||
1866           tex->op == nir_texop_lod);
1867    assert(tex->texture_index == tex->sampler_index);
1868 
1869    SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1870          offset = 0, sample = 0;
1871    unsigned coord_components = 0, coord_bitsize = 0, offset_components = 0;
1872    for (unsigned i = 0; i < tex->num_srcs; i++) {
1873       switch (tex->src[i].src_type) {
1874       case nir_tex_src_coord:
1875          if (tex->op == nir_texop_txf ||
1876              tex->op == nir_texop_txf_ms)
1877             coord = get_src_int(ctx, &tex->src[i].src);
1878          else
1879             coord = get_src_float(ctx, &tex->src[i].src);
1880          coord_components = nir_src_num_components(tex->src[i].src);
1881          coord_bitsize = nir_src_bit_size(tex->src[i].src);
1882          break;
1883 
1884       case nir_tex_src_projector:
1885          assert(nir_src_num_components(tex->src[i].src) == 1);
1886          proj = get_src_float(ctx, &tex->src[i].src);
1887          assert(proj != 0);
1888          break;
1889 
1890       case nir_tex_src_offset:
1891          offset = get_src_int(ctx, &tex->src[i].src);
1892          offset_components = nir_src_num_components(tex->src[i].src);
1893          break;
1894 
1895       case nir_tex_src_bias:
1896          assert(tex->op == nir_texop_txb);
1897          bias = get_src_float(ctx, &tex->src[i].src);
1898          assert(bias != 0);
1899          break;
1900 
1901       case nir_tex_src_lod:
1902          assert(nir_src_num_components(tex->src[i].src) == 1);
1903          if (tex->op == nir_texop_txf ||
1904              tex->op == nir_texop_txf_ms ||
1905              tex->op == nir_texop_txs)
1906             lod = get_src_int(ctx, &tex->src[i].src);
1907          else
1908             lod = get_src_float(ctx, &tex->src[i].src);
1909          assert(lod != 0);
1910          break;
1911 
1912       case nir_tex_src_ms_index:
1913          assert(nir_src_num_components(tex->src[i].src) == 1);
1914          sample = get_src_int(ctx, &tex->src[i].src);
1915          break;
1916 
1917       case nir_tex_src_comparator:
1918          assert(nir_src_num_components(tex->src[i].src) == 1);
1919          dref = get_src_float(ctx, &tex->src[i].src);
1920          assert(dref != 0);
1921          break;
1922 
1923       case nir_tex_src_ddx:
1924          dx = get_src_float(ctx, &tex->src[i].src);
1925          assert(dx != 0);
1926          break;
1927 
1928       case nir_tex_src_ddy:
1929          dy = get_src_float(ctx, &tex->src[i].src);
1930          assert(dy != 0);
1931          break;
1932 
1933       default:
1934          fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1935          unreachable("unknown texture source");
1936       }
1937    }
1938 
1939    if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1940       lod = emit_float_const(ctx, 32, 0.0f);
1941       assert(lod != 0);
1942    }
1943 
1944    SpvId image_type = ctx->image_types[tex->texture_index];
1945    SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1946                                                          image_type);
1947 
1948    assert(ctx->samplers_used & (1u << tex->texture_index));
1949    SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1950                                         ctx->samplers[tex->texture_index]);
1951 
1952    SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1953 
1954    if (!tex_instr_is_lod_allowed(tex))
1955       lod = 0;
1956    if (tex->op == nir_texop_txs) {
1957       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1958       SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1959                                                          dest_type, image,
1960                                                          lod);
1961       store_dest(ctx, &tex->dest, result, tex->dest_type);
1962       return;
1963    }
1964 
1965    if (proj && coord_components > 0) {
1966       SpvId constituents[coord_components + 1];
1967       if (coord_components == 1)
1968          constituents[0] = coord;
1969       else {
1970          assert(coord_components > 1);
1971          SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1972          for (uint32_t i = 0; i < coord_components; ++i)
1973             constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1974                                                  float_type,
1975                                                  coord,
1976                                                  &i, 1);
1977       }
1978 
1979       constituents[coord_components++] = proj;
1980 
1981       SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1982       coord = spirv_builder_emit_composite_construct(&ctx->builder,
1983                                                             vec_type,
1984                                                             constituents,
1985                                                             coord_components);
1986    }
1987    if (tex->op == nir_texop_lod) {
1988       SpvId result = spirv_builder_emit_image_query_lod(&ctx->builder,
1989                                                          dest_type, load,
1990                                                          coord);
1991       store_dest(ctx, &tex->dest, result, tex->dest_type);
1992       return;
1993    }
1994    SpvId actual_dest_type = dest_type;
1995    if (dref)
1996       actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1997 
1998    SpvId result;
1999    if (tex->op == nir_texop_txf ||
2000        tex->op == nir_texop_txf_ms) {
2001       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
2002       if (offset) {
2003          /* SPIRV requires matched length vectors for OpIAdd, so if a shader
2004           * uses vecs of differing sizes we need to make a new vec padded with zeroes
2005           * to mimic how GLSL does this implicitly
2006           */
2007          if (offset_components > coord_components)
2008             coord = pad_coord_vector(ctx, coord, coord_components, offset_components);
2009          else if (coord_components > offset_components)
2010             offset = pad_coord_vector(ctx, offset, offset_components, coord_components);
2011          coord = emit_binop(ctx, SpvOpIAdd,
2012                             get_ivec_type(ctx, coord_bitsize, coord_components),
2013                             coord, offset);
2014       }
2015       result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
2016                                               image, coord, lod, sample);
2017    } else {
2018       result = spirv_builder_emit_image_sample(&ctx->builder,
2019                                                actual_dest_type, load,
2020                                                coord,
2021                                                proj != 0,
2022                                                lod, bias, dref, dx, dy,
2023                                                offset);
2024    }
2025 
2026    spirv_builder_emit_decoration(&ctx->builder, result,
2027                                  SpvDecorationRelaxedPrecision);
2028 
2029    if (dref && nir_dest_num_components(tex->dest) > 1) {
2030       SpvId components[4] = { result, result, result, result };
2031       result = spirv_builder_emit_composite_construct(&ctx->builder,
2032                                                       dest_type,
2033                                                       components,
2034                                                       4);
2035    }
2036 
2037    store_dest(ctx, &tex->dest, result, tex->dest_type);
2038 }
2039 
2040 static void
start_block(struct ntv_context * ctx,SpvId label)2041 start_block(struct ntv_context *ctx, SpvId label)
2042 {
2043    /* terminate previous block if needed */
2044    if (ctx->block_started)
2045       spirv_builder_emit_branch(&ctx->builder, label);
2046 
2047    /* start new block */
2048    spirv_builder_label(&ctx->builder, label);
2049    ctx->block_started = true;
2050 }
2051 
2052 static void
branch(struct ntv_context * ctx,SpvId label)2053 branch(struct ntv_context *ctx, SpvId label)
2054 {
2055    assert(ctx->block_started);
2056    spirv_builder_emit_branch(&ctx->builder, label);
2057    ctx->block_started = false;
2058 }
2059 
2060 static void
branch_conditional(struct ntv_context * ctx,SpvId condition,SpvId then_id,SpvId else_id)2061 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
2062                    SpvId else_id)
2063 {
2064    assert(ctx->block_started);
2065    spirv_builder_emit_branch_conditional(&ctx->builder, condition,
2066                                          then_id, else_id);
2067    ctx->block_started = false;
2068 }
2069 
2070 static void
emit_jump(struct ntv_context * ctx,nir_jump_instr * jump)2071 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
2072 {
2073    switch (jump->type) {
2074    case nir_jump_break:
2075       assert(ctx->loop_break);
2076       branch(ctx, ctx->loop_break);
2077       break;
2078 
2079    case nir_jump_continue:
2080       assert(ctx->loop_cont);
2081       branch(ctx, ctx->loop_cont);
2082       break;
2083 
2084    default:
2085       unreachable("Unsupported jump type\n");
2086    }
2087 }
2088 
2089 static void
emit_deref_var(struct ntv_context * ctx,nir_deref_instr * deref)2090 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
2091 {
2092    assert(deref->deref_type == nir_deref_type_var);
2093 
2094    struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
2095    assert(he);
2096    SpvId result = (SpvId)(intptr_t)he->data;
2097    store_dest_raw(ctx, &deref->dest, result);
2098 }
2099 
2100 static void
emit_deref_array(struct ntv_context * ctx,nir_deref_instr * deref)2101 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
2102 {
2103    assert(deref->deref_type == nir_deref_type_array);
2104    nir_variable *var = nir_deref_instr_get_variable(deref);
2105 
2106    SpvStorageClass storage_class;
2107    switch (var->data.mode) {
2108    case nir_var_shader_in:
2109       storage_class = SpvStorageClassInput;
2110       break;
2111 
2112    case nir_var_shader_out:
2113       storage_class = SpvStorageClassOutput;
2114       break;
2115 
2116    default:
2117       unreachable("Unsupported nir_variable_mode\n");
2118    }
2119 
2120    SpvId index = get_src(ctx, &deref->arr.index);
2121 
2122    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
2123                                                storage_class,
2124                                                get_glsl_type(ctx, deref->type));
2125 
2126    SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
2127                                                   ptr_type,
2128                                                   get_src(ctx, &deref->parent),
2129                                                   &index, 1);
2130    /* uint is a bit of a lie here, it's really just an opaque type */
2131    store_dest(ctx, &deref->dest, result, nir_type_uint);
2132 }
2133 
2134 static void
emit_deref(struct ntv_context * ctx,nir_deref_instr * deref)2135 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
2136 {
2137    switch (deref->deref_type) {
2138    case nir_deref_type_var:
2139       emit_deref_var(ctx, deref);
2140       break;
2141 
2142    case nir_deref_type_array:
2143       emit_deref_array(ctx, deref);
2144       break;
2145 
2146    default:
2147       unreachable("unexpected deref_type");
2148    }
2149 }
2150 
2151 static void
emit_block(struct ntv_context * ctx,struct nir_block * block)2152 emit_block(struct ntv_context *ctx, struct nir_block *block)
2153 {
2154    start_block(ctx, block_label(ctx, block));
2155    nir_foreach_instr(instr, block) {
2156       switch (instr->type) {
2157       case nir_instr_type_alu:
2158          emit_alu(ctx, nir_instr_as_alu(instr));
2159          break;
2160       case nir_instr_type_intrinsic:
2161          emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
2162          break;
2163       case nir_instr_type_load_const:
2164          emit_load_const(ctx, nir_instr_as_load_const(instr));
2165          break;
2166       case nir_instr_type_ssa_undef:
2167          emit_undef(ctx, nir_instr_as_ssa_undef(instr));
2168          break;
2169       case nir_instr_type_tex:
2170          emit_tex(ctx, nir_instr_as_tex(instr));
2171          break;
2172       case nir_instr_type_phi:
2173          unreachable("nir_instr_type_phi not supported");
2174          break;
2175       case nir_instr_type_jump:
2176          emit_jump(ctx, nir_instr_as_jump(instr));
2177          break;
2178       case nir_instr_type_call:
2179          unreachable("nir_instr_type_call not supported");
2180          break;
2181       case nir_instr_type_parallel_copy:
2182          unreachable("nir_instr_type_parallel_copy not supported");
2183          break;
2184       case nir_instr_type_deref:
2185          emit_deref(ctx, nir_instr_as_deref(instr));
2186          break;
2187       }
2188    }
2189 }
2190 
2191 static void
2192 emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
2193 
2194 static SpvId
get_src_bool(struct ntv_context * ctx,nir_src * src)2195 get_src_bool(struct ntv_context *ctx, nir_src *src)
2196 {
2197    assert(nir_src_bit_size(*src) == 1);
2198    return get_src(ctx, src);
2199 }
2200 
2201 static void
emit_if(struct ntv_context * ctx,nir_if * if_stmt)2202 emit_if(struct ntv_context *ctx, nir_if *if_stmt)
2203 {
2204    SpvId condition = get_src_bool(ctx, &if_stmt->condition);
2205 
2206    SpvId header_id = spirv_builder_new_id(&ctx->builder);
2207    SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
2208    SpvId endif_id = spirv_builder_new_id(&ctx->builder);
2209    SpvId else_id = endif_id;
2210 
2211    bool has_else = !exec_list_is_empty(&if_stmt->else_list);
2212    if (has_else) {
2213       assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
2214       else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
2215    }
2216 
2217    /* create a header-block */
2218    start_block(ctx, header_id);
2219    spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
2220                                       SpvSelectionControlMaskNone);
2221    branch_conditional(ctx, condition, then_id, else_id);
2222 
2223    emit_cf_list(ctx, &if_stmt->then_list);
2224 
2225    if (has_else) {
2226       if (ctx->block_started)
2227          branch(ctx, endif_id);
2228 
2229       emit_cf_list(ctx, &if_stmt->else_list);
2230    }
2231 
2232    start_block(ctx, endif_id);
2233 }
2234 
2235 static void
emit_loop(struct ntv_context * ctx,nir_loop * loop)2236 emit_loop(struct ntv_context *ctx, nir_loop *loop)
2237 {
2238    SpvId header_id = spirv_builder_new_id(&ctx->builder);
2239    SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
2240    SpvId break_id = spirv_builder_new_id(&ctx->builder);
2241    SpvId cont_id = spirv_builder_new_id(&ctx->builder);
2242 
2243    /* create a header-block */
2244    start_block(ctx, header_id);
2245    spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
2246    branch(ctx, begin_id);
2247 
2248    SpvId save_break = ctx->loop_break;
2249    SpvId save_cont = ctx->loop_cont;
2250    ctx->loop_break = break_id;
2251    ctx->loop_cont = cont_id;
2252 
2253    emit_cf_list(ctx, &loop->body);
2254 
2255    ctx->loop_break = save_break;
2256    ctx->loop_cont = save_cont;
2257 
2258    /* loop->body may have already ended our block */
2259    if (ctx->block_started)
2260       branch(ctx, cont_id);
2261    start_block(ctx, cont_id);
2262    branch(ctx, header_id);
2263 
2264    start_block(ctx, break_id);
2265 }
2266 
2267 static void
emit_cf_list(struct ntv_context * ctx,struct exec_list * list)2268 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
2269 {
2270    foreach_list_typed(nir_cf_node, node, node, list) {
2271       switch (node->type) {
2272       case nir_cf_node_block:
2273          emit_block(ctx, nir_cf_node_as_block(node));
2274          break;
2275 
2276       case nir_cf_node_if:
2277          emit_if(ctx, nir_cf_node_as_if(node));
2278          break;
2279 
2280       case nir_cf_node_loop:
2281          emit_loop(ctx, nir_cf_node_as_loop(node));
2282          break;
2283 
2284       case nir_cf_node_function:
2285          unreachable("nir_cf_node_function not supported");
2286          break;
2287       }
2288    }
2289 }
2290 
2291 static SpvExecutionMode
get_input_prim_type_mode(uint16_t type)2292 get_input_prim_type_mode(uint16_t type)
2293 {
2294    switch (type) {
2295    case GL_POINTS:
2296       return SpvExecutionModeInputPoints;
2297    case GL_LINES:
2298    case GL_LINE_LOOP:
2299    case GL_LINE_STRIP:
2300       return SpvExecutionModeInputLines;
2301    case GL_TRIANGLE_STRIP:
2302    case GL_TRIANGLES:
2303    case GL_TRIANGLE_FAN:
2304       return SpvExecutionModeTriangles;
2305    case GL_QUADS:
2306    case GL_QUAD_STRIP:
2307       return SpvExecutionModeQuads;
2308       break;
2309    case GL_POLYGON:
2310       unreachable("handle polygons in gs");
2311       break;
2312    case GL_LINES_ADJACENCY:
2313    case GL_LINE_STRIP_ADJACENCY:
2314       return SpvExecutionModeInputLinesAdjacency;
2315    case GL_TRIANGLES_ADJACENCY:
2316    case GL_TRIANGLE_STRIP_ADJACENCY:
2317       return SpvExecutionModeInputTrianglesAdjacency;
2318       break;
2319    case GL_ISOLINES:
2320       return SpvExecutionModeIsolines;
2321    default:
2322       debug_printf("unknown geometry shader input mode %u\n", type);
2323       unreachable("error!");
2324       break;
2325    }
2326 
2327    return 0;
2328 }
2329 static SpvExecutionMode
get_output_prim_type_mode(uint16_t type)2330 get_output_prim_type_mode(uint16_t type)
2331 {
2332    switch (type) {
2333    case GL_POINTS:
2334       return SpvExecutionModeOutputPoints;
2335    case GL_LINES:
2336    case GL_LINE_LOOP:
2337       unreachable("GL_LINES/LINE_LOOP passed as gs output");
2338       break;
2339    case GL_LINE_STRIP:
2340       return SpvExecutionModeOutputLineStrip;
2341    case GL_TRIANGLE_STRIP:
2342       return SpvExecutionModeOutputTriangleStrip;
2343    case GL_TRIANGLES:
2344    case GL_TRIANGLE_FAN: //FIXME: not sure if right for output
2345       return SpvExecutionModeTriangles;
2346    case GL_QUADS:
2347    case GL_QUAD_STRIP:
2348       return SpvExecutionModeQuads;
2349    case GL_POLYGON:
2350       unreachable("handle polygons in gs");
2351       break;
2352    case GL_LINES_ADJACENCY:
2353    case GL_LINE_STRIP_ADJACENCY:
2354       unreachable("handle line adjacency in gs");
2355       break;
2356    case GL_TRIANGLES_ADJACENCY:
2357    case GL_TRIANGLE_STRIP_ADJACENCY:
2358       unreachable("handle triangle adjacency in gs");
2359       break;
2360    case GL_ISOLINES:
2361       return SpvExecutionModeIsolines;
2362    default:
2363       debug_printf("unknown geometry shader output mode %u\n", type);
2364       unreachable("error!");
2365       break;
2366    }
2367 
2368    return 0;
2369 }
2370 
2371 struct spirv_shader *
nir_to_spirv(struct nir_shader * s,const struct zink_so_info * so_info,unsigned char * shader_slot_map,unsigned char * shader_slots_reserved)2372 nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info,
2373              unsigned char *shader_slot_map, unsigned char *shader_slots_reserved)
2374 {
2375    struct spirv_shader *ret = NULL;
2376 
2377    struct ntv_context ctx = {};
2378    ctx.mem_ctx = ralloc_context(NULL);
2379    ctx.builder.mem_ctx = ctx.mem_ctx;
2380 
2381    switch (s->info.stage) {
2382    case MESA_SHADER_VERTEX:
2383    case MESA_SHADER_FRAGMENT:
2384    case MESA_SHADER_COMPUTE:
2385       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
2386       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageBuffer);
2387       spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampledBuffer);
2388       break;
2389 
2390    case MESA_SHADER_TESS_CTRL:
2391    case MESA_SHADER_TESS_EVAL:
2392       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
2393       break;
2394 
2395    case MESA_SHADER_GEOMETRY:
2396       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
2397       if (s->info.gs.active_stream_mask)
2398          spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometryStreams);
2399       if (s->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_PSIZ))
2400          spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometryPointSize);
2401       break;
2402 
2403    default:
2404       unreachable("invalid stage");
2405    }
2406 
2407    if (s->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_VIEWPORT)) {
2408       if (s->info.stage < MESA_SHADER_GEOMETRY)
2409          spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShaderViewportIndex);
2410       else
2411          spirv_builder_emit_cap(&ctx.builder, SpvCapabilityMultiViewport);
2412    }
2413 
2414    // TODO: only enable when needed
2415    if (s->info.stage == MESA_SHADER_FRAGMENT) {
2416       spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
2417       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
2418       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
2419       spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampleRateShading);
2420    }
2421 
2422    ctx.stage = s->info.stage;
2423    ctx.so_info = so_info;
2424    ctx.shader_slot_map = shader_slot_map;
2425    ctx.shader_slots_reserved = *shader_slots_reserved;
2426    ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
2427    spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
2428 
2429    spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
2430                                 SpvMemoryModelGLSL450);
2431 
2432    SpvExecutionModel exec_model;
2433    switch (s->info.stage) {
2434    case MESA_SHADER_VERTEX:
2435       exec_model = SpvExecutionModelVertex;
2436       break;
2437    case MESA_SHADER_TESS_CTRL:
2438       exec_model = SpvExecutionModelTessellationControl;
2439       break;
2440    case MESA_SHADER_TESS_EVAL:
2441       exec_model = SpvExecutionModelTessellationEvaluation;
2442       break;
2443    case MESA_SHADER_GEOMETRY:
2444       exec_model = SpvExecutionModelGeometry;
2445       break;
2446    case MESA_SHADER_FRAGMENT:
2447       exec_model = SpvExecutionModelFragment;
2448       break;
2449    case MESA_SHADER_COMPUTE:
2450       exec_model = SpvExecutionModelGLCompute;
2451       break;
2452    default:
2453       unreachable("invalid stage");
2454    }
2455 
2456    SpvId type_void = spirv_builder_type_void(&ctx.builder);
2457    SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
2458                                                  NULL, 0);
2459    SpvId entry_point = spirv_builder_new_id(&ctx.builder);
2460    spirv_builder_emit_name(&ctx.builder, entry_point, "main");
2461 
2462    ctx.vars = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_pointer,
2463                                       _mesa_key_pointer_equal);
2464 
2465    ctx.so_outputs = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_u32,
2466                                             _mesa_key_u32_equal);
2467 
2468    nir_foreach_shader_in_variable(var, s)
2469       emit_input(&ctx, var);
2470 
2471    nir_foreach_shader_out_variable(var, s)
2472       emit_output(&ctx, var);
2473 
2474 
2475    if (so_info)
2476       emit_so_info(&ctx, so_info);
2477    /* we have to reverse iterate to match what's done in zink_compiler.c */
2478    foreach_list_typed_reverse(nir_variable, var, node, &s->variables)
2479       if (_nir_shader_variable_has_mode(var, nir_var_uniform |
2480                                         nir_var_mem_ubo |
2481                                         nir_var_mem_ssbo))
2482          emit_uniform(&ctx, var);
2483 
2484    switch (s->info.stage) {
2485    case MESA_SHADER_FRAGMENT:
2486       spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2487                                    SpvExecutionModeOriginUpperLeft);
2488       if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2489          spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2490                                       SpvExecutionModeDepthReplacing);
2491       break;
2492    case MESA_SHADER_GEOMETRY:
2493       spirv_builder_emit_exec_mode(&ctx.builder, entry_point, get_input_prim_type_mode(s->info.gs.input_primitive));
2494       spirv_builder_emit_exec_mode(&ctx.builder, entry_point, get_output_prim_type_mode(s->info.gs.output_primitive));
2495       spirv_builder_emit_exec_mode_literal(&ctx.builder, entry_point, SpvExecutionModeInvocations, s->info.gs.invocations);
2496       spirv_builder_emit_exec_mode_literal(&ctx.builder, entry_point, SpvExecutionModeOutputVertices, s->info.gs.vertices_out);
2497       break;
2498    default:
2499       break;
2500    }
2501    if (so_info && so_info->so_info.num_outputs) {
2502       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTransformFeedback);
2503       spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2504                                    SpvExecutionModeXfb);
2505    }
2506    spirv_builder_function(&ctx.builder, entry_point, type_void,
2507                                             SpvFunctionControlMaskNone,
2508                                             type_main);
2509 
2510    nir_function_impl *entry = nir_shader_get_entrypoint(s);
2511    nir_metadata_require(entry, nir_metadata_block_index);
2512 
2513    ctx.defs = ralloc_array_size(ctx.mem_ctx,
2514                                 sizeof(SpvId), entry->ssa_alloc);
2515    if (!ctx.defs)
2516       goto fail;
2517    ctx.num_defs = entry->ssa_alloc;
2518 
2519    nir_index_local_regs(entry);
2520    ctx.regs = ralloc_array_size(ctx.mem_ctx,
2521                                 sizeof(SpvId), entry->reg_alloc);
2522    if (!ctx.regs)
2523       goto fail;
2524    ctx.num_regs = entry->reg_alloc;
2525 
2526    SpvId *block_ids = ralloc_array_size(ctx.mem_ctx,
2527                                         sizeof(SpvId), entry->num_blocks);
2528    if (!block_ids)
2529       goto fail;
2530 
2531    for (int i = 0; i < entry->num_blocks; ++i)
2532       block_ids[i] = spirv_builder_new_id(&ctx.builder);
2533 
2534    ctx.block_ids = block_ids;
2535    ctx.num_blocks = entry->num_blocks;
2536 
2537    /* emit a block only for the variable declarations */
2538    start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2539    foreach_list_typed(nir_register, reg, node, &entry->registers) {
2540       SpvId type = get_vec_from_bit_size(&ctx, reg->bit_size, reg->num_components);
2541       SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2542                                                       SpvStorageClassFunction,
2543                                                       type);
2544       SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2545                                          SpvStorageClassFunction);
2546 
2547       ctx.regs[reg->index] = var;
2548    }
2549 
2550    emit_cf_list(&ctx, &entry->body);
2551 
2552    /* vertex shader emits copied xfb outputs at the end of the shader */
2553    if (so_info && ctx.stage == MESA_SHADER_VERTEX)
2554       emit_so_outputs(&ctx, so_info);
2555 
2556    spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2557    spirv_builder_function_end(&ctx.builder);
2558 
2559    spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2560                                   "main", ctx.entry_ifaces,
2561                                   ctx.num_entry_ifaces);
2562 
2563    size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2564 
2565    ret = CALLOC_STRUCT(spirv_shader);
2566    if (!ret)
2567       goto fail;
2568 
2569    ret->words = MALLOC(sizeof(uint32_t) * num_words);
2570    if (!ret->words)
2571       goto fail;
2572 
2573    ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2574    assert(ret->num_words == num_words);
2575 
2576    ralloc_free(ctx.mem_ctx);
2577    *shader_slots_reserved = ctx.shader_slots_reserved;
2578 
2579    return ret;
2580 
2581 fail:
2582    ralloc_free(ctx.mem_ctx);
2583 
2584    if (ret)
2585       spirv_shader_delete(ret);
2586 
2587    return NULL;
2588 }
2589 
2590 void
spirv_shader_delete(struct spirv_shader * s)2591 spirv_shader_delete(struct spirv_shader *s)
2592 {
2593    FREE(s->words);
2594    FREE(s);
2595 }
2596