1 //
2 // Copyright 2016 Google Inc.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file.
6 //
7 
8 #ifndef HS_CUDA_MACROS_ONCE
9 #define HS_CUDA_MACROS_ONCE
10 
11 //
12 //
13 //
14 
15 #ifdef __cplusplus
16 extern "C" {
17 #endif
18 
19 #include <stdint.h>
20 
21 #ifdef __cplusplus
22 }
23 #endif
24 
25 //
26 // Define the type based on key and val sizes
27 //
28 
// HS_KEY_TYPE is the integer type used for one key.  It is defined only
// for the configurations listed here and is left undefined otherwise:
//
//   HS_KEY_WORDS == 2                         -> uint64_t
//   HS_KEY_WORDS == 1 and HS_VAL_WORDS == 0   -> uint32_t
//
#if   HS_KEY_WORDS == 2
#define HS_KEY_TYPE  uint64_t
#elif HS_KEY_WORDS == 1
#if   HS_VAL_WORDS == 0
#define HS_KEY_TYPE  uint32_t
#endif
#endif
36 
37 //
38 // FYI, restrict shouldn't have any impact on these kernels and
39 // benchmarks appear to prove that true
40 //
41 
// No-alias qualifier placed on the kernels' pointer parameters.
#define HS_RESTRICT             __restrict__

//
// Pieces shared by every kernel prototype macro:
//
//   HS_SCOPE()            - linkage keyword (file-local)
//   HS_KERNEL_QUALIFIER() - CUDA kernel qualifier plus return type
//

#define HS_SCOPE()              static

#define HS_KERNEL_QUALIFIER()   __global__ void
53 
54 //
55 // The sm_35 arch has a maximum of 16 blocks per multiprocessor.  Just
56 // clamp it to 16 when targeting this arch.
57 //
58 // This only arises when compiling the 32-bit sorting kernels.
59 //
60 // You can also generate a narrower 16-warp wide 32-bit sorting kernel
61 // which is sometimes faster and sometimes slower than the 32-block
62 // configuration.
63 //
64 
// Upper limit applied to the "minimum blocks per multiprocessor" launch
// hint: 16 when compiling for __CUDA_ARCH__ 350, effectively unlimited
// (UINT32_MAX) for everything else, including the host pass where
// __CUDA_ARCH__ is not defined.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 350)
#define HS_CUDA_MAX_BPM  16
#else
#define HS_CUDA_MAX_BPM  UINT32_MAX // 32
#endif

// Yields min_bpm unless it reaches HS_CUDA_MAX_BPM, in which case the
// limit itself is used.
#define HS_CLAMPED_BPM(min_bpm)                                 \
  (HS_CUDA_MAX_BPM > (min_bpm) ? (min_bpm) : HS_CUDA_MAX_BPM)

//
// Kernel launch-bounds attribute: max_tpb is passed through unchanged,
// min_bpm is clamped with HS_CLAMPED_BPM().
//

#define HS_LAUNCH_BOUNDS(max_tpb,min_bpm)                       \
  __launch_bounds__(max_tpb,HS_CLAMPED_BPM(min_bpm))
80 
81 //
82 // KERNEL PROTOS
83 //
84 
// Identifier of a block-sort kernel: "hs_kernel_bs_" with the
// slab_count_ru_log2 token appended.
#define HS_BS_KERNEL_NAME(slab_count_ru_log2)                                   \
  hs_kernel_bs_##slab_count_ru_log2

// Block-sort kernel prototype.  `vin` is a pointer to const keys and
// `vout` a pointer to writable keys.  The launch bounds allow up to
// HS_SLAB_THREADS*slab_count threads per block with a minimum-blocks
// hint of 1.
//
// NOTE(review): the offset variant below uses
// HS_BS_SLABS/(1<<slab_count_ru_log2) as its minimum-blocks hint while
// this prototype uses 1 -- confirm the difference is intentional.
#define HS_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)                       \
  HS_SCOPE() HS_KERNEL_QUALIFIER()                                              \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,1)                                \
  HS_BS_KERNEL_NAME(slab_count_ru_log2)(                                        \
    HS_KEY_TYPE       * const HS_RESTRICT vout,                                 \
    HS_KEY_TYPE const * const HS_RESTRICT vin)

