1 /*
2  * Copyright © 2015 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "vtn_private.h"
25 #include "nir/nir_vla.h"
26 
27 static struct vtn_pointer *
vtn_pointer_for_image_or_sampler_variable(struct vtn_builder * b,struct vtn_variable * var)28 vtn_pointer_for_image_or_sampler_variable(struct vtn_builder *b,
29                                           struct vtn_variable *var)
30 {
31    assert(var->type->base_type == vtn_base_type_image ||
32           var->type->base_type == vtn_base_type_sampler);
33 
34    struct vtn_type *ptr_type = rzalloc(b, struct vtn_type);
35    ptr_type->base_type = vtn_base_type_pointer;
36    ptr_type->storage_class = SpvStorageClassUniformConstant;
37    ptr_type->deref = var->type;
38 
39    return vtn_pointer_for_variable(b, var, ptr_type);
40 }
41 
42 static bool
vtn_cfg_handle_prepass_instruction(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)43 vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
44                                    const uint32_t *w, unsigned count)
45 {
46    switch (opcode) {
47    case SpvOpFunction: {
48       vtn_assert(b->func == NULL);
49       b->func = rzalloc(b, struct vtn_function);
50 
51       list_inithead(&b->func->body);
52       b->func->control = w[3];
53 
54       MAYBE_UNUSED const struct glsl_type *result_type =
55          vtn_value(b, w[1], vtn_value_type_type)->type->type;
56       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
57       val->func = b->func;
58 
59       const struct vtn_type *func_type =
60          vtn_value(b, w[4], vtn_value_type_type)->type;
61 
62       vtn_assert(func_type->return_type->type == result_type);
63 
64       nir_function *func =
65          nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
66 
67       func->num_params = func_type->length;
68       func->params = ralloc_array(b->shader, nir_parameter, func->num_params);
69       unsigned np = 0;
70       for (unsigned i = 0; i < func_type->length; i++) {
71          if (func_type->params[i]->base_type == vtn_base_type_pointer &&
72              func_type->params[i]->type == NULL) {
73             func->params[np].type = func_type->params[i]->deref->type;
74             func->params[np].param_type = nir_parameter_inout;
75             np++;
76          } else if (func_type->params[i]->base_type ==
77                     vtn_base_type_sampled_image) {
78             /* Sampled images are actually two parameters */
79             func->params = reralloc(b->shader, func->params,
80                                     nir_parameter, func->num_params++);
81             func->params[np].type = func_type->params[i]->type;
82             func->params[np].param_type = nir_parameter_in;
83             np++;
84             func->params[np].type = glsl_bare_sampler_type();
85             func->params[np].param_type = nir_parameter_in;
86             np++;
87          } else {
88             func->params[np].type = func_type->params[i]->type;
89             func->params[np].param_type = nir_parameter_in;
90             np++;
91          }
92       }
93       assert(np == func->num_params);
94 
95       func->return_type = func_type->return_type->type;
96 
97       b->func->impl = nir_function_impl_create(func);
98       b->nb.cursor = nir_before_cf_list(&b->func->impl->body);
99 
100       b->func_param_idx = 0;
101       break;
102    }
103 
104    case SpvOpFunctionEnd:
105       b->func->end = w;
106       b->func = NULL;
107       break;
108 
109    case SpvOpFunctionParameter: {
110       struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
111 
112       vtn_assert(b->func_param_idx < b->func->impl->num_params);
113       nir_variable *param = b->func->impl->params[b->func_param_idx++];
114 
115       if (type->base_type == vtn_base_type_pointer && type->type == NULL) {
116          struct vtn_variable *vtn_var = rzalloc(b, struct vtn_variable);
117          vtn_var->type = type->deref;
118          vtn_var->var = param;
119 
120          vtn_assert(vtn_var->type->type == param->type);
121 
122          struct vtn_type *without_array = vtn_var->type;
123          while(glsl_type_is_array(without_array->type))
124             without_array = without_array->array_element;
125 
126          if (glsl_type_is_image(without_array->type)) {
127             vtn_var->mode = vtn_variable_mode_image;
128             param->interface_type = without_array->type;
129          } else if (glsl_type_is_sampler(without_array->type)) {
130             vtn_var->mode = vtn_variable_mode_sampler;
131             param->interface_type = without_array->type;
132          } else {
133             vtn_var->mode = vtn_variable_mode_param;
134          }
135 
136          struct vtn_value *val =
137             vtn_push_value(b, w[2], vtn_value_type_pointer);
138 
139          /* Name the parameter so it shows up nicely in NIR */
140          param->name = ralloc_strdup(param, val->name);
141 
142          val->pointer = vtn_pointer_for_variable(b, vtn_var, type);
143       } else if (type->base_type == vtn_base_type_image ||
144                  type->base_type == vtn_base_type_sampler ||
145                  type->base_type == vtn_base_type_sampled_image) {
146          struct vtn_variable *vtn_var = rzalloc(b, struct vtn_variable);
147          vtn_var->type = type;
148          vtn_var->var = param;
149          param->interface_type = param->type;
150 
151          if (type->base_type == vtn_base_type_sampled_image) {
152             /* Sampled images are actually two parameters.  The first is the
153              * image and the second is the sampler.
154              */
155             struct vtn_value *val =
156                vtn_push_value(b, w[2], vtn_value_type_sampled_image);
157 
158             /* Name the parameter so it shows up nicely in NIR */
159             param->name = ralloc_strdup(param, val->name);
160 
161             /* Adjust the type of the image variable to the image type */
162             vtn_var->type = type->image;
163 
164             /* Now get the sampler parameter and set up its variable */
165             param = b->func->impl->params[b->func_param_idx++];
166             struct vtn_variable *sampler_var = rzalloc(b, struct vtn_variable);
167             sampler_var->type = rzalloc(b, struct vtn_type);
168             sampler_var->type->base_type = vtn_base_type_sampler;
169             sampler_var->type->type = glsl_bare_sampler_type();
170             sampler_var->var = param;
171             param->interface_type = param->type;
172             param->name = ralloc_strdup(param, val->name);
173 
174             val->sampled_image = ralloc(b, struct vtn_sampled_image);
175             val->sampled_image->type = type;
176             val->sampled_image->image =
177                vtn_pointer_for_image_or_sampler_variable(b, vtn_var);
178             val->sampled_image->sampler =
179                vtn_pointer_for_image_or_sampler_variable(b, sampler_var);
180          } else {
181             struct vtn_value *val =
182                vtn_push_value(b, w[2], vtn_value_type_pointer);
183 
184             /* Name the parameter so it shows up nicely in NIR */
185             param->name = ralloc_strdup(param, val->name);
186 
187             val->pointer =
188                vtn_pointer_for_image_or_sampler_variable(b, vtn_var);
189          }
190       } else {
191          /* We're a regular SSA value. */
192          struct vtn_ssa_value *param_ssa =
193             vtn_local_load(b, nir_deref_var_create(b, param));
194          struct vtn_value *val = vtn_push_ssa(b, w[2], type, param_ssa);
195 
196          /* Name the parameter so it shows up nicely in NIR */
197          param->name = ralloc_strdup(param, val->name);
198       }
199       break;
200    }
201 
202    case SpvOpLabel: {
203       vtn_assert(b->block == NULL);
204       b->block = rzalloc(b, struct vtn_block);
205       b->block->node.type = vtn_cf_node_type_block;
206       b->block->label = w;
207       vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;
208 
209       if (b->func->start_block == NULL) {
210          /* This is the first block encountered for this function.  In this
211           * case, we set the start block and add it to the list of
212           * implemented functions that we'll walk later.
213           */
214          b->func->start_block = b->block;
215          exec_list_push_tail(&b->functions, &b->func->node);
216       }
217       break;
218    }
219 
220    case SpvOpSelectionMerge:
221    case SpvOpLoopMerge:
222       vtn_assert(b->block && b->block->merge == NULL);
223       b->block->merge = w;
224       break;
225 
226    case SpvOpBranch:
227    case SpvOpBranchConditional:
228    case SpvOpSwitch:
229    case SpvOpKill:
230    case SpvOpReturn:
231    case SpvOpReturnValue:
232    case SpvOpUnreachable:
233       vtn_assert(b->block && b->block->branch == NULL);
234       b->block->branch = w;
235       b->block = NULL;
236       break;
237 
238    default:
239       /* Continue on as per normal */
240       return true;
241    }
242 
243    return true;
244 }
245 
246 static void
vtn_add_case(struct vtn_builder * b,struct vtn_switch * swtch,struct vtn_block * break_block,uint32_t block_id,uint64_t val,bool is_default)247 vtn_add_case(struct vtn_builder *b, struct vtn_switch *swtch,
248              struct vtn_block *break_block,
249              uint32_t block_id, uint64_t val, bool is_default)
250 {
251    struct vtn_block *case_block =
252       vtn_value(b, block_id, vtn_value_type_block)->block;
253 
254    /* Don't create dummy cases that just break */
255    if (case_block == break_block)
256       return;
257 
258    if (case_block->switch_case == NULL) {
259       struct vtn_case *c = ralloc(b, struct vtn_case);
260 
261       list_inithead(&c->body);
262       c->start_block = case_block;
263       c->fallthrough = NULL;
264       util_dynarray_init(&c->values, b);
265       c->is_default = false;
266       c->visited = false;
267 
268       list_addtail(&c->link, &swtch->cases);
269 
270       case_block->switch_case = c;
271    }
272 
273    if (is_default) {
274       case_block->switch_case->is_default = true;
275    } else {
276       util_dynarray_append(&case_block->switch_case->values, uint64_t, val);
277    }
278 }
279 
280 /* This function performs a depth-first search of the cases and puts them
281  * in fall-through order.
282  */
283 static void
vtn_order_case(struct vtn_switch * swtch,struct vtn_case * cse)284 vtn_order_case(struct vtn_switch *swtch, struct vtn_case *cse)
285 {
286    if (cse->visited)
287       return;
288 
289    cse->visited = true;
290 
291    list_del(&cse->link);
292 
293    if (cse->fallthrough) {
294       vtn_order_case(swtch, cse->fallthrough);
295 
296       /* If we have a fall-through, place this case right before the case it
297        * falls through to.  This ensures that fallthroughs come one after
298        * the other.  These two can never get separated because that would
299        * imply something else falling through to the same case.  Also, this
300        * can't break ordering because the DFS ensures that this case is
301        * visited before anything that falls through to it.
302        */
303       list_addtail(&cse->link, &cse->fallthrough->link);
304    } else {
305       list_add(&cse->link, &swtch->cases);
306    }
307 }
308 
309 static enum vtn_branch_type
vtn_get_branch_type(struct vtn_builder * b,struct vtn_block * block,struct vtn_case * swcase,struct vtn_block * switch_break,struct vtn_block * loop_break,struct vtn_block * loop_cont)310 vtn_get_branch_type(struct vtn_builder *b,
311                     struct vtn_block *block,
312                     struct vtn_case *swcase, struct vtn_block *switch_break,
313                     struct vtn_block *loop_break, struct vtn_block *loop_cont)
314 {
315    if (block->switch_case) {
316       /* This branch is actually a fallthrough */
317       vtn_assert(swcase->fallthrough == NULL ||
318                  swcase->fallthrough == block->switch_case);
319       swcase->fallthrough = block->switch_case;
320       return vtn_branch_type_switch_fallthrough;
321    } else if (block == loop_break) {
322       return vtn_branch_type_loop_break;
323    } else if (block == loop_cont) {
324       return vtn_branch_type_loop_continue;
325    } else if (block == switch_break) {
326       return vtn_branch_type_switch_break;
327    } else {
328       return vtn_branch_type_none;
329    }
330 }
331 
332 static void
vtn_cfg_walk_blocks(struct vtn_builder * b,struct list_head * cf_list,struct vtn_block * start,struct vtn_case * switch_case,struct vtn_block * switch_break,struct vtn_block * loop_break,struct vtn_block * loop_cont,struct vtn_block * end)333 vtn_cfg_walk_blocks(struct vtn_builder *b, struct list_head *cf_list,
334                     struct vtn_block *start, struct vtn_case *switch_case,
335                     struct vtn_block *switch_break,
336                     struct vtn_block *loop_break, struct vtn_block *loop_cont,
337                     struct vtn_block *end)
338 {
339    struct vtn_block *block = start;
340    while (block != end) {
341       if (block->merge && (*block->merge & SpvOpCodeMask) == SpvOpLoopMerge &&
342           !block->loop) {
343          struct vtn_loop *loop = ralloc(b, struct vtn_loop);
344 
345          loop->node.type = vtn_cf_node_type_loop;
346          list_inithead(&loop->body);
347          list_inithead(&loop->cont_body);
348          loop->control = block->merge[3];
349 
350          list_addtail(&loop->node.link, cf_list);
351          block->loop = loop;
352 
353          struct vtn_block *new_loop_break =
354             vtn_value(b, block->merge[1], vtn_value_type_block)->block;
355          struct vtn_block *new_loop_cont =
356             vtn_value(b, block->merge[2], vtn_value_type_block)->block;
357 
358          /* Note: This recursive call will start with the current block as
359           * its start block.  If we weren't careful, we would get here
360           * again and end up in infinite recursion.  This is why we set
361           * block->loop above and check for it before creating one.  This
362           * way, we only create the loop once and the second call that
363           * tries to handle this loop goes to the cases below and gets
364           * handled as a regular block.
365           *
366           * Note: When we make the recursive walk calls, we pass NULL for
367           * the switch break since you have to break out of the loop first.
368           * We do, however, still pass the current switch case because it's
369           * possible that the merge block for the loop is the start of
370           * another case.
371           */
372          vtn_cfg_walk_blocks(b, &loop->body, block, switch_case, NULL,
373                              new_loop_break, new_loop_cont, NULL );
374          vtn_cfg_walk_blocks(b, &loop->cont_body, new_loop_cont, NULL, NULL,
375                              new_loop_break, NULL, block);
376 
377          block = new_loop_break;
378          continue;
379       }
380 
381       vtn_assert(block->node.link.next == NULL);
382       list_addtail(&block->node.link, cf_list);
383 
384       switch (*block->branch & SpvOpCodeMask) {
385       case SpvOpBranch: {
386          struct vtn_block *branch_block =
387             vtn_value(b, block->branch[1], vtn_value_type_block)->block;
388 
389          block->branch_type = vtn_get_branch_type(b, branch_block,
390                                                   switch_case, switch_break,
391                                                   loop_break, loop_cont);
392 
393          if (block->branch_type != vtn_branch_type_none)
394             return;
395 
396          block = branch_block;
397          continue;
398       }
399 
400       case SpvOpReturn:
401       case SpvOpReturnValue:
402          block->branch_type = vtn_branch_type_return;
403          return;
404 
405       case SpvOpKill:
406          block->branch_type = vtn_branch_type_discard;
407          return;
408 
409       case SpvOpBranchConditional: {
410          struct vtn_block *then_block =
411             vtn_value(b, block->branch[2], vtn_value_type_block)->block;
412          struct vtn_block *else_block =
413             vtn_value(b, block->branch[3], vtn_value_type_block)->block;
414 
415          struct vtn_if *if_stmt = ralloc(b, struct vtn_if);
416 
417          if_stmt->node.type = vtn_cf_node_type_if;
418          if_stmt->condition = block->branch[1];
419          list_inithead(&if_stmt->then_body);
420          list_inithead(&if_stmt->else_body);
421 
422          list_addtail(&if_stmt->node.link, cf_list);
423 
424          if (block->merge &&
425              (*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge) {
426             if_stmt->control = block->merge[2];
427          }
428 
429          if_stmt->then_type = vtn_get_branch_type(b, then_block,
430                                                   switch_case, switch_break,
431                                                   loop_break, loop_cont);
432          if_stmt->else_type = vtn_get_branch_type(b, else_block,
433                                                   switch_case, switch_break,
434                                                   loop_break, loop_cont);
435 
436          if (then_block == else_block) {
437             block->branch_type = if_stmt->then_type;
438             if (block->branch_type == vtn_branch_type_none) {
439                block = then_block;
440                continue;
441             } else {
442                return;
443             }
444          } else if (if_stmt->then_type == vtn_branch_type_none &&
445                     if_stmt->else_type == vtn_branch_type_none) {
446             /* Neither side of the if is something we can short-circuit. */
447             vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
448             struct vtn_block *merge_block =
449                vtn_value(b, block->merge[1], vtn_value_type_block)->block;
450 
451             vtn_cfg_walk_blocks(b, &if_stmt->then_body, then_block,
452                                 switch_case, switch_break,
453                                 loop_break, loop_cont, merge_block);
454             vtn_cfg_walk_blocks(b, &if_stmt->else_body, else_block,
455                                 switch_case, switch_break,
456                                 loop_break, loop_cont, merge_block);
457 
458             enum vtn_branch_type merge_type =
459                vtn_get_branch_type(b, merge_block, switch_case, switch_break,
460                                    loop_break, loop_cont);
461             if (merge_type == vtn_branch_type_none) {
462                block = merge_block;
463                continue;
464             } else {
465                return;
466             }
467          } else if (if_stmt->then_type != vtn_branch_type_none &&
468                     if_stmt->else_type != vtn_branch_type_none) {
469             /* Both sides were short-circuited.  We're done here. */
470             return;
471          } else {
472             /* Exeactly one side of the branch could be short-circuited.
473              * We set the branch up as a predicated break/continue and we
474              * continue on with the other side as if it were what comes
475              * after the if.
476              */
477             if (if_stmt->then_type == vtn_branch_type_none) {
478                block = then_block;
479             } else {
480                block = else_block;
481             }
482             continue;
483          }
484          vtn_fail("Should have returned or continued");
485       }
486 
487       case SpvOpSwitch: {
488          vtn_assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
489          struct vtn_block *break_block =
490             vtn_value(b, block->merge[1], vtn_value_type_block)->block;
491 
492          struct vtn_switch *swtch = ralloc(b, struct vtn_switch);
493 
494          swtch->node.type = vtn_cf_node_type_switch;
495          swtch->selector = block->branch[1];
496          list_inithead(&swtch->cases);
497 
498          list_addtail(&swtch->node.link, cf_list);
499 
500          /* First, we go through and record all of the cases. */
501          const uint32_t *branch_end =
502             block->branch + (block->branch[0] >> SpvWordCountShift);
503 
504          struct vtn_value *cond_val = vtn_untyped_value(b, block->branch[1]);
505          vtn_fail_if(!cond_val->type ||
506                      cond_val->type->base_type != vtn_base_type_scalar,
507                      "Selector of OpSelect must have a type of OpTypeInt");
508 
509          nir_alu_type cond_type =
510             nir_get_nir_type_for_glsl_type(cond_val->type->type);
511          vtn_fail_if(nir_alu_type_get_base_type(cond_type) != nir_type_int &&
512                      nir_alu_type_get_base_type(cond_type) != nir_type_uint,
513                      "Selector of OpSelect must have a type of OpTypeInt");
514 
515          bool is_default = true;
516          const uint bitsize = nir_alu_type_get_type_size(cond_type);
517          for (const uint32_t *w = block->branch + 2; w < branch_end;) {
518             uint64_t literal = 0;
519             if (!is_default) {
520                if (bitsize <= 32) {
521                   literal = *(w++);
522                } else {
523                   assert(bitsize == 64);
524                   literal = vtn_u64_literal(w);
525                   w += 2;
526                }
527             }
528 
529             uint32_t block_id = *(w++);
530 
531             vtn_add_case(b, swtch, break_block, block_id, literal, is_default);
532             is_default = false;
533          }
534 
535          /* Now, we go through and walk the blocks.  While we walk through
536           * the blocks, we also gather the much-needed fall-through
537           * information.
538           */
539          list_for_each_entry(struct vtn_case, cse, &swtch->cases, link) {
540             vtn_assert(cse->start_block != break_block);
541             vtn_cfg_walk_blocks(b, &cse->body, cse->start_block, cse,
542                                 break_block, loop_break, loop_cont, NULL);
543          }
544 
545          /* Finally, we walk over all of the cases one more time and put
546           * them in fall-through order.
547           */
548          for (const uint32_t *w = block->branch + 2; w < branch_end;) {
549             struct vtn_block *case_block =
550                vtn_value(b, *w, vtn_value_type_block)->block;
551 
552             if (bitsize <= 32) {
553                w += 2;
554             } else {
555                assert(bitsize == 64);
556                w += 3;
557             }
558 
559             if (case_block == break_block)
560                continue;
561 
562             vtn_assert(case_block->switch_case);
563 
564             vtn_order_case(swtch, case_block->switch_case);
565          }
566 
567          enum vtn_branch_type branch_type =
568             vtn_get_branch_type(b, break_block, switch_case, NULL,
569                                 loop_break, loop_cont);
570 
571          if (branch_type != vtn_branch_type_none) {
572             /* It is possible that the break is actually the continue block
573              * for the containing loop.  In this case, we need to bail and let
574              * the loop parsing code handle the continue properly.
575              */
576             vtn_assert(branch_type == vtn_branch_type_loop_continue);
577             return;
578          }
579 
580          block = break_block;
581          continue;
582       }
583 
584       case SpvOpUnreachable:
585          return;
586 
587       default:
588          vtn_fail("Unhandled opcode");
589       }
590    }
591 }
592 
593 void
vtn_build_cfg(struct vtn_builder * b,const uint32_t * words,const uint32_t * end)594 vtn_build_cfg(struct vtn_builder *b, const uint32_t *words, const uint32_t *end)
595 {
596    vtn_foreach_instruction(b, words, end,
597                            vtn_cfg_handle_prepass_instruction);
598 
599    foreach_list_typed(struct vtn_function, func, node, &b->functions) {
600       vtn_cfg_walk_blocks(b, &func->body, func->start_block,
601                           NULL, NULL, NULL, NULL, NULL);
602    }
603 }
604 
605 static bool
vtn_handle_phis_first_pass(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)606 vtn_handle_phis_first_pass(struct vtn_builder *b, SpvOp opcode,
607                            const uint32_t *w, unsigned count)
608 {
609    if (opcode == SpvOpLabel)
610       return true; /* Nothing to do */
611 
612    /* If this isn't a phi node, stop. */
613    if (opcode != SpvOpPhi)
614       return false;
615 
616    /* For handling phi nodes, we do a poor-man's out-of-ssa on the spot.
617     * For each phi, we create a variable with the appropreate type and
618     * do a load from that variable.  Then, in a second pass, we add
619     * stores to that variable to each of the predecessor blocks.
620     *
621     * We could do something more intelligent here.  However, in order to
622     * handle loops and things properly, we really need dominance
623     * information.  It would end up basically being the into-SSA
624     * algorithm all over again.  It's easier if we just let
625     * lower_vars_to_ssa do that for us instead of repeating it here.
626     */
627    struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
628    nir_variable *phi_var =
629       nir_local_variable_create(b->nb.impl, type->type, "phi");
630    _mesa_hash_table_insert(b->phi_table, w, phi_var);
631 
632    vtn_push_ssa(b, w[2], type,
633                 vtn_local_load(b, nir_deref_var_create(b, phi_var)));
634 
635    return true;
636 }
637 
638 static bool
vtn_handle_phi_second_pass(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)639 vtn_handle_phi_second_pass(struct vtn_builder *b, SpvOp opcode,
640                            const uint32_t *w, unsigned count)
641 {
642    if (opcode != SpvOpPhi)
643       return true;
644 
645    struct hash_entry *phi_entry = _mesa_hash_table_search(b->phi_table, w);
646    vtn_assert(phi_entry);
647    nir_variable *phi_var = phi_entry->data;
648 
649    for (unsigned i = 3; i < count; i += 2) {
650       struct vtn_block *pred =
651          vtn_value(b, w[i + 1], vtn_value_type_block)->block;
652 
653       b->nb.cursor = nir_after_instr(&pred->end_nop->instr);
654 
655       struct vtn_ssa_value *src = vtn_ssa_value(b, w[i]);
656 
657       vtn_local_store(b, src, nir_deref_var_create(b, phi_var));
658    }
659 
660    return true;
661 }
662 
663 static void
vtn_emit_branch(struct vtn_builder * b,enum vtn_branch_type branch_type,nir_variable * switch_fall_var,bool * has_switch_break)664 vtn_emit_branch(struct vtn_builder *b, enum vtn_branch_type branch_type,
665                 nir_variable *switch_fall_var, bool *has_switch_break)
666 {
667    switch (branch_type) {
668    case vtn_branch_type_switch_break:
669       nir_store_var(&b->nb, switch_fall_var, nir_imm_int(&b->nb, NIR_FALSE), 1);
670       *has_switch_break = true;
671       break;
672    case vtn_branch_type_switch_fallthrough:
673       break; /* Nothing to do */
674    case vtn_branch_type_loop_break:
675       nir_jump(&b->nb, nir_jump_break);
676       break;
677    case vtn_branch_type_loop_continue:
678       nir_jump(&b->nb, nir_jump_continue);
679       break;
680    case vtn_branch_type_return:
681       nir_jump(&b->nb, nir_jump_return);
682       break;
683    case vtn_branch_type_discard: {
684       nir_intrinsic_instr *discard =
685          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_discard);
686       nir_builder_instr_insert(&b->nb, &discard->instr);
687       break;
688    }
689    default:
690       vtn_fail("Invalid branch type");
691    }
692 }
693 
694 static void
vtn_emit_cf_list(struct vtn_builder * b,struct list_head * cf_list,nir_variable * switch_fall_var,bool * has_switch_break,vtn_instruction_handler handler)695 vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
696                  nir_variable *switch_fall_var, bool *has_switch_break,
697                  vtn_instruction_handler handler)
698 {
699    list_for_each_entry(struct vtn_cf_node, node, cf_list, link) {
700       switch (node->type) {
701       case vtn_cf_node_type_block: {
702          struct vtn_block *block = (struct vtn_block *)node;
703 
704          const uint32_t *block_start = block->label;
705          const uint32_t *block_end = block->merge ? block->merge :
706                                                     block->branch;
707 
708          block_start = vtn_foreach_instruction(b, block_start, block_end,
709                                                vtn_handle_phis_first_pass);
710 
711          vtn_foreach_instruction(b, block_start, block_end, handler);
712 
713          block->end_nop = nir_intrinsic_instr_create(b->nb.shader,
714                                                      nir_intrinsic_nop);
715          nir_builder_instr_insert(&b->nb, &block->end_nop->instr);
716 
717          if ((*block->branch & SpvOpCodeMask) == SpvOpReturnValue) {
718             struct vtn_ssa_value *src = vtn_ssa_value(b, block->branch[1]);
719             vtn_local_store(b, src,
720                             nir_deref_var_create(b, b->nb.impl->return_var));
721          }
722 
723          if (block->branch_type != vtn_branch_type_none) {
724             vtn_emit_branch(b, block->branch_type,
725                             switch_fall_var, has_switch_break);
726          }
727 
728          break;
729       }
730 
731       case vtn_cf_node_type_if: {
732          struct vtn_if *vtn_if = (struct vtn_if *)node;
733          bool sw_break = false;
734 
735          nir_if *nif =
736             nir_push_if(&b->nb, vtn_ssa_value(b, vtn_if->condition)->def);
737          if (vtn_if->then_type == vtn_branch_type_none) {
738             vtn_emit_cf_list(b, &vtn_if->then_body,
739                              switch_fall_var, &sw_break, handler);
740          } else {
741             vtn_emit_branch(b, vtn_if->then_type, switch_fall_var, &sw_break);
742          }
743 
744          nir_push_else(&b->nb, nif);
745          if (vtn_if->else_type == vtn_branch_type_none) {
746             vtn_emit_cf_list(b, &vtn_if->else_body,
747                              switch_fall_var, &sw_break, handler);
748          } else {
749             vtn_emit_branch(b, vtn_if->else_type, switch_fall_var, &sw_break);
750          }
751 
752          nir_pop_if(&b->nb, nif);
753 
         /* If we encountered a switch break somewhere inside of the if,
          * then it would have been handled correctly by calling
          * emit_cf_list or emit_branch for the interior.  However, we
          * need to predicate everything following on whether or not we're
          * still going.
          */
760          if (sw_break) {
761             *has_switch_break = true;
762             nir_push_if(&b->nb, nir_load_var(&b->nb, switch_fall_var));
763          }
764          break;
765       }
766 
767       case vtn_cf_node_type_loop: {
768          struct vtn_loop *vtn_loop = (struct vtn_loop *)node;
769 
770          nir_loop *loop = nir_push_loop(&b->nb);
771          vtn_emit_cf_list(b, &vtn_loop->body, NULL, NULL, handler);
772 
773          if (!list_empty(&vtn_loop->cont_body)) {
774             /* If we have a non-trivial continue body then we need to put
775              * it at the beginning of the loop with a flag to ensure that
776              * it doesn't get executed in the first iteration.
777              */
778             nir_variable *do_cont =
779                nir_local_variable_create(b->nb.impl, glsl_bool_type(), "cont");
780 
781             b->nb.cursor = nir_before_cf_node(&loop->cf_node);
782             nir_store_var(&b->nb, do_cont, nir_imm_int(&b->nb, NIR_FALSE), 1);
783 
784             b->nb.cursor = nir_before_cf_list(&loop->body);
785 
786             nir_if *cont_if =
787                nir_push_if(&b->nb, nir_load_var(&b->nb, do_cont));
788 
789             vtn_emit_cf_list(b, &vtn_loop->cont_body, NULL, NULL, handler);
790 
791             nir_pop_if(&b->nb, cont_if);
792 
793             nir_store_var(&b->nb, do_cont, nir_imm_int(&b->nb, NIR_TRUE), 1);
794 
795             b->has_loop_continue = true;
796          }
797 
798          nir_pop_loop(&b->nb, loop);
799          break;
800       }
801 
802       case vtn_cf_node_type_switch: {
803          struct vtn_switch *vtn_switch = (struct vtn_switch *)node;
804 
805          /* First, we create a variable to keep track of whether or not the
806           * switch is still going at any given point.  Any switch breaks
807           * will set this variable to false.
808           */
809          nir_variable *fall_var =
810             nir_local_variable_create(b->nb.impl, glsl_bool_type(), "fall");
811          nir_store_var(&b->nb, fall_var, nir_imm_int(&b->nb, NIR_FALSE), 1);
812 
813          /* Next, we gather up all of the conditions.  We have to do this
814           * up-front because we also need to build an "any" condition so
815           * that we can use !any for default.
816           */
817          const int num_cases = list_length(&vtn_switch->cases);
818          NIR_VLA(nir_ssa_def *, conditions, num_cases);
819 
820          nir_ssa_def *sel = vtn_ssa_value(b, vtn_switch->selector)->def;
821          /* An accumulation of all conditions.  Used for the default */
822          nir_ssa_def *any = NULL;
823 
824          int i = 0;
825          list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
826             if (cse->is_default) {
827                conditions[i++] = NULL;
828                continue;
829             }
830 
831             nir_ssa_def *cond = NULL;
832             util_dynarray_foreach(&cse->values, uint64_t, val) {
833                nir_ssa_def *imm = nir_imm_intN_t(&b->nb, *val, sel->bit_size);
834                nir_ssa_def *is_val = nir_ieq(&b->nb, sel, imm);
835 
836                cond = cond ? nir_ior(&b->nb, cond, is_val) : is_val;
837             }
838 
839             any = any ? nir_ior(&b->nb, any, cond) : cond;
840             conditions[i++] = cond;
841          }
842          vtn_assert(i == num_cases);
843 
844          /* Now we can walk the list of cases and actually emit code */
845          i = 0;
846          list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
847             /* Figure out the condition */
848             nir_ssa_def *cond = conditions[i++];
849             if (cse->is_default) {
850                vtn_assert(cond == NULL);
851                cond = nir_inot(&b->nb, any);
852             }
853             /* Take fallthrough into account */
854             cond = nir_ior(&b->nb, cond, nir_load_var(&b->nb, fall_var));
855 
856             nir_if *case_if = nir_push_if(&b->nb, cond);
857 
858             bool has_break = false;
859             nir_store_var(&b->nb, fall_var, nir_imm_int(&b->nb, NIR_TRUE), 1);
860             vtn_emit_cf_list(b, &cse->body, fall_var, &has_break, handler);
861             (void)has_break; /* We don't care */
862 
863             nir_pop_if(&b->nb, case_if);
864          }
865          vtn_assert(i == num_cases);
866 
867          break;
868       }
869 
870       default:
871          vtn_fail("Invalid CF node type");
872       }
873    }
874 }
875 
876 void
vtn_function_emit(struct vtn_builder * b,struct vtn_function * func,vtn_instruction_handler instruction_handler)877 vtn_function_emit(struct vtn_builder *b, struct vtn_function *func,
878                   vtn_instruction_handler instruction_handler)
879 {
880    nir_builder_init(&b->nb, func->impl);
881    b->nb.cursor = nir_after_cf_list(&func->impl->body);
882    b->has_loop_continue = false;
883    b->phi_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
884                                           _mesa_key_pointer_equal);
885 
886    vtn_emit_cf_list(b, &func->body, NULL, NULL, instruction_handler);
887 
888    vtn_foreach_instruction(b, func->start_block->label, func->end,
889                            vtn_handle_phi_second_pass);
890 
891    /* Continue blocks for loops get inserted before the body of the loop
892     * but instructions in the continue may use SSA defs in the loop body.
893     * Therefore, we need to repair SSA to insert the needed phi nodes.
894     */
895    if (b->has_loop_continue)
896       nir_repair_ssa_impl(func->impl);
897 
898    func->emitted = true;
899 }
900