/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can
 * be found in the LICENSE file.
 *
 */

//
//
//

#ifdef __cplusplus
extern "C" {
#endif

#include "common/cuda/assert_cuda.h"
#include "common/macros.h"
#include "common/util.h"

#ifdef __cplusplus
}
#endif

//
// We want concurrent kernel execution to occur in a few places.
//
// The summary is:
//
//   1) If necessary, some max valued keys are written to the end of
//      the vin/vout buffers.
//
//   2) Blocks of slabs of keys are sorted.
//
//   3) If necessary, the blocks of slabs are merged until complete.
//
//   4) If requested, the slabs will be converted from slab ordering
//      to linear ordering.
//
// Below is the general "happens-before" relationship between HotSort
// compute kernels.
//
// Note the diagram assumes vin and vout are different buffers.  If
// they're not, then the first merge doesn't include the pad_vout
// event in the wait list.
//
//                    +----------+            +---------+
//                    | pad_vout |            | pad_vin |
//                    +----+-----+            +----+----+
//                         |                       |
//                         |                WAITFOR(pad_vin)
//                         |                       |
//                         |                 +-----v-----+
//                         |                 |           |
//                         |            +----v----+ +----v----+
//                         |            | bs_full | | bs_frac |
//                         |            +----+----+ +----+----+
//                         |                 |           |
//                         |                 +-----v-----+
//                         |                       |
//                         |  +------NO------JUST ONE BLOCK?
//                         | /                     |
//                         |/                     YES
//                         +                       |
//                         |                       v
//                         |         END_WITH_EVENTS(bs_full,bs_frac)
//                         |
//                         |
//        WAITFOR(pad_vout,bs_full,bs_frac) >>> first iteration of loop <<<
//                         |
//                         |
//                         +-----------<------------+
//                         |                        |
//                   +-----v-----+                  |
//                   |           |                  |
//              +----v----+ +----v----+             |
//              | fm_full | | fm_frac |             |
//              +----+----+ +----+----+             |
//                   |           |                  ^
//                   +-----v-----+                  |
//                         |                        |
//              WAITFOR(fm_full,fm_frac)            |
//                         |                        |
//                         v                        |
//                      +--v--+                WAITFOR(bc)
//                      | hm  |                     |
//                      +-----+                     |
//                         |                        |
//                    WAITFOR(hm)                   |
//                         |                        ^
//                      +--v--+                     |
//                      | bc  |                     |
//                      +-----+                     |
//                         |                        |
//                         v                        |
//                  MERGING COMPLETE?-------NO------+
//                         |
//                        YES
//                         |
//                         v
//                END_WITH_EVENTS(bc)
//
//
// NOTE: CUDA streams are in-order so a dependency isn't required for
// kernels launched on the same stream.
//
// Mapping the diagram above onto streams and events is a more subtle
// problem than it first appears.
//
// We'll take a different approach and simply declare the
// "happens-before" kernel relationships:
//
//      concurrent (pad_vin,pad_vout) -> (pad_vin)  happens_before (bs_full,bs_frac)
//                                       (pad_vout) happens_before (fm_full,fm_frac)
//
//      concurrent (bs_full,bs_frac)  -> (bs_full)  happens_before (fm_full,fm_frac)
//                                       (bs_frac)  happens_before (fm_full,fm_frac)
//
//      concurrent (fm_full,fm_frac)  -> (fm_full)  happens_before (hm)
//                                       (fm_frac)  happens_before (hm)
//
//      launch     (hm)               -> (hm)       happens_before (hm)
//                                       (hm)       happens_before (bc)
//
//      launch     (bc)               -> (bc)       happens_before (fm_full,fm_frac)
//
//
// We can go ahead and permanently map kernel launches to our 3
// streams.  As an optimization, we'll dynamically assign each kernel
// to the lowest available stream.  This transforms the problem into
// one that considers streams happening before streams -- which
// kernels are involved doesn't matter.
//
//      STREAM0   STREAM1   STREAM2
//      -------   -------   -------
//
//      pad_vin             pad_vout     (pad_vin)  happens_before (bs_full,bs_frac)
//                                       (pad_vout) happens_before (fm_full,fm_frac)
//
//      bs_full   bs_frac                (bs_full)  happens_before (fm_full,fm_frac)
//                                       (bs_frac)  happens_before (fm_full,fm_frac)
//
//      fm_full   fm_frac                (fm_full)  happens_before (hm or bc)
//                                       (fm_frac)  happens_before (hm or bc)
//
//      hm                               (hm)       happens_before (hm or bc)
//
//      bc                               (bc)       happens_before (fm_full,fm_frac)
//
// A single final kernel will always complete on stream 0.
//
// This simplifies reasoning about concurrency that's downstream of
// hs_cuda_sort().
//

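//
// For reference, a single "A happens_before B" edge between kernels on
// two different streams maps onto the CUDA event API roughly as in the
// sketch below -- hs_barrier_enqueue() further down wraps exactly this
// pattern:
//
//   cudaEvent_t ev;
//
//   cudaEventCreate(&ev);
//   cudaEventRecord(ev,stream_a);        // record after kernel A
//   cudaStreamWaitEvent(stream_b,ev,0);  // kernel B waits on the event
//   cudaEventDestroy(ev);
//
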
typedef void (*hs_kernel_offset_bs_pfn)(HS_KEY_TYPE       * const HS_RESTRICT vout,
                                        HS_KEY_TYPE const * const HS_RESTRICT vin,
                                        uint32_t            const slab_offset);

static hs_kernel_offset_bs_pfn const hs_kernels_offset_bs[]
{
#if HS_BS_SLABS_LOG2_RU >= 1
  hs_kernel_bs_0,
#endif
#if HS_BS_SLABS_LOG2_RU >= 2
  hs_kernel_bs_1,
#endif
#if HS_BS_SLABS_LOG2_RU >= 3
  hs_kernel_bs_2,
#endif
#if HS_BS_SLABS_LOG2_RU >= 4
  hs_kernel_bs_3,
#endif
#if HS_BS_SLABS_LOG2_RU >= 5
  hs_kernel_bs_4,
#endif
#if HS_BS_SLABS_LOG2_RU >= 6
  hs_kernel_bs_5,
#endif
#if HS_BS_SLABS_LOG2_RU >= 7
  hs_kernel_bs_6,
#endif
#if HS_BS_SLABS_LOG2_RU >= 8
  hs_kernel_bs_7,
#endif
};
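
//
// NOTE: hs_kernels_offset_bs[] is indexed by msb_idx_u32() of the
// fractional slab count -- see hs_bs() below.
//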

//
//
//

typedef void (*hs_kernel_bc_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout);

static hs_kernel_bc_pfn const hs_kernels_bc[]
{
  hs_kernel_bc_0,
#if HS_BC_SLABS_LOG2_MAX >= 1
  hs_kernel_bc_1,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 2
  hs_kernel_bc_2,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 3
  hs_kernel_bc_3,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 4
  hs_kernel_bc_4,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 5
  hs_kernel_bc_5,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 6
  hs_kernel_bc_6,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 7
  hs_kernel_bc_7,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 8
  hs_kernel_bc_8,
#endif
};

//
//
//

typedef void (*hs_kernel_hm_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout);

static hs_kernel_hm_pfn const hs_kernels_hm[]
{
#if (HS_HM_SCALE_MIN == 0)
  hs_kernel_hm_0,
#endif
#if (HS_HM_SCALE_MIN <= 1) && (1 <= HS_HM_SCALE_MAX)
  hs_kernel_hm_1,
#endif
#if (HS_HM_SCALE_MIN <= 2) && (2 <= HS_HM_SCALE_MAX)
  hs_kernel_hm_2,
#endif
};

//
//
//

typedef void (*hs_kernel_fm_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout);

static hs_kernel_fm_pfn const hs_kernels_fm[]
{
#if (HS_FM_SCALE_MIN == 0)
#if (HS_BS_SLABS_LOG2_RU == 1)
  hs_kernel_fm_0_0,
#endif
#if (HS_BS_SLABS_LOG2_RU == 2)
  hs_kernel_fm_0_1,
#endif
#if (HS_BS_SLABS_LOG2_RU == 3)
  hs_kernel_fm_0_2,
#endif
#if (HS_BS_SLABS_LOG2_RU == 4)
  hs_kernel_fm_0_3,
#endif
#if (HS_BS_SLABS_LOG2_RU == 5)
  hs_kernel_fm_0_4,
#endif
#if (HS_BS_SLABS_LOG2_RU == 6)
  hs_kernel_fm_0_5,
#endif
#if (HS_BS_SLABS_LOG2_RU == 7)
  hs_kernel_fm_0_6,
#endif
#endif

#if (HS_FM_SCALE_MIN <= 1) && (1 <= HS_FM_SCALE_MAX)
  CONCAT_MACRO(hs_kernel_fm_1_,HS_BS_SLABS_LOG2_RU),
#endif

#if (HS_FM_SCALE_MIN <= 2) && (2 <= HS_FM_SCALE_MAX)
#if (HS_BS_SLABS_LOG2_RU == 1)
  hs_kernel_fm_2_2,
#endif
#if (HS_BS_SLABS_LOG2_RU == 2)
  hs_kernel_fm_2_3,
#endif
#if (HS_BS_SLABS_LOG2_RU == 3)
  hs_kernel_fm_2_4,
#endif
#if (HS_BS_SLABS_LOG2_RU == 4)
  hs_kernel_fm_2_5,
#endif
#if (HS_BS_SLABS_LOG2_RU == 5)
  hs_kernel_fm_2_6,
#endif
#if (HS_BS_SLABS_LOG2_RU == 6)
  hs_kernel_fm_2_7,
#endif
#if (HS_BS_SLABS_LOG2_RU == 7)
  hs_kernel_fm_2_8,
#endif

#endif
};

//
//
//

typedef void (*hs_kernel_offset_fm_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout,
                                        uint32_t const span_offset);

#if (HS_FM_SCALE_MIN == 0)
static hs_kernel_offset_fm_pfn const hs_kernels_offset_fm_0[]
{
#if (HS_BS_SLABS_LOG2_RU >= 2)
  hs_kernel_fm_0_0,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 3)
  hs_kernel_fm_0_1,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 4)
  hs_kernel_fm_0_2,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 5)
  hs_kernel_fm_0_3,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 6)
  hs_kernel_fm_0_4,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 7)
  hs_kernel_fm_0_5,
#endif
};
#endif

#if (HS_FM_SCALE_MIN <= 1) && (1 <= HS_FM_SCALE_MAX)
static hs_kernel_offset_fm_pfn const hs_kernels_offset_fm_1[]
{
#if (HS_BS_SLABS_LOG2_RU >= 1)
  hs_kernel_fm_1_0,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 2)
  hs_kernel_fm_1_1,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 3)
  hs_kernel_fm_1_2,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 4)
  hs_kernel_fm_1_3,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 5)
  hs_kernel_fm_1_4,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 6)
  hs_kernel_fm_1_5,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 7)
  hs_kernel_fm_1_6,
#endif
};
#endif

#if (HS_FM_SCALE_MIN <= 2) && (2 <= HS_FM_SCALE_MAX)
static hs_kernel_offset_fm_pfn const hs_kernels_offset_fm_2[]
{
  hs_kernel_fm_2_0,
#if (HS_BS_SLABS_LOG2_RU >= 1)
  hs_kernel_fm_2_1,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 2)
  hs_kernel_fm_2_2,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 3)
  hs_kernel_fm_2_3,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 4)
  hs_kernel_fm_2_4,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 5)
  hs_kernel_fm_2_5,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 6)
  hs_kernel_fm_2_6,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 7)
  hs_kernel_fm_2_7,
#endif
};
#endif

static hs_kernel_offset_fm_pfn const * const hs_kernels_offset_fm[]
{
#if (HS_FM_SCALE_MIN == 0)
  hs_kernels_offset_fm_0,
#endif
#if (HS_FM_SCALE_MIN <= 1) && (1 <= HS_FM_SCALE_MAX)
  hs_kernels_offset_fm_1,
#endif
#if (HS_FM_SCALE_MIN <= 2) && (2 <= HS_FM_SCALE_MAX)
  hs_kernels_offset_fm_2,
#endif
};
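
//
// NOTE: hs_kernels_offset_fm[][] is indexed first by the flip-merge
// scale (scale_log2 - HS_FM_SCALE_MIN) and then by msb_idx_u32() of
// the fractional span count -- see hs_fm() below.
//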

//
//
//

typedef uint32_t hs_indices_t;

//
//
//

struct hs_state
{
  // key buffers
  HS_KEY_TYPE *  vin;
  HS_KEY_TYPE *  vout; // can be vin

  cudaStream_t   streams[3];

  // pool of stream indices
  hs_indices_t   pool;

  // bx_ru is the rounded-up number of slabs (warps) in vin
  uint32_t       bx_ru;
};

//
//
//

static
uint32_t
hs_indices_acquire(hs_indices_t * const indices)
{
  //
  // FIXME -- an FFS intrinsic might be faster but there are so few
  // bits in this implementation that it might not matter.
  //
  if      (*indices & 1)
    {
      *indices = *indices & ~1;
      return 0;
    }
  else if (*indices & 2)
    {
      *indices = *indices & ~2;
      return 1;
    }
  else // if (*indices & 4)
    {
      *indices = *indices & ~4;
      return 2;
    }
}
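
//
// If the FFS route were ever taken, a minimal sketch (assuming a
// GCC/Clang-style __builtin_ffs() and a non-empty index set) might be:
//
//   uint32_t const idx = (uint32_t)__builtin_ffs((int)*indices) - 1;
//
//   *indices &= ~(1u << idx);
//
//   return idx;
//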


static
uint32_t
hs_state_acquire(struct hs_state * const state,
                 hs_indices_t    * const indices)
{
  //
  // FIXME -- an FFS intrinsic might be faster but there are so few
  // bits in this implementation that it might not matter.
  //
  if      (state->pool & 1)
    {
      state->pool &= ~1;
      *indices    |=  1;
      return 0;
    }
  else if (state->pool & 2)
    {
      state->pool &= ~2;
      *indices    |=  2;
      return 1;
    }
  else // (state->pool & 4)
    {
      state->pool &= ~4;
      *indices    |=  4;
      return 2;
    }
}

static
void
hs_indices_merge(hs_indices_t * const to, hs_indices_t const from)
{
  *to |= from;
}

static
void
hs_barrier_enqueue(cudaStream_t to, cudaStream_t from)
{
  cudaEvent_t event_before;

  cuda(EventCreate(&event_before));

  cuda(EventRecord(event_before,from));

  cuda(StreamWaitEvent(to,event_before,0));

  cuda(EventDestroy(event_before));
}

static
hs_indices_t
hs_barrier(struct hs_state * const state,
           hs_indices_t      const before,
           hs_indices_t    * const after,
           uint32_t          const count) // count is 1 or 2
{
  // return streams this stage depends on back into the pool
  hs_indices_merge(&state->pool,before);

  hs_indices_t indices = 0;

  // acquire 'count' stream indices for this stage
  for (uint32_t ii=0; ii<count; ii++)
    {
      hs_indices_t new_indices = 0;

      // new index
      uint32_t const idx = hs_state_acquire(state,&new_indices);

      // add the new index to the indices
      indices |= new_indices;

      // only enqueue barriers when streams are different
      uint32_t const wait = before & ~new_indices;

      if (wait != 0)
        {
          cudaStream_t to = state->streams[idx];

          //
          // FIXME -- an FFS loop might be slower for so few bits. So
          // leave it as is for now.
          //
          if (wait & 1)
            hs_barrier_enqueue(to,state->streams[0]);
          if (wait & 2)
            hs_barrier_enqueue(to,state->streams[1]);
          if (wait & 4)
            hs_barrier_enqueue(to,state->streams[2]);
        }
    }

  hs_indices_merge(after,indices);

  return indices;
}

//
//
//

#ifndef NDEBUG

#include <stdio.h>

#define HS_STREAM_SYNCHRONIZE(s)                \
  do {                                          \
    cuda(StreamSynchronize(s));                 \
    fprintf(stderr,"%s\n",__func__);            \
  } while (0)

#else

#define HS_STREAM_SYNCHRONIZE(s)

#endif

//
//
//

static
void
hs_transpose(struct hs_state * const state)
{
  HS_TRANSPOSE_KERNEL_NAME()
    <<<state->bx_ru,HS_SLAB_THREADS,0,state->streams[0]>>>
    (state->vout);

  HS_STREAM_SYNCHRONIZE(state->streams[0]);
}

//
//
//

static
void
hs_bc(struct hs_state * const state,
      hs_indices_t      const hs_bc,
      hs_indices_t    * const fm,
      uint32_t          const down_slabs,
      uint32_t          const clean_slabs_log2)
{
  // enqueue any necessary barriers
  hs_indices_t indices = hs_barrier(state,hs_bc,fm,1);

  // block clean the minimal number of down_slabs_log2 spans
  uint32_t const frac_ru = (1u << clean_slabs_log2) - 1;
  uint32_t const full    = (down_slabs + frac_ru) >> clean_slabs_log2;
  uint32_t const threads = HS_SLAB_THREADS << clean_slabs_log2;

  // stream will *always* be stream[0]
  cudaStream_t stream  = state->streams[hs_indices_acquire(&indices)];

  hs_kernels_bc[clean_slabs_log2]
    <<<full,threads,0,stream>>>
    (state->vout);

  HS_STREAM_SYNCHRONIZE(stream);
}

//
//
//

static
uint32_t
hs_hm(struct hs_state  * const state,
      hs_indices_t       const hs_bc,
      hs_indices_t     * const hs_bc_tmp,
      uint32_t           const down_slabs,
      uint32_t           const clean_slabs_log2)
{
  // enqueue any necessary barriers
  hs_indices_t   indices    = hs_barrier(state,hs_bc,hs_bc_tmp,1);

  // how many scaled half-merge spans are there?
  uint32_t const frac_ru    = (1 << clean_slabs_log2) - 1;
  uint32_t const spans      = (down_slabs + frac_ru) >> clean_slabs_log2;

  // for now, just clamp to the max
  uint32_t const log2_rem   = clean_slabs_log2 - HS_BC_SLABS_LOG2_MAX;
  uint32_t const scale_log2 = MIN_MACRO(HS_HM_SCALE_MAX,log2_rem);
  uint32_t const log2_out   = log2_rem - scale_log2;

  //
  // Size the grid
  //
  // The simplifying choices below limit the maximum keys that can be
  // sorted with this grid scheme to around ~2B.
  //
  //   .x : slab height << clean_log2  -- this is the slab span
  //   .y : [1...65535]                -- this is the slab index
  //   .z : ( this could also be used to further expand .y )
  //
  // Note that OpenCL declares a grid in terms of global threads and
  // not grids and blocks
  //
  dim3 grid;

  grid.x = (HS_SLAB_HEIGHT / HS_HM_BLOCK_HEIGHT) << log2_out;
  grid.y = spans;
  grid.z = 1;

  cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

  hs_kernels_hm[scale_log2-HS_HM_SCALE_MIN]
    <<<grid,HS_SLAB_THREADS * HS_HM_BLOCK_HEIGHT,0,stream>>>
    (state->vout);

  HS_STREAM_SYNCHRONIZE(stream);

  return log2_out;
}

//
// FIXME -- some of this logic can be skipped if BS is a power-of-two
//

static
uint32_t
hs_fm(struct hs_state * const state,
      hs_indices_t      const fm,
      hs_indices_t    * const hs_bc,
      uint32_t        * const down_slabs,
      uint32_t          const up_scale_log2)
{
  //
  // FIXME OPTIMIZATION: in previous HotSort launchers it's sometimes
  // a performance win to bias toward launching the smaller flip merge
  // kernel in order to get more warps in flight (increased
  // occupancy).  This is useful when merging small numbers of slabs.
  //
  // Note that HS_FM_SCALE_MIN will always be 0 or 1.
  //
  // So, for now, just clamp to the max until there is a reason to
  // restore the fancier and probably low-impact approach.
  //
  uint32_t const scale_log2 = MIN_MACRO(HS_FM_SCALE_MAX,up_scale_log2);
  uint32_t const clean_log2 = up_scale_log2 - scale_log2;

  // number of slabs in a full-sized scaled flip-merge span
  uint32_t const full_span_slabs = HS_BS_SLABS << up_scale_log2;

  // how many full-sized scaled flip-merge spans are there?
  uint32_t full_fm = state->bx_ru / full_span_slabs;
  uint32_t frac_fm = 0;

  // initialize down_slabs
  *down_slabs = full_fm * full_span_slabs;

  // how many half-size scaled + fractional scaled spans are there?
  uint32_t const span_rem        = state->bx_ru - *down_slabs;
  uint32_t const half_span_slabs = full_span_slabs >> 1;

  // if we have over a half-span then fractionally merge it
  if (span_rem > half_span_slabs)
    {
      // the remaining slabs will be cleaned
      *down_slabs += span_rem;

      uint32_t const frac_rem      = span_rem - half_span_slabs;
      uint32_t const frac_rem_pow2 = pow2_ru_u32(frac_rem);

      if (frac_rem_pow2 >= half_span_slabs)
        {
          // bump it up to a full span
          full_fm += 1;
        }
      else
        {
          // otherwise, add fractional
          frac_fm  = MAX_MACRO(1,frac_rem_pow2 >> clean_log2);
        }
    }

  // enqueue any necessary barriers
  bool const   both    = (full_fm != 0) && (frac_fm != 0);
  hs_indices_t indices = hs_barrier(state,fm,hs_bc,both ? 2 : 1);

  //
  // Size the grid
  //
  // The simplifying choices below limit the maximum keys that can be
  // sorted with this grid scheme to around ~2B.
  //
  //   .x : slab height << clean_log2  -- this is the slab span
  //   .y : [1...65535]                -- this is the slab index
  //   .z : ( this could also be used to further expand .y )
  //
  // Note that OpenCL declares a grid in terms of global threads and
  // not grids and blocks
  //
  dim3 grid;

  grid.x = (HS_SLAB_HEIGHT / HS_FM_BLOCK_HEIGHT) << clean_log2;
  grid.z = 1;

  if (full_fm > 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      grid.y = full_fm;

      hs_kernels_fm[scale_log2-HS_FM_SCALE_MIN]
        <<<grid,HS_SLAB_THREADS * HS_FM_BLOCK_HEIGHT,0,stream>>>
        (state->vout);

      HS_STREAM_SYNCHRONIZE(stream);
    }

  if (frac_fm > 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      grid.y = 1;

      hs_kernels_offset_fm[scale_log2-HS_FM_SCALE_MIN][msb_idx_u32(frac_fm)]
        <<<grid,HS_SLAB_THREADS * HS_FM_BLOCK_HEIGHT,0,stream>>>
        (state->vout,full_fm);

      HS_STREAM_SYNCHRONIZE(stream);
    }

  return clean_log2;
}

//
//
//

static
void
hs_bs(struct hs_state * const state,
      hs_indices_t      const bs,
      hs_indices_t    * const fm,
      uint32_t          const count_padded_in)
{
  uint32_t const slabs_in = count_padded_in / HS_SLAB_KEYS;
  uint32_t const full_bs  = slabs_in / HS_BS_SLABS;
  uint32_t const frac_bs  = slabs_in - full_bs * HS_BS_SLABS;
  bool     const both     = (full_bs != 0) && (frac_bs != 0);

  // enqueue any necessary barriers
  hs_indices_t   indices  = hs_barrier(state,bs,fm,both ? 2 : 1);

  if (full_bs != 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      CONCAT_MACRO(hs_kernel_bs_,HS_BS_SLABS_LOG2_RU)
        <<<full_bs,HS_BS_SLABS*HS_SLAB_THREADS,0,stream>>>
        (state->vout,state->vin);

      HS_STREAM_SYNCHRONIZE(stream);
    }

  if (frac_bs != 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      hs_kernels_offset_bs[msb_idx_u32(frac_bs)]
        <<<1,frac_bs*HS_SLAB_THREADS,0,stream>>>
        (state->vout,state->vin,full_bs*HS_BS_SLABS*HS_SLAB_THREADS);

      HS_STREAM_SYNCHRONIZE(stream);
    }
}

//
//
//

static
void
hs_keyset_pre_merge(struct hs_state * const state,
                    hs_indices_t    * const fm,
                    uint32_t          const count_lo,
                    uint32_t          const count_hi)
{
  uint32_t const vout_span = count_hi - count_lo;
  cudaStream_t   stream    = state->streams[hs_state_acquire(state,fm)];

  cuda(MemsetAsync(state->vout + count_lo,
                   0xFF,
                   vout_span * sizeof(HS_KEY_TYPE),
                   stream));
}

//
//
//

static
void
hs_keyset_pre_sort(struct hs_state * const state,
                   hs_indices_t    * const bs,
                   uint32_t          const count,
                   uint32_t          const count_hi)
{
  uint32_t const vin_span = count_hi - count;
  cudaStream_t   stream   = state->streams[hs_state_acquire(state,bs)];

  cuda(MemsetAsync(state->vin + count,
                   0xFF,
                   vin_span * sizeof(HS_KEY_TYPE),
                   stream));
}
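
//
// NOTE: both keyset routines above fill the pad region with 0xFF
// bytes -- the maximum representable value for HotSort's unsigned key
// types -- so the pad keys sort to the end of the buffer.
//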

//
//
//

void
CONCAT_MACRO(hs_cuda_sort_,HS_KEY_TYPE_PRETTY)
  (HS_KEY_TYPE * const vin,
   HS_KEY_TYPE * const vout,
   uint32_t      const count,
   uint32_t      const count_padded_in,
   uint32_t      const count_padded_out,
   bool          const linearize,
   cudaStream_t        stream0,  // primary stream
   cudaStream_t        stream1,  // auxiliary
   cudaStream_t        stream2)  // auxiliary
{
  // is this sort in place?
  bool const is_in_place = (vout == NULL);

  // key buffers, streams, stream-index pool and slab count
  struct hs_state state;

  state.vin        = vin;
  state.vout       = is_in_place ? vin : vout;
  state.streams[0] = stream0;
  state.streams[1] = stream1;
  state.streams[2] = stream2;
  state.pool       = 0x7; // 3 bits
  state.bx_ru      = (count + HS_SLAB_KEYS - 1) / HS_SLAB_KEYS;

  // initialize vin
  uint32_t const count_hi                 = is_in_place ? count_padded_out : count_padded_in;
  bool     const is_pre_sort_keyset_reqd  = count_hi > count;
  bool     const is_pre_merge_keyset_reqd = !is_in_place && (count_padded_out > count_padded_in);

  hs_indices_t bs = 0;

  // initialize any trailing keys in vin before sorting
  if (is_pre_sort_keyset_reqd)
    hs_keyset_pre_sort(&state,&bs,count,count_hi);

  hs_indices_t fm = 0;

  // concurrently initialize any trailing keys in vout before merging
  if (is_pre_merge_keyset_reqd)
    hs_keyset_pre_merge(&state,&fm,count_padded_in,count_padded_out);

  // immediately sort blocks of slabs
  hs_bs(&state,bs,&fm,count_padded_in);

  //
  // we're done if this was a single bs block...
  //
  // otherwise, merge sorted spans of slabs until done
  //
  if (state.bx_ru > HS_BS_SLABS)
    {
      int32_t up_scale_log2 = 1;

      while (true)
        {
          hs_indices_t hs_or_bc = 0;

          uint32_t down_slabs;

          // flip merge slabs -- return span of slabs that must be cleaned
          uint32_t clean_slabs_log2 = hs_fm(&state,
                                            fm,
                                            &hs_or_bc,
                                            &down_slabs,
                                            up_scale_log2);

          // if span is gt largest slab block cleaner then half merge
          while (clean_slabs_log2 > HS_BC_SLABS_LOG2_MAX)
            {
              hs_indices_t hs_or_bc_tmp;

              clean_slabs_log2 = hs_hm(&state,
                                       hs_or_bc,
                                       &hs_or_bc_tmp,
                                       down_slabs,
                                       clean_slabs_log2);
              hs_or_bc = hs_or_bc_tmp;
            }

          // reset fm
          fm = 0;

          // launch clean slab grid -- is it the final launch?
          hs_bc(&state,
                hs_or_bc,
                &fm,
                down_slabs,
                clean_slabs_log2);

          // was this the final block clean?
          if (((uint32_t)HS_BS_SLABS << up_scale_log2) >= state.bx_ru)
            break;

          // otherwise, merge twice as many slabs
          up_scale_log2 += 1;
        }
    }

  // slabs or linear?
  if (linearize)
    {
      // guaranteed to be on stream0
      hs_transpose(&state);
    }
}

//
// all grids will be computed as a function of the minimum number of slabs
//

void
CONCAT_MACRO(hs_cuda_pad_,HS_KEY_TYPE_PRETTY)
  (uint32_t   const count,
   uint32_t * const count_padded_in,
   uint32_t * const count_padded_out)
{
  //
  // round up the count to slabs
  //
  uint32_t const slabs_ru        = (count + HS_SLAB_KEYS - 1) / HS_SLAB_KEYS;
  uint32_t const blocks          = slabs_ru / HS_BS_SLABS;
  uint32_t const block_slabs     = blocks * HS_BS_SLABS;
  uint32_t const slabs_ru_rem    = slabs_ru - block_slabs;
  uint32_t const slabs_ru_rem_ru = MIN_MACRO(pow2_ru_u32(slabs_ru_rem),HS_BS_SLABS);

  *count_padded_in  = (block_slabs + slabs_ru_rem_ru) * HS_SLAB_KEYS;
  *count_padded_out = *count_padded_in;

  //
  // will merging be required?
  //
  if (slabs_ru > HS_BS_SLABS)
    {
      // more than one block
      uint32_t const blocks_lo       = pow2_rd_u32(blocks);
      uint32_t const block_slabs_lo  = blocks_lo * HS_BS_SLABS;
      uint32_t const block_slabs_rem = slabs_ru - block_slabs_lo;

      if (block_slabs_rem > 0)
        {
          uint32_t const block_slabs_rem_ru     = pow2_ru_u32(block_slabs_rem);

          uint32_t const block_slabs_hi         = MAX_MACRO(block_slabs_rem_ru,
                                                            blocks_lo << (1 - HS_FM_SCALE_MIN));

          uint32_t const block_slabs_padded_out = MIN_MACRO(block_slabs_lo+block_slabs_hi,
                                                            block_slabs_lo*2); // clamp non-pow2 blocks

          *count_padded_out = block_slabs_padded_out * HS_SLAB_KEYS;
        }
    }
}
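
//
// Hypothetical usage sketch.  It assumes this translation unit was
// generated with HS_KEY_TYPE = uint32_t so that HS_KEY_TYPE_PRETTY
// expands to "u32" -- adjust the "_u32" suffixes and key type to match
// the actual build.  Disabled so it never affects compilation:
//
#if 0
static
void
hs_cuda_example_sort_u32(uint32_t const count,
                         cudaStream_t   s0,
                         cudaStream_t   s1,
                         cudaStream_t   s2)
{
  uint32_t count_padded_in, count_padded_out;

  // query the padded buffer sizes required for this key count
  hs_cuda_pad_u32(count,&count_padded_in,&count_padded_out);

  uint32_t * vin  = NULL;
  uint32_t * vout = NULL;

  // vin must hold count_padded_in keys and vout count_padded_out keys
  cudaMalloc((void **)&vin, count_padded_in  * sizeof(*vin));
  cudaMalloc((void **)&vout,count_padded_out * sizeof(*vout));

  // ... copy 'count' unsorted keys into vin ...

  // sort and convert from slab ordering to linear ordering
  hs_cuda_sort_u32(vin,vout,
                   count,count_padded_in,count_padded_out,
                   true, // linearize
                   s0,s1,s2);

  // ... sorted keys are in vout[0..count-1]; pad keys sort to the end ...

  cudaFree(vout);
  cudaFree(vin);
}
#endif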

//
//
//

void
CONCAT_MACRO(hs_cuda_info_,HS_KEY_TYPE_PRETTY)
  (uint32_t * const key_words,
   uint32_t * const val_words,
   uint32_t * const slab_height,
   uint32_t * const slab_width_log2)
{
  *key_words       = HS_KEY_WORDS;
  *val_words       = HS_VAL_WORDS;
  *slab_height     = HS_SLAB_HEIGHT;
  *slab_width_log2 = HS_SLAB_WIDTH_LOG2;
}

//
//
//