//

// Identifier of the offset block-sort kernel.  It expands to the same
// name as HS_BS_KERNEL_NAME(); the two prototypes differ only in their
// parameter lists and launch bounds.
#define HS_OFFSET_BS_KERNEL_NAME(slab_count_ru_log2)                            \
  hs_kernel_bs_##slab_count_ru_log2

// Offset block-sort kernel prototype: same as the block-sort prototype
// plus a trailing `slab_offset` parameter.
#define HS_OFFSET_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)                \
  HS_SCOPE() HS_KERNEL_QUALIFIER()                                              \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,                                  \
                   HS_BS_SLABS/(1<<slab_count_ru_log2))                         \
  HS_OFFSET_BS_KERNEL_NAME(slab_count_ru_log2)(                                 \
    HS_KEY_TYPE       * const HS_RESTRICT vout,                                 \
    HS_KEY_TYPE const * const HS_RESTRICT vin,                                  \
    uint32_t            const             slab_offset)
107 
108 //
109 
// Identifier of a block-clean kernel: "hs_kernel_bc_" with the
// slab_count_log2 token appended.
#define HS_BC_KERNEL_NAME(slab_count_log2)                                      \
  hs_kernel_bc_##slab_count_log2

// Block-clean kernel prototype.  The kernel takes a single writable key
// pointer `vout`.  Launch bounds: up to HS_SLAB_THREADS*slab_count
// threads per block, minimum-blocks hint HS_BS_SLABS/(1<<slab_count_log2)
// (clamped by HS_LAUNCH_BOUNDS).
#define HS_BC_KERNEL_PROTO(slab_count,slab_count_log2)                          \
  HS_SCOPE() HS_KERNEL_QUALIFIER()                                              \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,                                  \
                   HS_BS_SLABS/(1<<slab_count_log2))                            \
  HS_BC_KERNEL_NAME(slab_count_log2)(HS_KEY_TYPE * const HS_RESTRICT vout)
118 
119 //
120 
// Identifier of a half-merge kernel: "hs_kernel_hm_" + s.
#define HS_HM_KERNEL_NAME(s)    hs_kernel_hm_##s

// Half-merge kernel prototype (no launch bounds); one writable key
// pointer `vout`.
#define HS_HM_KERNEL_PROTO(s)                                           \
  HS_SCOPE() HS_KERNEL_QUALIFIER()                                      \
  HS_HM_KERNEL_NAME(s)(HS_KEY_TYPE * const HS_RESTRICT vout)

//

// Identifier of a flip-merge kernel: "hs_kernel_fm_" + s + "_" + r.
#define HS_FM_KERNEL_NAME(s,r)  hs_kernel_fm_##s##_##r

// Flip-merge kernel prototype (no launch bounds); one writable key
// pointer `vout`.
#define HS_FM_KERNEL_PROTO(s,r)                                         \
  HS_SCOPE() HS_KERNEL_QUALIFIER()                                      \
  HS_FM_KERNEL_NAME(s,r)(HS_KEY_TYPE * const HS_RESTRICT vout)

//

// Identifier of the offset flip-merge kernel.  It expands to the same
// name as HS_FM_KERNEL_NAME().
#define HS_OFFSET_FM_KERNEL_NAME(s,r)                                   \
  hs_kernel_fm_##s##_##r

// Offset flip-merge kernel prototype: the flip-merge prototype plus a
// trailing `span_offset` parameter.
#define HS_OFFSET_FM_KERNEL_PROTO(s,r)                                  \
  HS_SCOPE() HS_KERNEL_QUALIFIER()                                      \
  HS_OFFSET_FM_KERNEL_NAME(s,r)(                                        \
    HS_KEY_TYPE * const HS_RESTRICT vout,                               \
    uint32_t      const             span_offset)
149 
150 //
151 
// Identifier of the slab transpose kernel.
#define HS_TRANSPOSE_KERNEL_NAME()      hs_kernel_transpose

// Transpose kernel prototype: one writable key pointer `vout`, launch
// bounds of HS_SLAB_THREADS threads per block and a minimum-blocks hint
// of 1.
#define HS_TRANSPOSE_KERNEL_PROTO()                                     \
  HS_SCOPE() HS_KERNEL_QUALIFIER()                                      \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS,1)                                   \
  HS_TRANSPOSE_KERNEL_NAME()(HS_KEY_TYPE * const HS_RESTRICT vout)
