1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stdint.h>
12 #include <stddef.h>
13 
14 #include <xnnpack/common.h>
15 #include <xnnpack/operator.h>
16 
17 
18 #ifdef __cplusplus
19 extern "C" {
20 #endif
21 
22 
// Extra inputs handed (through the `params` argument) to the qu8 packing
// functions in this header, e.g. xnn_pack_qu8_gemm_goi_w.
// Field order defines the struct layout; keep it stable.
struct xnn_qu8_packing_params {
  uint8_t input_zero_point;   // zero point of the quantized input
  uint8_t kernel_zero_point;  // zero point of the quantized (uint8_t) kernel
};
27 
// Extra inputs handed (through the `params` argument) to the qs8 packing
// functions in this header. Unlike xnn_qu8_packing_params it carries no
// kernel zero point.
// NOTE(review): presumably because qs8 kernels are symmetric — confirm.
struct xnn_qs8_packing_params {
  int8_t input_zero_point;  // zero point of the quantized input
};
31 
32 
// Type-erased signature of the *_gemm_goi_w packing functions in this header
// (f32, f16, qu8, qs8, qs8 "xw"). Those functions use typed element pointers
// (and, for qu8/qs8, a typed params struct), so storing one of them in this
// pointer type requires an explicit cast.
// NOTE(review): "goi" presumably names the kernel layout (groups, output
// channels, input channels) and nr/kr/sr the microkernel tiling — confirm.
typedef void (*xnn_pack_gemm_goi_w_function)(
  size_t g, size_t nc, size_t kc,
  size_t nr, size_t kr, size_t sr,
  const void* k, const void* b,
  void* packed_w,
  const void* params);
44 
45 XNN_INTERNAL void xnn_pack_f32_gemm_goi_w(
46   size_t g,
47   size_t nc,
48   size_t kc,
49   size_t nr,
50   size_t kr,
51   size_t sr,
52   const float* k,
53   const float* b,
54   float* packed_w,
55   const void* params);
56 
57 XNN_INTERNAL void xnn_pack_f16_gemm_goi_w(
58   size_t g,
59   size_t nc,
60   size_t kc,
61   size_t nr,
62   size_t kr,
63   size_t sr,
64   const uint16_t* k,
65   const uint16_t* b,
66   uint16_t* packed_w,
67   const void* params);
68 
69 XNN_INTERNAL void xnn_pack_qu8_gemm_goi_w(
70   size_t g,
71   size_t nc,
72   size_t kc,
73   size_t nr,
74   size_t kr,
75   size_t sr,
76   const uint8_t* k,
77   const int32_t* b,
78   void* packed_w,
79   const struct xnn_qu8_packing_params* params);
80 
81 XNN_INTERNAL void xnn_pack_qs8_gemm_goi_w(
82   size_t g,
83   size_t nc,
84   size_t kc,
85   size_t nr,
86   size_t kr,
87   size_t sr,
88   const int8_t* k,
89   const int32_t* b,
90   void* packed_w,
91   const struct xnn_qs8_packing_params* params);
92 
93 XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w(
94   size_t g,
95   size_t nc,
96   size_t kc,
97   size_t nr,
98   size_t kr,
99   size_t sr,
100   const int8_t* k,
101   const int32_t* b,
102   void* packed_w,
103   const struct xnn_qs8_packing_params* params);
104 
105 
// Type-erased signature of the *_gemm_io_w packing functions in this header
// (f32, f16, qu8). Unlike the "goi" family it has no group count `g`.
// NOTE(review): "io" presumably means the kernel is stored input-channel
// major, then output channel — confirm.
typedef void (*xnn_pack_gemm_io_w_function)(
  size_t nc, size_t kc,
  size_t nr, size_t kr, size_t sr,
  const void* k, const void* b,
  void* packed_w,
  const void* params);
116 
117 XNN_INTERNAL void xnn_pack_f32_gemm_io_w(
118   size_t nc,
119   size_t kc,
120   size_t nr,
121   size_t kr,
122   size_t sr,
123   const float* k,
124   const float* b,
125   float* packed_w,
126   const void* params);
127 
128 XNN_INTERNAL void xnn_pack_f16_gemm_io_w(
129   size_t nc,
130   size_t kc,
131   size_t nr,
132   size_t kr,
133   size_t sr,
134   const uint16_t* k,
135   const uint16_t* b,
136   uint16_t* packed_w,
137   const void* params);
138 
139 XNN_INTERNAL void xnn_pack_qu8_gemm_io_w(
140   size_t nc,
141   size_t kc,
142   size_t nr,
143   size_t kr,
144   size_t sr,
145   const uint8_t* k,
146   const int32_t* b,
147   void* packed_w,
148   const struct xnn_qu8_packing_params* params);
149 
150 
// Type-erased signature of the *_conv_goki_w packing functions in this header
// (f32, f16, qu8, qs8). Compared with the gemm "goi" family it adds `ks`.
// NOTE(review): "goki" presumably = groups, output channels, kernel positions,
// input channels, with `ks` the number of kernel positions — confirm.
typedef void (*xnn_pack_conv_goki_w_function)(
  size_t g, size_t nc, size_t ks, size_t kc,
  size_t nr, size_t kr, size_t sr,
  const void* k, const void* b,
  void* packed_w,
  const void* params);
163 
164 XNN_INTERNAL void xnn_pack_f32_conv_goki_w(
165   size_t g,
166   size_t nc,
167   size_t ks,
168   size_t kc,
169   size_t nr,
170   size_t kr,
171   size_t sr,
172   const float* k,
173   const float* b,
174   float* packed_w,
175   const void* params);
176 
177 XNN_INTERNAL void xnn_pack_f16_conv_goki_w(
178   size_t g,
179   size_t nc,
180   size_t ks,
181   size_t kc,
182   size_t nr,
183   size_t kr,
184   size_t sr,
185   const uint16_t* k,
186   const uint16_t* b,
187   uint16_t* packed_w,
188   const void* params);
189 
190 XNN_INTERNAL void xnn_pack_qu8_conv_goki_w(
191   size_t g,
192   size_t nc,
193   size_t ks,
194   size_t kc,
195   size_t nr,
196   size_t kr,
197   size_t sr,
198   const uint8_t* k,
199   const int32_t* b,
200   void* packed_w,
201   const struct xnn_qu8_packing_params* params);
202 
203 XNN_INTERNAL void xnn_pack_qs8_conv_goki_w(
204   size_t g,
205   size_t nc,
206   size_t ks,
207   size_t kc,
208   size_t nr,
209   size_t kr,
210   size_t sr,
211   const int8_t* k,
212   const int32_t* b,
213   void* packed_w,
214   const struct xnn_qs8_packing_params* params);
215 
216 
// Type-erased signature of the *_conv_kgo_w packing functions in this header
// (f32, f16, qu8, qs8). This family takes no `kc` or `sr` argument.
// NOTE(review): "kgo" presumably = kernel positions, groups, output channels
// — confirm.
typedef void (*xnn_pack_conv_kgo_w_function)(
  size_t g, size_t nc, size_t ks,
  size_t nr, size_t kr,
  const void* k, const void* b,
  void* packed_w,
  const void* params);
227 
228 XNN_INTERNAL void xnn_pack_f32_conv_kgo_w(
229   size_t g,
230   size_t nc,
231   size_t ks,
232   size_t nr,
233   size_t kr,
234   const float* k,
235   const float* b,
236   float* packed_w,
237   const void* params);
238 
239 XNN_INTERNAL void xnn_pack_f16_conv_kgo_w(
240   size_t g,
241   size_t nc,
242   size_t ks,
243   size_t nr,
244   size_t kr,
245   const uint16_t* k,
246   const uint16_t* b,
247   uint16_t* packed_w,
248   const void* params);
249 
250 XNN_INTERNAL void xnn_pack_qu8_conv_kgo_w(
251   size_t g,
252   size_t nc,
253   size_t ks,
254   size_t nr,
255   size_t kr,
256   const uint8_t* k,
257   const int32_t* b,
258   void* packed_w,
259   const struct xnn_qu8_packing_params* params);
260 
261 XNN_INTERNAL void xnn_pack_qs8_conv_kgo_w(
262   size_t g,
263   size_t nc,
264   size_t ks,
265   size_t nr,
266   size_t kr,
267   const int8_t* k,
268   const int32_t* b,
269   void* packed_w,
270   const struct xnn_qs8_packing_params* params);
271 
272 
// Type-erased signature of the *_deconv_goki_w packing functions in this
// header (f32, f16, qu8). `subconv_params` is a non-const pointer, so the
// packer may write through it.
// NOTE(review): kh/kw and sh/sw presumably are the kernel extent and stride
// per spatial axis — confirm.
typedef void (*xnn_pack_deconv_goki_w_function)(
  size_t g, size_t nc,
  size_t kh, size_t kw, size_t kc,
  size_t sh, size_t sw,
  size_t nr, size_t kr, size_t sr,
  const void* k, const void* b,
  void* packed_w,
  struct subconvolution_params* subconv_params,
  const void* params);
289 
290 XNN_INTERNAL void xnn_pack_f32_deconv_goki_w(
291   size_t g,
292   size_t nc,
293   size_t kh,
294   size_t kw,
295   size_t kc,
296   size_t sh,
297   size_t sw,
298   size_t nr,
299   size_t kr,
300   size_t sr,
301   const float* k,
302   const float* b,
303   float* packed_w,
304   struct subconvolution_params* subconv_params,
305   const void* params);
306 
307 XNN_INTERNAL void xnn_pack_f16_deconv_goki_w(
308   size_t g,
309   size_t nc,
310   size_t kh,
311   size_t kw,
312   size_t kc,
313   size_t sh,
314   size_t sw,
315   size_t nr,
316   size_t kr,
317   size_t sr,
318   const uint16_t* k,
319   const uint16_t* b,
320   uint16_t* packed_w,
321   struct subconvolution_params* subconv_params,
322   const void* params);
323 
324 XNN_INTERNAL void xnn_pack_qu8_deconv_goki_w(
325   size_t g,
326   size_t nc,
327   size_t kh,
328   size_t kw,
329   size_t kc,
330   size_t sh,
331   size_t sw,
332   size_t nr,
333   size_t kr,
334   size_t sr,
335   const uint8_t* k,
336   const int32_t* b,
337   void* packed_w,
338   struct subconvolution_params* subconv_params,
339   const struct xnn_qu8_packing_params* params);
340 
341 
// Type-erased signature of the *_dwconv_ghw_w packing functions in this
// header (f32, f16, qu8, qs8).
// NOTE(review): "ghw" presumably = kernel stored as [channel][h][w], with `c`
// the channel count and `cr` the channel tile — confirm.
typedef void (*xnn_pack_dwconv_ghw_w_function)(
  size_t h, size_t w, size_t c, size_t cr,
  const void* k, const void* b,
  void* packed_w,
  const void* params);
351 
352 XNN_INTERNAL void xnn_pack_f32_dwconv_ghw_w(
353   size_t h,
354   size_t w,
355   size_t c,
356   size_t cr,
357   const float* k,
358   const float* b,
359   float* packed_w,
360   const void* params);
361 
362 XNN_INTERNAL void xnn_pack_f16_dwconv_ghw_w(
363   size_t h,
364   size_t w,
365   size_t c,
366   size_t cr,
367   const uint16_t* k,
368   const uint16_t* b,
369   uint16_t* packed_w,
370   const void* params);
371 
372 XNN_INTERNAL void xnn_pack_qu8_dwconv_ghw_w(
373   size_t h,
374   size_t w,
375   size_t c,
376   size_t cr,
377   const uint8_t* k,
378   const int32_t* b,
379   void* packed_w,
380   const struct xnn_qu8_packing_params* params);
381 
382 XNN_INTERNAL void xnn_pack_qs8_dwconv_ghw_w(
383   size_t h,
384   size_t w,
385   size_t c,
386   size_t cr,
387   const int8_t* k,
388   const int32_t* b,
389   void* packed_w,
390   const struct xnn_qs8_packing_params* params);
391 
392 
// Type-erased signature of the *_dwconv_hwg_w packing functions in this
// header (f32, f16, qu8, qs8). Same parameter list as the "ghw" family.
// NOTE(review): "hwg" presumably = kernel stored as [h][w][channel] — confirm.
typedef void (*xnn_pack_dwconv_hwg_w_function)(
  size_t h, size_t w, size_t c, size_t cr,
  const void* k, const void* b,
  void* packed_w,
  const void* params);
402 
403 XNN_INTERNAL void xnn_pack_f32_dwconv_hwg_w(
404   size_t h,
405   size_t w,
406   size_t c,
407   size_t cr,
408   const float* k,
409   const float* b,
410   float* packed_w,
411   const void* params);
412 
413 XNN_INTERNAL void xnn_pack_f16_dwconv_hwg_w(
414   size_t h,
415   size_t w,
416   size_t c,
417   size_t cr,
418   const uint16_t* k,
419   const uint16_t* b,
420   uint16_t* packed_w,
421   const void* params);
422 
423 XNN_INTERNAL void xnn_pack_qu8_dwconv_hwg_w(
424   size_t h,
425   size_t w,
426   size_t c,
427   size_t cr,
428   const uint8_t* k,
429   const int32_t* b,
430   void* packed_w,
431   const struct xnn_qu8_packing_params* params);
432 
433 XNN_INTERNAL void xnn_pack_qs8_dwconv_hwg_w(
434   size_t h,
435   size_t w,
436   size_t c,
437   size_t cr,
438   const int8_t* k,
439   const int32_t* b,
440   void* packed_w,
441   const struct xnn_qs8_packing_params* params);
442 
443 
444 XNN_INTERNAL void xnn_pack_f32_gemminc_goi_w(
445   size_t g,
446   size_t nc,
447   size_t kc,
448   size_t nr,
449   size_t kr,
450   size_t sr,
451   const float* k,
452   float* packed_w,
453   const void* params);
454 
455 XNN_INTERNAL void xnn_pack_f16_gemminc_goi_w(
456   size_t g,
457   size_t nc,
458   size_t kc,
459   size_t nr,
460   size_t kr,
461   size_t sr,
462   const uint16_t* k,
463   uint16_t* packed_w,
464   const void* params);
465 
466 
467 XNN_INTERNAL void xnn_pack_f32_dconv_oki_w(
468   size_t nc,
469   size_t kc,
470   size_t nr,
471   size_t kh,
472   size_t kw,
473   const float* k,
474   const float* b,
475   float* packed_w,
476   const void* params);
477 
478 XNN_INTERNAL void xnn_pack_f16_dconv_oki_w(
479   size_t nc,
480   size_t kc,
481   size_t nr,
482   size_t kh,
483   size_t kw,
484   const uint16_t* k,
485   const uint16_t* b,
486   uint16_t* packed_w,
487   const void* params);
488 
489 
490 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_ghw_w(
491   size_t kernel_size,
492   size_t groups,
493   const float* kernel,
494   const float* bias,
495   float* packed_weights,
496   const void* params);
497 
498 XNN_INTERNAL void xnn_pack_f16_chw_dwconv_ghw_w(
499   size_t kernel_size,
500   size_t groups,
501   const uint16_t* kernel,
502   const uint16_t* bias,
503   uint16_t* packed_weights,
504   const void* params);
505 
506 
507 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_hwg_w(
508   size_t kernel_size,
509   size_t groups,
510   const float* kernel,
511   const float* bias,
512   float* packed_weights,
513   const void* params);
514 
515 
// Type-erased signature of the *_vmulcaddc_w packing functions in this header
// (f32, f16).
// NOTE(review): `s` and `b` presumably are per-channel scale and bias for a
// multiply-then-add kernel, `c` the channel count and `cr` its tile — confirm.
typedef void (*xnn_pack_vmulcaddc_w_function)(
  size_t c, size_t cr,
  const void* s, const void* b,
  void* packed_w,
  const void* params);
523 
524 XNN_INTERNAL void xnn_pack_f32_vmulcaddc_w(
525   size_t c,
526   size_t cr,
527   const float* s,
528   const float* b,
529   float* packed_w,
530   const void* params);
531 
532 XNN_INTERNAL void xnn_pack_f16_vmulcaddc_w(
533   size_t c,
534   size_t cr,
535   const uint16_t* s,
536   const uint16_t* b,
537   uint16_t* packed_w,
538   const void* params);
539 
540 #ifdef __cplusplus
541 }  // extern "C"
542 #endif
543