1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can
5  * be found in the LICENSE file.
6  *
7  */
8 
9 #include <stdio.h>
10 #include <stdlib.h>
11 
12 //
13 //
14 //
15 
16 #include "gen.h"
17 #include "transpose.h"
18 
19 #include "common/util.h"
20 #include "common/macros.h"
21 
22 //
23 //
24 //
25 
//
// Context handed (as a void pointer) to the transpose blend/remap
// callbacks: the stream that receives the generated macro lines and
// the generator configuration those callbacks read.
//
struct hsg_transpose_state
{
  FILE                    * header; // destination stream for emitted lines
  struct hsg_config const * config; // read-only generator configuration
};
31 
//
// Map a transpose stage number (log2 of the column count) to the
// single-letter register-name prefix used for that stage.
// Stage 0 maps to 'r', and each later stage advances one letter,
// wrapping from 'z' back to 'a'. The addition is done in 32-bit
// unsigned arithmetic, so a wrapped argument such as (0u - 1) maps
// to the letter just before 'r'.
//
static
char
hsg_transpose_reg_prefix(uint32_t const cols_log2)
{
  uint32_t const start  = (uint32_t)('r' - 'a');       // stage 0 => 'r'
  uint32_t const letter = (start + cols_log2) % 26u;   // 0 .. 25

  return (char)('a' + letter);
}
38 
39 static
40 void
hsg_transpose_blend(uint32_t const cols_log2,uint32_t const row_ll,uint32_t const row_ur,void * blend)41 hsg_transpose_blend(uint32_t const cols_log2,
42                     uint32_t const row_ll, // lower-left
43                     uint32_t const row_ur, // upper-right
44                     void *         blend)
45 {
46   struct hsg_transpose_state * const state = blend;
47 
48   // we're starting register names at '1' for now
49   fprintf(state->header,
50           "  HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n",
51           hsg_transpose_reg_prefix(cols_log2-1),
52           hsg_transpose_reg_prefix(cols_log2),
53           cols_log2,row_ll+1,row_ur+1);
54 }
55 
56 static
57 void
hsg_transpose_remap(uint32_t const row_from,uint32_t const row_to,void * remap)58 hsg_transpose_remap(uint32_t const row_from,
59                     uint32_t const row_to,
60                     void *         remap)
61 {
62   struct hsg_transpose_state * const state = remap;
63 
64   // we're starting register names at '1' for now
65   fprintf(state->header,
66           "  HS_TRANSPOSE_REMAP( %c, %3u, %3u )        \\\n",
67           hsg_transpose_reg_prefix(state->config->warp.lanes_log2),
68           row_from+1,row_to+1);
69 }
70 
71 //
72 //
73 //
74 
//
// Write the license banner comment, followed by one blank line, to
// 'file'. Write errors are not reported.
//
static
void
hsg_copyright(FILE * file)
{
  static char const * const banner[] =
    {
      "//                                                    \n",
      "// Copyright 2016 Google Inc.                         \n",
      "//                                                    \n",
      "// Use of this source code is governed by a BSD-style \n",
      "// license that can be found in the LICENSE file.     \n",
      "//                                                    \n",
      "\n"
    };

  size_t const count = sizeof(banner) / sizeof(banner[0]);

  for (size_t line = 0; line < count; line++)
    fputs(banner[line],file);
}
88 
//
// Write the kernel-source prologue to 'file': an #include of
// "hs_config.h" followed by an #include of "hs_cl_macros.h", then a
// short run of comment separator lines. Write errors are not reported.
//
static
void
hsg_macros(FILE * file)
{
  static char const * const prologue[] =
    {
      "// target-specific config     \n",
      "#include \"hs_config.h\"      \n",
      "                              \n",
      "// arch/target-specific macros\n",
      "#include \"hs_cl_macros.h\"   \n",
      "                              \n",
      "//                            \n",
      "//                            \n",
      "//                            \n"
    };

  size_t const count = sizeof(prologue) / sizeof(prologue[0]);

  for (size_t line = 0; line < count; line++)
    fputs(prologue[line],file);
}
104 
105 //
106 //
107 //
108 
//
// Per-run state of the OpenCL target: the two output streams that
// hsg_target_opencl() writes generated text into.
//
struct hsg_target_state
{
  FILE * header; // generated "hs_config.h"
  FILE * source; // generated "hs_kernels.cl"
};
114 
115 //
116 //
117 //
118 
119 void
hsg_target_opencl(struct hsg_target * const target,struct hsg_config const * const config,struct hsg_merge const * const merge,struct hsg_op const * const ops,uint32_t const depth)120 hsg_target_opencl(struct hsg_target       * const target,
121                   struct hsg_config const * const config,
122                   struct hsg_merge  const * const merge,
123                   struct hsg_op     const * const ops,
124                   uint32_t                  const depth)
125 {
126   switch (ops->type)
127     {
128     case HSG_OP_TYPE_END:
129       fprintf(target->state->source,
130               "}\n");
131       break;
132 
133     case HSG_OP_TYPE_BEGIN:
134       fprintf(target->state->source,
135               "{\n");
136       break;
137 
138     case HSG_OP_TYPE_ELSE:
139       fprintf(target->state->source,
140               "else\n");
141       break;
142 
143     case HSG_OP_TYPE_TARGET_BEGIN:
144       {
145         // allocate state
146         target->state = malloc(sizeof(*target->state));
147 
148         // allocate files
149         target->state->header = fopen("hs_config.h",  "wb");
150         target->state->source = fopen("hs_kernels.cl","wb");
151 
152         // initialize header
153         uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps));
154 
155         hsg_copyright(target->state->header);
156 
157         fprintf(target->state->header,
158                 "#ifndef HS_CL_ONCE                                              \n"
159                 "#define HS_CL_ONCE                                              \n"
160                 "                                                                \n"
161                 "#define HS_SLAB_THREADS_LOG2    %u                              \n"
162                 "#define HS_SLAB_THREADS         (1 << HS_SLAB_THREADS_LOG2)     \n"
163                 "#define HS_SLAB_WIDTH_LOG2      %u                              \n"
164                 "#define HS_SLAB_WIDTH           (1 << HS_SLAB_WIDTH_LOG2)       \n"
165                 "#define HS_SLAB_HEIGHT          %u                              \n"
166                 "#define HS_SLAB_KEYS            (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n"
167                 "#define HS_REG_LAST(c)          c##%u                           \n"
168                 "#define HS_KEY_WORDS            %u                              \n"
169                 "#define HS_VAL_WORDS            0                               \n"
170                 "#define HS_BS_SLABS             %u                              \n"
171                 "#define HS_BS_SLABS_LOG2_RU     %u                              \n"
172                 "#define HS_BC_SLABS_LOG2_MAX    %u                              \n"
173                 "#define HS_FM_BLOCK_HEIGHT      %u                              \n"
174                 "#define HS_FM_SCALE_MIN         %u                              \n"
175                 "#define HS_FM_SCALE_MAX         %u                              \n"
176                 "#define HS_HM_BLOCK_HEIGHT      %u                              \n"
177                 "#define HS_HM_SCALE_MIN         %u                              \n"
178                 "#define HS_HM_SCALE_MAX         %u                              \n"
179                 "#define HS_EMPTY                                                \n"
180                 "                                                                \n",
181                 config->warp.lanes_log2, // FIXME - may be different on a SIMD target
182                 config->warp.lanes_log2,
183                 config->thread.regs,
184                 config->thread.regs,
185                 config->type.words,
186                 merge->warps,
187                 msb_idx_u32(pow2_ru_u32(merge->warps)),
188                 bc_max,
189                 config->merge.flip.warps,
190                 config->merge.flip.lo,
191                 config->merge.flip.hi,
192                 config->merge.half.warps,
193                 config->merge.half.lo,
194                 config->merge.half.hi);
195 
196         if (target->define != NULL)
197           fprintf(target->state->header,"#define %s\n\n",target->define);
198 
199         fprintf(target->state->header,
200                 "#define HS_SLAB_ROWS()    \\\n");
201 
202         for (uint32_t ii=1; ii<=config->thread.regs; ii++)
203           fprintf(target->state->header,
204                   "  HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1);
205 
206         fprintf(target->state->header,
207                 "  HS_EMPTY\n"
208                 "          \n");
209 
210         fprintf(target->state->header,
211                 "#define HS_TRANSPOSE_SLAB()                \\\n");
212 
213         for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++)
214           fprintf(target->state->header,
215                   "  HS_TRANSPOSE_STAGE( %u )                  \\\n",ii);
216 
217         struct hsg_transpose_state state[1] =
218           {
219            { .header = target->state->header,
220              .config = config
221            }
222           };
223 
224         hsg_transpose(config->warp.lanes_log2,
225                       config->thread.regs,
226                       hsg_transpose_blend,state,
227                       hsg_transpose_remap,state);
228 
229         fprintf(target->state->header,
230                 "  HS_EMPTY\n"
231                 "          \n");
232 
233         hsg_copyright(target->state->source);
234 
235         hsg_macros(target->state->source);
236       }
237       break;
238 
239     case HSG_OP_TYPE_TARGET_END:
240       // decorate the files
241       fprintf(target->state->header,
242               "#endif  \n"
243               "        \n"
244               "//      \n"
245               "//      \n"
246               "//      \n"
247               "        \n");
248       fprintf(target->state->source,
249               "        \n"
250               "//      \n"
251               "//      \n"
252               "//      \n"
253               "        \n");
254 
255       // close files
256       fclose(target->state->header);
257       fclose(target->state->source);
258 
259       // free state
260       free(target->state);
261       break;
262 
263     case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO:
264       {
265         fprintf(target->state->source,
266                 "\nHS_TRANSPOSE_KERNEL_PROTO()\n");
267       }
268       break;
269 
270     case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE:
271       {
272         fprintf(target->state->source,
273                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
274       }
275       break;
276 
277     case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY:
278       {
279         fprintf(target->state->source,
280                 "HS_TRANSPOSE_SLAB()\n");
281       }
282       break;
283 
284     case HSG_OP_TYPE_BS_KERNEL_PROTO:
285       {
286         struct hsg_merge const * const m = merge + ops->a;
287 
288         uint32_t const bs  = pow2_ru_u32(m->warps);
289         uint32_t const msb = msb_idx_u32(bs);
290 
291         fprintf(target->state->source,
292                 "\nHS_BS_KERNEL_PROTO(%u,%u)\n",
293                 m->warps,msb);
294       }
295       break;
296 
297     case HSG_OP_TYPE_BS_KERNEL_PREAMBLE:
298       {
299         struct hsg_merge const * const m = merge + ops->a;
300 
301         if (m->warps > 1)
302           {
303             fprintf(target->state->source,
304                     "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
305                     m->warps * config->warp.lanes,
306                     m->rows_bs);
307           }
308 
309         fprintf(target->state->source,
310                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
311       }
312       break;
313 
314     case HSG_OP_TYPE_BC_KERNEL_PROTO:
315       {
316         struct hsg_merge const * const m = merge + ops->a;
317 
318         uint32_t const msb = msb_idx_u32(m->warps);
319 
320         fprintf(target->state->source,
321                 "\nHS_BC_KERNEL_PROTO(%u,%u)\n",
322                 m->warps,msb);
323       }
324       break;
325 
326     case HSG_OP_TYPE_BC_KERNEL_PREAMBLE:
327       {
328         struct hsg_merge const * const m = merge + ops->a;
329 
330         if (m->warps > 1)
331           {
332             fprintf(target->state->source,
333                     "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
334                     m->warps * config->warp.lanes,
335                     m->rows_bc);
336           }
337 
338         fprintf(target->state->source,
339                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
340       }
341       break;
342 
343     case HSG_OP_TYPE_FM_KERNEL_PROTO:
344       fprintf(target->state->source,
345               "\nHS_FM_KERNEL_PROTO(%u,%u)\n",
346               ops->a,ops->b);
347       break;
348 
349     case HSG_OP_TYPE_FM_KERNEL_PREAMBLE:
350       fprintf(target->state->source,
351               "HS_FM_PREAMBLE(%u);\n",
352               ops->a);
353       break;
354 
355     case HSG_OP_TYPE_HM_KERNEL_PROTO:
356       {
357         fprintf(target->state->source,
358                 "\nHS_HM_KERNEL_PROTO(%u)\n",
359                 ops->a);
360       }
361       break;
362 
363     case HSG_OP_TYPE_HM_KERNEL_PREAMBLE:
364       fprintf(target->state->source,
365               "HS_HM_PREAMBLE(%u);\n",
366               ops->a);
367       break;
368 
369     case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD:
370       {
371         static char const * const vstr[] = { "vin", "vout" };
372 
373         fprintf(target->state->source,
374                 "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n",
375                 ops->n,vstr[ops->v],ops->n-1);
376       }
377       break;
378 
379     case HSG_OP_TYPE_BX_REG_GLOBAL_STORE:
380       fprintf(target->state->source,
381               "HS_SLAB_GLOBAL_STORE(%u,r%u);\n",
382               ops->n-1,ops->n);
383       break;
384 
385     case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD:
386       fprintf(target->state->source,
387               "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
388               ops->a,ops->b);
389       break;
390 
391     case HSG_OP_TYPE_HM_REG_GLOBAL_STORE:
392       fprintf(target->state->source,
393               "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
394               ops->b,ops->a);
395       break;
396 
397     case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT:
398       fprintf(target->state->source,
399               "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
400               ops->a,ops->b);
401       break;
402 
403     case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT:
404       fprintf(target->state->source,
405               "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
406               ops->b,ops->a);
407       break;
408 
409     case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT:
410       fprintf(target->state->source,
411               "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n",
412               ops->b,ops->a);
413       break;
414 
415     case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT:
416       fprintf(target->state->source,
417               "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n",
418               ops->a,ops->b);
419       break;
420 
421     case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED:
422       {
423         if (ops->a <= ops->b)
424           {
425             fprintf(target->state->source,
426                     "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n");
427           }
428         else if (ops->b > 1)
429           {
430             fprintf(target->state->source,
431                     "else if (fm_frac == %u)\n",
432                     ops->b);
433           }
434         else
435           {
436 	    fprintf(target->state->source,
437 		    "else\n");
438           }
439       }
440       break;
441 
442     case HSG_OP_TYPE_SLAB_FLIP:
443       fprintf(target->state->source,
444               "HS_SLAB_FLIP_PREAMBLE(%u);\n",
445               ops->n-1);
446       break;
447 
448     case HSG_OP_TYPE_SLAB_HALF:
449       fprintf(target->state->source,
450               "HS_SLAB_HALF_PREAMBLE(%u);\n",
451               ops->n / 2);
452       break;
453 
454     case HSG_OP_TYPE_CMP_FLIP:
455       fprintf(target->state->source,
456               "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c);
457       break;
458 
459     case HSG_OP_TYPE_CMP_HALF:
460       fprintf(target->state->source,
461               "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b);
462       break;
463 
464     case HSG_OP_TYPE_CMP_XCHG:
465       if (ops->c == UINT32_MAX)
466         {
467           fprintf(target->state->source,
468                   "HS_CMP_XCHG(r%-3u,r%-3u);\n",
469                   ops->a,ops->b);
470         }
471       else
472         {
473           fprintf(target->state->source,
474                   "HS_CMP_XCHG(r%u_%u,r%u_%u);\n",
475                   ops->c,ops->a,ops->c,ops->b);
476         }
477       break;
478 
479     case HSG_OP_TYPE_BS_REG_SHARED_STORE_V:
480       fprintf(target->state->source,
481               "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n",
482               merge[ops->a].warps,ops->c,ops->b);
483       break;
484 
485     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V:
486       fprintf(target->state->source,
487               "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
488               ops->b,merge[ops->a].warps,ops->c);
489       break;
490 
491     case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V:
492       fprintf(target->state->source,
493               "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
494               ops->b,ops->a,ops->c);
495       break;
496 
497     case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT:
498       fprintf(target->state->source,
499               "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n",
500               ops->b * config->warp.lanes,
501               ops->c,
502               ops->a);
503       break;
504 
505     case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT:
506       fprintf(target->state->source,
507               "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n",
508               ops->b * config->warp.lanes,
509               ops->c,
510               ops->a);
511       break;
512 
513     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT:
514       fprintf(target->state->source,
515               "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n",
516               ops->c,
517               ops->a,
518               ops->b * config->warp.lanes);
519       break;
520 
521     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT:
522       fprintf(target->state->source,
523               "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n",
524               ops->c,
525               ops->a,
526               ops->b * config->warp.lanes);
527       break;
528 
529     case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT:
530       fprintf(target->state->source,
531               "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n",
532               ops->c,
533               ops->a,
534               ops->b);
535       break;
536 
537     case HSG_OP_TYPE_BLOCK_SYNC:
538       fprintf(target->state->source,
539               "HS_BLOCK_BARRIER();\n");
540       //
541       // FIXME - Named barriers to allow coordinating warps to proceed?
542       //
543       break;
544 
545     case HSG_OP_TYPE_BS_FRAC_PRED:
546       {
547         if (ops->m == 0)
548           {
549             fprintf(target->state->source,
550                     "if (warp_idx < bs_full)\n");
551           }
552         else
553           {
554             fprintf(target->state->source,
555                     "else if (bs_frac == %u)\n",
556                     ops->w);
557           }
558       }
559       break;
560 
561     case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE:
562       {
563         struct hsg_merge const * const m = merge + ops->a;
564 
565         fprintf(target->state->source,
566                 "HS_BS_MERGE_H_PREAMBLE(%u);\n",
567                 m->warps);
568       }
569       break;
570 
571     case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE:
572       {
573         struct hsg_merge const * const m = merge + ops->a;
574 
575         fprintf(target->state->source,
576                 "HS_BC_MERGE_H_PREAMBLE(%u);\n",
577                 m->warps);
578       }
579       break;
580 
581     case HSG_OP_TYPE_BX_MERGE_H_PRED:
582       fprintf(target->state->source,
583               "if (HS_SUBGROUP_ID() < %u)\n",
584               ops->a);
585       break;
586 
587     case HSG_OP_TYPE_BS_ACTIVE_PRED:
588       {
589         struct hsg_merge const * const m = merge + ops->a;
590 
591         if (m->warps <= 32)
592           {
593             fprintf(target->state->source,
594                     "if (((1u << HS_SUBGROUP_ID()) & 0x%08X) != 0)\n",
595                     m->levels[ops->b].active.b32a2[0]);
596           }
597         else
598           {
599             fprintf(target->state->source,
600                     "if (((1UL << HS_SUBGROUP_ID()) & 0x%08X%08XL) != 0L)\n",
601                     m->levels[ops->b].active.b32a2[1],
602                     m->levels[ops->b].active.b32a2[0]);
603           }
604       }
605       break;
606 
607     default:
608       fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]);
609       exit(EXIT_FAILURE);
610       break;
611     }
612 }
613 
614 //
615 //
616 //
617