160 
161 //
162 // BLOCK LOCAL MEMORY DECLARATION
163 //
164 
// Declares a __shared__ variable named `shared` in the expanding scope.
// Its single member `m` is an array of (width * height) keys.  Both
// arguments are parenthesized so that expression arguments such as
// `a+b` produce the intended extent.
#define HS_BLOCK_LOCAL_MEM_DECL(width,height)   \
  __shared__ struct {                           \
    HS_KEY_TYPE m[(width) * (height)];          \
  } shared

//
// BLOCK BARRIER
//

// Block-wide barrier.
#define HS_BLOCK_BARRIER()                      \
  __syncthreads()
176 
177 //
178 // GRID VARIABLES
179 //
180 
// Thread indexing helpers along the x dimension.
//
//   HS_GLOBAL_SIZE_X() - total number of threads in the grid
//   HS_GLOBAL_ID_X()   - this thread's index within the grid
//   HS_LOCAL_ID_X()    - this thread's index within its block
//   HS_WARP_ID_X()     - block-local index divided by 32
//   HS_LANE_ID()       - block-local index modulo 32
//
#define HS_GLOBAL_SIZE_X() (blockDim.x * gridDim.x)
#define HS_GLOBAL_ID_X()   (blockIdx.x * blockDim.x + threadIdx.x)
#define HS_LOCAL_ID_X()    threadIdx.x
#define HS_WARP_ID_X()     (threadIdx.x >> 5)
#define HS_LANE_ID()       (threadIdx.x % 32)
186 
187 //
188 // SLAB GLOBAL
189 //
190 
// Declares `gmem_idx` in the expanding scope:
//
//   (global thread id with its low bits masked off by
//    ~(HS_SLAB_THREADS-1)) * HS_SLAB_HEIGHT + lane id
//
// The mask clears whole low bits only when HS_SLAB_THREADS is a power
// of two -- presumably always the case; confirm against the config.
#define HS_SLAB_GLOBAL_PREAMBLE()               \
  uint32_t const gmem_idx =                     \
    (HS_GLOBAL_ID_X() & ~(HS_SLAB_THREADS-1)) * \
    HS_SLAB_HEIGHT + HS_LANE_ID()

// Same as HS_SLAB_GLOBAL_PREAMBLE() except that `slab_offset` (a name
// expected in the expanding scope) is added to the global thread id
// before masking.
#define HS_OFFSET_SLAB_GLOBAL_PREAMBLE()                        \
  uint32_t const gmem_idx =                                     \
    ((slab_offset + HS_GLOBAL_ID_X()) & ~(HS_SLAB_THREADS-1)) * \
    HS_SLAB_HEIGHT + HS_LANE_ID()

// Element of `extent` at row `row_idx` for this thread: rows are
// HS_SLAB_THREADS elements apart, starting at gmem_idx.  row_idx is
// parenthesized so expression arguments index the intended row.
#define HS_SLAB_GLOBAL_LOAD(extent,row_idx)  \
  extent[gmem_idx + HS_SLAB_THREADS * (row_idx)]

// Stores `reg` to `vout` at row `row_idx` for this thread.
#define HS_SLAB_GLOBAL_STORE(row_idx,reg)    \
  vout[gmem_idx + HS_SLAB_THREADS * (row_idx)] = (reg)
206 
207 //
208 // SLAB LOCAL
209 //
210 
// Element of the block-local array `shared.m` at `offset` past this
// thread's left index (`smem_l_idx`, declared by a merge preamble).
#define HS_SLAB_LOCAL_L(offset)   shared.m[(offset) + smem_l_idx]

// Element of `shared.m` at `offset` past this thread's right index
// (`smem_r_idx`, declared by a merge preamble).
#define HS_SLAB_LOCAL_R(offset)   shared.m[(offset) + smem_r_idx]

//
// SLAB LOCAL VERTICAL LOADS
//

