1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include <math.h>
25 #include "vtn_private.h"
26 #include "spirv_info.h"
27 
28 /*
29  * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30  * definition. But for matrix multiplies, we want to do one routine for
31  * multiplying a matrix by a matrix and then pretend that vectors are matrices
32  * with one column. So we "wrap" these things, and unwrap the result before we
33  * send it off.
34  */
35 
36 static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder * b,struct vtn_ssa_value * val)37 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38 {
39    if (val == NULL)
40       return NULL;
41 
42    if (glsl_type_is_matrix(val->type))
43       return val;
44 
45    struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
46    dest->type = glsl_get_bare_type(val->type);
47    dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
48    dest->elems[0] = val;
49 
50    return dest;
51 }
52 
53 static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value * val)54 unwrap_matrix(struct vtn_ssa_value *val)
55 {
56    if (glsl_type_is_matrix(val->type))
57          return val;
58 
59    return val->elems[0];
60 }
61 
62 static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder * b,struct vtn_ssa_value * _src0,struct vtn_ssa_value * _src1)63 matrix_multiply(struct vtn_builder *b,
64                 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65 {
66 
67    struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68    struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69    struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70    struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71 
72    unsigned src0_rows = glsl_get_vector_elements(src0->type);
73    unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74    unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75 
76    const struct glsl_type *dest_type;
77    if (src1_columns > 1) {
78       dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79                                    src0_rows, src1_columns);
80    } else {
81       dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82    }
83    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84 
85    dest = wrap_matrix(b, dest);
86 
87    bool transpose_result = false;
88    if (src0_transpose && src1_transpose) {
89       /* transpose(A) * transpose(B) = transpose(B * A) */
90       src1 = src0_transpose;
91       src0 = src1_transpose;
92       src0_transpose = NULL;
93       src1_transpose = NULL;
94       transpose_result = true;
95    }
96 
97    if (src0_transpose && !src1_transpose &&
98        glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
99       /* We already have the rows of src0 and the columns of src1 available,
100        * so we can just take the dot product of each row with each column to
101        * get the result.
102        */
103 
104       for (unsigned i = 0; i < src1_columns; i++) {
105          nir_ssa_def *vec_src[4];
106          for (unsigned j = 0; j < src0_rows; j++) {
107             vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
108                                           src1->elems[i]->def);
109          }
110          dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
111       }
112    } else {
113       /* We don't handle the case where src1 is transposed but not src0, since
114        * the general case only uses individual components of src1 so the
115        * optimizer should chew through the transpose we emitted for src1.
116        */
117 
118       for (unsigned i = 0; i < src1_columns; i++) {
119          /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
120          dest->elems[i]->def =
121             nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
122                      nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
123          for (int j = src0_columns - 2; j >= 0; j--) {
124             dest->elems[i]->def =
125                nir_fadd(&b->nb, dest->elems[i]->def,
126                         nir_fmul(&b->nb, src0->elems[j]->def,
127                                  nir_channel(&b->nb, src1->elems[i]->def, j)));
128          }
129       }
130    }
131 
132    dest = unwrap_matrix(dest);
133 
134    if (transpose_result)
135       dest = vtn_ssa_transpose(b, dest);
136 
137    return dest;
138 }
139 
140 static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder * b,struct vtn_ssa_value * mat,nir_ssa_def * scalar)141 mat_times_scalar(struct vtn_builder *b,
142                  struct vtn_ssa_value *mat,
143                  nir_ssa_def *scalar)
144 {
145    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
146    for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
147       if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
148          dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149       else
150          dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
151    }
152 
153    return dest;
154 }
155 
156 static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder * b,SpvOp opcode,struct vtn_ssa_value * src0,struct vtn_ssa_value * src1)157 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
158                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
159 {
160    switch (opcode) {
161    case SpvOpFNegate: {
162       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
163       unsigned cols = glsl_get_matrix_columns(src0->type);
164       for (unsigned i = 0; i < cols; i++)
165          dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
166       return dest;
167    }
168 
169    case SpvOpFAdd: {
170       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
171       unsigned cols = glsl_get_matrix_columns(src0->type);
172       for (unsigned i = 0; i < cols; i++)
173          dest->elems[i]->def =
174             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
175       return dest;
176    }
177 
178    case SpvOpFSub: {
179       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
180       unsigned cols = glsl_get_matrix_columns(src0->type);
181       for (unsigned i = 0; i < cols; i++)
182          dest->elems[i]->def =
183             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
184       return dest;
185    }
186 
187    case SpvOpTranspose:
188       return vtn_ssa_transpose(b, src0);
189 
190    case SpvOpMatrixTimesScalar:
191       if (src0->transposed) {
192          return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
193                                                          src1->def));
194       } else {
195          return mat_times_scalar(b, src0, src1->def);
196       }
197       break;
198 
199    case SpvOpVectorTimesMatrix:
200    case SpvOpMatrixTimesVector:
201    case SpvOpMatrixTimesMatrix:
202       if (opcode == SpvOpVectorTimesMatrix) {
203          return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
204       } else {
205          return matrix_multiply(b, src0, src1);
206       }
207       break;
208 
209    default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
210    }
211 }
212 
213 static nir_alu_type
convert_op_src_type(SpvOp opcode)214 convert_op_src_type(SpvOp opcode)
215 {
216    switch (opcode) {
217    case SpvOpFConvert:
218    case SpvOpConvertFToS:
219    case SpvOpConvertFToU:
220       return nir_type_float;
221    case SpvOpSConvert:
222    case SpvOpConvertSToF:
223    case SpvOpSatConvertSToU:
224       return nir_type_int;
225    case SpvOpUConvert:
226    case SpvOpConvertUToF:
227    case SpvOpSatConvertUToS:
228       return nir_type_uint;
229    default:
230       unreachable("Unhandled conversion op");
231    }
232 }
233 
234 static nir_alu_type
convert_op_dst_type(SpvOp opcode)235 convert_op_dst_type(SpvOp opcode)
236 {
237    switch (opcode) {
238    case SpvOpFConvert:
239    case SpvOpConvertSToF:
240    case SpvOpConvertUToF:
241       return nir_type_float;
242    case SpvOpSConvert:
243    case SpvOpConvertFToS:
244    case SpvOpSatConvertUToS:
245       return nir_type_int;
246    case SpvOpUConvert:
247    case SpvOpConvertFToU:
248    case SpvOpSatConvertSToU:
249       return nir_type_uint;
250    default:
251       unreachable("Unhandled conversion op");
252    }
253 }
254 
255 nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder * b,SpvOp opcode,bool * swap,bool * exact,unsigned src_bit_size,unsigned dst_bit_size)256 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
257                                 SpvOp opcode, bool *swap, bool *exact,
258                                 unsigned src_bit_size, unsigned dst_bit_size)
259 {
260    /* Indicates that the first two arguments should be swapped.  This is
261     * used for implementing greater-than and less-than-or-equal.
262     */
263    *swap = false;
264 
265    *exact = false;
266 
267    switch (opcode) {
268    case SpvOpSNegate:            return nir_op_ineg;
269    case SpvOpFNegate:            return nir_op_fneg;
270    case SpvOpNot:                return nir_op_inot;
271    case SpvOpIAdd:               return nir_op_iadd;
272    case SpvOpFAdd:               return nir_op_fadd;
273    case SpvOpISub:               return nir_op_isub;
274    case SpvOpFSub:               return nir_op_fsub;
275    case SpvOpIMul:               return nir_op_imul;
276    case SpvOpFMul:               return nir_op_fmul;
277    case SpvOpUDiv:               return nir_op_udiv;
278    case SpvOpSDiv:               return nir_op_idiv;
279    case SpvOpFDiv:               return nir_op_fdiv;
280    case SpvOpUMod:               return nir_op_umod;
281    case SpvOpSMod:               return nir_op_imod;
282    case SpvOpFMod:               return nir_op_fmod;
283    case SpvOpSRem:               return nir_op_irem;
284    case SpvOpFRem:               return nir_op_frem;
285 
286    case SpvOpShiftRightLogical:     return nir_op_ushr;
287    case SpvOpShiftRightArithmetic:  return nir_op_ishr;
288    case SpvOpShiftLeftLogical:      return nir_op_ishl;
289    case SpvOpLogicalOr:             return nir_op_ior;
290    case SpvOpLogicalEqual:          return nir_op_ieq;
291    case SpvOpLogicalNotEqual:       return nir_op_ine;
292    case SpvOpLogicalAnd:            return nir_op_iand;
293    case SpvOpLogicalNot:            return nir_op_inot;
294    case SpvOpBitwiseOr:             return nir_op_ior;
295    case SpvOpBitwiseXor:            return nir_op_ixor;
296    case SpvOpBitwiseAnd:            return nir_op_iand;
297    case SpvOpSelect:                return nir_op_bcsel;
298    case SpvOpIEqual:                return nir_op_ieq;
299 
300    case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
301    case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
302    case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
303    case SpvOpBitReverse:            return nir_op_bitfield_reverse;
304 
305    case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
306    /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
307    case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
308    case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
309    case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
310    case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
311    case SpvOpIAverageINTEL:         return nir_op_ihadd;
312    case SpvOpUAverageINTEL:         return nir_op_uhadd;
313    case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
314    case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
315    case SpvOpISubSatINTEL:          return nir_op_isub_sat;
316    case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
317    case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
318    case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;
319 
320    /* The ordered / unordered operators need special implementation besides
321     * the logical operator to use since they also need to check if operands are
322     * ordered.
323     */
324    case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
325    case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
326    case SpvOpINotEqual:                                            return nir_op_ine;
327    case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
328    case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
329    case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
330    case SpvOpULessThan:                                            return nir_op_ult;
331    case SpvOpSLessThan:                                            return nir_op_ilt;
332    case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
333    case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
334    case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
335    case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
336    case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
337    case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
338    case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
339    case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
340    case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
341    case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
342    case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
343    case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
344    case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
345    case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;
346 
347    /* Conversions: */
348    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
349    case SpvOpUConvert:
350    case SpvOpConvertFToU:
351    case SpvOpConvertFToS:
352    case SpvOpConvertSToF:
353    case SpvOpConvertUToF:
354    case SpvOpSConvert:
355    case SpvOpFConvert: {
356       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
357       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
358       return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
359    }
360 
361    case SpvOpPtrCastToGeneric:   return nir_op_mov;
362    case SpvOpGenericCastToPtr:   return nir_op_mov;
363 
364    /* Derivatives: */
365    case SpvOpDPdx:         return nir_op_fddx;
366    case SpvOpDPdy:         return nir_op_fddy;
367    case SpvOpDPdxFine:     return nir_op_fddx_fine;
368    case SpvOpDPdyFine:     return nir_op_fddy_fine;
369    case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
370    case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
371 
372    case SpvOpIsNormal:     return nir_op_fisnormal;
373    case SpvOpIsFinite:     return nir_op_fisfinite;
374 
375    default:
376       vtn_fail("No NIR equivalent: %u", opcode);
377    }
378 }
379 
380 static void
handle_no_contraction(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * _void)381 handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
382                       const struct vtn_decoration *dec, void *_void)
383 {
384    vtn_assert(dec->scope == VTN_DEC_DECORATION);
385    if (dec->decoration != SpvDecorationNoContraction)
386       return;
387 
388    b->nb.exact = true;
389 }
390 
391 nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder * b,SpvFPRoundingMode mode)392 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
393 {
394    switch (mode) {
395    case SpvFPRoundingModeRTE:
396       return nir_rounding_mode_rtne;
397    case SpvFPRoundingModeRTZ:
398       return nir_rounding_mode_rtz;
399    case SpvFPRoundingModeRTP:
400       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
401                   "FPRoundingModeRTP is only supported in kernels");
402       return nir_rounding_mode_ru;
403    case SpvFPRoundingModeRTN:
404       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
405                   "FPRoundingModeRTN is only supported in kernels");
406       return nir_rounding_mode_rd;
407    default:
408       vtn_fail("Unsupported rounding mode: %s",
409                spirv_fproundingmode_to_string(mode));
410       break;
411    }
412 }
413 
414 struct conversion_opts {
415    nir_rounding_mode rounding_mode;
416    bool saturate;
417 };
418 
419 static void
handle_conversion_opts(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * _opts)420 handle_conversion_opts(struct vtn_builder *b, struct vtn_value *val, int member,
421                        const struct vtn_decoration *dec, void *_opts)
422 {
423    struct conversion_opts *opts = _opts;
424 
425    switch (dec->decoration) {
426    case SpvDecorationFPRoundingMode:
427       opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
428       break;
429 
430    case SpvDecorationSaturatedConversion:
431       vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
432                   "Saturated conversions are only allowed in kernels");
433       opts->saturate = true;
434       break;
435 
436    default:
437       break;
438    }
439 }
440 
441 static void
handle_no_wrap(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * _alu)442 handle_no_wrap(struct vtn_builder *b, struct vtn_value *val, int member,
443                const struct vtn_decoration *dec, void *_alu)
444 {
445    nir_alu_instr *alu = _alu;
446    switch (dec->decoration) {
447    case SpvDecorationNoSignedWrap:
448       alu->no_signed_wrap = true;
449       break;
450    case SpvDecorationNoUnsignedWrap:
451       alu->no_unsigned_wrap = true;
452       break;
453    default:
454       /* Do nothing. */
455       break;
456    }
457 }
458 
459 void
vtn_handle_alu(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)460 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
461                const uint32_t *w, unsigned count)
462 {
463    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
464    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
465 
466    vtn_foreach_decoration(b, dest_val, handle_no_contraction, NULL);
467 
468    /* Collect the various SSA sources */
469    const unsigned num_inputs = count - 3;
470    struct vtn_ssa_value *vtn_src[4] = { NULL, };
471    for (unsigned i = 0; i < num_inputs; i++)
472       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
473 
474    if (glsl_type_is_matrix(vtn_src[0]->type) ||
475        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
476       vtn_push_ssa_value(b, w[2],
477          vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
478       b->nb.exact = b->exact;
479       return;
480    }
481 
482    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
483    nir_ssa_def *src[4] = { NULL, };
484    for (unsigned i = 0; i < num_inputs; i++) {
485       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
486       src[i] = vtn_src[i]->def;
487    }
488 
489    switch (opcode) {
490    case SpvOpAny:
491       dest->def = nir_bany(&b->nb, src[0]);
492       break;
493 
494    case SpvOpAll:
495       dest->def = nir_ball(&b->nb, src[0]);
496       break;
497 
498    case SpvOpOuterProduct: {
499       for (unsigned i = 0; i < src[1]->num_components; i++) {
500          dest->elems[i]->def =
501             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
502       }
503       break;
504    }
505 
506    case SpvOpDot:
507       dest->def = nir_fdot(&b->nb, src[0], src[1]);
508       break;
509 
510    case SpvOpIAddCarry:
511       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
512       dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
513       dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
514       break;
515 
516    case SpvOpISubBorrow:
517       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
518       dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
519       dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
520       break;
521 
522    case SpvOpUMulExtended: {
523       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
524       nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
525       dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
526       dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
527       break;
528    }
529 
530    case SpvOpSMulExtended: {
531       vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
532       nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
533       dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
534       dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
535       break;
536    }
537 
538    case SpvOpFwidth:
539       dest->def = nir_fadd(&b->nb,
540                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
541                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
542       break;
543    case SpvOpFwidthFine:
544       dest->def = nir_fadd(&b->nb,
545                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
546                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
547       break;
548    case SpvOpFwidthCoarse:
549       dest->def = nir_fadd(&b->nb,
550                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
551                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
552       break;
553 
554    case SpvOpVectorTimesScalar:
555       /* The builder will take care of splatting for us. */
556       dest->def = nir_fmul(&b->nb, src[0], src[1]);
557       break;
558 
559    case SpvOpIsNan: {
560       const bool save_exact = b->nb.exact;
561 
562       b->nb.exact = true;
563       dest->def = nir_fneu(&b->nb, src[0], src[0]);
564       b->nb.exact = save_exact;
565       break;
566    }
567 
568    case SpvOpOrdered: {
569       const bool save_exact = b->nb.exact;
570 
571       b->nb.exact = true;
572       dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
573                                    nir_feq(&b->nb, src[1], src[1]));
574       b->nb.exact = save_exact;
575       break;
576    }
577 
578    case SpvOpUnordered: {
579       const bool save_exact = b->nb.exact;
580 
581       b->nb.exact = true;
582       dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
583                                   nir_fneu(&b->nb, src[1], src[1]));
584       b->nb.exact = save_exact;
585       break;
586    }
587 
588    case SpvOpIsInf: {
589       nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
590       dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
591       break;
592    }
593 
594    case SpvOpFUnordEqual:
595    case SpvOpFUnordNotEqual:
596    case SpvOpFUnordLessThan:
597    case SpvOpFUnordGreaterThan:
598    case SpvOpFUnordLessThanEqual:
599    case SpvOpFUnordGreaterThanEqual: {
600       bool swap;
601       bool unused_exact;
602       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
603       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
604       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
605                                                   &unused_exact,
606                                                   src_bit_size, dst_bit_size);
607 
608       if (swap) {
609          nir_ssa_def *tmp = src[0];
610          src[0] = src[1];
611          src[1] = tmp;
612       }
613 
614       const bool save_exact = b->nb.exact;
615 
616       b->nb.exact = true;
617 
618       dest->def =
619          nir_ior(&b->nb,
620                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
621                  nir_ior(&b->nb,
622                          nir_fneu(&b->nb, src[0], src[0]),
623                          nir_fneu(&b->nb, src[1], src[1])));
624 
625       b->nb.exact = save_exact;
626       break;
627    }
628 
629    case SpvOpLessOrGreater:
630    case SpvOpFOrdNotEqual: {
631       /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
632        * from the ALU will probably already be false if the operands are not
633        * ordered so we don’t need to handle it specially.
634        */
635       bool swap;
636       bool exact;
637       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
638       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
639       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
640                                                   src_bit_size, dst_bit_size);
641 
642       assert(!swap);
643       assert(exact);
644 
645       const bool save_exact = b->nb.exact;
646 
647       b->nb.exact = true;
648 
649       dest->def =
650          nir_iand(&b->nb,
651                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
652                   nir_iand(&b->nb,
653                           nir_feq(&b->nb, src[0], src[0]),
654                           nir_feq(&b->nb, src[1], src[1])));
655 
656       b->nb.exact = save_exact;
657       break;
658    }
659 
660    case SpvOpUConvert:
661    case SpvOpConvertFToU:
662    case SpvOpConvertFToS:
663    case SpvOpConvertSToF:
664    case SpvOpConvertUToF:
665    case SpvOpSConvert:
666    case SpvOpFConvert:
667    case SpvOpSatConvertSToU:
668    case SpvOpSatConvertUToS: {
669       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
670       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
671       nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
672       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
673 
674       struct conversion_opts opts = {
675          .rounding_mode = nir_rounding_mode_undef,
676          .saturate = false,
677       };
678       vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
679 
680       if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
681          opts.saturate = true;
682 
683       if (b->shader->info.stage == MESA_SHADER_KERNEL) {
684          if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
685             nir_op op = nir_type_conversion_op(src_type, dst_type,
686                                                nir_rounding_mode_undef);
687             dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
688          } else {
689             dest->def = nir_convert_alu_types(&b->nb, src[0], src_type,
690                                               dst_type, opts.rounding_mode,
691                                               opts.saturate);
692          }
693       } else {
694          vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
695                      dst_type != nir_type_float16,
696                      "Rounding modes are only allowed on conversions to "
697                      "16-bit float types");
698          nir_op op = nir_type_conversion_op(src_type, dst_type,
699                                             opts.rounding_mode);
700          dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
701       }
702       break;
703    }
704 
705    case SpvOpBitFieldInsert:
706    case SpvOpBitFieldSExtract:
707    case SpvOpBitFieldUExtract:
708    case SpvOpShiftLeftLogical:
709    case SpvOpShiftRightArithmetic:
710    case SpvOpShiftRightLogical: {
711       bool swap;
712       bool exact;
713       unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
714       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
715       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
716                                                   src0_bit_size, dst_bit_size);
717 
718       assert(!exact);
719 
720       assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
721               op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
722               op == nir_op_ibitfield_extract);
723 
724       for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
725          unsigned src_bit_size =
726             nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
727          if (src_bit_size == 0)
728             continue;
729          if (src_bit_size != src[i]->bit_size) {
730             assert(src_bit_size == 32);
731             /* Convert the Shift, Offset and Count  operands to 32 bits, which is the bitsize
732              * supported by the NIR instructions. See discussion here:
733              *
734              * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
735              */
736             src[i] = nir_u2u32(&b->nb, src[i]);
737          }
738       }
739       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
740       break;
741    }
742 
743    case SpvOpSignBitSet:
744       dest->def = nir_i2b(&b->nb,
745          nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
746       break;
747 
748    case SpvOpUCountTrailingZerosINTEL:
749       dest->def = nir_umin(&b->nb,
750                                nir_find_lsb(&b->nb, src[0]),
751                                nir_imm_int(&b->nb, 32u));
752       break;
753 
754    case SpvOpBitCount: {
755       /* bit_count always returns int32, but the SPIR-V opcode just says the return
756        * value needs to be big enough to store the number of bits.
757        */
758       dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
759       break;
760    }
761 
762    default: {
763       bool swap;
764       bool exact;
765       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
766       unsigned dst_bit_size = glsl_get_bit_size(dest_type);
767       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
768                                                   &exact,
769                                                   src_bit_size, dst_bit_size);
770 
771       if (swap) {
772          nir_ssa_def *tmp = src[0];
773          src[0] = src[1];
774          src[1] = tmp;
775       }
776 
777       switch (op) {
778       case nir_op_ishl:
779       case nir_op_ishr:
780       case nir_op_ushr:
781          if (src[1]->bit_size != 32)
782             src[1] = nir_u2u32(&b->nb, src[1]);
783          break;
784       default:
785          break;
786       }
787 
788       const bool save_exact = b->nb.exact;
789 
790       if (exact)
791          b->nb.exact = true;
792 
793       dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
794 
795       b->nb.exact = save_exact;
796       break;
797    } /* default */
798    }
799 
800    switch (opcode) {
801    case SpvOpIAdd:
802    case SpvOpIMul:
803    case SpvOpISub:
804    case SpvOpShiftLeftLogical:
805    case SpvOpSNegate: {
806       nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
807       vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
808       break;
809    }
810    default:
811       /* Do nothing. */
812       break;
813    }
814 
815    vtn_push_ssa_value(b, w[2], dest);
816 
817    b->nb.exact = b->exact;
818 }
819 
820 void
vtn_handle_bitcast(struct vtn_builder * b,const uint32_t * w,unsigned count)821 vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
822 {
823    vtn_assert(count == 4);
824    /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
825     *
826     *    "If Result Type has the same number of components as Operand, they
827     *    must also have the same component width, and results are computed per
828     *    component.
829     *
830     *    If Result Type has a different number of components than Operand, the
831     *    total number of bits in Result Type must equal the total number of
832     *    bits in Operand. Let L be the type, either Result Type or Operand’s
833     *    type, that has the larger number of components. Let S be the other
834     *    type, with the smaller number of components. The number of components
835     *    in L must be an integer multiple of the number of components in S.
836     *    The first component (that is, the only or lowest-numbered component)
837     *    of S maps to the first components of L, and so on, up to the last
838     *    component of S mapping to the last components of L. Within this
839     *    mapping, any single component of S (mapping to multiple components of
840     *    L) maps its lower-ordered bits to the lower-numbered components of L."
841     */
842 
843    struct vtn_type *type = vtn_get_type(b, w[1]);
844    struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);
845 
846    vtn_fail_if(src->num_components * src->bit_size !=
847                glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
848                "Source and destination of OpBitcast must have the same "
849                "total number of bits");
850    nir_ssa_def *val =
851       nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
852    vtn_push_nir_ssa(b, w[2], val);
853 }
854