/*
 * Copyright © 2016 Broadcom
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_deref.h"

/** @file nir_lower_io_to_scalar.c
 *
 * Replaces nir_load_input/nir_store_output operations with num_components !=
 * 1 with individual per-channel operations.
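 *
 * For example (schematic NIR, the exact printed form varies between
 * versions), a vec4 input load such as
 *
 *    vec4 ssa_1 = load_input (ssa_0) (base=0, component=0)
 *
 * is replaced by four single-component loads whose results are recombined
 * with a vec4:
 *
 *    vec1 ssa_2 = load_input (ssa_0) (base=0, component=0)
 *    vec1 ssa_3 = load_input (ssa_0) (base=0, component=1)
 *    vec1 ssa_4 = load_input (ssa_0) (base=0, component=2)
 *    vec1 ssa_5 = load_input (ssa_0) (base=0, component=3)
 *    vec4 ssa_6 = vec4 ssa_2, ssa_3, ssa_4, ssa_5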
 */

static void
lower_load_input_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->dest.is_ssa);

   nir_ssa_def *loads[NIR_MAX_VEC_COMPONENTS];

   /* Emit a single-component copy of the load for each channel. */
   for (unsigned i = 0; i < intr->num_components; i++) {
      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
                        1, intr->dest.ssa.bit_size, NULL);
      chan_intr->num_components = 1;

      nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
      nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
      nir_intrinsic_set_dest_type(chan_intr, nir_intrinsic_dest_type(intr));
      nir_intrinsic_set_io_semantics(chan_intr, nir_intrinsic_io_semantics(intr));
      /* offset */
      nir_src_copy(&chan_intr->src[0], &intr->src[0], chan_intr);

      nir_builder_instr_insert(b, &chan_intr->instr);

      loads[i] = &chan_intr->dest.ssa;
   }

   /* Recombine the scalar results and replace the original vector load. */
   nir_ssa_def_rewrite_uses(&intr->dest.ssa,
                            nir_src_for_ssa(nir_vec(b, loads,
                                                    intr->num_components)));
   nir_instr_remove(&intr->instr);
}

static void
lower_store_output_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
{
   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *value = nir_ssa_for_src(b, intr->src[0], intr->num_components);

   for (unsigned i = 0; i < intr->num_components; i++) {
      /* Skip any channels the original store does not write. */
      if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
         continue;

      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      chan_intr->num_components = 1;

      nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
      nir_intrinsic_set_write_mask(chan_intr, 0x1);
      nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
      nir_intrinsic_set_src_type(chan_intr, nir_intrinsic_src_type(intr));
      nir_intrinsic_set_io_semantics(chan_intr, nir_intrinsic_io_semantics(intr));

      /* value */
      chan_intr->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      /* offset */
      nir_src_copy(&chan_intr->src[1], &intr->src[1], chan_intr);

      nir_builder_instr_insert(b, &chan_intr->instr);
   }

   nir_instr_remove(&intr->instr);
}

static bool
nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
{
   nir_variable_mode mask = *(nir_variable_mode *)data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->num_components == 1)
      return false;

   if (intr->intrinsic == nir_intrinsic_load_input &&
       (mask & nir_var_shader_in)) {
      lower_load_input_to_scalar(b, intr);
      return true;
   }

   if (intr->intrinsic == nir_intrinsic_store_output &&
       (mask & nir_var_shader_out)) {
      lower_store_output_to_scalar(b, intr);
      return true;
   }

   return false;
}

void
nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask)
{
   nir_shader_instructions_pass(shader,
                                nir_lower_io_to_scalar_instr,
                                nir_metadata_block_index |
                                nir_metadata_dominance,
                                &mask);
}

static nir_variable **
get_channel_variables(struct hash_table *ht, nir_variable *var)
{
   nir_variable **chan_vars;
   struct hash_entry *entry = _mesa_hash_table_search(ht, var);
   if (!entry) {
      /* Lazily allocate one slot for each of the variable's possible
       * components.
       */
      chan_vars = (nir_variable **) calloc(4, sizeof(nir_variable *));
      _mesa_hash_table_insert(ht, var, chan_vars);
   } else {
      chan_vars = (nir_variable **) entry->data;
   }

   return chan_vars;
}

/*
 * Note that the src deref that we are cloning is the head of the
 * chain of deref instructions from the original intrinsic, but
 * the dst we are cloning to is the tail (because chains of deref
 * instructions are created back to front).
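 *
 * For example (illustrative), when lowering an access to v[i][j] the
 * intrinsic points at the innermost deref_array, whose parent chain is
 *
 *    deref_var v -> deref_array [i] -> deref_array [j]
 *
 * and clone_deref_array() walks the parent links first so that the same
 * array derefs are rebuilt, outermost first, on top of the deref_var of
 * the new per-channel variable.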
 */

static nir_deref_instr *
clone_deref_array(nir_builder *b, nir_deref_instr *dst_tail,
                  const nir_deref_instr *src_head)
{
   const nir_deref_instr *parent = nir_deref_instr_parent(src_head);

   if (!parent)
      return dst_tail;

   assert(src_head->deref_type == nir_deref_type_array);

   dst_tail = clone_deref_array(b, dst_tail, parent);

   return nir_build_deref_array(b, dst_tail,
                                nir_ssa_for_src(b, src_head->arr.index, 1));
}

static void
lower_load_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
                           nir_variable *var, struct hash_table *split_inputs,
                           struct hash_table *split_outputs)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->dest.is_ssa);

   nir_ssa_def *loads[NIR_MAX_VEC_COMPONENTS];

   nir_variable **chan_vars;
   if (var->data.mode == nir_var_shader_in) {
      chan_vars = get_channel_variables(split_inputs, var);
   } else {
      chan_vars = get_channel_variables(split_outputs, var);
   }

   for (unsigned i = 0; i < intr->num_components; i++) {
      /* Create the scalar variable for this channel on first use. */
      nir_variable *chan_var = chan_vars[var->data.location_frac + i];
      if (!chan_vars[var->data.location_frac + i]) {
         chan_var = nir_variable_clone(var, b->shader);
         chan_var->data.location_frac = var->data.location_frac + i;
         chan_var->type = glsl_channel_type(chan_var->type);
         if (var->data.explicit_offset) {
            unsigned comp_size = glsl_get_bit_size(chan_var->type) / 8;
            chan_var->data.offset = var->data.offset + i * comp_size;
         }

         chan_vars[var->data.location_frac + i] = chan_var;

         nir_shader_add_variable(b->shader, chan_var);
      }

      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
                        1, intr->dest.ssa.bit_size, NULL);
      chan_intr->num_components = 1;

      nir_deref_instr *deref = nir_build_deref_var(b, chan_var);

      deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));

      chan_intr->src[0] = nir_src_for_ssa(&deref->dest.ssa);

      if (intr->intrinsic == nir_intrinsic_interp_deref_at_offset ||
          intr->intrinsic == nir_intrinsic_interp_deref_at_sample ||
          intr->intrinsic == nir_intrinsic_interp_deref_at_vertex)
         nir_src_copy(&chan_intr->src[1], &intr->src[1], &chan_intr->instr);

      nir_builder_instr_insert(b, &chan_intr->instr);

      loads[i] = &chan_intr->dest.ssa;
   }

   nir_ssa_def_rewrite_uses(&intr->dest.ssa,
                            nir_src_for_ssa(nir_vec(b, loads,
                                                    intr->num_components)));

   /* Remove the old load intrinsic */
   nir_instr_remove(&intr->instr);
}

static void
lower_store_output_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
                                   nir_variable *var,
                                   struct hash_table *split_outputs)
{
   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *value = nir_ssa_for_src(b, intr->src[1], intr->num_components);

   nir_variable **chan_vars = get_channel_variables(split_outputs, var);
   for (unsigned i = 0; i < intr->num_components; i++) {
      if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
         continue;

      nir_variable *chan_var = chan_vars[var->data.location_frac + i];
      if (!chan_vars[var->data.location_frac + i]) {
         chan_var = nir_variable_clone(var, b->shader);
         chan_var->data.location_frac = var->data.location_frac + i;
         chan_var->type = glsl_channel_type(chan_var->type);
         if (var->data.explicit_offset) {
            unsigned comp_size = glsl_get_bit_size(chan_var->type) / 8;
            chan_var->data.offset = var->data.offset + i * comp_size;
         }

         chan_vars[var->data.location_frac + i] = chan_var;

         nir_shader_add_variable(b->shader, chan_var);
      }

      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      chan_intr->num_components = 1;

      nir_intrinsic_set_write_mask(chan_intr, 0x1);

      nir_deref_instr *deref = nir_build_deref_var(b, chan_var);

      deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));

      chan_intr->src[0] = nir_src_for_ssa(&deref->dest.ssa);
      chan_intr->src[1] = nir_src_for_ssa(nir_channel(b, value, i));

      nir_builder_instr_insert(b, &chan_intr->instr);
   }

   /* Remove the old store intrinsic */
   nir_instr_remove(&intr->instr);
}

struct io_to_scalar_early_state {
   struct hash_table *split_inputs, *split_outputs;
   nir_variable_mode mask;
};

static bool
nir_lower_io_to_scalar_early_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct io_to_scalar_early_state *state = data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->num_components == 1)
      return false;

   if (intr->intrinsic != nir_intrinsic_load_deref &&
       intr->intrinsic != nir_intrinsic_store_deref &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_centroid &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_sample &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_offset &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_vertex)
      return false;

   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   if (!nir_deref_mode_is_one_of(deref, state->mask))
      return false;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   nir_variable_mode mode = var->data.mode;

   /* TODO: add patch support */
   if (var->data.patch)
      return false;

   /* TODO: add doubles support */
   if (glsl_type_is_64bit(glsl_without_array(var->type)))
      return false;

   /* Only split generic varyings; built-in varying slots are left alone.
    * Vertex shader inputs are always allowed since their locations are
    * vertex attributes rather than varying slots.
    */
   if (!(b->shader->info.stage == MESA_SHADER_VERTEX &&
         mode == nir_var_shader_in) &&
       var->data.location < VARYING_SLOT_VAR0 &&
       var->data.location >= 0)
      return false;

   /* Don't bother splitting if we can't opt away any unused
    * components.
    */
   if (var->data.always_active_io)
      return false;

   /* Skip types we cannot split */
   if (glsl_type_is_matrix(glsl_without_array(var->type)) ||
       glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
      return false;

   switch (intr->intrinsic) {
   case nir_intrinsic_interp_deref_at_centroid:
   case nir_intrinsic_interp_deref_at_sample:
   case nir_intrinsic_interp_deref_at_offset:
   case nir_intrinsic_interp_deref_at_vertex:
   case nir_intrinsic_load_deref:
      if ((state->mask & nir_var_shader_in && mode == nir_var_shader_in) ||
          (state->mask & nir_var_shader_out && mode == nir_var_shader_out)) {
         lower_load_to_scalar_early(b, intr, var, state->split_inputs,
                                    state->split_outputs);
         return true;
      }
      break;
   case nir_intrinsic_store_deref:
      if (state->mask & nir_var_shader_out &&
          mode == nir_var_shader_out) {
         lower_store_output_to_scalar_early(b, intr, var, state->split_outputs);
         return true;
      }
      break;
   default:
      break;
   }

   return false;
}

/*
 * This function is intended to be called earlier than nir_lower_io_to_scalar(),
 * i.e. before nir_lower_io() is called.
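 *
 * An illustrative ordering in a driver's lowering path (the type_size
 * callback name and exact arguments are examples only, not prescribed by
 * this file):
 *
 *    NIR_PASS_V(nir, nir_lower_io_to_scalar_early,
 *               nir_var_shader_in | nir_var_shader_out);
 *    ... remove now-unused per-channel variables / varyings ...
 *    NIR_PASS_V(nir, nir_lower_io, nir_var_shader_in | nir_var_shader_out,
 *               type_size_vec4, (nir_lower_io_options)0);
 *    NIR_PASS_V(nir, nir_lower_io_to_scalar,
 *               nir_var_shader_in | nir_var_shader_out);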
 */
bool
nir_lower_io_to_scalar_early(nir_shader *shader, nir_variable_mode mask)
{
   struct io_to_scalar_early_state state = {
      .split_inputs = _mesa_pointer_hash_table_create(NULL),
      .split_outputs = _mesa_pointer_hash_table_create(NULL),
      .mask = mask
   };

   bool progress = nir_shader_instructions_pass(shader,
                                                nir_lower_io_to_scalar_early_instr,
                                                nir_metadata_block_index |
                                                nir_metadata_dominance,
                                                &state);

   /* Remove the old inputs from the shader's input list */
   hash_table_foreach(state.split_inputs, entry) {
      nir_variable *var = (nir_variable *) entry->key;
      exec_node_remove(&var->node);

      free(entry->data);
   }

   /* Remove the old outputs from the shader's output list */
   hash_table_foreach(state.split_outputs, entry) {
      nir_variable *var = (nir_variable *) entry->key;
      exec_node_remove(&var->node);

      free(entry->data);
   }

   _mesa_hash_table_destroy(state.split_inputs, NULL);
   _mesa_hash_table_destroy(state.split_outputs, NULL);

   nir_remove_dead_derefs(shader);

   return progress;
}