// Element of `shared.m` at `offset` past this thread's block-local id.
#define HS_BX_LOCAL_V(offset)     shared.m[(offset) + HS_LOCAL_ID_X()]
223 
224 //
225 // BLOCK SORT MERGE HORIZONTAL
226 //
227 
// Declares the block-local indices used by the block-sort horizontal
// merge:
//
//   smem_l_idx - warp id * (HS_SLAB_THREADS * slab_count) + lane id
//   smem_r_idx - same base computed for the neighbouring warp
//                (warp id ^ 1), plus the lane id XORed with
//                (HS_SLAB_THREADS - 1)
//
// slab_count is parenthesized so expression arguments scale correctly.
#define HS_BS_MERGE_H_PREAMBLE(slab_count)                      \
  uint32_t const smem_l_idx =                                   \
    HS_WARP_ID_X() * (HS_SLAB_THREADS * (slab_count)) +         \
    HS_LANE_ID();                                               \
  uint32_t const smem_r_idx =                                   \
    (HS_WARP_ID_X() ^ 1) * (HS_SLAB_THREADS * (slab_count)) +   \
    (HS_LANE_ID() ^ (HS_SLAB_THREADS - 1))

//
// BLOCK CLEAN MERGE HORIZONTAL
//

// Declares the indices used by the block-clean horizontal merge:
//
//   gmem_l_idx - (global thread id masked with
//                 ~(HS_SLAB_THREADS*slab_count-1)) * HS_SLAB_HEIGHT
//                + block-local thread id
//   smem_l_idx - warp id * (HS_SLAB_THREADS * slab_count) + lane id
//
#define HS_BC_MERGE_H_PREAMBLE(slab_count)                      \
  uint32_t const gmem_l_idx =                                   \
    (HS_GLOBAL_ID_X() & ~(HS_SLAB_THREADS*(slab_count)-1)) *    \
    HS_SLAB_HEIGHT + HS_LOCAL_ID_X();                           \
  uint32_t const smem_l_idx =                                   \
    HS_WARP_ID_X() * (HS_SLAB_THREADS * (slab_count)) +         \
    HS_LANE_ID()

// Element of `vout` that is `slab_idx` strides of HS_SLAB_THREADS past
// gmem_l_idx.
#define HS_BC_GLOBAL_LOAD_L(slab_idx)                   \
  vout[gmem_l_idx + (HS_SLAB_THREADS * (slab_idx))]
250 
251 //
252 // SLAB FLIP AND HALF PREAMBLES
253 //
254 
// Declares, in the expanding scope:
//
//   flip_lane_idx - this lane's id XOR `mask` (the lane it exchanges with)
//   t_lt          - 1 when this lane's id is below flip_lane_idx, else 0
//
// `mask` is parenthesized so that an expression such as `a|b` is XORed
// as a whole.  The expansion already ends with a semicolon.
#define HS_SLAB_FLIP_PREAMBLE(mask)                             \
  uint32_t const flip_lane_idx  = HS_LANE_ID() ^ (mask);        \
  int32_t  const t_lt           = HS_LANE_ID() < flip_lane_idx;

// if we want to shfl_xor: uint32_t const flip_lane_mask = mask;

// Same as HS_SLAB_FLIP_PREAMBLE() but the partner lane is named
// half_lane_idx.
#define HS_SLAB_HALF_PREAMBLE(mask)                             \
  uint32_t const half_lane_idx  = HS_LANE_ID() ^ (mask);        \
  int32_t  const t_lt           = HS_LANE_ID() < half_lane_idx;

// if we want to shfl_xor: uint32_t const half_lane_mask = mask;
266 
267 //
268 // Inter-lane compare exchange
269 //
270 
//
// Each HS_CMP_XCHG_Vn(a,b) variant leaves the smaller key in `a` and
// the larger key in `b`.  Both arguments must be assignable.  Every
// variant expands to one brace-enclosed block.
//

// good
#define HS_CMP_XCHG_V0(a,b)                     \
  {                                             \
    HS_KEY_TYPE const t = min(a,b);             \
    b = max(a,b);                               \
    a = t;                                      \
  }

// surprisingly fast -- #1 on 64-bit keys
//
// After `a` becomes the minimum, (a ^ tmp) is zero when nothing moved
// and (old a ^ old b) when it did, so the XOR turns `b` into the maximum.
#define HS_CMP_XCHG_V1(a,b)                     \
  {                                             \
    HS_KEY_TYPE const tmp = a;                  \
    a  = (a < b) ? a : b;                       \
    b ^= a ^ tmp;                               \
  }

