// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stddef.h>
12 #include <stdint.h>
13 
14 #include <pthreadpool.h>
15 
16 #include <xnnpack/params.h>
17 #include <xnnpack/compute.h>
18 
19 
// Tag describing which kind of micro-kernel a struct xnn_ukernel carries
// (it is the type of xnn_ukernel::type).
//
// xnn_ukernel_type_default is pinned to 0, so a zero-initialized xnn_ukernel
// reports the default type. All other enumerators are numbered implicitly in
// declaration order: inserting or reordering entries changes their values.
enum xnn_ukernel_type {
  xnn_ukernel_type_default = 0,
  xnn_ukernel_type_average_pooling,
  xnn_ukernel_type_conv2d_hwc2chw,
  xnn_ukernel_type_dwconv,
  xnn_ukernel_type_gemm,
  xnn_ukernel_type_igemm,
  xnn_ukernel_type_pixelwise_average_pooling,
  xnn_ukernel_type_spmm,
  xnn_ukernel_type_subconv2d,
  xnn_ukernel_type_vmulcaddc,
};
32 
// Identifies the kind of operator a struct xnn_operator represents
// (it is the type of xnn_operator::type).
//
// xnn_operator_type_invalid is pinned to 0, so a zero-initialized operator
// reports an invalid type. All other enumerators are numbered implicitly in
// declaration order: inserting or reordering entries changes their values.
//
// NOTE(review): the name suffixes look like <layout>_<element type>
// (e.g. nhwc_f32, nc_qu8) - confirm against the operator creation API.
// NOTE(review): the list is nearly, but not strictly, alphabetical
// (clamp precedes ceiling, fully_connected precedes floor).
enum xnn_operator_type {
  xnn_operator_type_invalid = 0,
  xnn_operator_type_abs_nc_f32,
  xnn_operator_type_add_nd_f16,
  xnn_operator_type_add_nd_f32,
  xnn_operator_type_add_nd_qs8,
  xnn_operator_type_argmax_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_qu8,
  xnn_operator_type_bankers_rounding_nc_f32,
  xnn_operator_type_channel_shuffle_nc_x32,
  xnn_operator_type_channel_shuffle_nc_x8,
  xnn_operator_type_clamp_nc_f32,
  xnn_operator_type_clamp_nc_u8,
  xnn_operator_type_ceiling_nc_f32,
  xnn_operator_type_constant_pad_nd_x32,
  xnn_operator_type_convolution_nchw_f32,
  xnn_operator_type_convolution_nhwc_f16,
  xnn_operator_type_convolution_nhwc_f32,
  xnn_operator_type_convolution_nhwc_qs8,
  xnn_operator_type_convolution_nhwc_qu8,
  xnn_operator_type_copy_nc_x32,
  xnn_operator_type_deconvolution_nhwc_f32,
  xnn_operator_type_deconvolution_nhwc_qu8,
  xnn_operator_type_depth_to_space_nchw2nhwc_x32,
  xnn_operator_type_depth_to_space_nhwc_x32,
  xnn_operator_type_divide_nd_f32,
  xnn_operator_type_elu_nc_f32,
  xnn_operator_type_fully_connected_nc_f32,
  xnn_operator_type_fully_connected_nc_qu8,
  xnn_operator_type_floor_nc_f32,
  xnn_operator_type_global_average_pooling_nwc_f16,
  xnn_operator_type_global_average_pooling_nwc_f32,
  xnn_operator_type_global_average_pooling_nwc_qs8,
  xnn_operator_type_global_average_pooling_nwc_qu8,
  xnn_operator_type_global_average_pooling_ncw_f32,
  xnn_operator_type_hardswish_nc_f16,
  xnn_operator_type_hardswish_nc_f32,
  xnn_operator_type_leaky_relu_nc_f32,
  xnn_operator_type_leaky_relu_nc_qu8,
  xnn_operator_type_max_pooling_nhwc_f32,
  xnn_operator_type_max_pooling_nhwc_u8,
  xnn_operator_type_maximum_nd_f32,
  xnn_operator_type_minimum_nd_f32,
  xnn_operator_type_multiply_nd_f16,
  xnn_operator_type_multiply_nd_f32,
  xnn_operator_type_negate_nc_f32,
  xnn_operator_type_prelu_nc_f32,
  xnn_operator_type_resize_bilinear_nchw_f32,
  xnn_operator_type_resize_bilinear_nhwc_f32,
  xnn_operator_type_sigmoid_nc_f32,
  xnn_operator_type_sigmoid_nc_qu8,
  xnn_operator_type_softmax_nc_f32,
  xnn_operator_type_softmax_nc_qu8,
  xnn_operator_type_square_nc_f32,
  xnn_operator_type_square_root_nc_f32,
  xnn_operator_type_squared_difference_nd_f32,
  xnn_operator_type_subtract_nd_f32,
  xnn_operator_type_truncation_nc_f32,
  xnn_operator_type_unpooling_nhwc_x32,
};
94 
95 struct xnn_ukernel_conv2d {
96   union {
97     xnn_conv_hwc2chw_ukernel_function hwc2chw_function;
98     xnn_conv_hwc_ukernel_function hwc_function;
99   };
100   uint8_t output_height_tile;
101   uint8_t output_channel_tile;
102 };
103 
104 struct xnn_ukernel_dwconv {
105   union {
106     xnn_dwconv_unipass_ukernel_function unipass_function;
107     xnn_dwconv_multipass_ukernel_function multipass_function;
108   };
109   uint8_t primary_tile;
110   uint8_t incremental_tile;
111 };
112 
113 // Direct 2D Depthwise Convolution
114 struct xnn_ukernel_dwconv2d {
115   union {
116     xnn_dwconv2d_chw_ukernel_function chw_function;
117   };
118   uint8_t output_width_tile;
119 };
120 
121 struct xnn_ukernel_gemm {
122   struct xnn_hmp_gemm_ukernel general_case;
123   struct xnn_hmp_gemm_ukernel mr1_case;
124   uint8_t mr;
125   uint8_t nr;
126   uint8_t kr;
127 };
128 
129 struct xnn_ukernel_igemm {
130   struct xnn_hmp_igemm_ukernel general_case;
131   struct xnn_hmp_igemm_ukernel mr1_case;
132   struct xnn_hmp_gemm_ukernel gemm_case;
133   uint8_t mr;
134   uint8_t nr;
135   uint8_t kr;
136 };
137 
138 struct xnn_ukernel_spmm {
139   xnn_spmm_ukernel_function function;
140   uint8_t mr;
141 };
142 
143 struct xnn_ukernel_vmulcaddc {
144   xnn_vmulcaddc_ukernel_function function;
145   uint8_t mr;
146 };
147 
148 struct xnn_ukernel_vbinary {
149   xnn_vbinary_ukernel_function op_function;
150   xnn_vbinary_ukernel_function opc_function;
151   xnn_vbinary_ukernel_function ropc_function;
152 };
153 
154 struct xnn_ukernel_vunary {
155   xnn_vunary_ukernel_function function;
156 };
157 
158 struct xnn_ukernel {
159   enum xnn_ukernel_type type;
160   union {
161     struct xnn_ukernel_conv2d conv2d;
162     struct xnn_ukernel_dwconv dwconv;
163     struct xnn_ukernel_dwconv2d dwconv2d;
164     struct xnn_ukernel_gemm gemm;
165     struct xnn_ukernel_igemm igemm;
166     struct xnn_ukernel_spmm spmm;
167     struct xnn_ukernel_vmulcaddc vmulcaddc;
168     struct xnn_ukernel_vbinary vbinary;
169     struct xnn_ukernel_vunary vunary;
170   };
171 };
172 
// Run state recorded in xnn_operator::state.
//
// xnn_run_state_invalid is pinned to 0, so a zero-initialized operator starts out invalid.
// NOTE(review): presumably "ready" means the operator has been set up and can run, and "skip"
// means running it is a no-op - confirm against the setup/run code.
enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,
  xnn_run_state_skip,
};
178 
// Parameters of one subconvolution; xnn_operator::subconvolution_buffer points to these.
// NOTE(review): presumably one entry per subkernel of a strided deconvolution - confirm.
struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width;
  size_t slice_height;
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  // scaled_kernel_size := kernel_size * mr * sizeof(void*).
  size_t scaled_kernel_size;
};
191 
192 struct xnn_operator {
193   size_t batch_size;
194   uint32_t padding_top;
195   uint32_t padding_right;
196   uint32_t padding_bottom;
197   uint32_t padding_left;
198   uint32_t kernel_height;
199   uint32_t kernel_width;
200   uint32_t stride_height;
201   uint32_t stride_width;
202   uint32_t dilation_height;
203   uint32_t dilation_width;
204   uint32_t groups;
205   size_t group_channels;
206   size_t group_input_channels;
207   size_t group_output_channels;
208   size_t channels;
209 
210   size_t pad_before_channels;
211   size_t pad_after_channels;
212   uint32_t pad_value;
213 
214   size_t input_height;
215   size_t input_width;
216   size_t input_pixel_stride;
217   const void* input;
218   const void* input2;
219   const void** indirection_buffer;
220 
221   size_t output_height;
222   size_t output_width;
223   size_t output_pixel_stride;
224   void* output;
225 
226   void* packed_weights;
227   // Total number of non-zero kernel elements when weights use sparse representation.
228   size_t num_nonzero_values;
229   // Total number of non-zero kernel blocks when weights use sparse representation.
230   size_t num_nonzero_blocks;
231   // Total number of output channel blocks when weights use sparse representation.
232   size_t num_output_channel_blocks;
233   // Input channel corresponding to the first non-zero kernel element.
234   size_t first_input_channel;
235 
236   float input_scale;
237   float output_scale;
238   int32_t input_zero_point;
239   uint8_t output_zero_point;
240   uint8_t output_min;
241   uint8_t output_max;
242 
243   size_t valid_batch_size;
244   size_t last_input_height;
245   size_t last_input_width;
246   const void* last_input;
247   size_t last_output_height;
248   size_t last_output_width;
249   void* last_output;
250 
251   uint32_t block_size;
252 
253   void* zero_buffer;
254   void* lookup_table;
255   void* pixelwise_buffer;
256   struct subconvolution_params* subconvolution_buffer;
257   uint32_t flags;
258 
259   union {
260     union xnn_f32_abs_params f32_abs;
261     union xnn_f32_elu_params f32_elu;
262     union xnn_f32_lrelu_params f32_lrelu;
263     union xnn_f32_neg_params f32_neg;
264     union xnn_f32_rnd_params f32_rnd;
265     // Parameters for Global Average Pooling in CHW layout
266     union xnn_f32_gavgpool_params f32_gavgpool;
267     struct xnn_f16_hswish_params f16_hswish;
268     union xnn_f32_hswish_params f32_hswish;
269     struct {
270       struct xnn_f16_minmax_params f16_minmax;
271       struct xnn_f16_scaleminmax_params f16_scaleminmax;
272     };
273     // Pixelwise Average Pooling normally use f32_minmax_params, but also initialize
274     // f32_scaleminmax_params in case it needs to switch to Global Average Pooling operation.
275     struct {
276       union xnn_f32_minmax_params f32_minmax;
277       union xnn_f32_scaleminmax_params f32_scaleminmax;
278     };
279     union xnn_f32_chw_params f32_chw;
280     union xnn_qs8_gemm_params qs8_gemm;
281     // Average Pooling normally use qs8_avgpool_params, but also initialize qs8_gavgpool_params in case it needs to switch
282     // to Global Average Pooling operation.
283     struct {
284       union xnn_qs8_avgpool_params qs8_avgpool;
285       union xnn_qs8_avgpool_params qs8_gavgpool;
286     };
287     // Quantized Add parameters are sensitive to order of inputs, so we initialize an extra copy with the reversed order.
288     struct {
289       union xnn_qs8_add_params qs8_add;
290       union xnn_qs8_add_params qs8_radd;
291     };
292     union xnn_qu8_add_params qu8_add;
293     union xnn_qu8_gemm_params qu8_gemm;
294     // Average Pooling normally use qu8_avgpool_params, but also initialize qu8_gavgpool_params in case it needs to switch
295     // to Global Average Pooling operation.
296     struct {
297       union xnn_qu8_avgpool_params qu8_avgpool;
298       union xnn_qu8_avgpool_params qu8_gavgpool;
299     };
300     union xnn_u8_minmax_params u8_minmax;
301   } params;
302   enum xnn_operator_type type;
303   struct xnn_ukernel ukernel;
304 
305   struct compute_parameters compute;
306   struct compute_parameters compute2;
307   union {
308     struct argmax_pooling_context argmax_pooling;
309     struct average_pooling_context average_pooling;
310     struct channel_shuffle_context channel_shuffle;
311     struct conv2d_context conv2d;
312     struct dwconv2d_context dwconv2d;
313     struct dwconv_context dwconv;
314     struct depthtospace2d_chw2hwc_context depthtospace2d_chw;
315     struct depthtospace2d_hwc_context depthtospace2d_hwc;
316     struct elementwise_binary_context elementwise_binary;
317     struct gemm_context gemm;
318     struct global_average_pooling_nwc_context global_average_pooling_nwc;
319     struct global_average_pooling_ncw_context global_average_pooling_ncw;
320     struct igemm_context igemm;
321     struct lut_contiguous_context lut_contiguous;
322     struct lut_strided_context lut_strided;
323     struct max_pooling_context max_pooling;
324     struct pad_context pad;
325     struct pixelwise_average_pooling_context pixelwise_average_pooling;
326     struct prelu_context prelu;
327     struct resize_bilinear_context resize_bilinear;
328     struct resize_bilinear_chw_context resize_bilinear_chw;
329     struct spmm_context spmm;
330     struct subconv_context subconv;
331     struct subgemm_context subgemm;
332     struct f32_three_pass_softmax_context f32_three_pass_softmax;
333     struct u8_softmax_context u8_softmax;
334     struct univector_contiguous_context univector_contiguous;
335     struct univector_strided_context univector_strided;
336     struct unpooling_context unpooling;
337     struct vmulcaddc_context vmulcaddc;
338   } context;
339 
340   enum xnn_run_state state;
341 };
342