/*
 * Copyright © 2016 Broadcom
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_deref.h"

/** @file nir_lower_io_to_scalar.c
 *
 * Replaces nir_load_input/nir_store_output operations that have
 * num_components != 1 with individual per-channel operations.
 */
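
/* As a rough illustration (the NIR printing below is illustrative, not
 * verbatim), a vec4 input load such as
 *
 *    vec4 32 ssa_2 = intrinsic load_input (ssa_1) (base=0, component=0, ...)
 *
 * becomes four scalar loads whose results are recombined with a vec4 ALU op:
 *
 *    vec1 32 ssa_3 = intrinsic load_input (ssa_1) (base=0, component=0, ...)
 *    vec1 32 ssa_4 = intrinsic load_input (ssa_1) (base=0, component=1, ...)
 *    vec1 32 ssa_5 = intrinsic load_input (ssa_1) (base=0, component=2, ...)
 *    vec1 32 ssa_6 = intrinsic load_input (ssa_1) (base=0, component=3, ...)
 *    vec4 32 ssa_7 = vec4 ssa_3, ssa_4, ssa_5, ssa_6
 *
 * Stores are split the same way, emitting one store_output per set bit in
 * the original write mask.
 */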

static void
lower_load_input_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->dest.is_ssa);

   nir_ssa_def *loads[NIR_MAX_VEC_COMPONENTS];

   for (unsigned i = 0; i < intr->num_components; i++) {
      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
                        1, intr->dest.ssa.bit_size, NULL);
      chan_intr->num_components = 1;

      nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
      nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
      nir_intrinsic_set_dest_type(chan_intr, nir_intrinsic_dest_type(intr));
      nir_intrinsic_set_io_semantics(chan_intr, nir_intrinsic_io_semantics(intr));
      /* offset */
      nir_src_copy(&chan_intr->src[0], &intr->src[0], chan_intr);

      nir_builder_instr_insert(b, &chan_intr->instr);

      loads[i] = &chan_intr->dest.ssa;
   }

   nir_ssa_def_rewrite_uses(&intr->dest.ssa,
                            nir_src_for_ssa(nir_vec(b, loads,
                                                    intr->num_components)));
   nir_instr_remove(&intr->instr);
}

static void
lower_store_output_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
{
   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *value = nir_ssa_for_src(b, intr->src[0], intr->num_components);

   for (unsigned i = 0; i < intr->num_components; i++) {
      if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
         continue;

      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      chan_intr->num_components = 1;

      nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
      nir_intrinsic_set_write_mask(chan_intr, 0x1);
      nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
      nir_intrinsic_set_src_type(chan_intr, nir_intrinsic_src_type(intr));
      nir_intrinsic_set_io_semantics(chan_intr, nir_intrinsic_io_semantics(intr));

      /* value */
      chan_intr->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      /* offset */
      nir_src_copy(&chan_intr->src[1], &intr->src[1], chan_intr);

      nir_builder_instr_insert(b, &chan_intr->instr);
   }

   nir_instr_remove(&intr->instr);
}

static bool
nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
{
   nir_variable_mode mask = *(nir_variable_mode *)data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->num_components == 1)
      return false;

   if (intr->intrinsic == nir_intrinsic_load_input &&
       (mask & nir_var_shader_in)) {
      lower_load_input_to_scalar(b, intr);
      return true;
   }

   if (intr->intrinsic == nir_intrinsic_store_output &&
       (mask & nir_var_shader_out)) {
      lower_store_output_to_scalar(b, intr);
      return true;
   }

   return false;
}

void
nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask)
{
   nir_shader_instructions_pass(shader,
                                nir_lower_io_to_scalar_instr,
                                nir_metadata_block_index |
                                nir_metadata_dominance,
                                &mask);
}

static nir_variable **
get_channel_variables(struct hash_table *ht, nir_variable *var)
{
   nir_variable **chan_vars;
   struct hash_entry *entry = _mesa_hash_table_search(ht, var);
   if (!entry) {
      chan_vars = (nir_variable **) calloc(4, sizeof(nir_variable *));
      _mesa_hash_table_insert(ht, var, chan_vars);
   } else {
      chan_vars = (nir_variable **) entry->data;
   }

   return chan_vars;
}

/*
 * Note that the src deref that we are cloning is the head of the
 * chain of deref instructions from the original intrinsic, but
 * the dst we are cloning to is the tail (because chains of deref
 * instructions are created back to front).
 */
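
/* For example (purely illustrative, the names are made up), for an original
 * access like "var[i][j]" the deref chain is
 *
 *    deref_var var -> deref_array [i] -> deref_array [j]
 *
 * and src_head is the innermost deref_array [j] taken from the intrinsic's
 * first source. The recursion below first walks up to the variable deref,
 * then rebuilds each deref_array on top of dst_tail, which starts out as the
 * deref of the new per-channel variable.
 */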

static nir_deref_instr *
clone_deref_array(nir_builder *b, nir_deref_instr *dst_tail,
                  const nir_deref_instr *src_head)
{
   const nir_deref_instr *parent = nir_deref_instr_parent(src_head);

   if (!parent)
      return dst_tail;

   assert(src_head->deref_type == nir_deref_type_array);

   dst_tail = clone_deref_array(b, dst_tail, parent);

   return nir_build_deref_array(b, dst_tail,
                                nir_ssa_for_src(b, src_head->arr.index, 1));
}

static void
lower_load_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
                           nir_variable *var, struct hash_table *split_inputs,
                           struct hash_table *split_outputs)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->dest.is_ssa);

   nir_ssa_def *loads[NIR_MAX_VEC_COMPONENTS];

   nir_variable **chan_vars;
   if (var->data.mode == nir_var_shader_in) {
      chan_vars = get_channel_variables(split_inputs, var);
   } else {
      chan_vars = get_channel_variables(split_outputs, var);
   }

   for (unsigned i = 0; i < intr->num_components; i++) {
      nir_variable *chan_var = chan_vars[var->data.location_frac + i];
      if (!chan_vars[var->data.location_frac + i]) {
         chan_var = nir_variable_clone(var, b->shader);
         chan_var->data.location_frac = var->data.location_frac + i;
         chan_var->type = glsl_channel_type(chan_var->type);
         if (var->data.explicit_offset) {
            unsigned comp_size = glsl_get_bit_size(chan_var->type) / 8;
            chan_var->data.offset = var->data.offset + i * comp_size;
         }

         chan_vars[var->data.location_frac + i] = chan_var;

         nir_shader_add_variable(b->shader, chan_var);
      }

      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
                        1, intr->dest.ssa.bit_size, NULL);
      chan_intr->num_components = 1;

      nir_deref_instr *deref = nir_build_deref_var(b, chan_var);

      deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));

      chan_intr->src[0] = nir_src_for_ssa(&deref->dest.ssa);

      if (intr->intrinsic == nir_intrinsic_interp_deref_at_offset ||
          intr->intrinsic == nir_intrinsic_interp_deref_at_sample ||
          intr->intrinsic == nir_intrinsic_interp_deref_at_vertex)
         nir_src_copy(&chan_intr->src[1], &intr->src[1], &chan_intr->instr);

      nir_builder_instr_insert(b, &chan_intr->instr);

      loads[i] = &chan_intr->dest.ssa;
   }

   nir_ssa_def_rewrite_uses(&intr->dest.ssa,
                            nir_src_for_ssa(nir_vec(b, loads,
                                                    intr->num_components)));

   /* Remove the old load intrinsic */
   nir_instr_remove(&intr->instr);
}

static void
lower_store_output_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
                                   nir_variable *var,
                                   struct hash_table *split_outputs)
{
   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *value = nir_ssa_for_src(b, intr->src[1], intr->num_components);

   nir_variable **chan_vars = get_channel_variables(split_outputs, var);
   for (unsigned i = 0; i < intr->num_components; i++) {
      if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
         continue;

      nir_variable *chan_var = chan_vars[var->data.location_frac + i];
      if (!chan_vars[var->data.location_frac + i]) {
         chan_var = nir_variable_clone(var, b->shader);
         chan_var->data.location_frac = var->data.location_frac + i;
         chan_var->type = glsl_channel_type(chan_var->type);
         if (var->data.explicit_offset) {
            unsigned comp_size = glsl_get_bit_size(chan_var->type) / 8;
            chan_var->data.offset = var->data.offset + i * comp_size;
         }

         chan_vars[var->data.location_frac + i] = chan_var;

         nir_shader_add_variable(b->shader, chan_var);
      }

      nir_intrinsic_instr *chan_intr =
         nir_intrinsic_instr_create(b->shader, intr->intrinsic);
      chan_intr->num_components = 1;

      nir_intrinsic_set_write_mask(chan_intr, 0x1);

      nir_deref_instr *deref = nir_build_deref_var(b, chan_var);

      deref = clone_deref_array(b, deref, nir_src_as_deref(intr->src[0]));

      chan_intr->src[0] = nir_src_for_ssa(&deref->dest.ssa);
      chan_intr->src[1] = nir_src_for_ssa(nir_channel(b, value, i));

      nir_builder_instr_insert(b, &chan_intr->instr);
   }

   /* Remove the old store intrinsic */
   nir_instr_remove(&intr->instr);
}

struct io_to_scalar_early_state {
   struct hash_table *split_inputs, *split_outputs;
   nir_variable_mode mask;
};

static bool
nir_lower_io_to_scalar_early_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct io_to_scalar_early_state *state = data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->num_components == 1)
      return false;

   if (intr->intrinsic != nir_intrinsic_load_deref &&
       intr->intrinsic != nir_intrinsic_store_deref &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_centroid &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_sample &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_offset &&
       intr->intrinsic != nir_intrinsic_interp_deref_at_vertex)
      return false;

   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   if (!nir_deref_mode_is_one_of(deref, state->mask))
      return false;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   nir_variable_mode mode = var->data.mode;

   /* TODO: add patch support */
   if (var->data.patch)
      return false;

   /* TODO: add doubles support */
   if (glsl_type_is_64bit(glsl_without_array(var->type)))
      return false;

   if (!(b->shader->info.stage == MESA_SHADER_VERTEX &&
         mode == nir_var_shader_in) &&
       var->data.location < VARYING_SLOT_VAR0 &&
       var->data.location >= 0)
      return false;

   /* Don't bother splitting if we can't opt away any unused
    * components.
    */
   if (var->data.always_active_io)
      return false;

   /* Skip types we cannot split */
   if (glsl_type_is_matrix(glsl_without_array(var->type)) ||
       glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
      return false;

   switch (intr->intrinsic) {
   case nir_intrinsic_interp_deref_at_centroid:
   case nir_intrinsic_interp_deref_at_sample:
   case nir_intrinsic_interp_deref_at_offset:
   case nir_intrinsic_interp_deref_at_vertex:
   case nir_intrinsic_load_deref:
      if ((state->mask & nir_var_shader_in && mode == nir_var_shader_in) ||
          (state->mask & nir_var_shader_out && mode == nir_var_shader_out)) {
         lower_load_to_scalar_early(b, intr, var, state->split_inputs,
                                    state->split_outputs);
         return true;
      }
      break;
   case nir_intrinsic_store_deref:
      if (state->mask & nir_var_shader_out &&
          mode == nir_var_shader_out) {
         lower_store_output_to_scalar_early(b, intr, var, state->split_outputs);
         return true;
      }
      break;
   default:
      break;
   }

   return false;
}

/*
 * This function is intended to be called earlier than nir_lower_io_to_scalar(),
 * i.e. before nir_lower_io() is called.
 */
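/* A hypothetical ordering in a driver's compile pipeline (the type_size_vec4
 * callback and the mode mask below are placeholders, not part of this file)
 * might look like:
 *
 *    nir_lower_io_to_scalar_early(nir, nir_var_shader_in | nir_var_shader_out);
 *    ... cross-stage linking and deref-based optimizations ...
 *    nir_lower_io(nir, nir_var_shader_in | nir_var_shader_out,
 *                 type_size_vec4, (nir_lower_io_options)0);
 *    nir_lower_io_to_scalar(nir, nir_var_shader_in | nir_var_shader_out);
 *
 * Exact pass ordering and arguments vary between drivers.
 */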
bool
nir_lower_io_to_scalar_early(nir_shader *shader, nir_variable_mode mask)
{
   struct io_to_scalar_early_state state = {
      .split_inputs = _mesa_pointer_hash_table_create(NULL),
      .split_outputs = _mesa_pointer_hash_table_create(NULL),
      .mask = mask
   };

   bool progress = nir_shader_instructions_pass(shader,
                                                nir_lower_io_to_scalar_early_instr,
                                                nir_metadata_block_index |
                                                nir_metadata_dominance,
                                                &state);

   /* Remove the old inputs from the shader's input list */
   hash_table_foreach(state.split_inputs, entry) {
      nir_variable *var = (nir_variable *) entry->key;
      exec_node_remove(&var->node);

      free(entry->data);
   }

   /* Remove the old outputs from the shader's output list */
   hash_table_foreach(state.split_outputs, entry) {
      nir_variable *var = (nir_variable *) entry->key;
      exec_node_remove(&var->node);

      free(entry->data);
   }

   _mesa_hash_table_destroy(state.split_inputs, NULL);
   _mesa_hash_table_destroy(state.split_outputs, NULL);

   nir_remove_dead_derefs(shader);

   return progress;
}