// good
//
// The outer braces keep the `if` from capturing an `else` that follows
// the macro at the call site.
#define HS_CMP_XCHG_V2(a,b)                     \
  {                                             \
    if (a >= b) {                               \
      HS_KEY_TYPE const t = a;                  \
      a = b;                                    \
      b = t;                                    \
    }                                           \
  }

// good
#define HS_CMP_XCHG_V3(a,b)                     \
  {                                             \
    int32_t     const ge = a >= b;              \
    HS_KEY_TYPE const t  = a;                   \
    a = ge ? b : a;                             \
    b = ge ? t : b;                             \
  }

//
// Variant selection: both supported key widths use V0.
//

#if   (HS_KEY_WORDS == 1)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V0(a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V0(a,b)
#endif
313 
//
// The flip/half comparisons rely on a "conditional min/max":
//
//  - if the flag is true (non-zero), return min(a,b)
//  - otherwise, return max(a,b)
//
// HS_CMP_FLIP() and HS_CMP_HALF() pass `t_lt` as the flag, so the lane
// with the lower id keeps the minimum.
//
// What's a little surprising is that sequence (1) is faster than (2)
// for 32-bit keys.
//
// I suspect either a code generation problem or that the sequence
// maps well to the GEN instruction set.
//
// We mostly care about 64-bit keys and unsurprisingly sequence (2) is
// fastest for this wider type.
//
// Both variants are fully parenthesized so they can be used inside a
// larger expression.  V1 expects the flag to be exactly 0 or 1 only in
// the sense that any non-zero flag selects `b` as the starting value.
//

// this is what you would normally use
#define HS_COND_MIN_MAX_V0(lt,a,b) ((((a) <= (b)) ^ (lt)) ? (b) : (a))

// this seems to be faster for 32-bit keys
#define HS_COND_MIN_MAX_V1(lt,a,b) (((lt) ? (b) : (a)) ^ (((a) ^ (b)) & HS_LTE_TO_MASK(a,b)))

//
// Variant selection: both supported key widths use V0.
//

#if   (HS_KEY_WORDS == 1)
#define HS_COND_MIN_MAX(lt,a,b) HS_COND_MIN_MAX_V0(lt,a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_COND_MIN_MAX(lt,a,b) HS_COND_MIN_MAX_V0(lt,a,b)
#endif
345 
346 //
347 // HotSort shuffles are always warp-wide
348 //
349 
// Participation mask handed to the *_sync shuffles: all 32 lanes.
#define HS_SHFL_ALL 0xFFFFFFFF

//
// Conditional inter-subgroup flip/half compare exchange
//
// Both macros expect `t_lt` and the partner lane index (flip_lane_idx
// or half_lane_idx) to be declared in the expanding scope.  The first
// parameter `i` is not used by the expansion.
//

// Fetches the partner lane's `a` and `b`, then replaces `a` with the
// conditional min/max of (a, partner b) and `b` with that of
// (b, partner a).  Both partner values are read before either register
// is overwritten.
#define HS_CMP_FLIP(i,a,b)                                                      \
  {                                                                             \
    HS_KEY_TYPE const hs_peer_a = __shfl_sync(HS_SHFL_ALL,a,flip_lane_idx);     \
    HS_KEY_TYPE const hs_peer_b = __shfl_sync(HS_SHFL_ALL,b,flip_lane_idx);     \
    a = HS_COND_MIN_MAX(t_lt,a,hs_peer_b);                                      \
    b = HS_COND_MIN_MAX(t_lt,b,hs_peer_a);                                      \
  }

// Fetches the partner lane's `a` and replaces `a` with the conditional
// min/max of the two values.
#define HS_CMP_HALF(i,a)                                                        \
  {                                                                             \
    HS_KEY_TYPE const hs_peer_a = __shfl_sync(HS_SHFL_ALL,a,half_lane_idx);     \
    a = HS_COND_MIN_MAX(t_lt,a,hs_peer_a);                                      \
  }
369 
370 //
371 // The device's comparison operator might return what we actually
372 // want.  For example, it appears GEN 'cmp' returns {true:-1,false:0}.
373 //
374 
#define HS_CMP_IS_ZERO_ONE

