1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#pragma once

#include <stddef.h>
10 #include <stdint.h>
11 
#include <pthreadpool.h>

#include <xnnpack.h>
13 #include <xnnpack/common.h>
14 #include <xnnpack/math.h>
15 #include <xnnpack/params.h>
16 
17 
18 enum xnn_parallelization_type {
19   xnn_parallelization_type_invalid = 0,
20   xnn_parallelization_type_1d,
21   xnn_parallelization_type_1d_tile_1d,
22   xnn_parallelization_type_2d,
23   xnn_parallelization_type_2d_tile_1d,
24   xnn_parallelization_type_2d_tile_2d,
25   xnn_parallelization_type_3d,
26   xnn_parallelization_type_3d_tile_2d,
27   xnn_parallelization_type_4d,
28   xnn_parallelization_type_4d_tile_2d,
29   xnn_parallelization_type_5d,
30   xnn_parallelization_type_5d_tile_2d,
31   xnn_parallelization_type_6d_tile_2d,
32 #if XNN_MAX_UARCH_TYPES > 1
33   xnn_parallelization_type_2d_tile_2d_with_uarch,
34   xnn_parallelization_type_3d_tile_2d_with_uarch,
35   xnn_parallelization_type_4d_tile_2d_with_uarch,
36 #endif  // XNN_MAX_UARCH_TYPES > 1
37 };
38 
39 struct compute_parameters {
40   enum xnn_parallelization_type type;
41   union {
42     pthreadpool_task_1d_t task_1d;
43     pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
44     pthreadpool_task_2d_t task_2d;
45     pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
46     pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
47     pthreadpool_task_3d_t task_3d;
48     pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
49     pthreadpool_task_4d_t task_4d;
50     pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
51     pthreadpool_task_5d_t task_5d;
52     pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
53     pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
54 #if XNN_MAX_UARCH_TYPES > 1
55     pthreadpool_task_2d_tile_2d_with_id_t task_2d_tile_2d_with_id;
56     pthreadpool_task_3d_tile_2d_with_id_t task_3d_tile_2d_with_id;
57     pthreadpool_task_4d_tile_2d_with_id_t task_4d_tile_2d_with_id;
58 #endif  // XNN_MAX_UARCH_TYPES > 1
59   };
60   size_t range[6];
61   size_t tile[2];
62 };
63 
64 struct gemm_context {
65   size_t k_scaled;
66   const void* a;
67   size_t a_stride;
68   const void* packed_w;
69   size_t w_stride;
70   size_t wg_stride;
71   void* c;
72   size_t cm_stride;
73   size_t cn_stride;
74   size_t cg_stride;
75   uint32_t log2_csize;
76   struct xnn_hmp_gemm_ukernel ukernel;
77   union {
78     union xnn_qs8_gemm_params qs8;
79     union xnn_qu8_gemm_params qu8;
80     struct xnn_f16_scaleminmax_params f16;
81     union xnn_f32_minmax_params f32;
82   } params;
83 };
84 
85 #ifndef __cplusplus
86   XNN_PRIVATE void xnn_compute_grouped_gemm(
87       const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
88       size_t group_index,
89       size_t mr_block_start,
90       size_t nr_block_start,
91       size_t mr_block_size,
92       size_t nr_block_size);
93 
94   XNN_PRIVATE void xnn_compute_gemm(
95       const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
96       size_t mr_block_start,
97       size_t nr_block_start,
98       size_t mr_block_size,
99       size_t nr_block_size);
100 
101   #if XNN_MAX_UARCH_TYPES > 1
102     XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
103         const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
104         uint32_t uarch_index,
105         size_t group_index,
106         size_t mr_block_start,
107         size_t nr_block_start,
108         size_t mr_block_size,
109         size_t nr_block_size);
110 
111     XNN_PRIVATE void xnn_compute_hmp_gemm(
112         const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
113         uint32_t uarch_index,
114         size_t mr_block_start,
115         size_t nr_block_start,
116         size_t mr_block_size,
117         size_t nr_block_size);
118   #endif  // XNN_MAX_UARCH_TYPES > 1
119 #endif
120 
121 // Context for Sparse Matrix-Dense Matrix Multiplication.
122 // C [MxN] := A [MxK] * B [KxN] + bias [N]
123 // A and C are dense matrices with row-major storage, B is a sparse matrix.
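// The packed sparse weights are traversed per output channel: for output
// channel n, the micro-kernel reads a bias followed by
// output_channel_nonzeros[n] non-zero weights from nonzero_weights, and after
// consuming each weight it advances its input pointer by the matching entry of
// input_increments. Illustrative inner loop for one output element (a sketch
// for the f32 case, not the actual micro-kernel):
//   float acc = *w++;  // bias
//   for (size_t k = 0; k < output_channel_nonzeros[n]; k++) {
//     acc += (*input) * (*w++);
//     input = (const float*) ((const char*) input + *increments++);
//   }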
124 struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to the number of output channels in a 1x1 convolution.
127   size_t n;
  // M dimension of the A and C matrices, pre-scaled by the element size.
  // Corresponds to the stride, in bytes, between adjacent rows of the C matrix.
130   size_t scaled_m;
131   // Input matrix A.
132   const void* input;
133   // Packed bias elements and non-zero filter elements.
134   const void* nonzero_weights;
  // Increments, in bytes, to be added to the input pointer after processing each non-zero weight.
136   const int32_t* input_increments;
  // Number of non-zero filter elements for each output channel (N dimension).
138   const uint32_t* output_channel_nonzeros;
139   // Output matrix C.
140   void* output;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 convolution.
142   size_t batched_input_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 convolution.
144   size_t batched_output_stride;
145   // Micro-kernel function pointer.
146   xnn_spmm_ukernel_function ukernel;
147   // Output activation parameters.
148   union {
149     union xnn_f32_minmax_params f32;
150   } params;
151 };
152 
153 #ifndef __cplusplus
154   XNN_PRIVATE void xnn_compute_spmm(
155     const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
156     size_t batch_index,
157     size_t mr_block_start,
158     size_t mr_block_size);
159 #endif
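
// Context for indirect GEMM (convolution lowered to GEMM through an
// indirection buffer): instead of a dense A matrix, indirect_a holds, for each
// output pixel, ks pointers to the input pixels covered by the filter window;
// a_offset is added to the pointers, and pointers for positions in the
// implicit padding reference the `zero` buffer. The g*_stride and b*_stride
// fields are per-group and per-batch strides for the grouped/batched variants.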
160 
161 struct igemm_context {
162   size_t ks;
163   size_t ks_scaled;
164   size_t kc;
165   size_t w_stride;
166   const void** indirect_a;
167   size_t a_offset;
168   void* zero;
169   const void* packed_w;
170   void* c;
171   size_t cm_stride;
172   size_t cn_stride;
173   size_t ga_stride;
174   size_t gw_stride;
175   size_t gc_stride;
176   size_t ba_stride;
177   size_t bc_stride;
178   uint32_t log2_csize;
179   struct xnn_hmp_igemm_ukernel ukernel;
180   union {
181     union xnn_qs8_gemm_params qs8;
182     union xnn_qu8_gemm_params qu8;
183     struct xnn_f16_scaleminmax_params f16;
184     union xnn_f32_minmax_params f32;
185   } params;
186 };
187 
188 #ifndef __cplusplus
189   XNN_PRIVATE void xnn_compute_grouped_igemm(
190       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
191       size_t group_index,
192       size_t mr_block_start,
193       size_t nr_block_start,
194       size_t mr_block_size,
195       size_t nr_block_size);
196 
197   XNN_PRIVATE void xnn_compute_grouped_batch_igemm(
198       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
199       size_t batch_index,
200       size_t group_index,
201       size_t mr_block_start,
202       size_t nr_block_start,
203       size_t mr_block_size,
204       size_t nr_block_size);
205 
206   XNN_PRIVATE void xnn_compute_igemm(
207       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
208       size_t mr_block_start,
209       size_t nr_block_start,
210       size_t mr_block_size,
211       size_t nr_block_size);
212 
213   XNN_PRIVATE void xnn_compute_batch_igemm(
214       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
215       size_t batch_index,
216       size_t mr_block_start,
217       size_t nr_block_start,
218       size_t mr_block_size,
219       size_t nr_block_size);
220 
221   #if XNN_MAX_UARCH_TYPES > 1
222     XNN_PRIVATE void xnn_compute_hmp_grouped_igemm(
223         const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
224         uint32_t uarch_index,
225         size_t group_index,
226         size_t mr_block_start,
227         size_t nr_block_start,
228         size_t mr_block_size,
229         size_t nr_block_size);
230 
231     XNN_PRIVATE void xnn_compute_hmp_grouped_batch_igemm(
232         const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
233         uint32_t uarch_index,
234         size_t batch_index,
235         size_t group_index,
236         size_t mr_block_start,
237         size_t nr_block_start,
238         size_t mr_block_size,
239         size_t nr_block_size);
240 
241     XNN_PRIVATE void xnn_compute_hmp_igemm(
242         const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
243         uint32_t uarch_index,
244         size_t mr_block_start,
245         size_t nr_block_start,
246         size_t mr_block_size,
247         size_t nr_block_size);
248 
249     XNN_PRIVATE void xnn_compute_batch_hmp_igemm(
250         const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
251         uint32_t uarch_index,
252         size_t batch_index,
253         size_t mr_block_start,
254         size_t nr_block_start,
255         size_t mr_block_size,
256         size_t nr_block_size);
257   #endif  // XNN_MAX_UARCH_TYPES > 1
258 #endif
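
// Context for the GEMM path of Deconvolution (transposed convolution): the
// output is decomposed into subconvolutions, one per phase of the stride, each
// described by an entry of subconvolution_params and computed as a GEMM over a
// strided slice of the output.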
259 
260 struct subgemm_context {
261   const struct subconvolution_params* subconvolution_params;
262   size_t kc;
263   const void* a;
264   size_t ax_stride;
265   size_t ay_stride;
266   size_t cx_stride;
267   size_t cy_stride;
268   size_t cn_stride;
269   size_t ga_stride;
270   size_t gw_stride;
271   size_t gc_stride;
272   size_t ba_stride;
273   size_t bc_stride;
274   uint32_t log2_csize;
275   struct xnn_hmp_gemm_ukernel ukernel;
276   union {
277     union xnn_qs8_gemm_params qs8;
278     union xnn_qu8_gemm_params qu8;
279     struct xnn_f16_scaleminmax_params f16;
280     union xnn_f32_minmax_params f32;
281   } params;
282 };
283 
284 #ifndef __cplusplus
285   XNN_PRIVATE void xnn_compute_grouped_subgemm2d(
286       const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
287       size_t batch_index,
288       size_t group_index,
289       size_t subkernel_index,
290       size_t slice_y,
291       size_t slice_x_start,
292       size_t nr_block_start,
293       size_t slice_x_max,
294       size_t nr_block_size);
295 
296   XNN_PRIVATE void xnn_compute_subgemm2d(
297       const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
298       size_t batch_index,
299       size_t subkernel_index,
300       size_t slice_y,
301       size_t slice_x_start,
302       size_t nr_block_start,
303       size_t slice_x_max,
304       size_t nr_block_size);
305 #endif
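
// Context for the IGEMM path of Deconvolution: the indirection-buffer
// counterpart of subgemm_context above, with input pixels reached through the
// indirection buffer referenced by subconvolution_params, a_offset added to
// each pointer, and `zero` standing in for padded positions.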
306 
307 struct subconv_context {
308   const struct subconvolution_params* subconvolution_params;
309   size_t kc;
310   size_t a_offset;
311   void* zero;
312   size_t cx_stride;
313   size_t cy_stride;
314   size_t cn_stride;
315   size_t ga_stride;
316   size_t gw_stride;
317   size_t gc_stride;
318   size_t ba_stride;
319   size_t bc_stride;
320   uint32_t log2_csize;
321   struct xnn_hmp_igemm_ukernel ukernel;
322   union {
323     union xnn_qs8_gemm_params qs8;
324     union xnn_qu8_gemm_params qu8;
325     struct xnn_f16_scaleminmax_params f16;
326     union xnn_f32_minmax_params f32;
327   } params;
328 };
329 
330 #ifndef __cplusplus
331   XNN_PRIVATE void xnn_compute_grouped_subconv2d(
332       const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
333       size_t batch_index,
334       size_t group_index,
335       size_t subkernel_index,
336       size_t slice_y,
337       size_t slice_x_start,
338       size_t nr_block_start,
339       size_t slice_x_max,
340       size_t nr_block_size);
341 
342   XNN_PRIVATE void xnn_compute_subconv2d(
343       const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
344       size_t batch_index,
345       size_t subkernel_index,
346       size_t slice_y,
347       size_t slice_x_start,
348       size_t nr_block_start,
349       size_t slice_x_max,
350       size_t nr_block_size);
351 #endif
352 
353 struct conv2d_context {
354   size_t input_height;
355   size_t input_width;
356   const void* input;
357   size_t input_batch_stride;
358   const void* zero;
359   const void* packed_weights;
360   void* output;
361   size_t output_batch_stride;
362   size_t input_padding_top;
363   size_t output_channels;
364   size_t output_height_stride;
365   size_t output_channel_stride;
366   union {
367     xnn_conv_hwc2chw_ukernel_function hwc2chw_ukernel;
368   };
369   union {
370     union xnn_f32_minmax_params f32;
371   } params;
372 };
373 
374 #ifndef __cplusplus
375   XNN_PRIVATE void xnn_compute_conv2d_hwc2chw(
376       const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
377       size_t batch_index,
378       size_t output_y_start,
379       size_t output_y_slice);
380 #endif
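
// Context for depthwise convolution in NHWC layout: each task processes one
// output row of one batch image. For every output pixel, indirect_input holds
// pointers to the input pixels covered by the filter window; positions in the
// implicit padding reference the `zero` buffer.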
381 
382 struct dwconv_context {
383   const void** indirect_input;
384   size_t indirect_input_width_stride;
385   size_t indirect_input_height_stride;
386   size_t input_offset;
387   size_t input_batch_stride;
388   const void* packed_weights;
389   void* output;
390   size_t output_batch_stride;
391   size_t output_height_stride;
392   size_t output_width;
393   size_t groups;
394   const void* zero;
395   size_t output_increment;
396   union {
397     union xnn_qs8_gemm_params qs8;
398     union xnn_qu8_gemm_params qu8;
399     struct xnn_f16_minmax_params f16;
400     union xnn_f32_minmax_params f32;
401   } params;
402   union {
403     xnn_dwconv_unipass_ukernel_function unipass_ukernel;
404   };
405 };
406 
407 #ifndef __cplusplus
408   XNN_PRIVATE void xnn_compute_dwconv_unipass(
409       const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
410       size_t batch_index,
411       size_t output_y);
412 #endif
413 
414 struct dwconv2d_context {
415   size_t input_height;
416   size_t input_width;
417   const void* input;
418   const void* zero;
419   uint32_t input_padding_top;
420   size_t input_channel_stride;
421   size_t input_batch_stride;
422   const void* packed_weights;
423   size_t weights_channel_stride;
424   void* output;
425   size_t output_channel_stride;
426   size_t output_batch_stride;
427   union {
428     union xnn_f32_chw_params f32;
429   } params;
430   union {
431     xnn_dwconv2d_chw_ukernel_function chw_ukernel;
432   };
433 };
434 
435 #ifndef __cplusplus
436   XNN_PRIVATE void xnn_compute_dwconv2d_chw(
437       const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
438       size_t batch_index,
439       size_t channel);
440 #endif
441 
442 struct depthtospace2d_hwc_context {
443   size_t elements;
444   size_t input_width;
445   size_t block_size;
446   const void* input;
447   void* output;
448   size_t input_height_stride;
449   size_t input_width_stride;
450   size_t output_height_stride;
451   size_t output_width_stride;
452   xnn_univector_ukernel_function ukernel;
453 };
454 
455 #ifndef __cplusplus
456   XNN_PRIVATE void xnn_compute_depthtospace2d_hwc_contiguous(
      const struct depthtospace2d_hwc_context context[restrict XNN_MIN_ELEMENTS(1)],
458       size_t batch_input_y,
459       size_t input_x,
460       size_t block_y);
461 
462   XNN_PRIVATE void xnn_compute_depthtospace2d_hwc_strided(
      const struct depthtospace2d_hwc_context context[restrict XNN_MIN_ELEMENTS(1)],
464       size_t batch_input_y,
465       size_t input_x,
466       size_t block_y,
467       size_t block_x);
468 #endif
469 
470 struct depthtospace2d_chw2hwc_context {
471   size_t output_channels;
472   size_t input_height;
473   size_t input_width;
474   uint32_t block_size;
475   const void* input;
476   void* output;
477   size_t input_batch_stride;
478   size_t output_batch_stride;
479   size_t output_channel_stride;
480   xnn_depthtospace2d_chw2hwc_ukernel_function ukernel;
481 };
482 
483 #ifndef __cplusplus
484   XNN_PRIVATE void xnn_compute_depthtospace2d_chw2hwc(
      const struct depthtospace2d_chw2hwc_context context[restrict XNN_MIN_ELEMENTS(1)],
486       size_t batch_index);
487 #endif
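
// Context for 2D max pooling in NHWC layout: each task computes one output row
// of one batch image, reading every pooling window through the indirection
// buffer (with input_offset added to the pointers).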
488 
489 struct max_pooling_context {
490   const void** indirect_input;
491   size_t indirect_input_height_stride;
492   size_t input_offset;
493   size_t input_batch_stride;
494   void* output;
495   size_t output_batch_stride;
496   size_t output_height_stride;
497   size_t output_width;
498   size_t pooling_size;
499   size_t channels;
500   size_t input_increment;
501   size_t output_increment;
502   union {
503     union xnn_u8_minmax_params u8;
504     union xnn_f32_minmax_params f32;
505   } params;
506   xnn_maxpool_ukernel_function ukernel;
507 };
508 
509 #ifndef __cplusplus
510   XNN_PRIVATE void xnn_compute_max_pooling(
511       const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
512       size_t batch_index,
513       size_t output_y);
514 #endif
515 
516 struct unpooling_context {
517   const void* input;
518   size_t input_height_stride;
519   size_t input_width_stride;
520   const uint32_t* index;
521   size_t index_height_stride;
522   size_t index_width_stride;
523   void** indirect_output;
524   size_t indirect_output_height_stride;
525   size_t indirect_output_width_stride;
526   size_t pooling_size;
527   size_t channels;
528   uint32_t fill_value;
529   xnn_unpool_ukernel_function ukernel;
530 };
531 
532 #ifndef __cplusplus
533   XNN_PRIVATE void xnn_compute_unpooling(
534       const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
535       size_t input_y,
536       size_t input_x);
537 #endif
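
// Context for 2D argmax pooling: like max_pooling_context, but the element
// index of each maximum is also written to `index`. Here and in the pooling
// contexts below, the unipass micro-kernel handles pooling windows that fit in
// a single pass, while larger windows use the multipass micro-kernel, which
// iterates over the window in several passes.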
538 
539 struct argmax_pooling_context {
540   const void** indirect_input;
541   size_t indirect_input_height_stride;
542   size_t input_offset;
543   size_t input_batch_stride;
544   void* output;
545   size_t output_batch_stride;
546   size_t output_height_stride;
547   size_t output_width;
548   uint32_t* index;
549   size_t index_batch_stride;
550   size_t index_height_stride;
551   size_t pooling_size;
552   size_t channels;
553   size_t input_increment;
554   size_t output_increment;
555   union {
556     xnn_argmaxpool_unipass_ukernel_function unipass_ukernel;
557     xnn_argmaxpool_multipass_ukernel_function multipass_ukernel;
558   };
559 };
560 
561 #ifndef __cplusplus
562   XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
563       const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
564       size_t batch_index,
565       size_t output_y);
566 
567   XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
568       const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
569       size_t batch_index,
570       size_t output_y);
571 #endif
572 
573 struct average_pooling_context {
574   const void** indirect_input;
575   size_t indirect_input_height_stride;
576   size_t input_offset;
577   size_t input_batch_stride;
578   void* output;
579   size_t output_batch_stride;
580   size_t output_height_stride;
581   size_t output_width;
582   size_t pooling_size;
583   size_t channels;
584   const void* zero;
585   size_t input_increment;
586   size_t output_increment;
587   union {
588     union xnn_qu8_avgpool_params qu8;
589     union xnn_f32_scaleminmax_params f32;
590   } params;
591   union {
592     xnn_avgpool_unipass_ukernel_function unipass_ukernel;
593     xnn_avgpool_multipass_ukernel_function multipass_ukernel;
594   };
595 };
596 
597 #ifndef __cplusplus
598   XNN_PRIVATE void xnn_compute_average_pooling_unipass(
599       const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
600       size_t batch_index,
601       size_t output_y);
602 
603   XNN_PRIVATE void xnn_compute_average_pooling_multipass(
604       const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
605       size_t batch_index,
606       size_t output_y);
607 #endif
608 
609 struct pixelwise_average_pooling_context {
610   const void** indirect_input;
611   size_t indirect_input_height_stride;
612   size_t input_offset;
613   size_t input_batch_stride;
614   const void* pixelwise_buffer;
615   size_t pixelwise_buffer_height_stride;
616   void* output;
617   size_t output_batch_stride;
618   size_t output_height_stride;
619   size_t output_width;
620   size_t pooling_size;
621   size_t channels;
622   const void* zero;
623   size_t input_increment;
624   size_t output_increment;
625   union {
626     union xnn_u8_minmax_params u8;
627     union xnn_f32_minmax_params f32;
628   } params;
629   union {
630     xnn_pavgpool_unipass_ukernel_function unipass_ukernel;
631     xnn_pavgpool_multipass_ukernel_function multipass_ukernel;
632   };
633 };
634 
635 #ifndef __cplusplus
636   XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
637       const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
638       size_t batch_index,
639       size_t output_y);
640 
641   XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
642       const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
643       size_t batch_index,
644       size_t output_y);
645 #endif
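
// Context for global average pooling in NWC layout: each task reduces the
// input_elements spatial positions of one batch image to a single pixel of
// `channels` values; the multipass micro-kernel is used when the reduction is
// too long for a single unipass call.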
646 
647 struct global_average_pooling_nwc_context {
648   const void* input;
649   const void* zero;
650   size_t input_pixel_stride;
651   size_t input_batch_stride;
652   size_t input_elements;
653   size_t channels;
654   void* output;
655   size_t output_batch_stride;
656   union {
657     union xnn_qs8_avgpool_params qs8;
658     union xnn_qu8_avgpool_params qu8;
659     struct xnn_f16_scaleminmax_params f16;
660     union xnn_f32_scaleminmax_params f32;
661   } params;
662   union {
663     xnn_gavgpool_unipass_ukernel_function unipass_ukernel;
664     xnn_gavgpool_multipass_ukernel_function multipass_ukernel;
665   };
666 };
667 
668 #ifndef __cplusplus
669   XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
670       const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
671       size_t batch_index);
672 
673   XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
674       const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
675       size_t batch_index);
676 #endif
677 
678 struct global_average_pooling_ncw_context {
679   size_t input_elements;
680   const void* input;
681   size_t input_channel_stride;
682   size_t input_batch_stride;
683   void* output;
684   size_t output_channel_stride;
685   size_t output_batch_stride;
686   xnn_gavgpool_cw_ukernel_function ukernel;
687   union {
688     union xnn_f32_gavgpool_params f32;
689   } params;
690 };
691 
692 #ifndef __cplusplus
693   XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
694       const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
695       size_t batch_index,
696       size_t channels_start,
697       size_t channels_slice);
698 #endif
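
// Context for bilinear resizing in NHWC layout: for every output pixel, the
// indirection buffer holds the four pointers of its 2x2 input neighborhood and
// packed_weights holds the matching pair of interpolation coefficients; each
// task interpolates a contiguous range of output pixels of one batch image.
// The CHW variant below shares the indirection buffer and weights across all
// channels of an image.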
699 
700 struct resize_bilinear_context {
701   // Number of channels multiplied by sizeof(input element).
702   size_t scaled_channels;
  // Indirection buffer with pointers to the input pixels used for interpolation.
  const void** indirect_input;
  // Offset, in bytes, to be added to the pointers in the indirection buffer.
706   size_t input_offset;
707   // Stride, in bytes, between images of consecutive batches in the input.
708   size_t input_batch_stride;
709   // Packed pairs of (x, y) linear interpolation coefficients.
710   const void* packed_weights;
711   // Pointer to the output tensor.
712   void* output;
713   // Stride, in bytes, between adjacent pixels in the output.
714   size_t output_pixel_stride;
715   // Stride, in bytes, between images of consecutive batches in the output.
716   size_t output_batch_stride;
717   // log2(sizeof(weight element)).
718   uint32_t log2_wsize;
719   // Pointer to BILINEAR micro-kernel function.
720   xnn_ibilinear_ukernel_function ukernel;
721 };
722 
723 struct resize_bilinear_chw_context {
724   // Number of pixels per output image plane.
725   size_t output_pixels;
726   // Number of channels multiplied by sizeof(input element).
727   size_t channels;
728   // Stride, in bytes, between adjacent channels in the input.
729   size_t input_channel_stride;
  // Indirection buffer with pointers to the rows of input pixels used for interpolation.
  const void** indirect_input;
  // Offset, in bytes, to be added to the pointers in the indirection buffer.
733   size_t input_offset;
734   // Stride, in bytes, between images of consecutive batches in the input.
735   size_t input_batch_stride;
736   // Packed pairs of (x, y) linear interpolation coefficients.
737   const void* packed_weights;
738   // Pointer to the output tensor.
739   void* output;
740   // Stride, in bytes, between images of consecutive batches in the output.
741   size_t output_batch_stride;
742   // Stride, in bytes, between consecutive channels of an output image.
743   size_t output_channel_stride;
744   // Pointer to BILINEAR micro-kernel function.
745   xnn_ibilinear_chw_ukernel_function ukernel;
746 };
747 
748 #ifndef __cplusplus
749   XNN_PRIVATE void xnn_compute_resize_bilinear(
750       const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
751       size_t batch_index,
752       size_t pixel_start,
753       size_t pixel_range);

  XNN_PRIVATE void xnn_compute_resize_bilinear_chw(
      const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channel_start,
      size_t channel_range);
759 #endif
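
// Context for broadcasted binary element-wise operations (add, multiply, ...):
// the five outer dimensions are parallelized and addressed with the byte
// strides below (a stride of 0 broadcasts along that dimension), while the
// innermost dimension (`elements`) is handled inside the micro-kernel.
// Illustrative addressing for one task (roughly what
// xnn_compute_elementwise_binary_5d does, not the exact implementation):
//   a_ptr = (uintptr_t) a + i*a_stride[0] + j*a_stride[1] + k*a_stride[2]
//                         + l*a_stride[3] + m*a_stride[4];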
760 
761 struct elementwise_binary_context {
762   const void* a;
763   size_t a_stride[XNN_MAX_TENSOR_DIMS - 1];
764   const void* b;
765   size_t b_stride[XNN_MAX_TENSOR_DIMS - 1];
766   void* y;
767   size_t y_stride[XNN_MAX_TENSOR_DIMS - 1];
768   size_t elements;
769   union {
770     union xnn_qs8_add_params qs8;
771     union xnn_qu8_add_params qu8;
772     struct xnn_f16_minmax_params f16;
773     union xnn_f32_minmax_params f32;
774   } params;
775   xnn_vbinary_ukernel_function ukernel;
776 };
777 
778 #ifndef __cplusplus
779   XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
780       const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
781       size_t i, size_t j, size_t k, size_t l, size_t m);
782 #endif
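
// Context for channel shuffle: each task re-interleaves the channels of one
// pixel, with `n` and `m` describing the group size and group count consumed
// by the ZIP micro-kernels (fixed-width zipc variants for small group counts,
// the variable zipv variant otherwise).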
783 
784 struct channel_shuffle_context {
785   const void* x;
786   size_t x_stride;
787   void* y;
788   size_t y_stride;
789   size_t n;
790   size_t m;
791   union {
792     xnn_zipc_ukernel_function fixed_ukernel;
793     xnn_zipv_ukernel_function variable_ukernel;
794   };
795 };
796 
797 #ifndef __cplusplus
798   XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
799       const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
800       size_t index);
801 
802   XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
803       const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
804       size_t index);
805 #endif
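
// Contexts for element-wise evaluation through a 256-entry byte lookup table:
// `t` points to the table; the strided variant processes one n-byte row per
// task, while the contiguous variant (further below) processes a flat byte
// range of the input.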
806 
807 struct lut_strided_context {
808   size_t n;
809   const void* x;
810   size_t x_stride;
811   const void* t;
812   void* y;
813   size_t y_stride;
814   xnn_x8_lut_ukernel_function ukernel;
815 };
816 
817 #ifndef __cplusplus
818   XNN_PRIVATE void xnn_compute_lut_strided(
819       const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
820       size_t batch_index);
821 #endif
822 
823 struct lut_contiguous_context {
824   const void* x;
825   size_t x_stride;
826   const void* t;
827   void* y;
828   size_t y_stride;
829   xnn_x8_lut_ukernel_function ukernel;
830 };
831 
832 #ifndef __cplusplus
833   XNN_PRIVATE void xnn_compute_lut_contiguous(
834       const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
835       size_t offset,
836       size_t size);
837 #endif
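
// Contexts for unary element-wise micro-kernels (e.g. Clamp, HardSwish): the
// strided variant applies the ukernel to whole rows of length n, while the
// contiguous variant (further below) covers a flat [offset, offset + size)
// byte range.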
838 
839 struct univector_strided_context {
840   size_t n;
841   const void* x;
842   size_t x_stride;
843   void* y;
844   size_t y_stride;
845   xnn_univector_ukernel_function ukernel;
846   union {
847     union xnn_u8_minmax_params u8_output;
848     union xnn_f32_minmax_params f32_output;
849     union xnn_f32_hswish_params f32_hswish;
850   } params;
851 };
852 
853 #ifndef __cplusplus
854   XNN_PRIVATE void xnn_compute_univector_strided(
855       const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
856       size_t batch_index,
857       size_t batch_range);
858 #endif
859 
860 struct univector_contiguous_context {
861   const void* x;
862   size_t x_stride;
863   void* y;
864   size_t y_stride;
865   xnn_univector_ukernel_function ukernel;
866   union {
867     union xnn_u8_minmax_params u8_output;
868     union xnn_f32_minmax_params f32_output;
869     union xnn_f32_hswish_params f32_hswish;
870   } params;
871 };
872 
873 #ifndef __cplusplus
874   XNN_PRIVATE void xnn_compute_univector_contiguous(
875       const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
876       size_t offset,
877       size_t size);
878 #endif
879 
880 struct prelu_context {
881   size_t n;
882   const void* x;
883   size_t x_stride;
884   const void* w;
885   void* y;
886   size_t y_stride;
887   xnn_prelu_ukernel_function ukernel;
888 };
889 
890 #ifndef __cplusplus
891   XNN_PRIVATE void xnn_compute_prelu(
892       const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
893       size_t batch_start,
894       size_t batch_range);
895 #endif
896 
897 struct vmulcaddc_context {
898   size_t n;
899   const void* x;
900   size_t x_stride;
901   const void* w;
902   void* y;
903   size_t y_stride;
904   xnn_vmulcaddc_ukernel_function ukernel;
905   union {
906     struct xnn_f16_minmax_params f16;
907     union xnn_f32_minmax_params f32;
908   } params;
909 };
910 
911 #ifndef __cplusplus
912   XNN_PRIVATE void xnn_compute_vmulcaddc(
913       const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
914       size_t batch_start,
915       size_t batch_size);
916 #endif
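
// Context for constant padding of a tensor with up to XNN_MAX_TENSOR_DIMS
// dimensions: the five outer dimensions are parallelized via
// xnn_compute_pad_5d, which uses the fill micro-kernel for rows that lie
// entirely in the padding and the pad micro-kernel to pad the innermost
// dimension of the remaining rows with padding_value.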
917 
918 struct pad_context {
919   const void* input;
920   size_t input_stride[XNN_MAX_TENSOR_DIMS - 1];
921   void* output;
922   size_t output_stride[XNN_MAX_TENSOR_DIMS - 1];
923   size_t pre_paddings[XNN_MAX_TENSOR_DIMS];
924   size_t post_paddings[1];
925   size_t input_size[XNN_MAX_TENSOR_DIMS];
926   size_t output_size[1];
927   uint32_t padding_value;
928   xnn_pad_ukernel_function pad_ukernel;
929   xnn_fill_ukernel_function fill_ukernel;
930 };
931 
932 #ifndef __cplusplus
933   XNN_PRIVATE void xnn_compute_pad_5d(
934       const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
935       size_t i, size_t j, size_t k, size_t l, size_t m);
936 #endif
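
// Context for unsigned 8-bit Softmax: the rmax micro-kernel finds the row
// maximum, which is used to offset the base of the 256-entry table `t` of
// precomputed scaled exponentials so that lookups are taken relative to the
// maximum; the lut32norm micro-kernel then sums the looked-up values and
// normalizes the outputs.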
937 
938 struct u8_softmax_context {
939   size_t n;
940   const uint8_t* x;
941   size_t x_stride;
942   const uint32_t* t;
943   uint8_t* y;
944   size_t y_stride;
945   xnn_u8_rmax_ukernel_function rmax_ukernel;
946   xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
947 };
948 
949 #ifndef __cplusplus
950   XNN_PRIVATE void xnn_compute_u8_softmax(
951       const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
952       size_t batch_index);
953 #endif
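
// Context for single-precision Softmax, computed in three passes over each
// row: (1) rmax finds the row maximum, (2) raddstoreexpminusmax stores
// exp(x - max) and accumulates its sum, (3) vmulc scales the stored
// exponentials by 1 / sum.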
954 
955 struct f32_three_pass_softmax_context {
956   size_t n;
957   const void* x;
958   size_t x_stride;
959   void* y;
960   size_t y_stride;
961   xnn_f32_rmax_ukernel_function rmax_ukernel;
962   xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
963   xnn_vbinary_ukernel_function vmulc_ukernel;
964   union xnn_f32_minmax_params params;
965 };
966 
967 #ifndef __cplusplus
968   XNN_PRIVATE void xnn_compute_f32_three_pass_softmax(
969       const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
970       size_t batch_index);
971 #endif
972