1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #if defined(__cplusplus) && (__cplusplus >= 201103L)
12   #include <cstdint>
13   #include <cstddef>
14   #include <cassert>
15   #include <cmath>
16 #else
17   #include <stdint.h>
18   #include <stddef.h>
19   #include <assert.h>
20   #include <math.h>
21 #endif
22 
23 #include <fp16.h>
24 
25 #include <xnnpack/common.h>
26 #include <xnnpack/math.h>
27 #include <xnnpack/params.h>
28 
29 
xnn_init_scalar_qu8_gemm_params(uint8_t kernel_zero_point,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)30 static inline union xnn_qu8_gemm_params xnn_init_scalar_qu8_gemm_params(
31   uint8_t kernel_zero_point,
32   float scale,
33   uint8_t output_zero_point,
34   uint8_t output_min,
35   uint8_t output_max)
36 {
37   // Compute requantization parameters
38   const uint32_t scale_bits = fp32_to_bits(scale);
39 
40   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
41   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
42   assert(multiplier >= INT32_C(0x40000000));
43   assert(multiplier <= INT32_C(0x7FFFFF80));
44 
45   // Shift is in [0, 31] range.
46   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
47   assert(shift >= 0);
48   assert(shift < 32);
49 
50   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
51   const uint32_t remainder_threshold = remainder_mask >> 1;
52 
53   union xnn_qu8_gemm_params params;
54   params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
55   params.scalar.multiplier = multiplier;
56   params.scalar.remainder_mask = (int32_t) remainder_mask;
57   params.scalar.remainder_threshold = (int32_t) remainder_threshold;
58   params.scalar.shift = (uint32_t) shift;
59   params.scalar.output_min_less_zero_point =
60     (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
61   params.scalar.output_max_less_zero_point =
62     (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
63   params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
64   return params;
65 }
66 
xnn_init_qu8_gemm_params(uint8_t kernel_zero_point,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)67 static inline union xnn_qu8_gemm_params xnn_init_qu8_gemm_params(
68   uint8_t kernel_zero_point,
69   float scale,
70   uint8_t output_zero_point,
71   uint8_t output_min,
72   uint8_t output_max)
73 {
74   // Compute requantization parameters.
75   const uint32_t scale_bits = fp32_to_bits(scale);
76 
77   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
78   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
79   assert(multiplier >= INT32_C(0x40000000));
80   assert(multiplier <= INT32_C(0x7FFFFF80));
81 
82   // Shift is in [0, 31] range.
83   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
84   assert(shift >= 0);
85   assert(shift < 32);
86 
87   union xnn_qu8_gemm_params params;
88   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
89     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
90     const uint32_t remainder_threshold = remainder_mask >> 1;
91     for (uint32_t i = 0; i < 8; i++) {
92       params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
93     }
94     params.sse2.multiplier[0] = multiplier;
95     params.sse2.multiplier[1] = multiplier;
96     params.sse2.multiplier[2] = multiplier;
97     params.sse2.multiplier[3] = multiplier;
98     params.sse2.rounding[0] = UINT64_C(0x40000000);
99     params.sse2.rounding[1] = UINT64_C(0x40000000);
100     params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
101     params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
102     params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
103     params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
104     params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
105     params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
106     params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
107     params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
108     params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
109     params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
110     for (uint32_t i = 0; i < 8; i++) {
111       params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
112     }
113     for (uint32_t i = 0; i < 16; i++) {
114       params.sse2.output_min[i] = output_min;
115       params.sse2.output_max[i] = output_max;
116     }
117   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
118     params.neon.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
119     params.neon.multiplier = multiplier;
120     params.neon.right_shift = -shift;
121     params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
122     params.neon.output_min = output_min;
123     params.neon.output_max = output_max;
124   #else
125     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
126     const uint32_t remainder_threshold = remainder_mask >> 1;
127     params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
128     params.scalar.multiplier = multiplier;
129     params.scalar.remainder_mask = (int32_t) remainder_mask;
130     params.scalar.remainder_threshold = (int32_t) remainder_threshold;
131     params.scalar.shift = (uint32_t) shift;
132     params.scalar.output_min_less_zero_point =
133       (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
134     params.scalar.output_max_less_zero_point =
135       (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
136     params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
137   #endif
138   return params;
139 }
140 
xnn_init_scalar_qs8_gemm_params(float scale,int8_t output_zero_point,int8_t output_min,int8_t output_max)141 static inline union xnn_qs8_gemm_params xnn_init_scalar_qs8_gemm_params(
142   float scale,
143   int8_t output_zero_point,
144   int8_t output_min,
145   int8_t output_max)
146 {
147   // Compute requantization parameters
148   const uint32_t scale_bits = fp32_to_bits(scale);
149 
150   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
151   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
152   assert(multiplier >= INT32_C(0x40000000));
153   assert(multiplier <= INT32_C(0x7FFFFF80));
154 
155   // Shift is in [0, 31] range.
156   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
157   assert(shift >= 0);
158   assert(shift < 32);
159 
160   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
161   const uint32_t remainder_threshold = remainder_mask >> 1;
162 
163   union xnn_qs8_gemm_params params;
164   params.scalar.multiplier = multiplier;
165   params.scalar.remainder_mask = (int32_t) remainder_mask;
166   params.scalar.remainder_threshold = (int32_t) remainder_threshold;
167   params.scalar.shift = (uint32_t) shift;
168   params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
169   params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
170   params.scalar.output_zero_point = (int32_t) output_zero_point;
171   return params;
172 }
173 
xnn_init_qs8_gemm_params(float scale,int8_t output_zero_point,int8_t output_min,int8_t output_max)174 static inline union xnn_qs8_gemm_params xnn_init_qs8_gemm_params(
175   float scale,
176   int8_t output_zero_point,
177   int8_t output_min,
178   int8_t output_max)
179 {
180   // Compute requantization parameters.
181   const uint32_t scale_bits = fp32_to_bits(scale);
182 
183   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
184   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
185   assert(multiplier >= INT32_C(0x40000000));
186   assert(multiplier <= INT32_C(0x7FFFFF80));
187 
188   // Shift is in [0, 31] range.
189   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
190   assert(shift >= 0);
191   assert(shift < 32);
192 
193   union xnn_qs8_gemm_params params;
194   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
195     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
196     const uint32_t remainder_threshold = remainder_mask >> 1;
197     params.sse2.multiplier[0] = multiplier;
198     params.sse2.multiplier[1] = multiplier;
199     params.sse2.multiplier[2] = multiplier;
200     params.sse2.multiplier[3] = multiplier;
201     params.sse2.rounding[0] = UINT64_C(0x40000000);
202     params.sse2.rounding[1] = UINT64_C(0x40000000);
203     params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
204     params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
205     params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
206     params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
207     params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
208     params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
209     params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
210     params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
211     params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
212     params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
213     for (uint32_t i = 0; i < 8; i++) {
214       params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
215       params.sse2.output_min[i] = (int16_t) output_min;
216       params.sse2.output_max[i] = (int16_t) output_max;
217     }
218   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
219     params.neon.multiplier = multiplier;
220     params.neon.right_shift = -shift;
221     params.neon.output_zero_point = (int16_t) output_zero_point;
222     params.neon.output_min = output_min;
223     params.neon.output_max = output_max;
224   #elif XNN_ARCH_WASMSIMD
225     const int64_t twice_multiplier = INT64_C(2) * (int64_t) multiplier;
226     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
227     const uint32_t remainder_threshold = remainder_mask >> 1;
228     params.wasmsimd.multiplier[0] = twice_multiplier;
229     params.wasmsimd.multiplier[1] = twice_multiplier;
230     params.wasmsimd.rounding[0] = INT64_C(0x80000000);
231     params.wasmsimd.rounding[1] = INT64_C(0x80000000);
232     params.wasmsimd.remainder_mask[0] = (int32_t) remainder_mask;
233     params.wasmsimd.remainder_mask[1] = (int32_t) remainder_mask;
234     params.wasmsimd.remainder_mask[2] = (int32_t) remainder_mask;
235     params.wasmsimd.remainder_mask[3] = (int32_t) remainder_mask;
236     params.wasmsimd.remainder_threshold[0] = (int32_t) remainder_threshold;
237     params.wasmsimd.remainder_threshold[1] = (int32_t) remainder_threshold;
238     params.wasmsimd.remainder_threshold[2] = (int32_t) remainder_threshold;
239     params.wasmsimd.remainder_threshold[3] = (int32_t) remainder_threshold;
240     params.wasmsimd.shift = shift;
241     for (uint32_t i = 0; i < 8; i++) {
242       params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
243     }
244     for (uint32_t i = 0; i < 16; i++) {
245       params.wasmsimd.output_min[i] = output_min;
246       params.wasmsimd.output_max[i] = output_max;
247     }
248   #else
249     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
250     const uint32_t remainder_threshold = remainder_mask >> 1;
251     params.scalar.multiplier = multiplier;
252     params.scalar.remainder_mask = (int32_t) remainder_mask;
253     params.scalar.remainder_threshold = (int32_t) remainder_threshold;
254     params.scalar.shift = (uint32_t) shift;
255     params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
256     params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
257     params.scalar.output_zero_point = (int32_t) output_zero_point;
258   #endif
259   return params;
260 }
261 
xnn_init_scalar_qs8_gemm_xw_params(float scale,int8_t output_zero_point,int8_t output_min,int8_t output_max)262 static inline union xnn_qs8_gemm_xw_params xnn_init_scalar_qs8_gemm_xw_params(
263   float scale,
264   int8_t output_zero_point,
265   int8_t output_min,
266   int8_t output_max)
267 {
268   union {
269     union xnn_qs8_gemm_xw_params gemm_xw;
270     union xnn_qs8_gemm_params gemm;
271   } params;
272   params.gemm = xnn_init_scalar_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
273   return params.gemm_xw;
274 }
275 
xnn_init_qs8_gemm_xw_params(float scale,int8_t output_zero_point,int8_t output_min,int8_t output_max)276 static inline union xnn_qs8_gemm_xw_params xnn_init_qs8_gemm_xw_params(
277   float scale,
278   int8_t output_zero_point,
279   int8_t output_min,
280   int8_t output_max)
281 {
282   union {
283     union xnn_qs8_gemm_xw_params gemm_xw;
284     union xnn_qs8_gemm_params gemm;
285   } params;
286   params.gemm = xnn_init_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
287   return params.gemm_xw;
288 }
289 
xnn_init_qu8_avgpool_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)290 static inline union xnn_qu8_avgpool_params xnn_init_qu8_avgpool_params(
291   int32_t bias,
292   float scale,
293   uint8_t output_zero_point,
294   uint8_t output_min,
295   uint8_t output_max)
296 {
297   // Compute requantization parameters.
298   assert(scale >= 0x1.0p-32f);
299   assert(scale < 256.0f);
300   const uint32_t scale_bits = fp32_to_bits(scale);
301 
302   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
303   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
304   assert(multiplier >= INT32_C(0x00800000));
305   assert(multiplier <= INT32_C(0x00FFFFFF));
306 
307   // Shift is in [16, 55] range.
308   const int32_t shift = 127 + 23 - (scale_bits >> 23);
309   assert(shift >= 16);
310   assert(shift < 64);
311 
312   union xnn_qu8_avgpool_params params;
313   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
314     const uint32_t right_shift = (uint32_t) shift;
315     const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
316     params.sse2.bias[0] = bias;
317     params.sse2.bias[1] = bias;
318     params.sse2.bias[2] = bias;
319     params.sse2.bias[3] = bias;
320     params.sse2.multiplier[0] = (uint32_t) multiplier;
321     params.sse2.multiplier[1] = (uint32_t) multiplier;
322     params.sse2.multiplier[2] = (uint32_t) multiplier;
323     params.sse2.multiplier[3] = (uint32_t) multiplier;
324     params.sse2.rounding[0] = rounding;
325     params.sse2.rounding[1] = rounding;
326     params.sse2.right_shift[0] = (uint64_t) right_shift;
327     params.sse2.right_shift[1] = (uint64_t) right_shift;
328     for (uint32_t i = 0; i < 8; i++) {
329       params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
330     }
331     for (uint32_t i = 0; i < 16; i++) {
332       params.sse2.output_min[i] = output_min;
333       params.sse2.output_max[i] = output_max;
334     }
335   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
336     params.neon.bias = bias;
337     params.neon.multiplier = multiplier;
338     params.neon.left_shift = (int64_t) -shift;
339     params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
340     params.neon.output_min = output_min;
341     params.neon.output_max = output_max;
342   #else
343     const uint32_t right_shift = (uint32_t) shift;
344     const int64_t rounding = INT64_C(1) << (right_shift - 1);
345     params.scalar.bias = bias;
346     params.scalar.multiplier = multiplier;
347     params.scalar.rounding = rounding;
348     params.scalar.right_shift = right_shift;
349     params.scalar.output_min_less_zero_point =
350       (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
351     params.scalar.output_max_less_zero_point =
352       (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
353     params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
354   #endif
355   return params;
356 }
357 
xnn_init_scalar_qu8_avgpool_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)358 static inline union xnn_qu8_avgpool_params xnn_init_scalar_qu8_avgpool_params(
359   int32_t bias,
360   float scale,
361   uint8_t output_zero_point,
362   uint8_t output_min,
363   uint8_t output_max)
364 {
365   // Compute requantization parameters.
366   assert(scale >= 0x1.0p-32f);
367   assert(scale < 256.0f);
368   const uint32_t scale_bits = fp32_to_bits(scale);
369 
370   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
371   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
372   assert(multiplier >= INT32_C(0x00800000));
373   assert(multiplier <= INT32_C(0x00FFFFFF));
374 
375   // Shift is in [16, 55] range.
376   const int32_t shift = 127 + 23 - (scale_bits >> 23);
377   assert(shift >= 16);
378   assert(shift < 64);
379 
380   union xnn_qu8_avgpool_params params;
381   const uint32_t right_shift = (uint32_t) shift;
382   const int64_t rounding = INT64_C(1) << (right_shift - 1);
383   params.scalar.bias = bias;
384   params.scalar.rounding = rounding;
385   params.scalar.multiplier = multiplier;
386   params.scalar.right_shift = right_shift;
387   params.scalar.output_min_less_zero_point =
388     (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
389   params.scalar.output_max_less_zero_point =
390     (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
391   params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
392   return params;
393 }
394 
xnn_update_qu8_avgpool_params(union xnn_qu8_avgpool_params * params,int32_t bias,float scale)395 static inline void xnn_update_qu8_avgpool_params(
396   union xnn_qu8_avgpool_params* params,
397   int32_t bias,
398   float scale)
399 {
400   // Compute requantization parameters.
401   assert(scale >= 0x1.0p-32f);
402   assert(scale < 256.0f);
403   const uint32_t scale_bits = fp32_to_bits(scale);
404 
405   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
406   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
407   assert(multiplier >= INT32_C(0x00800000));
408   assert(multiplier <= INT32_C(0x00FFFFFF));
409 
410   // Shift is in [16, 55] range.
411   const int32_t shift = 127 + 23 - (scale_bits >> 23);
412   assert(shift >= 16);
413   assert(shift < 64);
414 
415   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
416     const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
417     params->sse2.bias[0] = bias;
418     params->sse2.bias[1] = bias;
419     params->sse2.bias[2] = bias;
420     params->sse2.bias[3] = bias;
421     params->sse2.multiplier[0] = (uint32_t) multiplier;
422     params->sse2.multiplier[1] = (uint32_t) multiplier;
423     params->sse2.multiplier[2] = (uint32_t) multiplier;
424     params->sse2.multiplier[3] = (uint32_t) multiplier;
425     params->sse2.rounding[0] = rounding;
426     params->sse2.rounding[1] = rounding;
427     params->sse2.right_shift[0] = (uint64_t) (uint32_t) shift;
428     params->sse2.right_shift[1] = (uint64_t) (uint32_t) shift;
429   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
430     params->neon.bias = bias;
431     params->neon.multiplier = multiplier;
432     params->neon.left_shift = (int64_t) -shift;
433   #else
434     const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
435     params->scalar.bias = bias;
436     params->scalar.multiplier = multiplier;
437     params->scalar.rounding = rounding;
438     params->scalar.right_shift = (uint32_t) shift;
439   #endif
440 }
441 
xnn_init_qs8_avgpool_params(int32_t bias,float scale,int8_t output_zero_point,int8_t output_min,int8_t output_max)442 static inline union xnn_qs8_avgpool_params xnn_init_qs8_avgpool_params(
443   int32_t bias,
444   float scale,
445   int8_t output_zero_point,
446   int8_t output_min,
447   int8_t output_max)
448 {
449   // Compute requantization parameters.
450   assert(scale >= 0x1.0p-32f);
451   assert(scale < 256.0f);
452   const uint32_t scale_bits = fp32_to_bits(scale);
453 
454   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
455   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
456   assert(multiplier >= INT32_C(0x00800000));
457   assert(multiplier <= INT32_C(0x00FFFFFF));
458 
459   // Shift is in [16, 55] range.
460   const int32_t shift = 127 + 23 - (scale_bits >> 23);
461   assert(shift >= 16);
462   assert(shift < 64);
463 
464   union xnn_qs8_avgpool_params params;
465   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
466     const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
467     params.sse2.bias[0] = bias;
468     params.sse2.bias[1] = bias;
469     params.sse2.bias[2] = bias;
470     params.sse2.bias[3] = bias;
471     params.sse2.multiplier[0] = (uint32_t) multiplier;
472     params.sse2.multiplier[1] = (uint32_t) multiplier;
473     params.sse2.multiplier[2] = (uint32_t) multiplier;
474     params.sse2.multiplier[3] = (uint32_t) multiplier;
475     params.sse2.rounding[0] = rounding;
476     params.sse2.rounding[1] = rounding;
477     params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
478     params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
479     for (uint32_t i = 0; i < 8; i++) {
480       params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
481       params.sse2.output_min[i] = (int16_t) output_min;
482       params.sse2.output_max[i] = (int16_t) output_max;
483     }
484   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
485     params.neon.bias = bias;
486     params.neon.multiplier = multiplier;
487     params.neon.left_shift = (int64_t) -shift;
488     params.neon.output_zero_point = (int16_t) output_zero_point;
489     params.neon.output_min = output_min;
490     params.neon.output_max = output_max;
491   #elif XNN_ARCH_WASMSIMD
492     const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
493     params.wasmsimd.bias[0] = bias;
494     params.wasmsimd.bias[1] = bias;
495     params.wasmsimd.bias[2] = bias;
496     params.wasmsimd.bias[3] = bias;
497     params.wasmsimd.multiplier[0] = (int64_t) multiplier;
498     params.wasmsimd.multiplier[1] = (int64_t) multiplier;
499     params.wasmsimd.rounding[0] = rounding;
500     params.wasmsimd.rounding[1] = rounding;
501     params.wasmsimd.shift = shift;
502     for (uint32_t i = 0; i < 8; i++) {
503       params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
504     }
505     for (uint32_t i = 0; i < 16; i++) {
506       params.wasmsimd.output_min[i] = output_min;
507       params.wasmsimd.output_max[i] = output_max;
508     }
509   #else
510     const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
511     params.scalar.bias = bias;
512     params.scalar.multiplier = multiplier;
513     params.scalar.rounding = rounding;
514     params.scalar.shift = (uint32_t) shift;
515     params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
516     params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
517     params.scalar.output_zero_point = (int32_t) output_zero_point;
518   #endif
519   return params;
520 }
521 
xnn_init_scalar_qs8_avgpool_params(int32_t bias,float scale,int8_t output_zero_point,int8_t output_min,int8_t output_max)522 static inline union xnn_qs8_avgpool_params xnn_init_scalar_qs8_avgpool_params(
523   int32_t bias,
524   float scale,
525   int8_t output_zero_point,
526   int8_t output_min,
527   int8_t output_max)
528 {
529   // Compute requantization parameters.
530   assert(scale >= 0x1.0p-32f);
531   assert(scale < 256.0f);
532   const uint32_t scale_bits = fp32_to_bits(scale);
533 
534   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
535   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
536   assert(multiplier >= INT32_C(0x00800000));
537   assert(multiplier <= INT32_C(0x00FFFFFF));
538 
539   // Shift is in [16, 55] range.
540   const int32_t shift = 127 + 23 - (scale_bits >> 23);
541   assert(shift >= 16);
542   assert(shift < 64);
543 
544   union xnn_qs8_avgpool_params params;
545   const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
546   params.scalar.bias = bias;
547   params.scalar.rounding = rounding;
548   params.scalar.multiplier = multiplier;
549   params.scalar.shift = shift;
550   params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
551   params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
552   params.scalar.output_zero_point = (int32_t) output_zero_point;
553   return params;
554 }
555 
xnn_update_qs8_avgpool_params(union xnn_qs8_avgpool_params * params,int32_t bias,float scale)556 static inline void xnn_update_qs8_avgpool_params(
557   union xnn_qs8_avgpool_params* params,
558   int32_t bias,
559   float scale)
560 {
561   // Compute requantization parameters.
562   assert(scale >= 0x1.0p-32f);
563   assert(scale < 256.0f);
564   const uint32_t scale_bits = fp32_to_bits(scale);
565 
566   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
567   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
568   assert(multiplier >= INT32_C(0x00800000));
569   assert(multiplier <= INT32_C(0x00FFFFFF));
570 
571   // Shift is in [16, 55] range.
572   const int32_t shift = 127 + 23 - (scale_bits >> 23);
573   assert(shift >= 16);
574   assert(shift < 64);
575 
576   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
577     const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
578     params->sse2.bias[0] = bias;
579     params->sse2.bias[1] = bias;
580     params->sse2.bias[2] = bias;
581     params->sse2.bias[3] = bias;
582     params->sse2.multiplier[0] = (uint32_t) multiplier;
583     params->sse2.multiplier[1] = (uint32_t) multiplier;
584     params->sse2.multiplier[2] = (uint32_t) multiplier;
585     params->sse2.multiplier[3] = (uint32_t) multiplier;
586     params->sse2.rounding[0] = rounding;
587     params->sse2.rounding[1] = rounding;
588     params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
589     params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
590   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
591     params->neon.bias = bias;
592     params->neon.multiplier = multiplier;
593     params->neon.left_shift = (int64_t) -shift;
594   #elif XNN_ARCH_WASMSIMD
595     const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
596     params->wasmsimd.bias[0] = bias;
597     params->wasmsimd.bias[1] = bias;
598     params->wasmsimd.bias[2] = bias;
599     params->wasmsimd.bias[3] = bias;
600     params->wasmsimd.multiplier[0] = (int64_t) multiplier;
601     params->wasmsimd.multiplier[1] = (int64_t) multiplier;
602     params->wasmsimd.rounding[0] = rounding;
603     params->wasmsimd.rounding[1] = rounding;
604     params->wasmsimd.shift = shift;
605   #else
606     const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
607     params->scalar.bias = bias;
608     params->scalar.multiplier = multiplier;
609     params->scalar.rounding = rounding;
610     params->scalar.shift = (uint32_t) shift;
611   #endif
612 }
613 
xnn_update_f16_scaleminmax_params(struct xnn_f16_scaleminmax_params * params,uint16_t scale)614 static inline void xnn_update_f16_scaleminmax_params(
615   struct xnn_f16_scaleminmax_params* params,
616   uint16_t scale)
617 {
618   params->scale = scale;
619 }
620 
xnn_update_f32_scaleminmax_params(union xnn_f32_scaleminmax_params * params,float scale)621 static inline void xnn_update_f32_scaleminmax_params(
622   union xnn_f32_scaleminmax_params* params,
623   float scale)
624 {
625   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
626     for (uint32_t i = 0; i < 4; i++) {
627       params->sse2.scale[i] = scale;
628     }
629   #else
630     params->scalar.scale = scale;
631   #endif
632 }
633 
xnn_init_f16_scaleminmax_params(uint16_t scale,uint16_t min,uint16_t max)634 static inline struct xnn_f16_scaleminmax_params xnn_init_f16_scaleminmax_params(
635   uint16_t scale,
636   uint16_t min,
637   uint16_t max)
638 {
639   struct xnn_f16_scaleminmax_params params;
640   params.scale = scale;
641   params.min = min;
642   params.max = max;
643   return params;
644 }
645 
xnn_init_f32_scaleminmax_params(float scale,float min,float max)646 static inline union xnn_f32_scaleminmax_params xnn_init_f32_scaleminmax_params(
647   float scale,
648   float min,
649   float max)
650 {
651   union xnn_f32_scaleminmax_params params;
652   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
653     for (uint32_t i = 0; i < 4; i++) {
654       params.sse2.scale[i] = scale;
655       params.sse2.min[i] = min;
656       params.sse2.max[i] = max;
657     }
658   #else
659     params.scalar.scale = scale;
660     params.scalar.min = min;
661     params.scalar.max = max;
662   #endif
663   return params;
664 }
665 
xnn_init_f32_gavgpool_params(float multiplier,float output_min,float output_max,uint32_t width)666 static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
667   float multiplier,
668   float output_min,
669   float output_max,
670   uint32_t width)
671 {
672   union xnn_f32_gavgpool_params params;
673   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
674     for (uint32_t i = 0; i < 4; i++) {
675       params.sse.multiplier[i] = multiplier;
676       params.sse.output_min[i] = output_min;
677       params.sse.output_max[i] = output_max;
678     }
679 
680     const uint32_t w = (width - 1) & 3;
681     params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
682     params.sse.mask[1] = -(uint32_t) (w >= 1);
683     params.sse.mask[2] = -(uint32_t) (w >= 2);
684     params.sse.mask[3] = -(uint32_t) (w >= 3);
685   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
686     params.neon.multiplier = multiplier;
687     params.neon.output_min = output_min;
688     params.neon.output_max = output_max;
689 
690     const uint32_t w = (width - 1) & 3;
691     params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
692     params.neon.mask[1] = -(uint32_t) (w >= 1);
693     params.neon.mask[2] = -(uint32_t) (w >= 2);
694     params.neon.mask[3] = -(uint32_t) (w >= 3);
695   #else
696     params.scalar.multiplier = multiplier;
697     params.scalar.output_min = output_min;
698     params.scalar.output_max = output_max;
699 
700     const uint32_t w = (width - 1) & 3;
701     params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
702     params.scalar.mask[1] = -(int32_t) (w >= 1);
703     params.scalar.mask[2] = -(int32_t) (w >= 2);
704     params.scalar.mask[3] = -(int32_t) (w >= 3);
705   #endif
706   return params;
707 }
708 
xnn_update_f32_gavgpool_params(union xnn_f32_gavgpool_params * params,float multiplier,uint32_t width)709 static inline void xnn_update_f32_gavgpool_params(
710   union xnn_f32_gavgpool_params* params,
711   float multiplier,
712   uint32_t width)
713 {
714   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
715     for (uint32_t i = 0; i < 4; i++) {
716       params->sse.multiplier[i] = multiplier;
717     }
718 
719     const uint32_t w = (width - 1) & 3;
720     params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
721     params->sse.mask[1] = -(uint32_t) (w >= 1);
722     params->sse.mask[2] = -(uint32_t) (w >= 2);
723     params->sse.mask[3] = -(uint32_t) (w >= 3);
724   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
725     params->neon.multiplier = multiplier;
726 
727     const uint32_t w = (width - 1) & 3;
728     params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
729     params->neon.mask[1] = -(uint32_t) (w >= 1);
730     params->neon.mask[2] = -(uint32_t) (w >= 2);
731     params->neon.mask[3] = -(uint32_t) (w >= 3);
732   #else
733     params->scalar.multiplier = multiplier;
734 
735     const uint32_t w = (width - 1) & 3;
736     params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
737     params->scalar.mask[1] = -(int32_t) (w >= 1);
738     params->scalar.mask[2] = -(int32_t) (w >= 2);
739     params->scalar.mask[3] = -(int32_t) (w >= 3);
740   #endif
741 }
742 
xnn_init_scalar_f32_scaleminmax_params(float scale,float min,float max)743 static inline union xnn_f32_scaleminmax_params xnn_init_scalar_f32_scaleminmax_params(
744   float scale,
745   float min,
746   float max)
747 {
748   union xnn_f32_scaleminmax_params params;
749   params.scalar.scale = scale;
750   params.scalar.min = min;
751   params.scalar.max = max;
752   return params;
753 }
754 
xnn_init_scalar_f32_gavgpool_params(float multiplier,float output_min,float output_max,uint32_t width)755 static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
756   float multiplier,
757   float output_min,
758   float output_max,
759   uint32_t width)
760 {
761   union xnn_f32_gavgpool_params params;
762   params.scalar.multiplier = multiplier;
763   params.scalar.output_min = output_min;
764   params.scalar.output_max = output_max;
765 
766   const uint32_t w = (width - 1) & 3;
767   params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
768   params.scalar.mask[1] = -(int32_t) (w >= 1);
769   params.scalar.mask[2] = -(int32_t) (w >= 2);
770   params.scalar.mask[3] = -(int32_t) (w >= 3);
771   return params;
772 }
773 
xnn_init_f16_minmax_params(uint16_t min,uint16_t max)774 static inline struct xnn_f16_minmax_params xnn_init_f16_minmax_params(
775   uint16_t min,
776   uint16_t max)
777 {
778   struct xnn_f16_minmax_params params;
779   params.min = min;
780   params.max = max;
781   return params;
782 }
783 
xnn_init_f32_minmax_params(float output_min,float output_max)784 static inline union xnn_f32_minmax_params xnn_init_f32_minmax_params(
785   float output_min,
786   float output_max)
787 {
788   union xnn_f32_minmax_params params;
789   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
790     for (uint32_t i = 0; i < 4; i++) {
791       params.sse.min[i] = output_min;
792       params.sse.max[i] = output_max;
793     }
794   #else
795     params.scalar.min = output_min;
796     params.scalar.max = output_max;
797   #endif
798   return params;
799 }
800 
xnn_init_scalar_f32_minmax_params(float output_min,float output_max)801 static inline union xnn_f32_minmax_params xnn_init_scalar_f32_minmax_params(
802   float output_min,
803   float output_max)
804 {
805   union xnn_f32_minmax_params params;
806   params.scalar.min = output_min;
807   params.scalar.max = output_max;
808   return params;
809 }
810 
xnn_init_f16_hswish_params(void)811 static inline struct xnn_f16_hswish_params xnn_init_f16_hswish_params(void)
812 {
813   struct xnn_f16_hswish_params params;
814   params.sixth = UINT16_C(0x3155);
815   params.three = UINT16_C(0x4200);
816   params.six = UINT16_C(0x4600);
817   return params;
818 }
819 
xnn_init_f32_hswish_params(void)820 static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
821 {
822   union xnn_f32_hswish_params params;
823   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
824     for (uint32_t i = 0; i < 4; i++) {
825       params.sse.sixth[i] = 0x1.555556p-3f;
826       params.sse.half[i] = 0.5f;
827       params.sse.one[i] = 1.0f;
828     }
829   #else
830     params.scalar.sixth = 0x1.555556p-3f;
831     params.scalar.three = 3.0f;
832     params.scalar.six = 6.0f;
833   #endif
834   return params;
835 }
836 
xnn_init_scalar_f32_hswish_params(void)837 static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
838 {
839   union xnn_f32_hswish_params params;
840   params.scalar.sixth = 0x1.555556p-3f;
841   params.scalar.three = 3.0f;
842   params.scalar.six = 6.0f;
843   return params;
844 }
845 
xnn_init_f32_abs_params(void)846 static inline union xnn_f32_abs_params xnn_init_f32_abs_params(void)
847 {
848   union xnn_f32_abs_params params = { 0 };
849   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
850     for (uint32_t i = 0; i < 4; i++) {
851       params.sse.nonsign_mask[i] = math_nonsign_mask_f32();
852     }
853   #elif XNN_ARCH_WASMSIMD
854     params.wasmsimd.nonsign_mask = math_nonsign_mask_f32();
855   #endif
856   return params;
857 }
858 
xnn_init_scalar_f32_abs_params(void)859 static inline union xnn_f32_abs_params xnn_init_scalar_f32_abs_params(void)
860 {
861   union xnn_f32_abs_params params = { 0 };
862   return params;
863 }
864 
xnn_init_f32_neg_params(void)865 static inline union xnn_f32_neg_params xnn_init_f32_neg_params(void)
866 {
867   union xnn_f32_neg_params params = { 0 };
868   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
869     for (uint32_t i = 0; i < 4; i++) {
870       params.sse.sign_mask[i] = -0.0f;
871     }
872   #elif XNN_ARCH_WASMSIMD
873     params.wasmsimd.sign_mask = -0.0f;
874   #endif
875   return params;
876 }
877 
xnn_init_scalar_f32_neg_params(void)878 static inline union xnn_f32_neg_params xnn_init_scalar_f32_neg_params(void)
879 {
880   union xnn_f32_neg_params params = { 0 };
881   return params;
882 }
883 
xnn_init_f32_rnd_params(void)884 static inline union xnn_f32_rnd_params xnn_init_f32_rnd_params(void)
885 {
886   union xnn_f32_rnd_params params = { 0 };
887   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
888     for (uint32_t i = 0; i < 4; i++) {
889       params.sse2.sign_mask[i] = -0.0f;
890     }
891     for (uint32_t i = 0; i < 4; i++) {
892       params.sse2.one[i] = 1.0f;
893     }
894   #endif
895   return params;
896 }
897 
xnn_init_scalar_f32_rnd_params(void)898 static inline union xnn_f32_rnd_params xnn_init_scalar_f32_rnd_params(void)
899 {
900   union xnn_f32_rnd_params params = { 0 };
901   return params;
902 }
903 
xnn_init_f32_elu_params(float prescale,float alpha,float beta)904 static inline union xnn_f32_elu_params xnn_init_f32_elu_params(float prescale, float alpha, float beta)
905 {
906   union xnn_f32_elu_params params;
907   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
908     for (uint32_t i = 0; i < 4; i++) {
909       params.sse.prescale[i] = prescale;
910       params.sse.alpha[i] = alpha;
911       params.sse.beta[i] = beta;
912     }
913   #else
914     params.scalar.prescale = prescale;
915     params.scalar.alpha = alpha;
916     params.scalar.beta = beta;
917   #endif
918   return params;
919 }
920 
xnn_init_scalar_f32_elu_params(float prescale,float alpha,float beta)921 static inline union xnn_f32_elu_params xnn_init_scalar_f32_elu_params(float prescale, float alpha, float beta)
922 {
923   union xnn_f32_elu_params params;
924   params.scalar.prescale = prescale;
925   params.scalar.alpha = alpha;
926   params.scalar.beta = beta;
927   return params;
928 }
929 
xnn_init_f32_lrelu_params(float slope)930 static inline union xnn_f32_lrelu_params xnn_init_f32_lrelu_params(float slope)
931 {
932   union xnn_f32_lrelu_params params;
933   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
934     for (uint32_t i = 0; i < 4; i++) {
935       params.sse.slope[i] = slope;
936     }
937   #else
938     params.scalar.slope = slope;
939   #endif
940   return params;
941 }
942 
xnn_init_scalar_f32_lrelu_params(float slope)943 static inline union xnn_f32_lrelu_params xnn_init_scalar_f32_lrelu_params(float slope)
944 {
945   union xnn_f32_lrelu_params params;
946   params.scalar.slope = slope;
947   return params;
948 }
949 
xnn_init_f32_sqrt_params(void)950 static inline union xnn_f32_sqrt_params xnn_init_f32_sqrt_params(void)
951 {
952   union xnn_f32_sqrt_params params = { 0 };
953   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
954     params.fma.half = 0.5f;
955   #endif
956   return params;
957 }
958 
xnn_init_scalar_f32_sqrt_params(void)959 static inline union xnn_f32_sqrt_params xnn_init_scalar_f32_sqrt_params(void)
960 {
961   union xnn_f32_sqrt_params params = { 0 };
962   return params;
963 }
964 
xnn_init_f32_chw_params(uint32_t width,float output_min,float output_max)965 static inline union xnn_f32_chw_params xnn_init_f32_chw_params(
966   uint32_t width,
967   float output_min,
968   float output_max)
969 {
970   union xnn_f32_chw_params params;
971   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
972     for (uint32_t i = 0; i < 4; i++) {
973       params.sse.min[i] = output_min;
974       params.sse.max[i] = output_max;
975     }
976 
977     const uint32_t w4 = (width - 1) & 3;
978     params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
979     params.sse.mask[1] = -(uint32_t) (w4 >= 1);
980     params.sse.mask[2] = -(uint32_t) (w4 >= 2);
981     params.sse.mask[3] = -(uint32_t) (w4 >= 3);
982 
983     const uint32_t w8 = (width - 1) & 7;
984     params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
985     params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
986     params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
987     params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
988     params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
989     params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
990     params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
991     params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
992   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
993     params.neon.min = output_min;
994     params.neon.max = output_max;
995 
996     const uint32_t w4 = (width - 1) & 3;
997     params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
998     params.neon.mask[1] = -(uint32_t) (w4 >= 1);
999     params.neon.mask[2] = -(uint32_t) (w4 >= 2);
1000     params.neon.mask[3] = -(uint32_t) (w4 >= 3);
1001 
1002     const uint32_t w8 = (width - 1) & 7;
1003     params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1004     params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1005     params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1006     params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1007     params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1008     params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1009     params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1010     params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1011   #else
1012     params.scalar.min = output_min;
1013     params.scalar.max = output_max;
1014 
1015     const uint32_t w4 = (width - 1) & 3;
1016     params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1017     params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
1018     params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
1019     params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
1020 
1021     const uint32_t w8 = (width - 1) & 7;
1022     params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1023     params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1024     params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1025     params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1026     params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1027     params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1028     params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1029     params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1030   #endif
1031   return params;
1032 }
1033 
xnn_update_f32_chw_params(union xnn_f32_chw_params * params,uint32_t width)1034 static inline void xnn_update_f32_chw_params(
1035   union xnn_f32_chw_params* params,
1036   uint32_t width)
1037 {
1038   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1039     const uint32_t w4 = (width - 1) & 3;
1040     params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1041     params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1042     params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1043     params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1044 
1045     const uint32_t w8 = (width - 1) & 7;
1046     params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1047     params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1048     params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1049     params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1050     params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1051     params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1052     params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1053     params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1054   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1055     const uint32_t w4 = (width - 1) & 3;
1056     params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1057     params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1058     params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1059     params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1060 
1061     const uint32_t w8 = (width - 1) & 7;
1062     params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1063     params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1064     params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1065     params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1066     params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1067     params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1068     params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1069     params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1070   #else
1071     const uint32_t w4 = (width - 1) & 3;
1072     params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1073     params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1074     params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1075     params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1076 
1077     const uint32_t w8 = (width - 1) & 7;
1078     params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1079     params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1080     params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1081     params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1082     params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1083     params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1084     params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1085     params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1086   #endif
1087 }
1088 
xnn_init_scalar_f32_chw_params(uint32_t width,float output_min,float output_max)1089 static inline union xnn_f32_chw_params xnn_init_scalar_f32_chw_params(
1090   uint32_t width,
1091   float output_min,
1092   float output_max)
1093 {
1094   union xnn_f32_chw_params params;
1095   params.scalar.min = output_min;
1096   params.scalar.max = output_max;
1097 
1098   const uint32_t w4 = (width - 1) & 3;
1099   params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1100   params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
1101   params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
1102   params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
1103 
1104   const uint32_t w8 = (width - 1) & 7;
1105   params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1106   params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1107   params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1108   params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1109   params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1110   params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1111   params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1112   params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1113 
1114   return params;
1115 }
1116 
xnn_init_u8_minmax_params(uint8_t output_min,uint8_t output_max)1117 static inline union xnn_u8_minmax_params xnn_init_u8_minmax_params(
1118   uint8_t output_min,
1119   uint8_t output_max)
1120 {
1121   assert(output_min < output_max);
1122 
1123   union xnn_u8_minmax_params params;
1124   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1125     for (uint32_t i = 0; i < 16; i++) {
1126       params.sse2.min[i] = output_min;
1127       params.sse2.max[i] = output_max;
1128     }
1129   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1130     params.neon.min = output_min;
1131     params.neon.max = output_max;
1132   #else
1133     params.scalar.min = (int32_t) (uint32_t) output_min;
1134     params.scalar.max = (int32_t) (uint32_t) output_max;
1135   #endif
1136   return params;
1137 }
1138 
xnn_init_scalar_u8_minmax_params(uint8_t output_min,uint8_t output_max)1139 static inline union xnn_u8_minmax_params xnn_init_scalar_u8_minmax_params(
1140   uint8_t output_min,
1141   uint8_t output_max)
1142 {
1143   assert(output_min < output_max);
1144 
1145   union xnn_u8_minmax_params params;
1146   params.scalar.min = (int32_t) (uint32_t) output_min;
1147   params.scalar.max = (int32_t) (uint32_t) output_max;
1148   return params;
1149 }
1150 
xnn_init_qu8_add_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)1151 static inline union xnn_qu8_add_params xnn_init_qu8_add_params(
1152   uint8_t a_zero_point,
1153   uint8_t b_zero_point,
1154   uint8_t output_zero_point,
1155   float a_output_scale,
1156   float b_output_scale,
1157   uint8_t output_min,
1158   uint8_t output_max)
1159 {
1160   assert(a_output_scale >= 0x1.0p-14f);
1161   assert(b_output_scale >= 0x1.0p-14f);
1162   assert(a_output_scale < 0x1.0p+8f);
1163   assert(b_output_scale < 0x1.0p+8f);
1164 
1165   // Compute requantization parameters.
1166   const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1167   assert(max_output_scale >= 0x1.0p-14f);
1168   assert(max_output_scale < 0x1.0p+8f);
1169   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1170   const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1171   // Shift is in [13, 31] range.
1172   const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1173   assert(shift < 32);
1174   assert(shift >= 13);
1175 
1176   const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1177 
1178   // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1179   const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(a_output_scale * scale_multiplier);
1180   const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(b_output_scale * scale_multiplier);
1181   assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1182   assert(a_multiplier < UINT32_C(0x00400000));
1183   assert(b_multiplier < UINT32_C(0x00400000));
1184 
1185   union xnn_qu8_add_params params;
1186   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1187     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1188     const uint32_t remainder_threshold = remainder_mask >> 1;
1189     const int32_t zero_point_product =
1190       (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1191     for (uint32_t i = 0; i < 4; i++) {
1192       params.sse2.zero_point_product[i] = zero_point_product;
1193     }
1194     for (uint32_t i = 0; i < 8; i++) {
1195       params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
1196     }
1197     for (uint32_t i = 0; i < 8; i++) {
1198       params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
1199       params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
1200       params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
1201       params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
1202     }
1203     params.sse2.a_multiplier = a_multiplier;
1204     params.sse2.b_multiplier = b_multiplier;
1205     for (uint32_t i = 0; i < 4; i++) {
1206       params.sse2.remainder_mask[i] = remainder_mask;
1207       params.sse2.remainder_threshold[i] = remainder_threshold;
1208     }
1209     params.sse2.shift = shift;
1210     for (uint32_t i = 0; i < 16; i++) {
1211       params.sse2.y_min[i] = output_min;
1212       params.sse2.y_max[i] = output_max;
1213     }
1214   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1215     params.neon.a_zero_point = a_zero_point;
1216     params.neon.b_zero_point = b_zero_point;
1217     params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
1218     params.neon.a_multiplier = (int32_t) a_multiplier;
1219     params.neon.b_multiplier = (int32_t) b_multiplier;
1220     params.neon.right_shift = (int32_t) -shift;
1221     params.neon.y_min = output_min;
1222     params.neon.y_max = output_max;
1223   #else
1224     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1225     const uint32_t remainder_threshold = remainder_mask >> 1;
1226     params.scalar.zero_point_product =
1227       (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1228     params.scalar.a_multiplier = a_multiplier;
1229     params.scalar.b_multiplier = b_multiplier;
1230     params.scalar.remainder_mask = (int32_t) remainder_mask;
1231     params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1232     params.scalar.shift = shift;
1233     params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1234     params.scalar.y_min = (int32_t) (uint32_t) output_min;
1235     params.scalar.y_max = (int32_t) (uint32_t) output_max;
1236   #endif
1237   return params;
1238 }
1239 
xnn_init_scalar_qu8_add_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)1240 static inline union xnn_qu8_add_params xnn_init_scalar_qu8_add_params(
1241   uint8_t a_zero_point,
1242   uint8_t b_zero_point,
1243   uint8_t output_zero_point,
1244   float a_output_scale,
1245   float b_output_scale,
1246   uint8_t output_min,
1247   uint8_t output_max)
1248 {
1249   assert(a_output_scale >= 0x1.0p-10f);
1250   assert(b_output_scale >= 0x1.0p-10f);
1251   assert(a_output_scale < 0x1.0p+8f);
1252   assert(b_output_scale < 0x1.0p+8f);
1253 
1254   // Compute requantization parameters.
1255   const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1256   assert(max_output_scale >= 0x1.0p-10f);
1257   assert(max_output_scale < 0x1.0p+8f);
1258   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1259   const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1260   // Shift is in [13, 31] range.
1261   const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1262   assert(shift < 32);
1263   assert(shift >= 13);
1264 
1265   // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1266   const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1267   const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1268   assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1269   assert(a_multiplier < UINT32_C(0x00400000));
1270   assert(b_multiplier < UINT32_C(0x00400000));
1271 
1272   union xnn_qu8_add_params params;
1273   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1274   const uint32_t remainder_threshold = remainder_mask >> 1;
1275   params.scalar.zero_point_product =
1276     (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1277   params.scalar.a_multiplier = a_multiplier;
1278   params.scalar.b_multiplier = b_multiplier;
1279   params.scalar.remainder_mask = (int32_t) remainder_mask;
1280   params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1281   params.scalar.shift = shift;
1282   params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1283   params.scalar.y_min = (int32_t) (uint32_t) output_min;
1284   params.scalar.y_max = (int32_t) (uint32_t) output_max;
1285   return params;
1286 }
1287 
xnn_init_qs8_add_params(int8_t x_zero_point,int8_t y_zero_point,int8_t output_zero_point,float x_output_scale,float y_output_scale,int8_t output_min,int8_t output_max)1288 static inline union xnn_qs8_add_params xnn_init_qs8_add_params(
1289   int8_t x_zero_point,
1290   int8_t y_zero_point,
1291   int8_t output_zero_point,
1292   float x_output_scale,
1293   float y_output_scale,
1294   int8_t output_min,
1295   int8_t output_max)
1296 {
1297   assert(x_output_scale >= 0x1.0p-14f);
1298   assert(y_output_scale >= 0x1.0p-14f);
1299   assert(x_output_scale < 0x1.0p+8f);
1300   assert(y_output_scale < 0x1.0p+8f);
1301 
1302   // Compute requantization parameters.
1303   const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1304   assert(max_output_scale >= 0x1.0p-14f);
1305   assert(max_output_scale < 0x1.0p+8f);
1306   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1307   const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1308   // Shift is in [13, 31] range.
1309   const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1310   assert(shift < 32);
1311   assert(shift >= 13);
1312 
1313   const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1314 
1315   // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1316   const int32_t x_multiplier = (int32_t) lrintf(x_output_scale * scale_multiplier);
1317   const int32_t y_multiplier = (int32_t) lrintf(y_output_scale * scale_multiplier);
1318   assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1319   assert(x_multiplier < INT32_C(0x00400000));
1320   assert(y_multiplier < INT32_C(0x00400000));
1321 
1322   union xnn_qs8_add_params params;
1323   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1324     const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1325     const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1326     const int32_t zero_point_product =
1327       (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1328     for (uint32_t i = 0; i < 4; i++) {
1329       params.sse2.zero_point_product[i] = zero_point_product;
1330     }
1331     const uint16_t x_multiplier_lo = (uint16_t) x_multiplier;
1332     const uint16_t x_multiplier_hi = (uint16_t) ((uint32_t) x_multiplier >> 16);
1333     const uint16_t y_multiplier_lo = (uint16_t) y_multiplier;
1334     const uint16_t y_multiplier_hi = (uint16_t) ((uint32_t) y_multiplier >> 16);
1335     for (uint32_t i = 0; i < 8; i++) {
1336       params.sse2.x_multiplier_lo[i] = x_multiplier_lo;
1337       params.sse2.x_multiplier_hi[i] = x_multiplier_hi;
1338       params.sse2.y_multiplier_lo[i] = y_multiplier_lo;
1339       params.sse2.y_multiplier_hi[i] = y_multiplier_hi;
1340     }
1341     params.sse2.shift = shift;
1342     for (uint32_t i = 0; i < 4; i++) {
1343       params.sse2.x_multiplier[i] = x_multiplier;
1344       params.sse2.y_multiplier[i] = y_multiplier;
1345       params.sse2.remainder_mask[i] = remainder_mask;
1346       params.sse2.remainder_threshold[i] = remainder_threshold;
1347     }
1348     for (uint32_t i = 0; i < 8; i++) {
1349       params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
1350       params.sse2.output_min[i] = (int16_t) output_min;
1351       params.sse2.output_max[i] = (int16_t) output_max;
1352     }
1353   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1354     params.neon.x_zero_point = x_zero_point;
1355     params.neon.y_zero_point = y_zero_point;
1356     params.neon.x_multiplier = (int32_t) x_multiplier;
1357     params.neon.y_multiplier = (int32_t) y_multiplier;
1358     params.neon.right_shift = (int32_t) -shift;
1359     params.neon.output_zero_point = (int16_t) output_zero_point;
1360     params.neon.output_min = output_min;
1361     params.neon.output_max = output_max;
1362   #elif XNN_ARCH_WASMSIMD
1363     const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1364     const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1365     const int32_t zero_point_product =
1366       (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1367     for (uint32_t i = 0; i < 4; i++) {
1368       params.wasmsimd.zero_point_product[i] = zero_point_product;
1369       params.wasmsimd.x_multiplier[i] = x_multiplier;
1370       params.wasmsimd.y_multiplier[i] = y_multiplier;
1371       params.wasmsimd.remainder_mask[i] = remainder_mask;
1372       params.wasmsimd.remainder_threshold[i] = remainder_threshold;
1373     }
1374     params.wasmsimd.shift = shift;
1375     for (uint32_t i = 0; i < 8; i++) {
1376       params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
1377     }
1378     for (uint32_t i = 0; i < 16; i++) {
1379       params.wasmsimd.output_min[i] = output_min;
1380       params.wasmsimd.output_max[i] = output_max;
1381     }
1382   #else
1383     const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1384     const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1385     params.scalar.zero_point_product =
1386       (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1387     params.scalar.x_multiplier = x_multiplier;
1388     params.scalar.y_multiplier = y_multiplier;
1389     params.scalar.remainder_mask = (int32_t) remainder_mask;
1390     params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1391     params.scalar.shift = (int32_t) shift;
1392     params.scalar.output_zero_point = (int32_t) output_zero_point;
1393     params.scalar.output_min = (int32_t) output_min;
1394     params.scalar.output_max = (int32_t) output_max;
1395   #endif
1396   return params;
1397 }
1398 
xnn_init_scalar_qs8_add_params(int8_t x_zero_point,int8_t y_zero_point,int8_t output_zero_point,float x_output_scale,float y_output_scale,int8_t output_min,int8_t output_max)1399 static inline union xnn_qs8_add_params xnn_init_scalar_qs8_add_params(
1400   int8_t x_zero_point,
1401   int8_t y_zero_point,
1402   int8_t output_zero_point,
1403   float x_output_scale,
1404   float y_output_scale,
1405   int8_t output_min,
1406   int8_t output_max)
1407 {
1408   assert(x_output_scale >= 0x1.0p-10f);
1409   assert(y_output_scale >= 0x1.0p-10f);
1410   assert(x_output_scale < 0x1.0p+8f);
1411   assert(y_output_scale < 0x1.0p+8f);
1412 
1413   // Compute requantization parameters.
1414   const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1415   assert(max_output_scale >= 0x1.0p-10f);
1416   assert(max_output_scale < 0x1.0p+8f);
1417   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1418   const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1419   // Shift is in [13, 31] range.
1420   const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1421   assert(shift < 32);
1422   assert(shift >= 13);
1423 
1424   // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1425   const int32_t x_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(x_output_scale) + (shift << 23)));
1426   const int32_t y_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(y_output_scale) + (shift << 23)));
1427   assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1428   assert(x_multiplier < INT32_C(0x00400000));
1429   assert(y_multiplier < INT32_C(0x00400000));
1430 
1431   union xnn_qs8_add_params params;
1432   const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1433   const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1434   params.scalar.zero_point_product =
1435     (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1436   params.scalar.x_multiplier = x_multiplier;
1437   params.scalar.y_multiplier = y_multiplier;
1438   params.scalar.remainder_mask = (int32_t) remainder_mask;
1439   params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1440   params.scalar.shift = shift;
1441   params.scalar.output_zero_point = (int32_t) output_zero_point;
1442   params.scalar.output_min = (int32_t) output_min;
1443   params.scalar.output_max = (int32_t) output_max;
1444   return params;
1445 }
1446 
xnn_init_scalar_qu8_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)1447 static inline union xnn_qu8_requantization_params xnn_init_scalar_qu8_requantization_params(
1448   float scale,
1449   uint8_t zero_point,
1450   uint8_t min,
1451   uint8_t max)
1452 {
1453   // Compute requantization parameters.
1454   assert(scale < 1.0f);
1455   assert(scale >= 0x1.0p-32f);
1456   const uint32_t scale_bits = fp32_to_bits(scale);
1457 
1458   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1459   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1460   assert(multiplier >= INT32_C(0x40000000));
1461   assert(multiplier <= INT32_C(0x7FFFFF80));
1462 
1463   // Shift is in [0, 31] range.
1464   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1465   assert(shift >= 0);
1466   assert(shift < 32);
1467 
1468   union xnn_qu8_requantization_params params;
1469   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1470   const uint32_t remainder_threshold = remainder_mask >> 1;
1471   params.q31.multiplier = multiplier;
1472   params.q31.remainder_mask = (int32_t) remainder_mask;
1473   params.q31.remainder_threshold = (int32_t) remainder_threshold;
1474   params.q31.shift = (uint32_t) shift;
1475   params.q31.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
1476   params.q31.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
1477   params.q31.zero_point = (int32_t) (uint32_t) zero_point;
1478   return params;
1479 }
1480 
xnn_init_scalar_qs8_requantization_params(float scale,int8_t zero_point,int8_t min,int8_t max)1481 static inline union xnn_qs8_requantization_params xnn_init_scalar_qs8_requantization_params(
1482   float scale,
1483   int8_t zero_point,
1484   int8_t min,
1485   int8_t max)
1486 {
1487   // Compute requantization parameters.
1488   assert(scale < 1.0f);
1489   assert(scale >= 0x1.0p-32f);
1490   const uint32_t scale_bits = fp32_to_bits(scale);
1491 
1492   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1493   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1494   assert(multiplier >= INT32_C(0x40000000));
1495   assert(multiplier <= INT32_C(0x7FFFFF80));
1496 
1497   // Shift is in [0, 31] range.
1498   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1499   assert(shift >= 0);
1500   assert(shift < 32);
1501 
1502   union xnn_qs8_requantization_params params;
1503   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1504   const uint32_t remainder_threshold = remainder_mask >> 1;
1505   params.q31.multiplier = multiplier;
1506   params.q31.remainder_mask = (int32_t) remainder_mask;
1507   params.q31.remainder_threshold = (int32_t) remainder_threshold;
1508   params.q31.shift = (uint32_t) shift;
1509   params.q31.min_less_zero_point = (int32_t) min - (int32_t) zero_point;
1510   params.q31.max_less_zero_point = (int32_t) max - (int32_t) zero_point;
1511   params.q31.zero_point = (int32_t) zero_point;
1512   return params;
1513 }
1514 
1515