//
// HS_LTE_TO_MASK(a,b) - all-ones key when a <= b, zero otherwise
// HS_CMP_TO_MASK(a)   - turns a comparison result into the same kind
//                       of mask
//
// Arguments and results are parenthesized so expression arguments
// (e.g. HS_CMP_TO_MASK(x < y)) are converted as a whole.
//

#ifdef HS_CMP_IS_ZERO_ONE
// Scalar comparisons yield {true: +1, false: 0}
// (a <= b) -> { +1, 0 } -> NEGATE -> { all-ones, 0 }
#define HS_LTE_TO_MASK(a,b) ((HS_KEY_TYPE)(-((a) <= (b))))
#define HS_CMP_TO_MASK(a)   ((HS_KEY_TYPE)(-(a)))
#else
// However, OpenCL requires { -1, 0 } for vectors
// (a <= b) -> { all-ones, 0 }
#define HS_LTE_TO_MASK(a,b) ((a) <= (b)) // FIXME for uint64
#define HS_CMP_TO_MASK(a)   (a)
#endif
388 
389 //
390 // The "flip-merge" and "half-merge" preambles are very similar
391 //
392 // For now, we're only using the .y dimension for the span idx
393 //
394 
// Declares the span bookkeeping shared by the half-merge and flip-merge
// kernels:
//
//   span_idx    - span_offset + blockIdx.y
//   span_stride - total number of threads along x
//   span_size   - span_stride * half_span * 2
//   span_base   - span_idx * span_size
//   span_off    - this thread's global id along x
//   span_l      - span_base + span_off
//
// All values are 32-bit.  half_span and span_offset are parenthesized
// so expression arguments are evaluated as a whole.
#define HS_OFFSET_HM_PREAMBLE(half_span,span_offset)                    \
  uint32_t const span_idx    = (span_offset) + blockIdx.y;              \
  uint32_t const span_stride = HS_GLOBAL_SIZE_X();                      \
  uint32_t const span_size   = span_stride * (half_span) * 2;           \
  uint32_t const span_base   = span_idx * span_size;                    \
  uint32_t const span_off    = HS_GLOBAL_ID_X();                        \
  uint32_t const span_l      = span_base + span_off

// Half-merge preamble without an offset (span_offset is 0).
#define HS_HM_PREAMBLE(half_span)               \
  HS_OFFSET_HM_PREAMBLE(half_span,0)

// Flip-merge preamble: the half-merge declarations plus span_r, which
// is span_base + span_stride * (half_span + 1) - span_off - 1.
#define HS_FM_PREAMBLE(half_span)                                       \
  HS_HM_PREAMBLE(half_span);                                            \
  uint32_t const span_r = span_base + span_stride * ((half_span) + 1) - span_off - 1

// Offset flip-merge preamble: as HS_FM_PREAMBLE() but uses the
// `span_offset` name from the expanding scope.
#define HS_OFFSET_FM_PREAMBLE(half_span)                                \
  HS_OFFSET_HM_PREAMBLE(half_span,span_offset);                         \
  uint32_t const span_r = span_base + span_stride * ((half_span) + 1) - span_off - 1
413 
414 //
415 //
416 //
417 
// Left-side element of `vout` for this thread, `stride_idx` strides of
// span_stride past span_l.  stride_idx is parenthesized so expression
// arguments select the intended stride.
#define HS_XM_GLOBAL_L(stride_idx)              \
  vout[span_l + span_stride * (stride_idx)]

#define HS_XM_GLOBAL_LOAD_L(stride_idx)         \
  HS_XM_GLOBAL_L(stride_idx)

#define HS_XM_GLOBAL_STORE_L(stride_idx,reg)    \
  HS_XM_GLOBAL_L(stride_idx) = (reg)

// Right-side element of `vout` for this thread, `stride_idx` strides of
// span_stride past span_r (flip-merge only).
#define HS_FM_GLOBAL_R(stride_idx)              \
  vout[span_r + span_stride * (stride_idx)]

#define HS_FM_GLOBAL_LOAD_R(stride_idx)         \
  HS_FM_GLOBAL_R(stride_idx)

