1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <xnnpack/assembly.h>
7
8# void xnn_f16_gemm${"inc" if INC else ""}_minmax_ukernel_8x8__aarch64_neonfp16arith_ld64(
9#     size_t mr,                x0
10#     size_t nc,                x1
11#     size_t kc,                x2 / x0
12#     const uint8_t*restrict a, x3
13#     size_t a_stride,          x4
14#     const void*restrict w,    x5
15#     uint8_t*restrict c,       x6
16#     size_t cm_stride,         x7
17#     size_t cn_stride,         [sp] -> (x0)
18$if INC:
19  #     const float*restrict acc,  [sp + 8] -> x15
20  #     const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)])  [sp + 16] -> x8
21$else:
22  #     const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)])  [sp + 8] -> x8
23
24# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
25
26# A pointers
27#  x3 a0
28#  x9 a1
29# x10 a2
30# x11 a3
31# x12 a4
32# x19 a5
33# x20 a6
34#  x4 a7
35
36# C pointers
37#  x6 c0
38# x16 c1
39# x17 c2
40# x14 c3
41# x13 c4
42# x21 c5
43# x22 c6
44#  x7 c7
45
46# Vector register usage
47# A0   v0
48# A1   v1
49# A2   v2
50# A3   v3
51# A4   v4
52# A5   v5
53# A6   v6
54# A7   v7
55# B   v16 v17 v18 v19
56# C   v24
57# C   v25
58# C   v26
59# C   v27
60# C   v28
61# C   v29
62# C   v30
63# C   v31
64
65# Clamp v20 v21 v22
66# unused A   v8 v9 v10 v11
67# unused B   v12 v13 v14 v15
68
BEGIN_FUNCTION xnn_f16_gemm${"inc" if INC else ""}_minmax_ukernel_8x8__aarch64_neonfp16arith_ld64

        $if INC:
          # Load acc, params pointer
          LDP x15, x8, [sp, 8]
        $else:
          # Load params pointer
          LDR x8, [sp, 8]

        # Save x19,x20,x21,x22 on stack
        STP x19, x20, [sp, -32]!
        STP x21, x22, [sp, 16]

        # Clamp A and C pointers
        CMP x0, 2                // if mr < 2
        ADD x9, x3, x4           // a1 = a0 + a_stride
        ADD x16, x6, x7          // c1 = c0 + cm_stride
        CSEL x9, x3, x9, LO      //   a1 = a0
        CSEL x16, x6, x16, LO    //   c1 = c0

        ADD x10, x9, x4          // a2 = a1 + a_stride
        ADD x17, x16, x7         // c2 = c1 + cm_stride
                                 // if mr <= 2
        CSEL x10, x9, x10, LS    //   a2 = a1
        CSEL x17, x16, x17, LS   //   c2 = c1

        CMP x0, 4                // if mr < 4
        ADD x11, x10, x4         // a3 = a2 + a_stride
        ADD x14, x17, x7         // c3 = c2 + cm_stride
        CSEL x11, x10, x11, LO   //   a3 = a2
        CSEL x14, x17, x14, LO   //   c3 = c2

        ADD x12, x11, x4         // a4 = a3 + a_stride
        ADD x13, x14, x7         // c4 = c3 + cm_stride
                                 // if mr <= 4
        CSEL x12, x11, x12, LS   //   a4 = a3
        CSEL x13, x14, x13, LS   //   c4 = c3

        CMP x0, 6                // if mr < 6
        ADD x19, x12, x4         // a5 = a4 + a_stride
        ADD x21, x13, x7         // c5 = c4 + cm_stride
        CSEL x19, x12, x19, LO   //   a5 = a4
        CSEL x21, x13, x21, LO   //   c5 = c4

        ADD x20, x19, x4         // a6 = a5 + a_stride
        ADD x22, x21, x7         // c6 = c5 + cm_stride
                                 // if mr <= 6
        CSEL x20, x19, x20, LS   //   a6 = a5
        CSEL x22, x21, x22, LS   //   c6 = c5

        CMP x0, 8                // if mr < 8
        ADD x4, x20, x4          // a7 = a6 + a_stride
        ADD x7, x22, x7          // c7 = c6 + cm_stride
        CSEL x4, x20, x4, LO     //   a7 = a6
        CSEL x7, x22, x7, LO     //   c7 = c6

        # Load params scale, min, max values
        LD3R {v20.8h, v21.8h, v22.8h}, [x8]

