1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <assert.h>
10 #include <stddef.h>
11 #include <stdint.h>
12 #include <string.h>
13 
14 #include <xnnpack.h>
15 #include <xnnpack/allocator.h>
16 #include <xnnpack/operator.h>
17 #include <xnnpack/log.h>
18 #include <xnnpack/common.h>
19 #include <xnnpack/math.h>
20 #include <xnnpack/params.h>
21 #include <xnnpack/compute.h>
22 
23 
xnn_compute_grouped_gemm(const struct gemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)24 void xnn_compute_grouped_gemm(
25     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
26     size_t group_index,
27     size_t mr_block_start,
28     size_t nr_block_start,
29     size_t mr_block_size,
30     size_t nr_block_size)
31 {
32   const size_t k_scaled  = context->k_scaled;
33   const size_t a_stride  = context->a_stride;
34   const size_t cm_stride = context->cm_stride;
35 
36   context->ukernel.function[XNN_UARCH_DEFAULT](
37       mr_block_size,
38       nr_block_size,
39       k_scaled,
40       (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
41       a_stride,
42       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
43       (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
44       cm_stride,
45       context->cn_stride,
46       &context->params);
47 }
48 
xnn_compute_gemm(const struct gemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)49 void xnn_compute_gemm(
50     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
51     size_t mr_block_start,
52     size_t nr_block_start,
53     size_t mr_block_size,
54     size_t nr_block_size)
55 {
56   const size_t a_stride  = context->a_stride;
57   const size_t cm_stride = context->cm_stride;
58 
59   context->ukernel.function[XNN_UARCH_DEFAULT](
60       mr_block_size,
61       nr_block_size,
62       context->k_scaled,
63       (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
64       a_stride,
65       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
66       (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
67       cm_stride,
68       context->cn_stride,
69       &context->params);
70 }
71 
// Invokes the context's ukernel for one block of one batch element.
//
// Input and output are advanced by the per-batch stride times batch_index,
// plus mr_block_start, which is added directly as a byte offset.
void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    batch_index * context->batched_input_stride + mr_block_start);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->batched_output_stride + mr_block_start);

  context->ukernel(
    mr_block_size, context->n,
    input,
    context->nonzero_weights, context->input_increments, context->output_channel_nonzeros,
    output,
    context->scaled_m,
    &context->params);
}
89 
xnn_compute_grouped_batch_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)90 void xnn_compute_grouped_batch_igemm(
91     const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
92     size_t batch_index,
93     size_t group_index,
94     size_t mr_block_start,
95     size_t nr_block_start,
96     size_t mr_block_size,
97     size_t nr_block_size)
98 {
99   const size_t ks        = context->ks;
100   const size_t cm_stride = context->cm_stride;
101 
102   context->ukernel.function[XNN_UARCH_DEFAULT](
103       mr_block_size,
104       nr_block_size,
105       context->kc,
106       context->ks_scaled,
107       (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
108       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
109       (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
110       cm_stride,
111       context->cn_stride,
112       context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
113       context->zero,
114       &context->params);
115 }
116 
xnn_compute_grouped_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)117 void xnn_compute_grouped_igemm(
118     const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
119     size_t group_index,
120     size_t mr_block_start,
121     size_t nr_block_start,
122     size_t mr_block_size,
123     size_t nr_block_size)
124 {
125   const size_t ks        = context->ks;
126   const size_t cm_stride = context->cm_stride;
127 
128   context->ukernel.function[XNN_UARCH_DEFAULT](
129       mr_block_size,
130       nr_block_size,
131       context->kc,
132       context->ks_scaled,
133       (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
134       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
135       (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
136       cm_stride,
137       context->cn_stride,
138       context->a_offset + group_index * context->ga_stride,
139       context->zero,
140       &context->params);
141 }
142 
xnn_compute_batch_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)143 void xnn_compute_batch_igemm(
144     const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
145     size_t batch_index,
146     size_t mr_block_start,
147     size_t nr_block_start,
148     size_t mr_block_size,
149     size_t nr_block_size)
150 {
151   const size_t ks        = context->ks;
152   const size_t cm_stride = context->cm_stride;
153 
154   context->ukernel.function[XNN_UARCH_DEFAULT](
155       mr_block_size,
156       nr_block_size,
157       context->kc,
158       context->ks_scaled,
159       (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
160       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
161       (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
162       cm_stride,
163       context->cn_stride,
164       context->a_offset + batch_index * context->ba_stride,
165       context->zero,
166       &context->params);
167 }
168 
xnn_compute_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)169 void xnn_compute_igemm(
170     const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
171     size_t mr_block_start,
172     size_t nr_block_start,
173     size_t mr_block_size,
174     size_t nr_block_size)
175 {
176   const size_t ks        = context->ks;
177   const size_t cm_stride = context->cm_stride;
178 
179   context->ukernel.function[XNN_UARCH_DEFAULT](
180       mr_block_size,
181       nr_block_size,
182       context->kc,
183       context->ks_scaled,
184       (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
185       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
186       (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
187       cm_stride,
188       context->cn_stride,
189       context->a_offset,
190       context->zero,
191       &context->params);
192 }
193 
xnn_compute_grouped_subgemm2d(const struct subgemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t group_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)194 void xnn_compute_grouped_subgemm2d(
195       const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
196       size_t batch_index,
197       size_t group_index,
198       size_t subkernel_index,
199       size_t slice_y,
200       size_t slice_x_start,
201       size_t nc_block_start,
202       size_t slice_x_max,
203       size_t nc_block_size)
204 {
205   const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
206 
207   if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
208     return;
209   }
210 
211   const size_t slice_width = subconvolution_params->slice_width;
212   if XNN_UNLIKELY(slice_x_start >= slice_width) {
213     return;
214   }
215   const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
216 
217   const size_t ax_stride = context->ax_stride;
218   const size_t cx_stride = context->cx_stride;
219   context->ukernel.function[XNN_UARCH_DEFAULT](
220       slice_x_size,
221       nc_block_size,
222       context->kc,
223       (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
224       ax_stride,
225       (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
226       (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
227       cx_stride,
228       context->cn_stride,
229       &context->params);
230 }
231 
xnn_compute_subgemm2d(const struct subgemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)232 void xnn_compute_subgemm2d(
233       const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
234       size_t batch_index,
235       size_t subkernel_index,
236       size_t slice_y,
237       size_t slice_x_start,
238       size_t nc_block_start,
239       size_t slice_x_max,
240       size_t nc_block_size)
241 {
242   const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
243 
244   if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
245     return;
246   }
247 
248   const size_t slice_width = subconvolution_params->slice_width;
249   if XNN_UNLIKELY(slice_x_start >= slice_width) {
250     return;
251   }
252   const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
253 
254   const size_t ax_stride = context->ax_stride;
255   const size_t cx_stride = context->cx_stride;
256   context->ukernel.function[XNN_UARCH_DEFAULT](
257       slice_x_size,
258       nc_block_size,
259       context->kc,
260       (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
261       ax_stride,
262       (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
263       (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
264       cx_stride,
265       context->cn_stride,
266       &context->params);
267 }
268 
xnn_compute_grouped_subconv2d(const struct subconv_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t group_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)269 void xnn_compute_grouped_subconv2d(
270       const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
271       size_t batch_index,
272       size_t group_index,
273       size_t subkernel_index,
274       size_t slice_y,
275       size_t slice_x_start,
276       size_t nc_block_start,
277       size_t slice_x_max,
278       size_t nc_block_size)
279 {
280   const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
281 
282   if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
283     return;
284   }
285 
286   const size_t slice_width = subconvolution_params->slice_width;
287   if XNN_UNLIKELY(slice_x_start >= slice_width) {
288     return;
289   }
290   const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
291 
292   const size_t cx_stride = context->cx_stride;
293   context->ukernel.function[XNN_UARCH_DEFAULT](
294       slice_x_size,
295       nc_block_size,
296       context->kc,
297       subconvolution_params->scaled_kernel_size,
298       (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
299       (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
300       (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
301       cx_stride,
302       context->cn_stride,
303       context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
304       context->zero,
305       &context->params);
306 }
307 
xnn_compute_subconv2d(const struct subconv_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)308 void xnn_compute_subconv2d(
309       const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
310       size_t batch_index,
311       size_t subkernel_index,
312       size_t slice_y,
313       size_t slice_x_start,
314       size_t nc_block_start,
315       size_t slice_x_max,
316       size_t nc_block_size)
317 {
318   const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
319 
320   if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
321     return;
322   }
323 
324   const size_t slice_width = subconvolution_params->slice_width;
325   if XNN_UNLIKELY(slice_x_start >= slice_width) {
326     return;
327   }
328   const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
329 
330   const size_t cx_stride = context->cx_stride;
331   context->ukernel.function[XNN_UARCH_DEFAULT](
332       slice_x_size,
333       nc_block_size,
334       context->kc,
335       subconvolution_params->scaled_kernel_size,
336       (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
337       (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
338       (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
339       cx_stride,
340       context->cn_stride,
341       context->a_offset + batch_index * context->ba_stride,
342       context->zero,
343       &context->params);
344 }
345 
// Invokes the context's hwc2chw ukernel for one batch element over the output
// row range [output_y_start, output_y_start + output_y_slice).
//
// Input and output base pointers are advanced by their per-batch byte stride
// times batch_index.
void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  const void* batch_input = (const void*) ((uintptr_t) context->input +
    batch_index * context->input_batch_stride);
  void* batch_output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride);

  context->hwc2chw_ukernel(
    context->input_height, context->input_width,
    output_y_start, output_y_start + output_y_slice,
    batch_input,
    context->zero,
    context->packed_weights,
    batch_output,
    context->input_padding_top,
    context->output_channels,
    context->output_height_stride, context->output_channel_stride,
    &context->params);
}
367 
// Invokes the context's unipass ukernel for one output row of one batch
// element.
//
// The indirection buffer is advanced to the row for output_y, the output
// pointer to the (batch_index, output_y) row, and the input offset passed to
// the ukernel is the context's input_offset plus the per-batch input stride.
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void** indirection_row = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  void* output_row = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->groups,
    context->output_width,
    indirection_row,
    context->packed_weights,
    output_row,
    context->indirect_input_width_stride,
    context->output_increment,
    batch_input_offset,
    context->zero,
    &context->params);
}
386 
// Invokes the context's chw ukernel for one channel of one batch element.
//
// Input and output are advanced by their per-channel and per-batch byte
// strides; the packed weights are advanced by the per-channel weights stride.
void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{
  const void* channel_input = (const void*) ((uintptr_t) context->input +
    channel * context->input_channel_stride + batch_index * context->input_batch_stride);
  const void* channel_weights = (const void*) ((uintptr_t) context->packed_weights +
    channel * context->weights_channel_stride);
  void* channel_output = (void*) ((uintptr_t) context->output +
    channel * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->chw_ukernel(
    context->input_height, context->input_width,
    channel_input,
    channel_weights,
    context->zero,
    channel_output,
    context->input_padding_top,
    &context->params);
}
402 
xnn_compute_depthtospace2d_hwc_contiguous(const struct depthtospace2d_hwc_context * context,size_t batch_input_y,size_t input_x,size_t block_y)403 void xnn_compute_depthtospace2d_hwc_contiguous(
404     const struct depthtospace2d_hwc_context* context,
405     size_t batch_input_y,
406     size_t input_x,
407     size_t block_y)
408 {
409   const size_t input_width = context->input_width;
410   const size_t elements = context->elements;
411   const void* input = (const void*) ((uintptr_t) context->input +
412     (batch_input_y * input_width + input_x) * context->input_width_stride + block_y * elements);
413   void* output = (void*) ((uintptr_t) context->output +
414     ((batch_input_y * context->block_size + block_y) * input_width + input_x) * elements);
415 
416   context->ukernel(
417     elements,
418     input,
419     output,
420     NULL);
421 }
422 
xnn_compute_depthtospace2d_hwc_strided(const struct depthtospace2d_hwc_context * context,size_t batch_input_y,size_t input_x,size_t block_y,size_t block_x)423 void xnn_compute_depthtospace2d_hwc_strided(
424     const struct depthtospace2d_hwc_context* context,
425     size_t batch_input_y,
426     size_t input_x,
427     size_t block_y,
428     size_t block_x)
429 {
430   const size_t block_size = context->block_size;
431   const size_t elements = context->elements;
432   const void* input = (const void*) ((uintptr_t) context->input +
433     batch_input_y * context->input_height_stride + input_x * context->input_width_stride + (block_y * block_size + block_x) * elements);
434   void* output = (void*) ((uintptr_t) context->output +
435     (batch_input_y * block_size + block_y) * context->output_height_stride +
436     (input_x * block_size + block_x) * context->output_width_stride);
437 
438   context->ukernel(
439     elements,
440     input,
441     output,
442     NULL);
443 }
444 
xnn_compute_depthtospace2d_chw2hwc(const struct depthtospace2d_chw2hwc_context * context,size_t batch_index)445 void xnn_compute_depthtospace2d_chw2hwc(
446     const struct depthtospace2d_chw2hwc_context* context,
447     size_t batch_index)
448 {
449   context->ukernel(
450     context->output_channels,
451     context->input_height,
452     context->input_width,
453     context->block_size,
454     (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
455     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
456     context->output_channel_stride);
457 }
458 
// Invokes the context's unipass ukernel for one output row of one batch
// element.
//
// The indirection buffer is advanced to the row for output_y; the output and
// index pointers are advanced to the (batch_index, output_y) row using their
// own batch and height strides.
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void** indirection_row = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  void* output_row = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index_row = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
    context->output_width,
    context->pooling_size,
    context->channels,
    indirection_row,
    batch_input_offset,
    output_row,
    index_row,
    context->input_increment,
    context->output_increment);
}
477 
xnn_compute_argmax_pooling_multipass(const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t output_y)478 void xnn_compute_argmax_pooling_multipass(
479     const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
480     size_t batch_index,
481     size_t output_y)
482 {
483   const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
484     output_y * context->indirect_input_height_stride);
485   const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
486   void* output = (void*) ((uintptr_t) context->output +
487     batch_index * context->output_batch_stride + output_y * context->output_height_stride);
488   uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
489     batch_index * context->index_batch_stride + output_y * context->index_height_stride);
490 
491   void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
492   void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);
493 
494   context->multipass_ukernel(
495     context->output_width, context->pooling_size, context->channels,
496     indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
497     context->input_increment, context->output_increment);
498 }
499 
// Invokes the context's ukernel for one output row of one batch element.
//
// The indirection buffer is advanced to the row for output_y, the output
// pointer to the (batch_index, output_y) row, and the input offset passed to
// the ukernel is the context's input_offset plus the per-batch input stride.
void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void** indirection_row = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  void* output_row = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
    context->output_width,
    context->pooling_size,
    context->channels,
    indirection_row,
    batch_input_offset,
    output_row,
    context->input_increment,
    context->output_increment,
    &context->params);
}
517 
// Invokes the context's ukernel for one input pixel (input_y, input_x).
//
// The input, index and indirect-output pointers are each advanced by their
// own height and width byte strides for that pixel.
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{
  const void* pixel_input = (const void*) ((uintptr_t) context->input +
    input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* pixel_index = (const uint32_t*) ((uintptr_t) context->index +
    input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** pixel_outputs = (void**) ((uintptr_t) context->indirect_output +
    input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
    context->pooling_size,
    context->channels,
    context->fill_value,
    pixel_input,
    pixel_index,
    pixel_outputs);
}
537 
// Invokes the context's unipass ukernel for one output row of one batch
// element.
//
// The indirection buffer is advanced to the row for output_y, the output
// pointer to the (batch_index, output_y) row, and the input offset passed to
// the ukernel is the context's input_offset plus the per-batch input stride.
void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void** indirection_row = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  void* output_row = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->output_width,
    context->pooling_size,
    context->channels,
    indirection_row,
    batch_input_offset,
    context->zero,
    output_row,
    context->input_increment,
    context->output_increment,
    &context->params);
}
555 
xnn_compute_average_pooling_multipass(const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t output_y)556 void xnn_compute_average_pooling_multipass(
557     const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
558     size_t batch_index,
559     size_t output_y)
560 {
561   const void** indirect_input =
562     (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
563   const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
564   void* output = (void*) ((uintptr_t) context->output +
565     batch_index * context->output_batch_stride + output_y * context->output_height_stride);
566 
567   void* multipass_buffer =
568     XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
569 
570   context->multipass_ukernel(
571     context->output_width, context->pooling_size, context->channels,
572     indirect_input, input_offset, context->zero, multipass_buffer, output,
573     context->input_increment, context->output_increment,
574     &context->params);
575 }
576 
// Invokes the context's unipass ukernel for one output row of one batch
// element.
//
// Besides the indirection row, input offset and output row, the ukernel
// receives the pixelwise buffer advanced to the row for output_y. That
// offset depends on output_y only, not on batch_index.
void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void** indirection_row = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const void* pixelwise_row = (const void*) ((uintptr_t) context->pixelwise_buffer +
    output_y * context->pixelwise_buffer_height_stride);
  void* output_row = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->output_width,
    context->pooling_size,
    context->channels,
    indirection_row,
    batch_input_offset,
    context->zero,
    pixelwise_row,
    output_row,
    context->input_increment,
    context->output_increment,
    &context->params);
}
596 
xnn_compute_pixelwise_average_pooling_multipass(const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t output_y)597 void xnn_compute_pixelwise_average_pooling_multipass(
598     const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
599     size_t batch_index,
600     size_t output_y)
601 {
602   const void** indirect_input =
603     (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
604   const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
605   const void* pixelwise_buffer =
606     (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
607   void* output = (void*) ((uintptr_t) context->output +
608     batch_index * context->output_batch_stride + output_y * context->output_height_stride);
609 
610   void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
611 
612   context->multipass_ukernel(
613     context->output_width, context->pooling_size, context->channels,
614     indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
615     context->input_increment, context->output_increment,
616     &context->params);
617 }
618 
// Invokes the context's unipass ukernel for one batch element, with input and
// output base pointers advanced by their per-batch byte stride times
// batch_index.
void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* batch_input = (const void*) ((uintptr_t) context->input +
    batch_index * context->input_batch_stride);
  void* batch_output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride);

  context->unipass_ukernel(
    context->input_elements, context->channels,
    batch_input, context->input_pixel_stride,
    context->zero,
    batch_output,
    &context->params);
}
637 
xnn_compute_global_average_pooling_nwc_multipass(const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index)638 void xnn_compute_global_average_pooling_nwc_multipass(
639     const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
640     size_t batch_index)
641 {
642   const void* input =
643     (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
644   void* output =
645     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
646 
647   void* multipass_buffer =
648     XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
649 
650   context->multipass_ukernel(
651     context->input_elements,
652     context->channels,
653     input,
654     context->input_pixel_stride,
655     context->zero,
656     multipass_buffer,
657     output,
658     &context->params);
659 }
660 
// Global average pooling task (NCW variant) for a slice of channels of one batch element.
//
// context        - operator state: base pointers, byte strides, micro-kernel and its params.
// batch_index    - batch element; multiplies the input/output batch strides.
// channels_start - first channel of the slice; multiplies the input/output channel strides.
// channels_slice - number of channels handed to the micro-kernel.
void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  uintptr_t input_address = (uintptr_t) context->input;
  input_address += channels_start * context->input_channel_stride;
  input_address += batch_index * context->input_batch_stride;

  uintptr_t output_address = (uintptr_t) context->output;
  output_address += channels_start * context->output_channel_stride;
  output_address += batch_index * context->output_batch_stride;

  context->ukernel(
    context->input_elements, channels_slice,
    (const void*) input_address,
    (void*) output_address,
    &context->params);
}
679 
// Bilinear resize task for a range of output pixels of one batch element.
//
// context     - operator state: indirection buffer, packed weights, output pointer, strides.
// batch_index - batch element; multiplies the input/output batch strides.
// pixel_start - first output pixel of the range.
// pixel_range - number of output pixels handed to the micro-kernel.
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  uintptr_t output_address = (uintptr_t) context->output;
  output_address += pixel_start * context->output_pixel_stride;
  output_address += batch_index * context->output_batch_stride;

  // Weights for a pixel sit (pixel_index << log2_wsize) bytes past the start of the packed weights.
  const uintptr_t weights_address =
    (uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize);

  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const size_t output_increment = context->output_pixel_stride - context->scaled_channels;

  context->ukernel(
    pixel_range,
    context->scaled_channels,
    // The indirection buffer advances by 4 entries per output pixel
    // (presumably the four neighbouring input pixels; confirm against the micro-kernel).
    context->indirect_input + pixel_start * 4,
    input_offset,
    (const void*) weights_address,
    (void*) output_address,
    output_increment);
}
698 
// Bilinear resize task (CHW variant) for a range of channels of one batch element.
//
// context       - operator state: indirection buffer, packed weights, output pointer, strides.
// batch_index   - batch element; multiplies the input/output batch strides.
// channel_start - first channel of the range; multiplies the input/output channel strides.
// channel_range - number of channels handed to the micro-kernel.
//
// The indirection buffer and packed weights are passed unmodified; the batch and
// channel selection is carried entirely by input_offset and the output pointer.
void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{
  uintptr_t output_address = (uintptr_t) context->output;
  output_address += channel_start * context->output_channel_stride;
  output_address += batch_index * context->output_batch_stride;

  size_t offset_into_input = context->input_offset;
  offset_into_input += batch_index * context->input_batch_stride;
  offset_into_input += channel_start * context->input_channel_stride;

  context->ukernel(
    context->output_pixels, channel_range,
    context->indirect_input, offset_into_input,
    context->packed_weights,
    (void*) output_address,
    context->input_channel_stride);
}
718 
// PReLU task for a contiguous range of batch rows.
//
// context     - operator state: x/y base pointers, byte strides, row length n, weights w.
// batch_start - first row of the range; multiplies x_stride / y_stride.
// batch_range - number of rows handed to the micro-kernel.
void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_range)
{
  const size_t input_stride = context->x_stride;
  const size_t output_stride = context->y_stride;

  const uintptr_t input_address = (uintptr_t) context->x + input_stride * batch_start;
  const uintptr_t output_address = (uintptr_t) context->y + output_stride * batch_start;

  context->ukernel(
    batch_range, context->n,
    (const void*) input_address, input_stride,
    context->w,
    (void*) output_address, output_stride);
}
731 
xnn_compute_pad_5d(const struct pad_context context[restrict XNN_MIN_ELEMENTS (1)],size_t i,size_t j,size_t k,size_t l,size_t m)732 void xnn_compute_pad_5d(
733     const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
734     size_t i, size_t j, size_t k, size_t l, size_t m)
735 {
736   const void* input = (const void*) ((uintptr_t) context->input +
737     i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
738   void* output = (void*) ((uintptr_t) context->output +
739     i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);
740 
741   const size_t i_padding = context->pre_paddings[5];
742   const size_t j_padding = context->pre_paddings[4];
743   const size_t k_padding = context->pre_paddings[3];
744   const size_t l_padding = context->pre_paddings[2];
745   const size_t m_padding = context->pre_paddings[1];
746 
747   const size_t i_size = context->input_size[5];
748   const size_t j_size = context->input_size[4];
749   const size_t k_size = context->input_size[3];
750   const size_t l_size = context->input_size[2];
751   const size_t m_size = context->input_size[1];
752 
753   if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
754                 l - l_padding < l_size && m - m_padding < m_size)
755   {
756     context->pad_ukernel(
757       1 /* rows */,
758       context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
759       &context->padding_value,
760       input, 0 /* input stride */, output, 0 /* output stride */);
761   } else {
762     context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, &context->padding_value);
763   }
764 }
765 
// Element-wise binary operation task over one innermost run of a 5-D iteration space.
//
// context       - operator state: a/b/y base pointers, their per-dimension byte strides,
//                 the element count per call, micro-kernel and its params.
// i, j, k, l, m - coordinates in the five outer dimensions; coordinate d multiplies
//                 stride[d] of each operand (i -> stride[0], ..., m -> stride[4]).
void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  uintptr_t a_address = (uintptr_t) context->a;
  a_address += i * context->a_stride[0];
  a_address += j * context->a_stride[1];
  a_address += k * context->a_stride[2];
  a_address += l * context->a_stride[3];
  a_address += m * context->a_stride[4];

  uintptr_t b_address = (uintptr_t) context->b;
  b_address += i * context->b_stride[0];
  b_address += j * context->b_stride[1];
  b_address += k * context->b_stride[2];
  b_address += l * context->b_stride[3];
  b_address += m * context->b_stride[4];

  uintptr_t y_address = (uintptr_t) context->y;
  y_address += i * context->y_stride[0];
  y_address += j * context->y_stride[1];
  y_address += k * context->y_stride[2];
  y_address += l * context->y_stride[3];
  y_address += m * context->y_stride[4];

  context->ukernel(
    context->elements,
    (const void*) a_address, (const void*) b_address, (void*) y_address,
    &context->params);
}
778 
// Channel shuffle task using the fixed micro-kernel, for one row.
//
// context - operator state: x/y base pointers, byte strides, size n, micro-kernel.
// index   - row to process; multiplies x_stride / y_stride.
void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const uintptr_t input_address = (uintptr_t) context->x + index * context->x_stride;
  const uintptr_t output_address = (uintptr_t) context->y + index * context->y_stride;

  context->fixed_ukernel(context->n, (const void*) input_address, (void*) output_address);
}
788 
// Channel shuffle task using the variable micro-kernel, for one row.
//
// context - operator state: x/y base pointers, byte strides, sizes n and m, micro-kernel.
// index   - row to process; multiplies x_stride / y_stride.
//
// Differs from the fixed variant only in that the micro-kernel also receives context->m.
void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const uintptr_t input_address = (uintptr_t) context->x + index * context->x_stride;
  const uintptr_t output_address = (uintptr_t) context->y + index * context->y_stride;

  context->variable_ukernel(
    context->n, context->m,
    (const void*) input_address, (void*) output_address);
}
798 
// Lookup-table task for one row of a strided batch.
//
// context     - operator state: x/y base pointers, byte strides, row length n, table t.
// batch_index - row to process; multiplies x_stride / y_stride.
void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uintptr_t input_address = (uintptr_t) context->x + context->x_stride * batch_index;
  const uintptr_t output_address = (uintptr_t) context->y + context->y_stride * batch_index;

  context->ukernel(context->n, (const void*) input_address, context->t, (void*) output_address);
}
808 
// Lookup-table task for a chunk of a contiguous buffer.
//
// context - operator state: x/y base pointers and table t.
// offset  - byte offset of the chunk, applied identically to x and y.
// size    - chunk size handed to the micro-kernel.
void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const uintptr_t input_address = (uintptr_t) context->x + offset;
  const uintptr_t output_address = (uintptr_t) context->y + offset;

  context->ukernel(size, (const void*) input_address, context->t, (void*) output_address);
}
819 
// Unary vector operation task for one row of a strided batch.
//
// context     - operator state: x/y base pointers, byte strides, row length n,
//               micro-kernel and its params.
// batch_index - row to process; multiplies x_stride / y_stride.
// batch_range - must be 1 (checked by assert); exactly one row is processed per call.
void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const uintptr_t input_address = (uintptr_t) context->x + context->x_stride * batch_index;
  const uintptr_t output_address = (uintptr_t) context->y + context->y_stride * batch_index;

  context->ukernel(
    context->n,
    (const void*) input_address, (void*) output_address,
    &context->params);
}
831 
// Unary vector operation task for a chunk of a contiguous buffer.
//
// context - operator state: x/y base pointers, micro-kernel and its params.
// offset  - byte offset of the chunk, applied identically to x and y.
// size    - chunk size handed to the micro-kernel.
void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const uintptr_t input_address = (uintptr_t) context->x + offset;
  const uintptr_t output_address = (uintptr_t) context->y + offset;

  context->ukernel(
    size,
    (const void*) input_address, (void*) output_address,
    &context->params);
}
841 
// Quantized (uint8) softmax task for one row.
//
// context     - operator state: x/y base pointers, byte strides, row length n,
//               lookup table t, and the rmax / lut_norm micro-kernels.
// batch_index - row to process; multiplies x_stride / y_stride.
//
// Two micro-kernel passes: find the row maximum, then run the LUT-normalize
// kernel with a table pointer shifted according to that maximum.
void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const size_t row_length = context->n;
  const uint8_t* input =
    (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* output =
    (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);

  // Pass 1: maximum of the row.
  uint8_t row_max = 0;
  context->rmax_ukernel(row_length, input, &row_max);

  // Pass 2: table lookup + normalization. For a uint8_t value, (v ^ 255) equals
  // 255 - v, so the table is advanced by the distance of the maximum from 255
  // (presumably so that lookups are relative to the row maximum; confirm against
  // the lut_norm micro-kernel and the table layout).
  const size_t table_shift = row_max ^ 255;
  const uint32_t* shifted_table = (const uint32_t*) context->t + table_shift;
  context->lut_norm_ukernel(row_length, input, shifted_table, output);
}
856 
// Single-precision softmax task for one row, computed with three micro-kernel passes.
//
// context     - operator state: x/y base pointers, byte strides, row length n,
//               the three micro-kernels and the params for the final scaling.
// batch_index - row to process; multiplies x_stride / y_stride.
void xnn_compute_f32_three_pass_softmax(
    const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const size_t row_length = context->n;
  const float* input = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
  float* output = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);

  // Pass 1: the rmax kernel writes the row maximum through the pointer.
  float row_max;
  context->rmax_ukernel(row_length, input, &row_max);

  // Pass 2: store exp(x - max) into the output row and accumulate the sum.
  float exp_sum;
  context->raddstoreexpminusmax_ukernel(row_length, input, output, &exp_sum, row_max);

  // Pass 3: scale the output row in place by the reciprocal of the sum.
  const float inv_sum = 1.0f / exp_sum;
  context->vmulc_ukernel(row_length, output, &inv_sum, output, &context->params);
}
877 
// Multiply-by-constant / add-constant task for a contiguous range of batch rows.
//
// context     - operator state: x/y base pointers, byte strides, row length n,
//               packed constants w, micro-kernel and its params.
// batch_start - first row of the range; multiplies x_stride / y_stride.
// batch_size  - number of rows handed to the micro-kernel.
void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_size)
{
  const size_t input_stride = context->x_stride;
  const size_t output_stride = context->y_stride;

  const uintptr_t input_address = (uintptr_t) context->x + input_stride * batch_start;
  const uintptr_t output_address = (uintptr_t) context->y + output_stride * batch_start;

  context->ukernel(
    batch_size, context->n,
    (const void*) input_address, input_stride,
    context->w,
    (void*) output_address, output_stride,
    &context->params);
}
897 
898 #if XNN_MAX_UARCH_TYPES > 1
xnn_compute_hmp_grouped_gemm(const struct gemm_context context[restrict XNN_MIN_ELEMENTS (1)],uint32_t uarch_index,size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)899   void xnn_compute_hmp_grouped_gemm(
900       const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
901       uint32_t uarch_index,
902       size_t group_index,
903       size_t mr_block_start,
904       size_t nr_block_start,
905       size_t mr_block_size,
906       size_t nr_block_size)
907   {
908     const size_t k_scaled  = context->k_scaled;
909     const size_t a_stride  = context->a_stride;
910     const size_t cm_stride = context->cm_stride;
911 
912     context->ukernel.function[uarch_index](
913         mr_block_size,
914         nr_block_size,
915         k_scaled,
916         (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
917         a_stride,
918         (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
919         (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
920         cm_stride,
921         context->cn_stride,
922         &context->params);
923   }
924 
xnn_compute_hmp_gemm(const struct gemm_context context[restrict XNN_MIN_ELEMENTS (1)],uint32_t uarch_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)925   void xnn_compute_hmp_gemm(
926       const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
927       uint32_t uarch_index,
928       size_t mr_block_start,
929       size_t nr_block_start,
930       size_t mr_block_size,
931       size_t nr_block_size)
932   {
933     const size_t a_stride  = context->a_stride;
934     const size_t cm_stride = context->cm_stride;
935 
936     context->ukernel.function[uarch_index](
937         mr_block_size,
938         nr_block_size,
939         context->k_scaled,
940         (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
941         a_stride,
942         (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
943         (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
944         cm_stride,
945         context->cn_stride,
946         &context->params);
947   }
948 
xnn_compute_hmp_grouped_batch_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],uint32_t uarch_index,size_t batch_index,size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)949   void xnn_compute_hmp_grouped_batch_igemm(
950       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
951       uint32_t uarch_index,
952       size_t batch_index,
953       size_t group_index,
954       size_t mr_block_start,
955       size_t nr_block_start,
956       size_t mr_block_size,
957       size_t nr_block_size)
958   {
959     const size_t ks        = context->ks;
960     const size_t cm_stride = context->cm_stride;
961 
962     context->ukernel.function[uarch_index](
963         mr_block_size,
964         nr_block_size,
965         context->kc,
966         context->ks_scaled,
967         (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
968         (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
969         (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
970         cm_stride,
971         context->cn_stride,
972         context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
973         context->zero,
974         &context->params);
975   }
976 
xnn_compute_hmp_grouped_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],uint32_t uarch_index,size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)977   void xnn_compute_hmp_grouped_igemm(
978       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
979       uint32_t uarch_index,
980       size_t group_index,
981       size_t mr_block_start,
982       size_t nr_block_start,
983       size_t mr_block_size,
984       size_t nr_block_size)
985   {
986     const size_t ks        = context->ks;
987     const size_t cm_stride = context->cm_stride;
988 
989     context->ukernel.function[uarch_index](
990         mr_block_size,
991         nr_block_size,
992         context->kc,
993         context->ks_scaled,
994         (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
995         (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
996         (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
997         cm_stride,
998         context->cn_stride,
999         context->a_offset + group_index * context->ga_stride,
1000         context->zero,
1001         &context->params);
1002   }
1003 
xnn_compute_batch_hmp_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],uint32_t uarch_index,size_t batch_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)1004   void xnn_compute_batch_hmp_igemm(
1005       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
1006       uint32_t uarch_index,
1007       size_t batch_index,
1008       size_t mr_block_start,
1009       size_t nr_block_start,
1010       size_t mr_block_size,
1011       size_t nr_block_size)
1012   {
1013     const size_t ks        = context->ks;
1014     const size_t cm_stride = context->cm_stride;
1015 
1016     context->ukernel.function[uarch_index](
1017         mr_block_size,
1018         nr_block_size,
1019         context->kc,
1020         context->ks_scaled,
1021         (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
1022         (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
1023         (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
1024         cm_stride,
1025         context->cn_stride,
1026         context->a_offset + batch_index * context->ba_stride,
1027         context->zero,
1028         &context->params);
1029   }
1030 
xnn_compute_hmp_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],uint32_t uarch_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)1031   void xnn_compute_hmp_igemm(
1032       const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
1033       uint32_t uarch_index,
1034       size_t mr_block_start,
1035       size_t nr_block_start,
1036       size_t mr_block_size,
1037       size_t nr_block_size)
1038   {
1039     const size_t ks        = context->ks;
1040     const size_t cm_stride = context->cm_stride;
1041 
1042     context->ukernel.function[uarch_index](
1043         mr_block_size,
1044         nr_block_size,
1045         context->kc,
1046         context->ks_scaled,
1047         (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
1048         (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
1049         (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
1050         cm_stride,
1051         context->cn_stride,
1052         context->a_offset,
1053         context->zero,
1054         &context->params);
1055   }
1056 #endif  // XNN_MAX_UARCH_TYPES > 1
1057 
// Runs a previously set-up operator by dispatching its compute task to pthreadpool.
//
// op         - operator to run; op->state, op->compute and op->context are read.
// threadpool - thread pool handed to the pthreadpool_parallelize_* call.
//
// Returns:
//   xnn_status_uninitialized - XNN_INIT_FLAG_XNNPACK is not set in xnn_params.init_flags.
//   xnn_status_invalid_state - op->state is xnn_run_state_invalid.
//   xnn_status_success       - op->state is xnn_run_state_skip (nothing is dispatched), or the
//                              dispatch for op->compute.type returned. A compute type of
//                              xnn_parallelization_type_invalid dispatches nothing and still
//                              reports success.
//
// Every dispatch passes PTHREADPOOL_FLAG_DISABLE_DENORMALS and a pointer to op->context
// as the task argument. Range and tile extents are asserted to be non-zero.
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (op->state == xnn_run_state_invalid) {
    xnn_log_error("failed to run operator: operator was not successfully setup");
    return xnn_status_invalid_state;
  }
  if (op->state == xnn_run_state_skip) {
    return xnn_status_success;
  }

  // Shorthands shared by all dispatch cases below.
  const size_t* const range = op->compute.range;
  const size_t* const tile = op->compute.tile;
  void* const task_context = &op->context;
  const uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS;

  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool, op->compute.task_1d, task_context,
          range[0],
          flags);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(range[0] != 0);
      assert(tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool, op->compute.task_1d_tile_1d, task_context,
          range[0],
          tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool, op->compute.task_2d, task_context,
          range[0], range[1],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool, op->compute.task_2d_tile_1d, task_context,
          range[0], range[1],
          tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool, op->compute.task_2d_tile_2d, task_context,
          range[0], range[1],
          tile[0], tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      pthreadpool_parallelize_3d(
          threadpool, op->compute.task_3d, task_context,
          range[0], range[1], range[2],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool, op->compute.task_3d_tile_2d, task_context,
          range[0], range[1], range[2],
          tile[0], tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(range[3] != 0);
      pthreadpool_parallelize_4d(
          threadpool, op->compute.task_4d, task_context,
          range[0], range[1], range[2], range[3],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(range[3] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool, op->compute.task_4d_tile_2d, task_context,
          range[0], range[1], range[2], range[3],
          tile[0], tile[1],
          flags);
      break;
    case xnn_parallelization_type_5d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(range[3] != 0);
      assert(range[4] != 0);
      pthreadpool_parallelize_5d(
          threadpool, op->compute.task_5d, task_context,
          range[0], range[1], range[2], range[3], range[4],
          flags);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(range[3] != 0);
      assert(range[4] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool, op->compute.task_5d_tile_2d, task_context,
          range[0], range[1], range[2], range[3], range[4],
          tile[0], tile[1],
          flags);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(range[3] != 0);
      assert(range[4] != 0);
      assert(range[5] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool, op->compute.task_6d_tile_2d, task_context,
          range[0], range[1], range[2], range[3], range[4], range[5],
          tile[0], tile[1],
          flags);
      break;
#if XNN_MAX_UARCH_TYPES > 1
    // The *_with_uarch variants also pass the default uarch index (0) and the
    // maximum uarch index (XNN_MAX_UARCH_TYPES - 1) to pthreadpool.
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool, op->compute.task_2d_tile_2d_with_id, task_context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          range[0], range[1],
          tile[0], tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool, op->compute.task_3d_tile_2d_with_id, task_context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          range[0], range[1], range[2],
          tile[0], tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(range[0] != 0);
      assert(range[1] != 0);
      assert(range[2] != 0);
      assert(range[3] != 0);
      assert(tile[0] != 0);
      assert(tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool, op->compute.task_4d_tile_2d_with_id, task_context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          range[0], range[1], range[2], range[3],
          tile[0], tile[1],
          flags);
      break;
#endif  // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}
1282