#define HS_FM_GLOBAL_STORE_R(stride_idx,reg)    \
  HS_FM_GLOBAL_R(stride_idx) = (reg)
435 
436 //
437 // This snarl of macros is for transposing a "slab" of sorted elements
438 // into linear order.
439 //
440 // This can occur as the last step in hs_sort() or via a custom kernel
441 // that inspects the slab and then transposes and stores it to memory.
442 //
443 // The slab format can be inspected more efficiently than a linear
444 // arrangement.
445 //
// The prime example is detecting when adjacent keys (in sort order)
// have differing high order bits ("key changes").  The index of each
// change is recorded to an auxiliary array.
449 //
450 // A post-processing step like this needs to be able to navigate the
451 // slab and eventually transpose and store the slab in linear order.
452 //
453 
// Warp-wide XOR shuffle: returns `v` as held by the lane whose id is
// this lane's id XOR `m`.
#define HS_SUBGROUP_SHUFFLE_XOR(v,m)   __shfl_xor_sync(HS_SHFL_ALL,v,m)

//
// Name builders.  The arguments are pasted into identifiers, so they
// must be single tokens:
//
//   HS_TRANSPOSE_REG(prefix,row)      -> <prefix><row>
//   HS_TRANSPOSE_PRED(level)          -> is_lo_<level>
//   HS_TRANSPOSE_TMP_REG(p,ll,ur)     -> <p><ll>_<ur>
//
// The *_DECL forms prepend `HS_KEY_TYPE const` to declare the register.
//

#define HS_TRANSPOSE_REG(prefix,row)   prefix##row
#define HS_TRANSPOSE_DECL(prefix,row)  HS_KEY_TYPE const HS_TRANSPOSE_REG(prefix,row)
#define HS_TRANSPOSE_PRED(level)       is_lo_##level

#define HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)                 \
  prefix_curr##row_ll##_##row_ur

#define HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur)                \
  HS_KEY_TYPE const HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)

// Declares the predicate for `level`: true when bit (level-1) of the
// lane id is clear.  The expansion ends with a semicolon.
#define HS_TRANSPOSE_STAGE(level)                                       \
  bool const HS_TRANSPOSE_PRED(level) =                                 \
    0 == (HS_LANE_ID() & (1 << (level-1)));

// One exchange step between rows row_ll and row_ur at `level`:
//
//   1. every lane shuffles (XOR 1<<(level-1)) its previous row_ll value
//      when the level predicate is true, otherwise its previous row_ur
//      value, and keeps what it receives in a temporary;
//   2. the new row_ll is the received value when the predicate is true,
//      otherwise the previous row_ll;
//   3. the new row_ur is the previous row_ur when the predicate is
//      true, otherwise the received value.
//
// The expansion ends with a semicolon.
#define HS_TRANSPOSE_BLEND(prefix_prev,prefix_curr,level,row_ll,row_ur)         \
  HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur) =                            \
    HS_SUBGROUP_SHUFFLE_XOR(                                                    \
      HS_TRANSPOSE_PRED(level) ? HS_TRANSPOSE_REG(prefix_prev,row_ll)           \
                               : HS_TRANSPOSE_REG(prefix_prev,row_ur),          \
      1<<(level-1));                                                            \
  HS_TRANSPOSE_DECL(prefix_curr,row_ll) =                                       \
    HS_TRANSPOSE_PRED(level) ? HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)  \
                             : HS_TRANSPOSE_REG(prefix_prev,row_ll);            \
  HS_TRANSPOSE_DECL(prefix_curr,row_ur) =                                       \
    HS_TRANSPOSE_PRED(level) ? HS_TRANSPOSE_REG(prefix_prev,row_ur)             \
                             : HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur);

// Stores register <prefix><row_from> to `vout` at gmem_idx plus
// (row_to-1) shifted left by HS_SLAB_WIDTH_LOG2.  The expansion ends
// with a semicolon.
#define HS_TRANSPOSE_REMAP(prefix,row_from,row_to)                      \
  vout[gmem_idx + ((row_to-1) << HS_SLAB_WIDTH_LOG2)] =                 \
    HS_TRANSPOSE_REG(prefix,row_from);
490 
491 //
492 //
493 //
494 
495 #endif
496 
497 //
498 //
499 //
500