Lines Matching full:context

25     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],  in xnn_compute_grouped_gemm()
32 const size_t k_scaled = context->k_scaled; in xnn_compute_grouped_gemm()
33 const size_t a_stride = context->a_stride; in xnn_compute_grouped_gemm()
34 const size_t cm_stride = context->cm_stride; in xnn_compute_grouped_gemm()
36 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_grouped_gemm()
40 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled), in xnn_compute_grouped_gemm()
42 …(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * in xnn_compute_grouped_gemm()
43 …(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_cs… in xnn_compute_grouped_gemm()
45 context->cn_stride, in xnn_compute_grouped_gemm()
46 &context->params); in xnn_compute_grouped_gemm()
50 const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_gemm()
56 const size_t a_stride = context->a_stride; in xnn_compute_gemm()
57 const size_t cm_stride = context->cm_stride; in xnn_compute_gemm()
59 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_gemm()
62 context->k_scaled, in xnn_compute_gemm()
63 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), in xnn_compute_gemm()
65 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), in xnn_compute_gemm()
66 …(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_cs… in xnn_compute_gemm()
68 context->cn_stride, in xnn_compute_gemm()
69 &context->params); in xnn_compute_gemm()
73 const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_spmm()
78 context->ukernel( in xnn_compute_spmm()
80 context->n, in xnn_compute_spmm()
81 …(const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block… in xnn_compute_spmm()
82 context->nonzero_weights, in xnn_compute_spmm()
83 context->input_increments, in xnn_compute_spmm()
84 context->output_channel_nonzeros, in xnn_compute_spmm()
85 …(void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_sta… in xnn_compute_spmm()
86 context->scaled_m, in xnn_compute_spmm()
87 &context->params); in xnn_compute_spmm()
91 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_grouped_batch_igemm()
99 const size_t ks = context->ks; in xnn_compute_grouped_batch_igemm()
100 const size_t cm_stride = context->cm_stride; in xnn_compute_grouped_batch_igemm()
102 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_grouped_batch_igemm()
105 context->kc, in xnn_compute_grouped_batch_igemm()
106 context->ks_scaled, in xnn_compute_grouped_batch_igemm()
107 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_grouped_batch_igemm()
108 …(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * in xnn_compute_grouped_batch_igemm()
109 …ptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block… in xnn_compute_grouped_batch_igemm()
111 context->cn_stride, in xnn_compute_grouped_batch_igemm()
112 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride, in xnn_compute_grouped_batch_igemm()
113 context->zero, in xnn_compute_grouped_batch_igemm()
114 &context->params); in xnn_compute_grouped_batch_igemm()
118 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_grouped_igemm()
125 const size_t ks = context->ks; in xnn_compute_grouped_igemm()
126 const size_t cm_stride = context->cm_stride; in xnn_compute_grouped_igemm()
128 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_grouped_igemm()
131 context->kc, in xnn_compute_grouped_igemm()
132 context->ks_scaled, in xnn_compute_grouped_igemm()
133 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_grouped_igemm()
134 …(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * in xnn_compute_grouped_igemm()
135 …(void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + … in xnn_compute_grouped_igemm()
137 context->cn_stride, in xnn_compute_grouped_igemm()
138 context->a_offset + group_index * context->ga_stride, in xnn_compute_grouped_igemm()
139 context->zero, in xnn_compute_grouped_igemm()
140 &context->params); in xnn_compute_grouped_igemm()
144 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_batch_igemm()
151 const size_t ks = context->ks; in xnn_compute_batch_igemm()
152 const size_t cm_stride = context->cm_stride; in xnn_compute_batch_igemm()
154 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_batch_igemm()
157 context->kc, in xnn_compute_batch_igemm()
158 context->ks_scaled, in xnn_compute_batch_igemm()
159 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_batch_igemm()
160 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), in xnn_compute_batch_igemm()
161 …(void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + … in xnn_compute_batch_igemm()
163 context->cn_stride, in xnn_compute_batch_igemm()
164 context->a_offset + batch_index * context->ba_stride, in xnn_compute_batch_igemm()
165 context->zero, in xnn_compute_batch_igemm()
166 &context->params); in xnn_compute_batch_igemm()
170 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_igemm()
176 const size_t ks = context->ks; in xnn_compute_igemm()
177 const size_t cm_stride = context->cm_stride; in xnn_compute_igemm()
179 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_igemm()
182 context->kc, in xnn_compute_igemm()
183 context->ks_scaled, in xnn_compute_igemm()
184 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_igemm()
185 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), in xnn_compute_igemm()
186 …(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_cs… in xnn_compute_igemm()
188 context->cn_stride, in xnn_compute_igemm()
189 context->a_offset, in xnn_compute_igemm()
190 context->zero, in xnn_compute_igemm()
191 &context->params); in xnn_compute_igemm()
195 const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_grouped_subgemm2d()
205 …const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subker… in xnn_compute_grouped_subgemm2d()
217 const size_t ax_stride = context->ax_stride; in xnn_compute_grouped_subgemm2d()
218 const size_t cx_stride = context->cx_stride; in xnn_compute_grouped_subgemm2d()
219 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_grouped_subgemm2d()
222 context->kc, in xnn_compute_grouped_subgemm2d()
223 …(uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x… in xnn_compute_grouped_subgemm2d()
225 …ms->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride), in xnn_compute_grouped_subgemm2d()
226 …dex * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index … in xnn_compute_grouped_subgemm2d()
228 context->cn_stride, in xnn_compute_grouped_subgemm2d()
229 &context->params); in xnn_compute_grouped_subgemm2d()
233 const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_subgemm2d()
242 …const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subker… in xnn_compute_subgemm2d()
254 const size_t ax_stride = context->ax_stride; in xnn_compute_subgemm2d()
255 const size_t cx_stride = context->cx_stride; in xnn_compute_subgemm2d()
256 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_subgemm2d()
259 context->kc, in xnn_compute_subgemm2d()
260 …(const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride +… in xnn_compute_subgemm2d()
263 …->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_st… in xnn_compute_subgemm2d()
265 context->cn_stride, in xnn_compute_subgemm2d()
266 &context->params); in xnn_compute_subgemm2d()
270 const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_grouped_subconv2d()
280 …const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subker… in xnn_compute_grouped_subconv2d()
292 const size_t cx_stride = context->cx_stride; in xnn_compute_grouped_subconv2d()
293 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_grouped_subconv2d()
296 context->kc, in xnn_compute_grouped_subconv2d()
299 …ms->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride), in xnn_compute_grouped_subconv2d()
300 …dex * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index … in xnn_compute_grouped_subconv2d()
302 context->cn_stride, in xnn_compute_grouped_subconv2d()
303 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride, in xnn_compute_grouped_subconv2d()
304 context->zero, in xnn_compute_grouped_subconv2d()
305 &context->params); in xnn_compute_grouped_subconv2d()
309 const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_subconv2d()
318 …const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subker… in xnn_compute_subconv2d()
330 const size_t cx_stride = context->cx_stride; in xnn_compute_subconv2d()
331 context->ukernel.function[XNN_UARCH_DEFAULT]( in xnn_compute_subconv2d()
334 context->kc, in xnn_compute_subconv2d()
338 …->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_st… in xnn_compute_subconv2d()
340 context->cn_stride, in xnn_compute_subconv2d()
341 context->a_offset + batch_index * context->ba_stride, in xnn_compute_subconv2d()
342 context->zero, in xnn_compute_subconv2d()
343 &context->params); in xnn_compute_subconv2d()
347 const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_conv2d_hwc2chw()
352 context->hwc2chw_ukernel( in xnn_compute_conv2d_hwc2chw()
353 context->input_height, in xnn_compute_conv2d_hwc2chw()
354 context->input_width, in xnn_compute_conv2d_hwc2chw()
357 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride), in xnn_compute_conv2d_hwc2chw()
358 context->zero, in xnn_compute_conv2d_hwc2chw()
359 context->packed_weights, in xnn_compute_conv2d_hwc2chw()
360 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride), in xnn_compute_conv2d_hwc2chw()
361 context->input_padding_top, in xnn_compute_conv2d_hwc2chw()
362 context->output_channels, in xnn_compute_conv2d_hwc2chw()
363 context->output_height_stride, in xnn_compute_conv2d_hwc2chw()
364 context->output_channel_stride, in xnn_compute_conv2d_hwc2chw()
365 &context->params); in xnn_compute_conv2d_hwc2chw()
369 const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_dwconv_unipass()
374 …(const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_st… in xnn_compute_dwconv_unipass()
375 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_dwconv_unipass()
376 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_dwconv_unipass()
377 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_dwconv_unipass()
379 context->unipass_ukernel( in xnn_compute_dwconv_unipass()
380 context->groups, context->output_width, in xnn_compute_dwconv_unipass()
381 indirect_input, context->packed_weights, output, in xnn_compute_dwconv_unipass()
382 context->indirect_input_width_stride, context->output_increment, in xnn_compute_dwconv_unipass()
383 input_offset, context->zero, in xnn_compute_dwconv_unipass()
384 &context->params); in xnn_compute_dwconv_unipass()
388 const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_dwconv2d_chw()
392 context->chw_ukernel( in xnn_compute_dwconv2d_chw()
393 context->input_height, in xnn_compute_dwconv2d_chw()
394 context->input_width, in xnn_compute_dwconv2d_chw()
395 …(const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index … in xnn_compute_dwconv2d_chw()
396 (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride), in xnn_compute_dwconv2d_chw()
397 context->zero, in xnn_compute_dwconv2d_chw()
398 …(void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * co… in xnn_compute_dwconv2d_chw()
399 context->input_padding_top, in xnn_compute_dwconv2d_chw()
400 &context->params); in xnn_compute_dwconv2d_chw()
404 const struct depthtospace2d_hwc_context* context, in xnn_compute_depthtospace2d_hwc_contiguous() argument
409 const size_t input_width = context->input_width; in xnn_compute_depthtospace2d_hwc_contiguous()
410 const size_t elements = context->elements; in xnn_compute_depthtospace2d_hwc_contiguous()
411 const void* input = (const void*) ((uintptr_t) context->input + in xnn_compute_depthtospace2d_hwc_contiguous()
412 (batch_input_y * input_width + input_x) * context->input_width_stride + block_y * elements); in xnn_compute_depthtospace2d_hwc_contiguous()
413 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_depthtospace2d_hwc_contiguous()
414 ((batch_input_y * context->block_size + block_y) * input_width + input_x) * elements); in xnn_compute_depthtospace2d_hwc_contiguous()
416 context->ukernel( in xnn_compute_depthtospace2d_hwc_contiguous()
424 const struct depthtospace2d_hwc_context* context, in xnn_compute_depthtospace2d_hwc_strided() argument
430 const size_t block_size = context->block_size; in xnn_compute_depthtospace2d_hwc_strided()
431 const size_t elements = context->elements; in xnn_compute_depthtospace2d_hwc_strided()
432 const void* input = (const void*) ((uintptr_t) context->input + in xnn_compute_depthtospace2d_hwc_strided()
433 …batch_input_y * context->input_height_stride + input_x * context->input_width_stride + (block_y * … in xnn_compute_depthtospace2d_hwc_strided()
434 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_depthtospace2d_hwc_strided()
435 (batch_input_y * block_size + block_y) * context->output_height_stride + in xnn_compute_depthtospace2d_hwc_strided()
436 (input_x * block_size + block_x) * context->output_width_stride); in xnn_compute_depthtospace2d_hwc_strided()
438 context->ukernel( in xnn_compute_depthtospace2d_hwc_strided()
446 const struct depthtospace2d_chw2hwc_context* context, in xnn_compute_depthtospace2d_chw2hwc() argument
449 context->ukernel( in xnn_compute_depthtospace2d_chw2hwc()
450 context->output_channels, in xnn_compute_depthtospace2d_chw2hwc()
451 context->input_height, in xnn_compute_depthtospace2d_chw2hwc()
452 context->input_width, in xnn_compute_depthtospace2d_chw2hwc()
453 context->block_size, in xnn_compute_depthtospace2d_chw2hwc()
454 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride), in xnn_compute_depthtospace2d_chw2hwc()
455 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride), in xnn_compute_depthtospace2d_chw2hwc()
456 context->output_channel_stride); in xnn_compute_depthtospace2d_chw2hwc()
460 const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_argmax_pooling_unipass()
464 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + in xnn_compute_argmax_pooling_unipass()
465 output_y * context->indirect_input_height_stride); in xnn_compute_argmax_pooling_unipass()
466 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_argmax_pooling_unipass()
467 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_argmax_pooling_unipass()
468 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_argmax_pooling_unipass()
469 uint32_t* index = (uint32_t*) ((uintptr_t) context->index + in xnn_compute_argmax_pooling_unipass()
470 batch_index * context->index_batch_stride + output_y * context->index_height_stride); in xnn_compute_argmax_pooling_unipass()
472 context->unipass_ukernel( in xnn_compute_argmax_pooling_unipass()
473 context->output_width, context->pooling_size, context->channels, in xnn_compute_argmax_pooling_unipass()
475 context->input_increment, context->output_increment); in xnn_compute_argmax_pooling_unipass()
479 const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_argmax_pooling_multipass()
483 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + in xnn_compute_argmax_pooling_multipass()
484 output_y * context->indirect_input_height_stride); in xnn_compute_argmax_pooling_multipass()
485 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_argmax_pooling_multipass()
486 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_argmax_pooling_multipass()
487 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_argmax_pooling_multipass()
488 uint32_t* index = (uint32_t*) ((uintptr_t) context->index + in xnn_compute_argmax_pooling_multipass()
489 batch_index * context->index_batch_stride + output_y * context->index_height_stride); in xnn_compute_argmax_pooling_multipass()
491 …void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTR… in xnn_compute_argmax_pooling_multipass()
492 …void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BY… in xnn_compute_argmax_pooling_multipass()
494 context->multipass_ukernel( in xnn_compute_argmax_pooling_multipass()
495 context->output_width, context->pooling_size, context->channels, in xnn_compute_argmax_pooling_multipass()
497 context->input_increment, context->output_increment); in xnn_compute_argmax_pooling_multipass()
501 const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_max_pooling()
505 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input + in xnn_compute_max_pooling()
506 output_y * context->indirect_input_height_stride); in xnn_compute_max_pooling()
507 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_max_pooling()
508 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_max_pooling()
509 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_max_pooling()
511 context->ukernel( in xnn_compute_max_pooling()
512 context->output_width, context->pooling_size, context->channels, in xnn_compute_max_pooling()
514 context->input_increment, context->output_increment, in xnn_compute_max_pooling()
515 &context->params); in xnn_compute_max_pooling()
519 const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_unpooling()
523 const void* input = (const void*) ((uintptr_t) context->input + in xnn_compute_unpooling()
524 input_y * context->input_height_stride + input_x * context->input_width_stride); in xnn_compute_unpooling()
525 const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index + in xnn_compute_unpooling()
526 input_y * context->index_height_stride + input_x * context->index_width_stride); in xnn_compute_unpooling()
528 (void**) ((uintptr_t) context->indirect_output + in xnn_compute_unpooling()
529 …input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride… in xnn_compute_unpooling()
531 context->ukernel( in xnn_compute_unpooling()
532 context->pooling_size, in xnn_compute_unpooling()
533 context->channels, in xnn_compute_unpooling()
534 context->fill_value, in xnn_compute_unpooling()
539 const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_average_pooling_unipass()
544 …(const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_st… in xnn_compute_average_pooling_unipass()
545 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_average_pooling_unipass()
546 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_average_pooling_unipass()
547 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_average_pooling_unipass()
549 context->unipass_ukernel( in xnn_compute_average_pooling_unipass()
550 context->output_width, context->pooling_size, context->channels, in xnn_compute_average_pooling_unipass()
551 indirect_input, input_offset, context->zero, output, in xnn_compute_average_pooling_unipass()
552 context->input_increment, context->output_increment, in xnn_compute_average_pooling_unipass()
553 &context->params); in xnn_compute_average_pooling_unipass()
557 const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_average_pooling_multipass()
562 …(const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_st… in xnn_compute_average_pooling_multipass()
563 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_average_pooling_multipass()
564 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_average_pooling_multipass()
565 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_average_pooling_multipass()
568 …XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(u… in xnn_compute_average_pooling_multipass()
570 context->multipass_ukernel( in xnn_compute_average_pooling_multipass()
571 context->output_width, context->pooling_size, context->channels, in xnn_compute_average_pooling_multipass()
572 indirect_input, input_offset, context->zero, multipass_buffer, output, in xnn_compute_average_pooling_multipass()
573 context->input_increment, context->output_increment, in xnn_compute_average_pooling_multipass()
574 &context->params); in xnn_compute_average_pooling_multipass()
578 const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_pixelwise_average_pooling_unipass()
583 …(const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_st… in xnn_compute_pixelwise_average_pooling_unipass()
584 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_pixelwise_average_pooling_unipass()
586 …(const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height… in xnn_compute_pixelwise_average_pooling_unipass()
587 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_pixelwise_average_pooling_unipass()
588 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_pixelwise_average_pooling_unipass()
590 context->unipass_ukernel( in xnn_compute_pixelwise_average_pooling_unipass()
591 context->output_width, context->pooling_size, context->channels, in xnn_compute_pixelwise_average_pooling_unipass()
592 indirect_input, input_offset, context->zero, pixelwise_buffer, output, in xnn_compute_pixelwise_average_pooling_unipass()
593 context->input_increment, context->output_increment, in xnn_compute_pixelwise_average_pooling_unipass()
594 &context->params); in xnn_compute_pixelwise_average_pooling_unipass()
598 const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_pixelwise_average_pooling_multipass()
603 …(const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_st… in xnn_compute_pixelwise_average_pooling_multipass()
604 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride; in xnn_compute_pixelwise_average_pooling_multipass()
606 …(const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height… in xnn_compute_pixelwise_average_pooling_multipass()
607 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_pixelwise_average_pooling_multipass()
608 batch_index * context->output_batch_stride + output_y * context->output_height_stride); in xnn_compute_pixelwise_average_pooling_multipass()
610 …void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * s… in xnn_compute_pixelwise_average_pooling_multipass()
612 context->multipass_ukernel( in xnn_compute_pixelwise_average_pooling_multipass()
613 context->output_width, context->pooling_size, context->channels, in xnn_compute_pixelwise_average_pooling_multipass()
614 indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output, in xnn_compute_pixelwise_average_pooling_multipass()
615 context->input_increment, context->output_increment, in xnn_compute_pixelwise_average_pooling_multipass()
616 &context->params); in xnn_compute_pixelwise_average_pooling_multipass()
620 const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_global_average_pooling_nwc_unipass()
624 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride); in xnn_compute_global_average_pooling_nwc_unipass()
626 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride); in xnn_compute_global_average_pooling_nwc_unipass()
628 context->unipass_ukernel( in xnn_compute_global_average_pooling_nwc_unipass()
629 context->input_elements, in xnn_compute_global_average_pooling_nwc_unipass()
630 context->channels, in xnn_compute_global_average_pooling_nwc_unipass()
632 context->input_pixel_stride, in xnn_compute_global_average_pooling_nwc_unipass()
633 context->zero, in xnn_compute_global_average_pooling_nwc_unipass()
635 &context->params); in xnn_compute_global_average_pooling_nwc_unipass()
639 const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_global_average_pooling_nwc_multipass()
643 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride); in xnn_compute_global_average_pooling_nwc_multipass()
645 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride); in xnn_compute_global_average_pooling_nwc_multipass()
648 …XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(u… in xnn_compute_global_average_pooling_nwc_multipass()
650 context->multipass_ukernel( in xnn_compute_global_average_pooling_nwc_multipass()
651 context->input_elements, in xnn_compute_global_average_pooling_nwc_multipass()
652 context->channels, in xnn_compute_global_average_pooling_nwc_multipass()
654 context->input_pixel_stride, in xnn_compute_global_average_pooling_nwc_multipass()
655 context->zero, in xnn_compute_global_average_pooling_nwc_multipass()
658 &context->params); in xnn_compute_global_average_pooling_nwc_multipass()
662 const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_global_average_pooling_ncw()
667 const void* input = (const void*) ((uintptr_t) context->input + in xnn_compute_global_average_pooling_ncw()
668 channels_start * context->input_channel_stride + batch_index * context->input_batch_stride); in xnn_compute_global_average_pooling_ncw()
669 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_global_average_pooling_ncw()
670 channels_start * context->output_channel_stride + batch_index * context->output_batch_stride); in xnn_compute_global_average_pooling_ncw()
672 context->ukernel( in xnn_compute_global_average_pooling_ncw()
673 context->input_elements, in xnn_compute_global_average_pooling_ncw()
677 &context->params); in xnn_compute_global_average_pooling_ncw()
681 const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_resize_bilinear()
687 …(void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * in xnn_compute_resize_bilinear()
689 context->ukernel( in xnn_compute_resize_bilinear()
691 context->scaled_channels, in xnn_compute_resize_bilinear()
692 context->indirect_input + pixel_start * 4, in xnn_compute_resize_bilinear()
693 context->input_offset + batch_index * context->input_batch_stride, in xnn_compute_resize_bilinear()
694 (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)), in xnn_compute_resize_bilinear()
696 context->output_pixel_stride - context->scaled_channels); in xnn_compute_resize_bilinear()
700 const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_resize_bilinear_chw()
706 …(void*) ((uintptr_t) context->output + channel_start * context->output_channel_stride + batch_inde… in xnn_compute_resize_bilinear_chw()
707 …const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + ch… in xnn_compute_resize_bilinear_chw()
709 context->ukernel( in xnn_compute_resize_bilinear_chw()
710 context->output_pixels, in xnn_compute_resize_bilinear_chw()
712 context->indirect_input, in xnn_compute_resize_bilinear_chw()
714 context->packed_weights, in xnn_compute_resize_bilinear_chw()
716 context->input_channel_stride); in xnn_compute_resize_bilinear_chw()
720 const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_prelu()
724 const size_t x_stride = context->x_stride; in xnn_compute_prelu()
725 const size_t y_stride = context->y_stride; in xnn_compute_prelu()
726 const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start); in xnn_compute_prelu()
727 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start); in xnn_compute_prelu()
729 context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride); in xnn_compute_prelu()
733 const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_pad_5d()
736 const void* input = (const void*) ((uintptr_t) context->input + in xnn_compute_pad_5d()
737 …i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * c… in xnn_compute_pad_5d()
738 void* output = (void*) ((uintptr_t) context->output + in xnn_compute_pad_5d()
739 …i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l … in xnn_compute_pad_5d()
741 const size_t i_padding = context->pre_paddings[5]; in xnn_compute_pad_5d()
742 const size_t j_padding = context->pre_paddings[4]; in xnn_compute_pad_5d()
743 const size_t k_padding = context->pre_paddings[3]; in xnn_compute_pad_5d()
744 const size_t l_padding = context->pre_paddings[2]; in xnn_compute_pad_5d()
745 const size_t m_padding = context->pre_paddings[1]; in xnn_compute_pad_5d()
747 const size_t i_size = context->input_size[5]; in xnn_compute_pad_5d()
748 const size_t j_size = context->input_size[4]; in xnn_compute_pad_5d()
749 const size_t k_size = context->input_size[3]; in xnn_compute_pad_5d()
750 const size_t l_size = context->input_size[2]; in xnn_compute_pad_5d()
751 const size_t m_size = context->input_size[1]; in xnn_compute_pad_5d()
756 context->pad_ukernel( in xnn_compute_pad_5d()
758 context->input_size[0], context->pre_paddings[0], context->post_paddings[0], in xnn_compute_pad_5d()
759 &context->padding_value, in xnn_compute_pad_5d()
762context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, &conte… in xnn_compute_pad_5d()
767 const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_elementwise_binary_5d()
770 const void* a = (const void*) ((uintptr_t) context->a + in xnn_compute_elementwise_binary_5d()
771 …i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_st… in xnn_compute_elementwise_binary_5d()
772 const void* b = (const void*) ((uintptr_t) context->b + in xnn_compute_elementwise_binary_5d()
773 …i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_st… in xnn_compute_elementwise_binary_5d()
774 void* y = (void*) ((uintptr_t) context->y + in xnn_compute_elementwise_binary_5d()
775 …i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_st… in xnn_compute_elementwise_binary_5d()
776 context->ukernel(context->elements, a, b, y, &context->params); in xnn_compute_elementwise_binary_5d()
780 const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_channel_shuffle_fixed()
783 const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); in xnn_compute_channel_shuffle_fixed()
784 void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); in xnn_compute_channel_shuffle_fixed()
786 context->fixed_ukernel(context->n, x, y); in xnn_compute_channel_shuffle_fixed()
790 const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_channel_shuffle_variable()
793 const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); in xnn_compute_channel_shuffle_variable()
794 void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); in xnn_compute_channel_shuffle_variable()
796 context->variable_ukernel(context->n, context->m, x, y); in xnn_compute_channel_shuffle_variable()
800 const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_lut_strided()
803 const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index); in xnn_compute_lut_strided()
804 void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index); in xnn_compute_lut_strided()
806 context->ukernel(context->n, x, context->t, y); in xnn_compute_lut_strided()
810 const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_lut_contiguous()
814 const void* x = (const void*) ((uintptr_t) context->x + offset); in xnn_compute_lut_contiguous()
815 void* y = (void*) ((uintptr_t) context->y + offset); in xnn_compute_lut_contiguous()
817 context->ukernel(size, x, context->t, y); in xnn_compute_lut_contiguous()
821 const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_univector_strided()
827 const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index); in xnn_compute_univector_strided()
828 void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index); in xnn_compute_univector_strided()
829 context->ukernel(context->n, x, y, &context->params); in xnn_compute_univector_strided()
833 const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_univector_contiguous()
837 const void* x = (const void*) ((uintptr_t) context->x + offset); in xnn_compute_univector_contiguous()
838 void* y = (void*) ((uintptr_t) context->y + offset); in xnn_compute_univector_contiguous()
839 context->ukernel(size, x, y, &context->params); in xnn_compute_univector_contiguous()
843 const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_u8_softmax()
846 const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index); in xnn_compute_u8_softmax()
847 uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index); in xnn_compute_u8_softmax()
848 const size_t n = context->n; in xnn_compute_u8_softmax()
851 context->rmax_ukernel(n, x, &x_max); in xnn_compute_u8_softmax()
853 const uint32_t* t = (const uint32_t*) context->t + adjustment; in xnn_compute_u8_softmax()
854 context->lut_norm_ukernel(n, x, t, y); in xnn_compute_u8_softmax()
858 const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_f32_three_pass_softmax()
861 const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index); in xnn_compute_f32_three_pass_softmax()
862 float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index); in xnn_compute_f32_three_pass_softmax()
863 const size_t n = context->n; in xnn_compute_f32_three_pass_softmax()
867 context->rmax_ukernel(n, x, &x_max); in xnn_compute_f32_three_pass_softmax()
871 context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max); in xnn_compute_f32_three_pass_softmax()
875 context->vmulc_ukernel(n, y, &y_scale, y, &context->params); in xnn_compute_f32_three_pass_softmax()
879 const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_vmulcaddc()
883 const size_t x_stride = context->x_stride; in xnn_compute_vmulcaddc()
884 const size_t y_stride = context->y_stride; in xnn_compute_vmulcaddc()
886 const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start); in xnn_compute_vmulcaddc()
887 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start); in xnn_compute_vmulcaddc()
889 context->ukernel( in xnn_compute_vmulcaddc()
891 context->n, in xnn_compute_vmulcaddc()
893 context->w, in xnn_compute_vmulcaddc()
895 &context->params); in xnn_compute_vmulcaddc()
900 const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_hmp_grouped_gemm()
908 const size_t k_scaled = context->k_scaled; in xnn_compute_hmp_grouped_gemm()
909 const size_t a_stride = context->a_stride; in xnn_compute_hmp_grouped_gemm()
910 const size_t cm_stride = context->cm_stride; in xnn_compute_hmp_grouped_gemm()
912 context->ukernel.function[uarch_index]( in xnn_compute_hmp_grouped_gemm()
916 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled), in xnn_compute_hmp_grouped_gemm()
918 …(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * in xnn_compute_hmp_grouped_gemm()
919 …(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_cs… in xnn_compute_hmp_grouped_gemm()
921 context->cn_stride, in xnn_compute_hmp_grouped_gemm()
922 &context->params); in xnn_compute_hmp_grouped_gemm()
926 const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_hmp_gemm()
933 const size_t a_stride = context->a_stride; in xnn_compute_hmp_gemm()
934 const size_t cm_stride = context->cm_stride; in xnn_compute_hmp_gemm()
936 context->ukernel.function[uarch_index]( in xnn_compute_hmp_gemm()
939 context->k_scaled, in xnn_compute_hmp_gemm()
940 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), in xnn_compute_hmp_gemm()
942 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), in xnn_compute_hmp_gemm()
943 …(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_cs… in xnn_compute_hmp_gemm()
945 context->cn_stride, in xnn_compute_hmp_gemm()
946 &context->params); in xnn_compute_hmp_gemm()
950 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_hmp_grouped_batch_igemm()
959 const size_t ks = context->ks; in xnn_compute_hmp_grouped_batch_igemm()
960 const size_t cm_stride = context->cm_stride; in xnn_compute_hmp_grouped_batch_igemm()
962 context->ukernel.function[uarch_index]( in xnn_compute_hmp_grouped_batch_igemm()
965 context->kc, in xnn_compute_hmp_grouped_batch_igemm()
966 context->ks_scaled, in xnn_compute_hmp_grouped_batch_igemm()
967 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_hmp_grouped_batch_igemm()
968 …(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * in xnn_compute_hmp_grouped_batch_igemm()
969 …ptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block… in xnn_compute_hmp_grouped_batch_igemm()
971 context->cn_stride, in xnn_compute_hmp_grouped_batch_igemm()
972 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride, in xnn_compute_hmp_grouped_batch_igemm()
973 context->zero, in xnn_compute_hmp_grouped_batch_igemm()
974 &context->params); in xnn_compute_hmp_grouped_batch_igemm()
978 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_hmp_grouped_igemm()
986 const size_t ks = context->ks; in xnn_compute_hmp_grouped_igemm()
987 const size_t cm_stride = context->cm_stride; in xnn_compute_hmp_grouped_igemm()
989 context->ukernel.function[uarch_index]( in xnn_compute_hmp_grouped_igemm()
992 context->kc, in xnn_compute_hmp_grouped_igemm()
993 context->ks_scaled, in xnn_compute_hmp_grouped_igemm()
994 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_hmp_grouped_igemm()
995 …(const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * in xnn_compute_hmp_grouped_igemm()
996 …(void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + … in xnn_compute_hmp_grouped_igemm()
998 context->cn_stride, in xnn_compute_hmp_grouped_igemm()
999 context->a_offset + group_index * context->ga_stride, in xnn_compute_hmp_grouped_igemm()
1000 context->zero, in xnn_compute_hmp_grouped_igemm()
1001 &context->params); in xnn_compute_hmp_grouped_igemm()
1005 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_batch_hmp_igemm()
1013 const size_t ks = context->ks; in xnn_compute_batch_hmp_igemm()
1014 const size_t cm_stride = context->cm_stride; in xnn_compute_batch_hmp_igemm()
1016 context->ukernel.function[uarch_index]( in xnn_compute_batch_hmp_igemm()
1019 context->kc, in xnn_compute_batch_hmp_igemm()
1020 context->ks_scaled, in xnn_compute_batch_hmp_igemm()
1021 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_batch_hmp_igemm()
1022 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), in xnn_compute_batch_hmp_igemm()
1023 …(void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + … in xnn_compute_batch_hmp_igemm()
1025 context->cn_stride, in xnn_compute_batch_hmp_igemm()
1026 context->a_offset + batch_index * context->ba_stride, in xnn_compute_batch_hmp_igemm()
1027 context->zero, in xnn_compute_batch_hmp_igemm()
1028 &context->params); in xnn_compute_batch_hmp_igemm()
1032 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], in xnn_compute_hmp_igemm()
1039 const size_t ks = context->ks; in xnn_compute_hmp_igemm()
1040 const size_t cm_stride = context->cm_stride; in xnn_compute_hmp_igemm()
1042 context->ukernel.function[uarch_index]( in xnn_compute_hmp_igemm()
1045 context->kc, in xnn_compute_hmp_igemm()
1046 context->ks_scaled, in xnn_compute_hmp_igemm()
1047 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)), in xnn_compute_hmp_igemm()
1048 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), in xnn_compute_hmp_igemm()
1049 …(void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_cs… in xnn_compute_hmp_igemm()
1051 context->cn_stride, in xnn_compute_hmp_igemm()
1052 context->a_offset, in xnn_compute_hmp_igemm()
1053 context->zero, in xnn_compute_hmp_igemm()
1054 &context->params); in xnn_compute_hmp_igemm()
1082 &op->context, in xnn_run_operator()
1092 &op->context, in xnn_run_operator()
1103 &op->context, in xnn_run_operator()
1114 &op->context, in xnn_run_operator()
1127 &op->context, in xnn_run_operator()
1139 &op->context, in xnn_run_operator()
1152 &op->context, in xnn_run_operator()
1165 &op->context, in xnn_run_operator()
1179 &op->context, in xnn_run_operator()
1193 &op->context, in xnn_run_operator()
1208 &op->context, in xnn_run_operator()
1225 &op->context, in xnn_run_operator()
1239 &op->context, in xnn_run_operator()
1254 &op->context, in xnn_run_operator()
1270 &op->context, in xnn_run_operator()