// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/compute.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>

void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index, size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t mr_block_start, size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start),
      context->nonzero_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start),
      context->scaled_m,
      &context->params);
}
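// The xnn_compute_*igemm functions below drive indirect GEMM microkernels: instead of a
// dense input pointer, the microkernel receives an indirection buffer of per-row pointers
// (context->indirect_a), a byte offset added to each pointer (a_offset), and a pointer to
// a zero buffer used for padded rows (context->zero). The grouped and batch variants
// additionally offset the weights, output, and a_offset by the per-group (g*_stride) and
// per-batch (b*_stride) strides.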
void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t group_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}
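// The subgemm2d/subconv2d functions compute one slice of a subconvolution (e.g. one output
// phase of a strided deconvolution). Each subkernel_index selects a subconvolution_params
// entry; tiles whose slice_y or slice_x_start fall outside that subkernel's slice_height or
// slice_width return early, since the parallelization range is sized for the largest slice.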
void xnn_compute_grouped_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t group_index, size_t subkernel_index,
    size_t slice_y, size_t slice_x_start, size_t nc_block_start,
    size_t slice_x_max, size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t subkernel_index,
    size_t slice_y, size_t slice_x_start, size_t nc_block_start,
    size_t slice_x_max, size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_grouped_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t group_index, size_t subkernel_index,
    size_t slice_y, size_t slice_x_start, size_t nc_block_start,
    size_t slice_x_max, size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
void xnn_compute_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t subkernel_index,
    size_t slice_y, size_t slice_x_start, size_t nc_block_start,
    size_t slice_x_max, size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y_start, size_t output_y_slice)
{
  context->hwc2chw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}

void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->groups,
      context->output_width,
      indirect_input,
      context->packed_weights,
      output,
      context->indirect_input_width_stride,
      context->output_increment,
      input_offset,
      context->zero,
      &context->params);
}

void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t channel)
{
  context->chw_ukernel(
      context->input_height,
      context->input_width,
      (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
      (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
      context->zero,
      (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
      context->input_padding_top,
      &context->params);
}

void xnn_compute_depthtospace2d_hwc_contiguous(
    const struct depthtospace2d_hwc_context* context,
    size_t batch_input_y, size_t input_x, size_t block_y)
{
  const size_t input_width = context->input_width;
  const size_t elements = context->elements;
  const void* input = (const void*) ((uintptr_t) context->input +
      (batch_input_y * input_width + input_x) * context->input_width_stride + block_y * elements);
  void* output = (void*) ((uintptr_t) context->output +
      ((batch_input_y * context->block_size + block_y) * input_width + input_x) * elements);

  context->ukernel(
      elements,
      input,
      output,
      NULL);
}

void xnn_compute_depthtospace2d_hwc_strided(
    const struct depthtospace2d_hwc_context* context,
    size_t batch_input_y, size_t input_x, size_t block_y, size_t block_x)
{
  const size_t block_size = context->block_size;
  const size_t elements = context->elements;
  const void* input = (const void*) ((uintptr_t) context->input +
      batch_input_y * context->input_height_stride + input_x * context->input_width_stride +
      (block_y * block_size + block_x) * elements);
  void* output = (void*) ((uintptr_t) context->output +
      (batch_input_y * block_size + block_y) * context->output_height_stride +
      (input_x * block_size + block_x) * context->output_width_stride);

  context->ukernel(
      elements,
      input,
      output,
      NULL);
}

void xnn_compute_depthtospace2d_chw2hwc(
    const struct depthtospace2d_chw2hwc_context* context,
    size_t batch_index)
{
  context->ukernel(
      context->output_channels,
      context->input_height,
      context->input_width,
      context->block_size,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->output_channel_stride);
}
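// The pooling compute functions below come in unipass and multipass flavors: the unipass
// ukernels sweep the whole pooling window in one pass, while the multipass ukernels
// accumulate through a temporary buffer allocated on the stack with XNN_SIMD_ALLOCA
// (sized for context->channels plus XNN_EXTRA_BYTES of slack for vectorized access).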
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
      batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      indirect_input,
      input_offset,
      output,
      index,
      context->input_increment,
      context->output_increment);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
      batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      indirect_input,
      input_offset,
      multipass_accumulation_buffer,
      multipass_index_buffer,
      output,
      index,
      context->input_increment,
      context->output_increment);
}

void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      indirect_input,
      input_offset,
      output,
      context->input_increment,
      context->output_increment,
      &context->params);
}
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y, size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output = (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      input,
      index,
      indirect_output);
}

void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      indirect_input,
      input_offset,
      context->zero,
      output,
      context->input_increment,
      context->output_increment,
      &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
      XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      indirect_input,
      input_offset,
      context->zero,
      multipass_buffer,
      output,
      context->input_increment,
      context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
      (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      indirect_input,
      input_offset,
      context->zero,
      pixelwise_buffer,
      output,
      context->input_increment,
      context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t output_y)
{
  const void** indirect_input =
      (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
      (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
      XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      indirect_input,
      input_offset,
      context->zero,
      pixelwise_buffer,
      multipass_buffer,
      output,
      context->input_increment,
      context->output_increment,
      &context->params);
}
void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input = (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input = (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  void* multipass_buffer =
      XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t channels_start, size_t channels_slice)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output +
      channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      context->input_elements,
      channels_slice,
      input,
      output,
      &context->params);
}

void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t pixel_start, size_t pixel_range)
{
  void* output = (void*) ((uintptr_t) context->output +
      pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      pixel_range,
      context->scaled_channels,
      context->indirect_input + pixel_start * 4,
      context->input_offset + batch_index * context->input_batch_stride,
      (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
      output,
      context->output_pixel_stride - context->scaled_channels);
}

void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t channel_start, size_t channel_range)
{
  void* output = (void*) ((uintptr_t) context->output +
      channel_start * context->output_channel_stride + batch_index * context->output_batch_stride);
  const size_t input_offset = context->input_offset +
      batch_index * context->input_batch_stride + channel_start * context->input_channel_stride;

  context->ukernel(
      context->output_pixels,
      channel_range,
      context->indirect_input,
      input_offset,
      context->packed_weights,
      output,
      context->input_channel_stride);
}
void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start, size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride);
}

void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] +
      l * context->input_stride[1] + m * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output +
      i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] +
      l * context->output_stride[1] + m * context->output_stride[0]);

  const size_t i_padding = context->pre_paddings[5];
  const size_t j_padding = context->pre_paddings[4];
  const size_t k_padding = context->pre_paddings[3];
  const size_t l_padding = context->pre_paddings[2];
  const size_t m_padding = context->pre_paddings[1];

  const size_t i_size = context->input_size[5];
  const size_t j_size = context->input_size[4];
  const size_t k_size = context->input_size[3];
  const size_t l_size = context->input_size[2];
  const size_t m_size = context->input_size[1];

  // Unsigned wrap-around makes each (coord - padding < size) comparison check both bounds at
  // once: coordinates inside the pre-padding underflow to a huge value and fail the test.
  if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
                l - l_padding < l_size && m - m_padding < m_size)
  {
    context->pad_ukernel(
        1 /* rows */,
        context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
        &context->padding_value,
        input, 0 /* input stride */, output, 0 /* output stride */);
  } else {
    context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, &context->padding_value);
  }
}

void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* a = (const void*) ((uintptr_t) context->a +
      i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] +
      l * context->a_stride[3] + m * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
      i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] +
      l * context->b_stride[3] + m * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
      i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] +
      l * context->y_stride[3] + m * context->y_stride[4]);

  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}
void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, context->t, y);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset, size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, context->t, y);
}

void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index, size_t batch_range /* always 1 */)
{
  assert(batch_range == 1);

  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  context->ukernel(context->n, x, y, &context->params);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset, size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);
  context->ukernel(size, x, y, &context->params);
}

void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

void xnn_compute_f32_three_pass_softmax(
    const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
  float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  // First pass: reduce-max
  float x_max;
  context->rmax_ukernel(n, x, &x_max);

  // Second pass: reduce-add & store exp(x-x_max)
  float y_sum;
  context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);

  // Third pass: scale y
  const float y_scale = 1.0f / y_sum;
  context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
}

void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start, size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
      batch_size,
      context->n,
      x, x_stride,
      context->w,
      y, y_stride,
      &context->params);
}

#if XNN_MAX_UARCH_TYPES > 1
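// Heterogeneous multi-processor (HMP) variants of the GEMM/IGEMM compute functions:
// instead of ukernel.function[XNN_UARCH_DEFAULT], they index the ukernel table with the
// uarch_index supplied by the pthreadpool *_with_uarch parallelization, so each worker
// can run the microkernel variant tuned for its own core's microarchitecture.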
void xnn_compute_hmp_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t group_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_hmp_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_hmp_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t batch_index, size_t group_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_hmp_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t group_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

void xnn_compute_batch_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index, size_t batch_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
void xnn_compute_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start, size_t nr_block_start,
    size_t mr_block_size, size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}
#endif // XNN_MAX_UARCH_TYPES > 1

// Validates the operator's state, then dispatches its compute task to pthreadpool using the
// parallelization type, range, and tile sizes recorded when the operator was set up.
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully setup");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }
  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      pthreadpool_parallelize_3d(
          threadpool,
          op->compute.task_3d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      pthreadpool_parallelize_4d(
          threadpool,
          op->compute.task_4d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_5d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      pthreadpool_parallelize_5d(
          threadpool,
          op->compute.task_5d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
#if XNN_MAX_UARCH_TYPES > 1
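    // With more than one supported microarchitecture, GEMM/IGEMM operators use the
    // *_with_uarch parallelization types: pthreadpool is given a default uarch index of 0
    // and a maximum of XNN_MAX_UARCH_TYPES - 1, and passes the executing thread's uarch
    // index to the task, which the xnn_compute_hmp_* functions use to pick the microkernel.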
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_2d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_3d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_4d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
      break;
#endif // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}