// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/gemm.h>


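// GEMM microkernel template for signed 8-bit (QS8) matrix multiplication with
// min/max output clamping. Each instantiation computes an MR x NR tile of C
// using NEON MLAL instructions with lane-indexed 16-bit multiplicands.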
void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}__neon_mlal_lane(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

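  // Set up MR row pointers into A and C. Rows past mr collapse onto the
  // previous row, so out-of-range rows compute (and store) redundant but
  // harmless duplicates of a valid row.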
  const int8_t* a0 = a;
  int8_t* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
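    // Initialize the NR accumulators of row 0 from the bias values packed at
    // the start of w; the remaining rows start from the same biases.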
    $for N in range(0, NR, 4):
      int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t));
    $for M in range(1, MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};

    size_t k = kc;
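    // Main loop: consume 8 bytes of each row of A per iteration. Each byte is
    // widened to 16 bits and multiply-accumulated against 8 packed rows of B
    // via lane-indexed MLAL.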
    while (k >= 8 * sizeof(int8_t)) {
      $for M in range(MR):
        const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;
        const int16x8_t vxa${M} = vmovl_s8(va${M});

      $for K in range(4):
        $for N in range(0, NR, 8):
          const int8x8_t vb${ABC[N:N+8]}c${K} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
          const int16x8_t vxb${ABC[N:N+8]}c${K} = vmovl_s8(vb${ABC[N:N+8]}c${K});

          $for M in range(MR):
            vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c${K}), vget_low_s16(vxa${M}), ${K});
            vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c${K}), vget_low_s16(vxa${M}), ${K});

      $for K in range(4, 8):
        $for N in range(0, NR, 8):
          const int8x8_t vb${ABC[N:N+8]}c${K} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
          const int16x8_t vxb${ABC[N:N+8]}c${K} = vmovl_s8(vb${ABC[N:N+8]}c${K});

          $for M in range(MR):
            vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c${K}), vget_high_s16(vxa${M}), ${K-4});
            vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c${K}), vget_high_s16(vxa${M}), ${K-4});

      k -= 8 * sizeof(int8_t);
    }
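    // Remainder: handle the final 1..7 bytes of k. The 8-byte loads from A may
    // read past the logical end of the row, but lanes at or beyond k never
    // contribute to the accumulators.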
    if XNN_UNLIKELY(k != 0) {
      $for M in range(MR):
        const int8x8_t va${M} = vld1_s8(a${M}); a${M} = (const int8_t*) ((uintptr_t) a${M} + k);
        const int16x8_t vxa${M} = vmovl_s8(va${M});

      $for N in range(0, NR, 8):
        const int8x8_t vb${ABC[N:N+8]}c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
        const int16x8_t vxb${ABC[N:N+8]}c0 = vmovl_s8(vb${ABC[N:N+8]}c0);

      $for M in range(MR):
        $for N in range(0, NR, 8):
          vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c0), vget_low_s16(vxa${M}), 0);
          vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c0), vget_low_s16(vxa${M}), 0);

      if (k >= 2 * sizeof(int8_t)) {
        $for N in range(0, NR, 8):
          const int8x8_t vb${ABC[N:N+8]}c1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
          const int16x8_t vxb${ABC[N:N+8]}c1 = vmovl_s8(vb${ABC[N:N+8]}c1);

        $for M in range(MR):
          $for N in range(0, NR, 8):
            vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c1), vget_low_s16(vxa${M}), 1);
            vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c1), vget_low_s16(vxa${M}), 1);

        if (k > 2 * sizeof(int8_t)) {
          $for N in range(0, NR, 8):
            const int8x8_t vb${ABC[N:N+8]}c2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
            const int16x8_t vxb${ABC[N:N+8]}c2 = vmovl_s8(vb${ABC[N:N+8]}c2);

          $for M in range(MR):
            $for N in range(0, NR, 8):
              vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c2), vget_low_s16(vxa${M}), 2);
              vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c2), vget_low_s16(vxa${M}), 2);

          if (k >= 4 * sizeof(int8_t)) {
            $for N in range(0, NR, 8):
              const int8x8_t vb${ABC[N:N+8]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
              const int16x8_t vxb${ABC[N:N+8]}c3 = vmovl_s8(vb${ABC[N:N+8]}c3);

            $for M in range(MR):
              $for N in range(0, NR, 8):
                vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c3), vget_low_s16(vxa${M}), 3);
                vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c3), vget_low_s16(vxa${M}), 3);

            if (k > 4 * sizeof(int8_t)) {
              $for N in range(0, NR, 8):
                const int8x8_t vb${ABC[N:N+8]}c4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
                const int16x8_t vxb${ABC[N:N+8]}c4 = vmovl_s8(vb${ABC[N:N+8]}c4);

              $for M in range(MR):
                $for N in range(0, NR, 8):
                  vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c4), vget_high_s16(vxa${M}), 0);
                  vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c4), vget_high_s16(vxa${M}), 0);

              if (k >= 6 * sizeof(int8_t)) {
                $for N in range(0, NR, 8):
                  const int8x8_t vb${ABC[N:N+8]}c5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
                  const int16x8_t vxb${ABC[N:N+8]}c5 = vmovl_s8(vb${ABC[N:N+8]}c5);

                $for M in range(MR):
                  $for N in range(0, NR, 8):
                    vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c5), vget_high_s16(vxa${M}), 1);
                    vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c5), vget_high_s16(vxa${M}), 1);

                if (k > 6 * sizeof(int8_t)) {
                  $for N in range(0, NR, 8):
                    const int8x8_t vb${ABC[N:N+8]}c6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
                    const int16x8_t vxb${ABC[N:N+8]}c6 = vmovl_s8(vb${ABC[N:N+8]}c6);

                  $for M in range(MR):
                    $for N in range(0, NR, 8):
                      vacc${M}x${ABC[N:N+4]} = vmlal_lane_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vxb${ABC[N:N+8]}c6), vget_high_s16(vxa${M}), 2);
                      vacc${M}x${ABC[N+4:N+8]} = vmlal_lane_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vxb${ABC[N:N+8]}c6), vget_high_s16(vxa${M}), 2);
                }
              }
            }
          }
        }
      }
    }

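    // Requantize: Q31 fixed-point multiply by the scale multiplier, then a
    // rounding right shift. The vbic/vsraq pair subtracts 1 from negative
    // accumulators (only when the shift is non-zero) so that VRSHL's
    // round-half-up behavior becomes round-half-away-from-zero.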
    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vqrdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vsraq_n_s32(vacc${M}x${ABC[N:N+4]}, vbicq_s32(vacc${M}x${ABC[N:N+4]}, vzero_shift_mask), 31);

    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_shift);

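    // Narrow to 8 bits: saturating-narrow the 32-bit accumulators to 16 bits,
    // add the output zero point, then saturating-narrow to 8 bits. ARM64 uses
    // the *_high narrowing forms, avoiding separate vcombine operations.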
    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif
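    // Clamp the 8-bit outputs to [output_min, output_max].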
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min);
      const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
      const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

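    // Full tile: store all NR outputs per row, advance C by cn_stride, and
    // rewind A by kc for the next column tile.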
    if (nc >= ${NR}) {
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      $for M in range(MR):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
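      // Partial tile: store the remaining 1..NR-1 outputs with progressively
      // narrower stores, using vext to shift consumed lanes out of the vector.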
      $if NR == 16:
        $for M in range(MR):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u32(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
            vst1q_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u16(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
            vst1q_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}