1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
15 
16 #include <xnnpack.h>
17 #include <xnnpack/allocator.h>
18 #include <xnnpack/log.h>
19 #include <xnnpack/math.h>
20 #include <xnnpack/operator.h>
21 #include <xnnpack/pack.h>
22 #include <xnnpack/params-init.h>
23 #include <xnnpack/params.h>
24 
25 
// Shared implementation behind the per-datatype xnn_create_fully_connected_nc_*
// entry points.  Validates the channel/stride configuration, allocates the
// operator descriptor, packs the weights (and optional bias) into the blocked
// layout expected by the selected GEMM micro-kernel, and records the
// micro-kernel selection on the new operator.
//
//   input_channels / output_channels  - dense layer dimensions; must be non-zero.
//   input_stride / output_stride      - per-row element strides; each must be at
//                                       least the corresponding channel count.
//   kernel, bias                      - weights (transposed iff
//                                       XNN_FLAG_TRANSPOSE_WEIGHTS) and bias.
//   log2_filter_element_size          - log2(sizeof(filter element)).
//   bias_element_size                 - sizeof(bias element), in bytes.
//   pack_gemm_io_w / pack_gemm_goi_w  - packing routine for transposed /
//                                       non-transposed weights, respectively.
//   packing_params                    - datatype-specific packing parameters
//                                       (e.g. zero points for QU8), may be NULL.
//   packed_weights_padding_byte       - fill byte for padding in the packed
//                                       buffer (kernel zero point for QU8).
//   params / params_size              - micro-kernel parameters, copied into
//                                       the operator.
//   gemm_parameters / gemm_ukernels   - architecture-selected GEMM config.
//
// On success returns xnn_status_success and writes *fully_connected_op_out;
// on failure returns an error status and releases any partial allocation via
// the shared error path.
static enum xnn_status create_fully_connected_nc(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const void* kernel,
    const void* bias,
    uint32_t flags,
    uint32_t log2_filter_element_size,
    uint32_t bias_element_size,
    xnn_pack_gemm_io_w_function pack_gemm_io_w,
    xnn_pack_gemm_goi_w_function pack_gemm_goi_w,
    const void* packing_params,
    int packed_weights_padding_byte,
    const void* params,
    size_t params_size,
    const struct gemm_parameters* gemm_parameters,
    const struct gemm_fused_ukernels* gemm_ukernels,
    enum xnn_operator_type operator_type,
    xnn_operator_t* fully_connected_op_out)
{
  xnn_operator_t fully_connected_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (input_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu input channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), input_channels);
    goto error;
  }

  if (output_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu output channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), output_channels);
    goto error;
  }

  if (input_stride < input_channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of input channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, input_channels);
    goto error;
  }

  if (output_stride < output_channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of output channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, output_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  // SIMD-aligned zeroed allocation: micro-kernels require aligned accesses.
  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (fully_connected_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  // Micro-kernel tiling: kr/sr are stored as log2 in gemm_parameters.
  const uint32_t nr = gemm_parameters->nr;
  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;

  // Round channel counts up to tile multiples for the packed layout.
  const size_t n_stride = round_up(output_channels, nr);
  const size_t k_stride = round_up_po2(input_channels, kr);

  // Packed layout stores, per nr-wide output group: bias + k_stride filter elements.
  const size_t packed_weights_size = n_stride * (bias_element_size + (k_stride << log2_filter_element_size));
  fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
  if (fully_connected_op->packed_weights == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator packed weights",
      packed_weights_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }
  // Pre-fill so padding lanes hold a neutral value (e.g. kernel zero point for QU8).
  memset(fully_connected_op->packed_weights, packed_weights_padding_byte, packed_weights_size);

  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    // Weights given as [input_channels, output_channels] (IO layout).
    pack_gemm_io_w(
      output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      packing_params);
  } else {
    // Weights given as [output_channels, input_channels] (GOI layout, 1 group).
    pack_gemm_goi_w(
      1, output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      packing_params);
  }

  fully_connected_op->group_input_channels = input_channels;
  fully_connected_op->group_output_channels = output_channels;
  fully_connected_op->input_pixel_stride = input_stride;
  fully_connected_op->output_pixel_stride = output_stride;

  memcpy(&fully_connected_op->params, params, params_size);
  fully_connected_op->type = operator_type;

  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
    .general_case = gemm_ukernels->gemm,
    .mr1_case = gemm_ukernels->gemm1,  // specialized single-row kernel, used when batch_size == 1
    .mr = gemm_parameters->mr,
    .nr = nr,
    .kr = kr,
  };

  fully_connected_op->state = xnn_run_state_invalid;

  *fully_connected_op_out = fully_connected_op;
  return xnn_status_success;

error:
  // NOTE(review): fully_connected_op may still be NULL here — presumably
  // xnn_delete_operator tolerates NULL; confirm in operator-delete.c.
  xnn_delete_operator(fully_connected_op);
  return status;
}
157 
// Shared implementation behind the per-datatype xnn_setup_fully_connected_nc_*
// entry points.  Binds input/output pointers and batch size to a previously
// created operator, fills in the GEMM compute context, and chooses the
// parallelization tiling for the given thread count.
//
//   batch_size                - number of rows; 0 marks the operator as a no-op.
//   log2_*_element_size       - log2(sizeof(element)) for input/filter/output.
//   bias_element_size         - sizeof(bias element), in bytes.
//   params / params_size      - micro-kernel parameters copied into the context.
//   num_threads               - threadpool size, used only to pick tile sizes.
//
// Returns xnn_status_success and leaves the operator in xnn_run_state_ready
// (or xnn_run_state_skip for an empty batch).
static enum xnn_status setup_fully_connected_nc(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const void* input,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads)
{
  // Invalidate first so an early error leaves the operator non-runnable.
  fully_connected_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    // Nothing to compute; running the operator becomes a no-op.
    fully_connected_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  // A fully-connected layer is modeled as a 1x1 "image" with batch_size rows.
  fully_connected_op->batch_size = 1;
  fully_connected_op->input_height = batch_size;
  fully_connected_op->input_width = 1;
  fully_connected_op->input = input;

  fully_connected_op->output_height = batch_size;
  fully_connected_op->output_width = 1;
  fully_connected_op->output = output;

  const size_t input_channels = fully_connected_op->group_input_channels;
  const size_t output_channels = fully_connected_op->group_output_channels;

  uint32_t mr = fully_connected_op->ukernel.gemm.mr;
  const uint32_t nr = fully_connected_op->ukernel.gemm.nr;

  // Prefer the specialized single-row micro-kernel for batch_size == 1.
  struct xnn_hmp_gemm_ukernel gemm_ukernel = fully_connected_op->ukernel.gemm.general_case;
  if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_case;
    mr = 1;
  }

  fully_connected_op->context.gemm = (struct gemm_context) {
    .k_scaled = input_channels << log2_input_element_size,
    // Stride between packed-weight groups: bias plus kr-rounded filter row.
    // Fixed to use the *filter* element size, matching the packed layout built
    // in create_fully_connected_nc (k_stride << log2_filter_element_size);
    // previously this used log2_input_element_size, which is only correct
    // while input and filter elements happen to have the same size.
    .w_stride = (round_up_po2(input_channels, fully_connected_op->ukernel.gemm.kr) << log2_filter_element_size) + bias_element_size,
    .a = input,
    .a_stride = fully_connected_op->input_pixel_stride << log2_input_element_size,
    .packed_w = fully_connected_op->packed_weights,
    .c = output,
    .cm_stride = fully_connected_op->output_pixel_stride << log2_output_element_size,
    .cn_stride = nr << log2_output_element_size,
    .log2_csize = log2_output_element_size,
    .ukernel = gemm_ukernel,
  };
  memcpy(&fully_connected_op->context.gemm.params, params, params_size);

  // Pick the N-dimension tile: shrink it (in nr multiples) until there are
  // roughly target_tiles_per_thread tiles per thread.
  size_t nc = output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
  fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
  fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
  fully_connected_op->compute.range[0] = batch_size;
  fully_connected_op->compute.range[1] = output_channels;
  fully_connected_op->compute.tile[0] = mr;
  fully_connected_op->compute.tile[1] = nc;
  fully_connected_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