0:
        $if INC:
          # Load initial accumulators
          LDP q24, q25, [x15], 32
          LDP q26, q27, [x15], 32
          LDP q28, q29, [x15], 32
          LDP q30, q31, [x15], 32
        $else:
          # Load initial bias from w into accumulators
          LDR q24, [x5], 16
          MOV v25.16b, v24.16b
          MOV v26.16b, v24.16b
          MOV v27.16b, v24.16b
          MOV v28.16b, v24.16b
          MOV v29.16b, v24.16b
          MOV v30.16b, v24.16b
          MOV v31.16b, v24.16b

        # Is there at least 4 halffloats (8 bytes)?
        SUBS x0, x2, 8  // k = kc - 8
        B.LO 3f

        # Main loop - 4 halffloats of A (8 bytes)
        # 32 FMA + 8 ld64 A + 4 LDR B
1:
        LDR   d0,  [x3], 8
        LDR  q16,  [x5], 16
        LDR  q17,  [x5], 16
        LDR   d1,  [x9], 8
        LDR   d2, [x10], 8
        LDR   d3, [x11], 8
        LDR   d4, [x12], 8
        LDR   d5, [x19], 8
        LDR   d6, [x20], 8
        LDR   d7,  [x4], 8
        SUBS x0, x0, 8
        FMLA v24.8h, v16.8h,  v0.h[0]
        FMLA v25.8h, v16.8h,  v1.h[0]
        FMLA v26.8h, v16.8h,  v2.h[0]
        FMLA v27.8h, v16.8h,  v3.h[0]
        FMLA v28.8h, v16.8h,  v4.h[0]
        FMLA v29.8h, v16.8h,  v5.h[0]
        FMLA v30.8h, v16.8h,  v6.h[0]
        FMLA v31.8h, v16.8h,  v7.h[0]
        LDR  q18,  [x5], 16
        LDR  q19,  [x5], 16

        FMLA v24.8h, v17.8h,  v0.h[1]
        FMLA v25.8h, v17.8h,  v1.h[1]
        FMLA v26.8h, v17.8h,  v2.h[1]
        FMLA v27.8h, v17.8h,  v3.h[1]
        FMLA v28.8h, v17.8h,  v4.h[1]
        FMLA v29.8h, v17.8h,  v5.h[1]
        FMLA v30.8h, v17.8h,  v6.h[1]
        FMLA v31.8h, v17.8h,  v7.h[1]

        FMLA v24.8h, v18.8h,  v0.h[2]
        FMLA v25.8h, v18.8h,  v1.h[2]
        FMLA v26.8h, v18.8h,  v2.h[2]
        FMLA v27.8h, v18.8h,  v3.h[2]
        FMLA v28.8h, v18.8h,  v4.h[2]
        FMLA v29.8h, v18.8h,  v5.h[2]
        FMLA v30.8h, v18.8h,  v6.h[2]
        FMLA v31.8h, v18.8h,  v7.h[2]

        FMLA v24.8h, v19.8h,  v0.h[3]
        FMLA v25.8h, v19.8h,  v1.h[3]
        FMLA v26.8h, v19.8h,  v2.h[3]
        FMLA v27.8h, v19.8h,  v3.h[3]
        FMLA v28.8h, v19.8h,  v4.h[3]
        FMLA v29.8h, v19.8h,  v5.h[3]
        FMLA v30.8h, v19.8h,  v6.h[3]
        FMLA v31.8h, v19.8h,  v7.h[3]
        B.HS 1b

        # Is there a remainder?- 2 halffloats of A (4 bytes)
        TBNZ x0, 2, 4f
        # Is there a remainder?- 1 halffloats of A (2 bytes)
        TBNZ x0, 1, 5f
