// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <xnnpack/assembly.h>

# void xnn_qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld32(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     const int8_t* restrict a,  x3
#     size_t a_stride,           x4
#     const void* restrict w,    x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          x7
#     size_t cn_stride,          [sp] -> x12
#     const union xnn_qs8_gemm_params* params)  [sp + 8] -> x11
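#
# params, as consumed from x11 below:
#   int32 multiplier  (replicated into v0 by LD2R)
#   int32 shift       (non-positive left shift for SRSHL, replicated into v1)
#   int16 output zero point (replicated into v2.8h)
#   int8  output min  (replicated into v0.16b)
#   int8  output max  (replicated into v1.16b)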

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.

# Register usage
# A0  x3 v0
# A1 x15 v1
# A2 x13 v2
# A3  x4 v3
# B   x5 v4  v5  v6  v7
# C0  x6 v16 v20 v24 v28
# C1  x8 v17 v21 v25 v29
# C2  x9 v18 v22 v26 v30
# C3  x7 v19 v23 v27 v31
# unused v8 v9 v10 v11 v12 v13 v14 v15
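#
# The kernel computes a 4x16 tile of C. Each SDOT takes one 4-byte lane of
# A (4 consecutive int8 values along k) against a q register of B holding
# 4 k-values for each of 4 columns ("c4" packing); the 16 output columns
# span v4-v7.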

BEGIN_FUNCTION xnn_qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld32

        # Clamp A and C pointers
        CMP      x0, 2             // if mr < 2
        ADD      x2, x2, 3         // kc = (kc + 3) & ~3
        ADD     x15, x3, x4        // a1 = a0 + a_stride
        ADD      x8, x6, x7        // c1 = c0 + cm_stride
        CSEL    x15, x3, x15, LO   //   a1 = a0
        CSEL     x8, x6,  x8, LO   //   c1 = c0
        BIC      x2, x2, 3

        ADD     x13, x15, x4       // a2 = a1 + a_stride
        ADD      x9,  x8, x7       // c2 = c1 + cm_stride
                                   // if mr <= 2
        CSEL    x13, x15, x13, LS  //   a2 = a1
        CSEL     x9,  x8,  x9, LS  //   c2 = c1

        CMP      x0, 4             // if mr < 4
        ADD      x4, x13, x4       // a3 = a2 + a_stride
        ADD      x7,  x9, x7       // c3 = c2 + cm_stride
        CSEL     x4, x13, x4, LO   //   a3 = a2
        CSEL     x7,  x9, x7, LO   //   c3 = c2

        .p2align 3
0:
        # Load initial bias from w into accumulators
        LDP     q16, q20, [x5], 32
        MOV     v17.16b, v16.16b
        MOV     v18.16b, v16.16b
        LDP     q24, q28, [x5], 32
        MOV     v19.16b, v16.16b
        MOV     v21.16b, v20.16b
        LDR     x11, [sp, 8]       // params
        MOV     v22.16b, v20.16b
        MOV     v23.16b, v20.16b
        MOV     x0, x2             // k = kc.  assumes kc > 0
        MOV     v25.16b, v24.16b
        MOV     v26.16b, v24.16b
        MOV     v27.16b, v24.16b
        MOV     v29.16b, v28.16b
        MOV     v30.16b, v28.16b
        MOV     v31.16b, v28.16b

        # Main loop - 4 bytes of A
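        # Each iteration consumes 4 bytes of A per row and 64 bytes of packed
        # B. Each SDOT accumulates a 4-element int8 dot product into every
        # int32 lane, i.e. for one row m and column n of the tile:
        #   c[m][n] += a[m][k]*b[n][k] + ... + a[m][k+3]*b[n][k+3]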
        .p2align 3
1:
        LDR     s0,  [x3], 4
        LDR     q4, [x5], 16
        LDR     s1, [x15], 4
        LDR     s2, [x13], 4
        LDR     s3,  [x4], 4
        SDOT    v16.4s, v4.16b, v0.4b[0]
        SDOT    v17.4s, v4.16b, v1.4b[0]
        LDR     q5, [x5], 16
        SDOT    v18.4s, v4.16b, v2.4b[0]
        SDOT    v19.4s, v4.16b, v3.4b[0]
        LDR     q6, [x5], 16
        SDOT    v20.4s, v5.16b, v0.4b[0]
        SDOT    v21.4s, v5.16b, v1.4b[0]
        LDR     q7, [x5], 16
        SDOT    v22.4s, v5.16b, v2.4b[0]
        SDOT    v23.4s, v5.16b, v3.4b[0]
        SUBS    x0, x0, 4
        SDOT    v24.4s, v6.16b, v0.4b[0]
        SDOT    v25.4s, v6.16b, v1.4b[0]
        SDOT    v26.4s, v6.16b, v2.4b[0]
        SDOT    v27.4s, v6.16b, v3.4b[0]
        SDOT    v28.4s, v7.16b, v0.4b[0]
        SDOT    v29.4s, v7.16b, v1.4b[0]
        SDOT    v30.4s, v7.16b, v2.4b[0]
        SDOT    v31.4s, v7.16b, v3.4b[0]
        B.HI    1b

        # Apply params - scale, shift, bias and clamp
        LD2R    {v0.4s, v1.4s}, [x11], 8
        CMEQ    v2.4s, v1.4s, 0
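        # Requantization, per 32-bit lane (a C-like sketch of the intent;
        # v0 = multiplier, v1 = non-positive shift, v2 = (shift == 0) mask):
        #   product  = (int32_t)(((int64_t)acc * multiplier + (1 << 30)) >> 31);  // SQRDMULH
        #   product += (acc < 0 && shift != 0) ? -1 : 0;                          // BIC + SSRA
        #   scaled   = rounding_shift_right(product, -shift);                     // SRSHL
        # The SSRA fixup makes the final rounding match ties-away-from-zero.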

        BIC     v4.16b, v16.16b, v2.16b
        BIC     v5.16b, v17.16b, v2.16b
        BIC     v6.16b, v18.16b, v2.16b
        BIC     v7.16b, v19.16b, v2.16b

        SQRDMULH  v16.4s, v16.4s, v0.4s
        SQRDMULH  v17.4s, v17.4s, v0.4s
        SQRDMULH  v18.4s, v18.4s, v0.4s
        SQRDMULH  v19.4s, v19.4s, v0.4s

        SSRA    v16.4s, v4.4s, 31  // signed shift right accumulate
        SSRA    v17.4s, v5.4s, 31
        SSRA    v18.4s, v6.4s, 31
        SSRA    v19.4s, v7.4s, 31

        BIC     v4.16b, v20.16b, v2.16b
        BIC     v5.16b, v21.16b, v2.16b
        BIC     v6.16b, v22.16b, v2.16b
        BIC     v7.16b, v23.16b, v2.16b

        SQRDMULH  v20.4s, v20.4s, v0.4s
        SQRDMULH  v21.4s, v21.4s, v0.4s
        SQRDMULH  v22.4s, v22.4s, v0.4s
        SQRDMULH  v23.4s, v23.4s, v0.4s

        SSRA    v20.4s, v4.4s, 31
        SSRA    v21.4s, v5.4s, 31
        SSRA    v22.4s, v6.4s, 31
        SSRA    v23.4s, v7.4s, 31

        BIC     v4.16b, v24.16b, v2.16b
        BIC     v5.16b, v25.16b, v2.16b
        BIC     v6.16b, v26.16b, v2.16b
        BIC     v7.16b, v27.16b, v2.16b

        SQRDMULH  v24.4s, v24.4s, v0.4s
        SQRDMULH  v25.4s, v25.4s, v0.4s
        SQRDMULH  v26.4s, v26.4s, v0.4s
        SQRDMULH  v27.4s, v27.4s, v0.4s

        SSRA    v24.4s, v4.4s, 31
        SSRA    v25.4s, v5.4s, 31
        SSRA    v26.4s, v6.4s, 31
        SSRA    v27.4s, v7.4s, 31

        BIC     v4.16b, v28.16b, v2.16b
        BIC     v5.16b, v29.16b, v2.16b
        BIC     v6.16b, v30.16b, v2.16b
        BIC     v7.16b, v31.16b, v2.16b

        SQRDMULH  v28.4s, v28.4s, v0.4s
        SQRDMULH  v29.4s, v29.4s, v0.4s
        SQRDMULH  v30.4s, v30.4s, v0.4s
        SQRDMULH  v31.4s, v31.4s, v0.4s

        SSRA    v28.4s, v4.4s, 31
        SSRA    v29.4s, v5.4s, 31
        SSRA    v30.4s, v6.4s, 31
        SSRA    v31.4s, v7.4s, 31

        SRSHL   v16.4s, v16.4s, v1.4s  // signed rounding shift left
        SRSHL   v17.4s, v17.4s, v1.4s
        SRSHL   v18.4s, v18.4s, v1.4s
        SRSHL   v19.4s, v19.4s, v1.4s
        SRSHL   v20.4s, v20.4s, v1.4s
        SRSHL   v21.4s, v21.4s, v1.4s
        SRSHL   v22.4s, v22.4s, v1.4s
        SRSHL   v23.4s, v23.4s, v1.4s
        SRSHL   v24.4s, v24.4s, v1.4s
        SRSHL   v25.4s, v25.4s, v1.4s
        SRSHL   v26.4s, v26.4s, v1.4s
        SRSHL   v27.4s, v27.4s, v1.4s
        SRSHL   v28.4s, v28.4s, v1.4s
        SRSHL   v29.4s, v29.4s, v1.4s
        SRSHL   v30.4s, v30.4s, v1.4s
        SRSHL   v31.4s, v31.4s, v1.4s

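        # Narrow 32 -> 16 bits with saturation, add the output zero point,
        # then narrow 16 -> 8 bits and clamp to the output min/max.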
        SQXTN   v16.4h, v16.4s
        SQXTN   v17.4h, v17.4s
        SQXTN   v18.4h, v18.4s
        SQXTN   v19.4h, v19.4s
        SQXTN   v24.4h, v24.4s
        SQXTN   v25.4h, v25.4s
        SQXTN   v26.4h, v26.4s
        SQXTN   v27.4h, v27.4s
        LD1R    {v2.8h}, [x11], 2   // output zero point

        SQXTN2  v16.8h, v20.4s
        SQXTN2  v17.8h, v21.4s
        SQXTN2  v18.8h, v22.4s
        SQXTN2  v19.8h, v23.4s
        SQXTN2  v24.8h, v28.4s
        SQXTN2  v25.8h, v29.4s
        SQXTN2  v26.8h, v30.4s
        SQXTN2  v27.8h, v31.4s

        SQADD   v16.8h, v16.8h, v2.8h
        SQADD   v17.8h, v17.8h, v2.8h
        SQADD   v18.8h, v18.8h, v2.8h
        SQADD   v19.8h, v19.8h, v2.8h
        SQADD   v24.8h, v24.8h, v2.8h
        SQADD   v25.8h, v25.8h, v2.8h
        SQADD   v26.8h, v26.8h, v2.8h
        SQADD   v27.8h, v27.8h, v2.8h
        LD1R    {v0.16b}, [x11], 1  // clamp min value

        SQXTN    v4.8b, v16.8h
        SQXTN    v5.8b, v17.8h
        SQXTN    v6.8b, v18.8h
        SQXTN    v7.8b, v19.8h
        LD1R    {v1.16b}, [x11]     // clamp max value
        SQXTN2   v4.16b, v24.8h
        SQXTN2   v5.16b, v25.8h
        SQXTN2   v6.16b, v26.8h
        SQXTN2   v7.16b, v27.8h
        LDR     x12, [sp]   // cn_stride

        SMAX    v4.16b, v4.16b, v0.16b
        SMAX    v5.16b, v5.16b, v0.16b
        SMAX    v6.16b, v6.16b, v0.16b
        SMAX    v7.16b, v7.16b, v0.16b
        SUBS    x1, x1, 16
        SMIN    v4.16b, v4.16b, v1.16b
        SMIN    v5.16b, v5.16b, v1.16b
        SMIN    v6.16b, v6.16b, v1.16b
        SMIN    v7.16b, v7.16b, v1.16b
        B.LO    2f

        # Store full 4 x 16
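        # Stores interleave with rewinding the A pointers by kc so the same
        # rows feed the next 16-column tile; B.NE tests the flags set by
        # SUBS x1, x1, 16 above.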
        ST1     {v4.16b}, [x6], x12
        SUB      x3,  x3, x2         // a0 -= kc
        ST1     {v5.16b}, [x8], x12
        SUB     x15, x15, x2         // a1 -= kc
        ST1     {v6.16b}, [x9], x12
        SUB     x13, x13, x2         // a2 -= kc
        ST1     {v7.16b}, [x7], x12
        SUB      x4,  x4, x2         // a3 -= kc
        B.NE    0b
        RET

        # Store odd width
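        # The nc remainder (1-15) is handled bitwise: store 8, 4, 2, then 1
        # column(s), with DUP shifting the surviving lanes down after each
        # partial store.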
        .p2align 3
2:
        TBZ     x1, 3, 3f
        STR     d4, [x6], 8
        DUP     d4, v4.d[1]
        STR     d5, [x8], 8
        DUP     d5, v5.d[1]
        STR     d6, [x9], 8
        DUP     d6, v6.d[1]
        STR     d7, [x7], 8
        DUP     d7, v7.d[1]
3:
        TBZ     x1, 2, 4f
        STR     s4, [x6], 4
        DUP     s4, v4.s[1]
        STR     s5, [x8], 4
        DUP     s5, v5.s[1]
        STR     s6, [x9], 4
        DUP     s6, v6.s[1]
        STR     s7, [x7], 4
        DUP     s7, v7.s[1]
4:
        TBZ     x1, 1, 5f
        ST1     {v4.h}[0], [x6], 2
        DUP     h4, v4.h[1]
        ST1     {v5.h}[0], [x8], 2
        DUP     h5, v5.h[1]
        ST1     {v6.h}[0], [x9], 2
        DUP     h6, v6.h[1]
        ST1     {v7.h}[0], [x7], 2
        DUP     h7, v7.h[1]
5:
        TBZ     x1, 0, 6f
        ST1     {v4.b}[0], [x6]
        ST1     {v5.b}[0], [x8]
        ST1     {v6.b}[0], [x9]
        ST1     {v7.b}[0], [x7]
6:
        RET

END_FUNCTION xnn_qs8_gemm_minmax_ukernel_4x16c4__aarch64_neondot_ld32

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif