1 /*
2  * Copyright © 2015 Intel Corporation
3  * Copyright © 2019 Valve Corporation
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22  * IN THE SOFTWARE.
23  *
24  * Authors:
25  *    Jason Ekstrand (jason@jlekstrand.net)
 *    Samuel Pitoiset (samuel.pitoiset@gmail.com)
27  */
28 
29 #include "nir.h"
30 #include "nir_builder.h"
31 
32 static nir_ssa_def *
lower_frexp_sig(nir_builder * b,nir_ssa_def * x)33 lower_frexp_sig(nir_builder *b, nir_ssa_def *x)
34 {
35    nir_ssa_def *abs_x = nir_fabs(b, x);
36    nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
37    nir_ssa_def *sign_mantissa_mask, *exponent_value;
38    nir_ssa_def *is_not_zero = nir_fneu(b, abs_x, zero);
39 
40    switch (x->bit_size) {
41    case 16:
42       /* Half-precision floating-point values are stored as
43        *   1 sign bit;
44        *   5 exponent bits;
45        *   10 mantissa bits.
46        *
47        * An exponent shift of 10 will shift the mantissa out, leaving only the
48        * exponent and sign bit (which itself may be zero, if the absolute value
49        * was taken before the bitcast and shift).
50        */
51       sign_mantissa_mask = nir_imm_intN_t(b, 0x83ffu, 16);
52       /* Exponent of floating-point values in the range [0.5, 1.0). */
53       exponent_value = nir_imm_intN_t(b, 0x3800u, 16);
54       break;
55    case 32:
56       /* Single-precision floating-point values are stored as
57        *   1 sign bit;
58        *   8 exponent bits;
59        *   23 mantissa bits.
60        *
61        * An exponent shift of 23 will shift the mantissa out, leaving only the
62        * exponent and sign bit (which itself may be zero, if the absolute value
63        * was taken before the bitcast and shift.
64        */
65       sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
66       /* Exponent of floating-point values in the range [0.5, 1.0). */
67       exponent_value = nir_imm_int(b, 0x3f000000u);
68       break;
69    case 64:
70       /* Double-precision floating-point values are stored as
71        *   1 sign bit;
72        *   11 exponent bits;
73        *   52 mantissa bits.
74        *
75        * An exponent shift of 20 will shift the remaining mantissa bits out,
76        * leaving only the exponent and sign bit (which itself may be zero, if
77        * the absolute value was taken before the bitcast and shift.
78        */
79       sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
80       /* Exponent of floating-point values in the range [0.5, 1.0). */
81       exponent_value = nir_imm_int(b, 0x3fe00000u);
82       break;
83    default:
84       unreachable("Invalid bitsize");
85    }
86 
87    if (x->bit_size == 64) {
88       /* We only need to deal with the exponent so first we extract the upper
89        * 32 bits using nir_unpack_64_2x32_split_y.
90        */
91       nir_ssa_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
92       nir_ssa_def *zero32 = nir_imm_int(b, 0);
93 
94       nir_ssa_def *new_upper =
95          nir_ior(b, nir_iand(b, upper_x, sign_mantissa_mask),
96                     nir_bcsel(b, is_not_zero, exponent_value, zero32));
97 
98       nir_ssa_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
99 
100       return nir_pack_64_2x32_split(b, lower_x, new_upper);
101    } else {
102       return nir_ior(b, nir_iand(b, x, sign_mantissa_mask),
103                         nir_bcsel(b, is_not_zero, exponent_value, zero));
104    }
105 }
106 
107 static nir_ssa_def *
lower_frexp_exp(nir_builder * b,nir_ssa_def * x)108 lower_frexp_exp(nir_builder *b, nir_ssa_def *x)
109 {
110    nir_ssa_def *abs_x = nir_fabs(b, x);
111    nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
112    nir_ssa_def *is_not_zero = nir_fneu(b, abs_x, zero);
113    nir_ssa_def *exponent;
114 
115    switch (x->bit_size) {
116    case 16: {
117       nir_ssa_def *exponent_shift = nir_imm_int(b, 10);
118       nir_ssa_def *exponent_bias = nir_imm_intN_t(b, -14, 16);
119 
120       /* Significand return must be of the same type as the input, but the
121        * exponent must be a 32-bit integer.
122        */
123       exponent = nir_i2i32(b, nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
124                               nir_bcsel(b, is_not_zero, exponent_bias, zero)));
125       break;
126    }
127    case 32: {
128       nir_ssa_def *exponent_shift = nir_imm_int(b, 23);
129       nir_ssa_def *exponent_bias = nir_imm_int(b, -126);
130 
131       exponent = nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
132                              nir_bcsel(b, is_not_zero, exponent_bias, zero));
133       break;
134    }
135    case 64: {
136       nir_ssa_def *exponent_shift = nir_imm_int(b, 20);
137       nir_ssa_def *exponent_bias = nir_imm_int(b, -1022);
138 
139       nir_ssa_def *zero32 = nir_imm_int(b, 0);
140       nir_ssa_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
141 
142       exponent = nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
143                              nir_bcsel(b, is_not_zero, exponent_bias, zero32));
144       break;
145    }
146    default:
147       unreachable("Invalid bitsize");
148    }
149 
150    return exponent;
151 }
152 
153 static bool
lower_frexp_impl(nir_function_impl * impl)154 lower_frexp_impl(nir_function_impl *impl)
155 {
156    bool progress = false;
157 
158    nir_builder b;
159    nir_builder_init(&b, impl);
160 
161    nir_foreach_block(block, impl) {
162       nir_foreach_instr_safe(instr, block) {
163          if (instr->type != nir_instr_type_alu)
164             continue;
165 
166          nir_alu_instr *alu_instr = nir_instr_as_alu(instr);
167          nir_ssa_def *lower;
168 
169          b.cursor = nir_before_instr(instr);
170 
171          switch (alu_instr->op) {
172          case nir_op_frexp_sig:
173             lower = lower_frexp_sig(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
174             break;
175          case nir_op_frexp_exp:
176             lower = lower_frexp_exp(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
177             break;
178          default:
179             continue;
180          }
181 
182          nir_ssa_def_rewrite_uses(&alu_instr->dest.dest.ssa,
183                                   nir_src_for_ssa(lower));
184          nir_instr_remove(instr);
185          progress = true;
186       }
187    }
188 
189    if (progress) {
190       nir_metadata_preserve(impl, nir_metadata_block_index |
191                                   nir_metadata_dominance);
192    }
193 
194    return progress;
195 }
196 
197 bool
nir_lower_frexp(nir_shader * shader)198 nir_lower_frexp(nir_shader *shader)
199 {
200    bool progress = false;
201 
202    nir_foreach_function(function, shader) {
203       if (function->impl)
204          progress |= lower_frexp_impl(function->impl);
205    }
206 
207    return progress;
208 }
209