2:
        # Scale and Clamp
        FMUL v24.8h, v24.8h, v20.8h
        FMUL v25.8h, v25.8h, v20.8h
        FMUL v26.8h, v26.8h, v20.8h
        FMUL v27.8h, v27.8h, v20.8h
        FMUL v28.8h, v28.8h, v20.8h
        FMUL v29.8h, v29.8h, v20.8h
        FMUL v30.8h, v30.8h, v20.8h
        FMUL v31.8h, v31.8h, v20.8h
        # Load cn_stride
        LDR x0, [sp, 32]
        FMAX v24.8h, v24.8h, v21.8h
        FMAX v25.8h, v25.8h, v21.8h
        FMAX v26.8h, v26.8h, v21.8h
        FMAX v27.8h, v27.8h, v21.8h
        FMAX v28.8h, v28.8h, v21.8h
        FMAX v29.8h, v29.8h, v21.8h
        FMAX v30.8h, v30.8h, v21.8h
        FMAX v31.8h, v31.8h, v21.8h
        SUBS x1, x1, 8
        FMIN v24.8h, v24.8h, v22.8h
        FMIN v25.8h, v25.8h, v22.8h
        FMIN v26.8h, v26.8h, v22.8h
        FMIN v27.8h, v27.8h, v22.8h
        FMIN v28.8h, v28.8h, v22.8h
        FMIN v29.8h, v29.8h, v22.8h
        FMIN v30.8h, v30.8h, v22.8h
        FMIN v31.8h, v31.8h, v22.8h

        # Store full 8 x 8
        B.LO 6f

        $if INC:
          ST1 {v31.16b},  [x7], x0
          SUB  x3,  x3, x2 // a0 -= kc
          ST1 {v30.16b}, [x22], x0
          SUB  x9,  x9, x2 // a1 -= kc
          ST1 {v29.16b}, [x21], x0
          SUB x10, x10, x2 // a2 -= kc
          ST1 {v28.16b}, [x13], x0
          SUB x11, x11, x2 // a3 -= kc
          ST1 {v27.16b}, [x14], x0
          SUB x12, x12, x2 // a4 -= kc
          ST1 {v26.16b}, [x17], x0
          SUB x19, x19, x2 // a5 -= kc
          ST1 {v25.16b}, [x16], x0
          SUB x20, x20, x2 // a6 -= kc
          ST1 {v24.16b},  [x6], x0
          SUB  x4,  x4, x2 // a7 -= kc
        $else:
          ST1 {v24.16b},  [x6], x0
          SUB  x3,  x3, x2 // a0 -= kc
          ST1 {v25.16b}, [x16], x0
          SUB  x9,  x9, x2 // a1 -= kc
          ST1 {v26.16b}, [x17], x0
          SUB x10, x10, x2 // a2 -= kc
          ST1 {v27.16b}, [x14], x0
          SUB x11, x11, x2 // a3 -= kc
          ST1 {v28.16b}, [x13], x0
          SUB x12, x12, x2 // a4 -= kc
          ST1 {v29.16b}, [x21], x0
          SUB x19, x19, x2 // a5 -= kc
          ST1 {v30.16b}, [x22], x0
          SUB x20, x20, x2 // a6 -= kc
          ST1 {v31.16b},  [x7], x0
          SUB  x4,  x4, x2 // a7 -= kc

        B.HI 0b

        # Restore x19,x20,x21,x22 from stack
        LDP x21, x22, [sp, 16]
        LDP x19, x20, [sp], 32
        RET

3:
        TBZ x0, 2, 5f
4:
        # Remainder- 2 halffloats of A (4 bytes)
        LDR   s0,  [x3], 4
        LDR  q16,  [x5], 16
        LDR  q17,  [x5], 16
        LDR   s1,  [x9], 4
        LDR   s2, [x10], 4
        LDR   s3, [x11], 4
        LDR   s4, [x12], 4
        LDR   s5, [x19], 4
        LDR   s6, [x20], 4
        LDR   s7,  [x4], 4

        FMLA v24.8h, v16.8h,  v0.h[0]
        FMLA v25.8h, v16.8h,  v1.h[0]
        FMLA v26.8h, v16.8h,  v2.h[0]
        FMLA v27.8h, v16.8h,  v3.h[0]
        FMLA v28.8h, v16.8h,  v4.h[0]
        FMLA v29.8h, v16.8h,  v5.h[0]
        FMLA v30.8h, v16.8h,  v6.h[0]
        FMLA v31.8h, v16.8h,  v7.h[0]

        FMLA v24.8h, v17.8h,  v0.h[1]
        FMLA v25.8h, v17.8h,  v1.h[1]
        FMLA v26.8h, v17.8h,  v2.h[1]
        FMLA v27.8h, v17.8h,  v3.h[1]
        FMLA v28.8h, v17.8h,  v4.h[1]
        FMLA v29.8h, v17.8h,  v5.h[1]
        FMLA v30.8h, v17.8h,  v6.h[1]
        FMLA v31.8h, v17.8h,  v7.h[1]

        TBZ x0, 1, 2b