238 
// Creates a fully-connected operator for unsigned 8-bit quantized data with
// NC layout.  Validates the quantization parameters, derives the
// requantization scale, and delegates to the shared creation path with
// QU8-specific packing functions and micro-kernel parameters.
enum xnn_status xnn_create_fully_connected_nc_qu8(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  // Every scale must be a positive, finite, normalized float.
  if (!isnormal(input_scale) || input_scale <= 0.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (!isnormal(kernel_scale) || kernel_scale <= 0.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (!isnormal(output_scale) || output_scale <= 0.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  // Clamping range must be non-degenerate.
  if (!(output_min < output_max)) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // The fixed-point requantization path requires the combined scale to be < 1.
  const float rescale = input_scale * kernel_scale / output_scale;
  if (rescale >= 1.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 1.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      input_scale, kernel_scale, output_scale, rescale);
    return xnn_status_unsupported_parameter;
  }

  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };
  const union xnn_qu8_gemm_params gemm_params = xnn_init_qu8_gemm_params(
    kernel_zero_point, rescale, output_zero_point, output_min, output_max);

  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    /*log2_filter_element_size=*/ 0,  // filter elements are uint8_t
    /*bias_element_size=*/ sizeof(int32_t),
    (xnn_pack_gemm_io_w_function) xnn_pack_qu8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qu8_gemm_goi_w,
    &packing_params,
    /*packed_weights_padding_byte=*/ kernel_zero_point,
    &gemm_params, sizeof(gemm_params),
    &xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
    xnn_operator_type_fully_connected_nc_qu8,
    fully_connected_op_out);
}
315 
// Creates a fully-connected operator for single-precision floats with NC
// layout.  Validates the output clamping range, selects a linear-activation
// micro-kernel when the range is unbounded, and delegates to the shared
// creation path with F32 packing functions and min/max parameters.
enum xnn_status xnn_create_fully_connected_nc_f32(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  // Reject NaN bounds and a degenerate (min >= max) clamping range.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }
  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // A (-inf, +inf) range needs no clamping; prefer the linear-activation
  // micro-kernel when the architecture provides one.
  const bool unbounded = (output_max == INFINITY) && (output_min == -output_max);
  const struct gemm_fused_ukernels* gemm_ukernels =
    (unbounded && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL)
      ? &xnn_params.f32.gemm.linear
      : &xnn_params.f32.gemm.minmax;

  const union xnn_f32_minmax_params minmax_params = xnn_init_f32_minmax_params(output_min, output_max);
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    /*log2_filter_element_size=*/ 2,  // filter elements are float
    /*bias_element_size=*/ sizeof(float),
    (xnn_pack_gemm_io_w_function) xnn_pack_f32_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
    /*packing_params=*/ NULL,
    /*packed_weights_padding_byte=*/ 0,
    &minmax_params, sizeof(minmax_params),
    &xnn_params.f32.gemm, gemm_ukernels,
    xnn_operator_type_fully_connected_nc_f32,
    fully_connected_op_out);
}
370 
// Binds QU8 input/output buffers and a batch size to a QU8 fully-connected
// operator.  Rejects operators of any other type, then forwards to the
// shared setup path with QU8 element sizes and GEMM parameters.
enum xnn_status xnn_setup_fully_connected_nc_qu8(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_fully_connected_nc_qu8;
  if (fully_connected_op->type != expected_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_type),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    /*log2_input_element_size=*/ 0,   // uint8_t input
    /*log2_filter_element_size=*/ 0,  // uint8_t filter
    /*bias_element_size=*/ sizeof(int32_t),
    /*log2_output_element_size=*/ 0,  // uint8_t output
    &fully_connected_op->params.qu8_gemm,
    sizeof(fully_connected_op->params.qu8_gemm),
    pthreadpool_get_threads_count(threadpool));
}
397 
// Binds F32 input/output buffers and a batch size to an F32 fully-connected
// operator.  Rejects operators of any other type, then forwards to the
// shared setup path with float element sizes and min/max parameters.
enum xnn_status xnn_setup_fully_connected_nc_f32(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_fully_connected_nc_f32;
  if (fully_connected_op->type != expected_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_type),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    /*log2_input_element_size=*/ 2,   // float input
    /*log2_filter_element_size=*/ 2,  // float filter
    /*bias_element_size=*/ sizeof(float),
    /*log2_output_element_size=*/ 2,  // float output
    &fully_connected_op->params.f32_minmax,
    sizeof(fully_connected_op->params.f32_minmax),
    pthreadpool_get_threads_count(threadpool));
}
424