5:
        # Remainder- 1 halffloat of A (2 bytes)
        LDR   h0,  [x3], 2
        LDR  q16,  [x5], 16
        LDR   h1,  [x9], 2
        LDR   h2, [x10], 2
        LDR   h3, [x11], 2
        LDR   h4, [x12], 2
        LDR   h5, [x19], 2
        LDR   h6, [x20], 2
        LDR   h7,  [x4], 2

        FMLA v24.8h, v16.8h,  v0.h[0]
        FMLA v25.8h, v16.8h,  v1.h[0]
        FMLA v26.8h, v16.8h,  v2.h[0]
        FMLA v27.8h, v16.8h,  v3.h[0]
        FMLA v28.8h, v16.8h,  v4.h[0]
        FMLA v29.8h, v16.8h,  v5.h[0]
        FMLA v30.8h, v16.8h,  v6.h[0]
        FMLA v31.8h, v16.8h,  v7.h[0]
        B 2b

        # Store odd width
6:
        TBZ x1, 2, 7f
        $if INC:
          STR d31,  [x7], 8
          DUP d31, v31.d[1]
          STR d30, [x22], 8
          DUP d30, v30.d[1]
          STR d29, [x21], 8
          DUP d29, v29.d[1]
          STR d28, [x13], 8
          DUP d28, v28.d[1]
          STR d27, [x14], 8
          DUP d27, v27.d[1]
          STR d26, [x17], 8
          DUP d26, v26.d[1]
          STR d25, [x16], 8
          DUP d25, v25.d[1]
          STR d24,  [x6], 8
          DUP d24, v24.d[1]
        $else:
          STR d24,  [x6], 8
          DUP d24, v24.d[1]
          STR d25, [x16], 8
          DUP d25, v25.d[1]
          STR d26, [x17], 8
          DUP d26, v26.d[1]
          STR d27, [x14], 8
          DUP d27, v27.d[1]
          STR d28, [x13], 8
          DUP d28, v28.d[1]
          STR d29, [x21], 8
          DUP d29, v29.d[1]
          STR d30, [x22], 8
          DUP d30, v30.d[1]
          STR d31,  [x7], 8
          DUP d31, v31.d[1]
7:
        TBZ x1, 1, 8f
        $if INC:
          STR s31,  [x7], 4
          DUP s31, v31.s[1]
          STR s30, [x22], 4
          DUP s30, v30.s[1]
          STR s29, [x21], 4
          DUP s29, v29.s[1]
          STR s28, [x13], 4
          DUP s28, v28.s[1]
          STR s27, [x14], 4
          DUP s27, v27.s[1]
          STR s26, [x17], 4
          DUP s26, v26.s[1]
          STR s25, [x16], 4
          DUP s25, v25.s[1]
          STR s24,  [x6], 4
          DUP s24, v24.s[1]
        $else:
          STR s24,  [x6], 4
          DUP s24, v24.s[1]
          STR s25, [x16], 4
          DUP s25, v25.s[1]
          STR s26, [x17], 4
          DUP s26, v26.s[1]
          STR s27, [x14], 4
          DUP s27, v27.s[1]
          STR s28, [x13], 4
          DUP s28, v28.s[1]
          STR s29, [x21], 4
          DUP s29, v29.s[1]
          STR s30, [x22], 4
          DUP s30, v30.s[1]
          STR s31,  [x7], 4
          DUP s31, v31.s[1]

8:
        TBZ x1, 0, 9f
        $if INC:
          STR h31,  [x7]
          STR h30, [x22]
          STR h29, [x21]
          STR h28, [x13]
          STR h27, [x14]
          STR h26, [x17]
          STR h25, [x16]
          STR h24,  [x6]
        $else:
          STR h24,  [x6]
          STR h25, [x16]
          STR h26, [x17]
          STR h27, [x14]
          STR h28, [x13]
          STR h29, [x21]
          STR h30, [x22]
          STR h31,  [x7]
9:
        # Restore x19,x20,x21,x22 from stack
        LDP x21, x22, [sp, 16]
        LDP x19, x20, [sp], 32
        RET

END_FUNCTION xnn_f16_gemm${"inc" if INC else ""}_minmax_ukernel_8x8__aarch64_neonfp16arith_ld64
440
// Mark the stack as non-executable on ELF targets (GNU/Linux linkers
// otherwise assume an executable stack for hand-written assembly objects).
#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif
444