1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <xnnpack/assembly.h>
7
8# void xnn_f32_gemm${"inc" if INC else ""}_minmax_ukernel_4x8__aarch64_neonfma_ld64(
9#     size_t mr,                x0
10#     size_t nc,                x1
11#     size_t kc,                x2 / x0
12#     const uint8_t*restrict a, x3
13#     size_t a_stride,          x4
14#     const void*restrict w,    x5
15#     uint8_t*restrict c,       x6
16#     size_t cm_stride,         x7
17#     size_t cn_stride,         [sp] -> x14
18$if INC:
19  #     const float*restrict acc,  [sp + 8] -> x15
20  #     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])  [sp + 16] -> x8
21$else:
22  #     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])  [sp + 8] -> x8
23
24# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
25
26# A pointers
27# x3  a0
28# x11 a1
29# x12 a2
30# x4  a3 / a_stride
31
32# C pointers
33# x6  c0
34# x9  c1
35# x10 c2
36# x7  c3 / cm_stride
37
// Computes a 4-row by 8-column tile of f32 outputs per outer iteration
// (label 0) and loops over nc in steps of 8 columns.
//
// Vector register roles in this routine:
//   v0-v3     A values for rows 0-3 (only lanes 0 and 1 are used)
//   v4, v5    clamp bounds: v4 is applied with FMAX, v5 with FMIN
//   v16, v17  row 0 accumulators (columns 0-3 and 4-7)
//   v18, v19  row 1 accumulators
//   v28, v29  row 2 accumulators
//   v30, v31  row 3 accumulators
//   v20-v23   values loaded from w
//
// Local labels: 0 = next column tile, 1 = main k loop, 2 = clamp and store,
// 3 = one-float k remainder, 4/5/6 = partial-width stores, 7 = exit.
//
// Only x0-x15, v0-v5 and v16-v31 are written, so no callee-saved register is
// touched and nothing is spilled to the stack.
// NOTE(review): the k handling assumes kc is a non-zero multiple of 4 bytes
// and nc is non-zero - confirm against the caller.
38BEGIN_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_minmax_ukernel_4x8__aarch64_neonfma_ld64
39
40        $if INC:
41          # Load cn_stride, acc
42          LDP x14, x15, [sp]
43          # Load params pointer
44          LDR x8, [sp, 16]
45        $else:
46          # Load cn_stride, params pointer
47          LDP x14, x8, [sp]
48
        // LD2R reads two consecutive floats at params and broadcasts the first
        // to every lane of v4 and the second to every lane of v5.
49        # Load min/max values
50        LD2R {v4.4s, v5.4s}, [x8]
51
        // Rows at or beyond mr reuse the previous row's A and C addresses, so
        // their loads and stores land on memory belonging to that row.
52        # Clamp A and C pointers
53        CMP x0, 2                // if mr < 2
54        ADD x11, x3, x4          // a1 = a0 + a_stride
55        ADD x9, x6, x7           // c1 = c0 + cm_stride
56        CSEL x11, x3, x11, LO    //   a1 = a0
57        CSEL x9, x6, x9, LO      //   c1 = c0
58
        // The two CSELs below still use the flags of CMP x0, 2 above; the ADDs
        // in between do not modify the flags.
59        ADD x12, x11, x4         // a2 = a1 + a_stride
60        ADD x10, x9, x7          // c2 = c1 + cm_stride
61                                 // if mr <= 2
62        CSEL x12, x11, x12, LS   //   a2 = a1
63        CSEL x10, x9, x10, LS    //   c2 = c1
64
        // From here on x4 holds a3 and x7 holds c3; a_stride and cm_stride are
        // no longer available.
65        CMP x0, 4                // if mr < 4
66        ADD x4, x12, x4          // a3 = a2 + a_stride
67        ADD x7, x10, x7          // c3 = c2 + cm_stride
68        CSEL x4, x12, x4, LO     //   a3 = a2
69        CSEL x7, x10, x7, LO     //   c3 = c2
70
        // Start of one 8-column tile: initialise the accumulators.
        // INC variant: read 4 rows x 8 floats sequentially from acc (x15),
        // row 0 first. Otherwise: read 8 bias floats from w (x5) and copy them
        // into all four rows.
710:
72        $if INC:
73          # Load initial accumulators
74          LDP q16, q17, [x15], 32
75          LDP q18, q19, [x15], 32
76          LDP q28, q29, [x15], 32
77          LDP q30, q31, [x15], 32
78        $else:
79          # Load initial bias from w into accumulators
80          LDP q16, q17, [x5], 32
81          MOV v18.16b, v16.16b
82          MOV v19.16b, v17.16b
83          MOV v28.16b, v16.16b
84          MOV v29.16b, v17.16b
85          MOV v30.16b, v16.16b
86          MOV v31.16b, v17.16b
87
        // x0 (mr) is no longer needed and is reused as the k byte counter.
        // If kc < 8 the main loop is skipped and label 3 handles one float.
88        # Is there at least 2 floats (8 bytes)?
89        SUBS x0, x2, 8  // k = kc - 8
90        B.LO 3f
91
        // Each iteration reads 2 floats from every A row and 16 floats from w:
        // q20/q21 are multiplied by lane 0 of the A values and q22/q23 by
        // lane 1. All loads post-increment their pointer.
92        # Main loop - 2 floats of A (8 bytes)
931:
94        LDR d0,  [x3], 8
95        LDP q20, q21, [x5], 32
96        LDR d1, [x11], 8
97        LDR d2, [x12], 8
98        LDR d3,  [x4], 8
99        FMLA v16.4s, v20.4s, v0.s[0]
100        FMLA v17.4s, v21.4s, v0.s[0]
101        FMLA v18.4s, v20.4s, v1.s[0]
102        FMLA v19.4s, v21.4s, v1.s[0]
103        LDP q22, q23, [x5], 32
104        FMLA v28.4s, v20.4s, v2.s[0]
105        FMLA v29.4s, v21.4s, v2.s[0]
106        FMLA v30.4s, v20.4s, v3.s[0]
107        FMLA v31.4s, v21.4s, v3.s[0]
108        FMLA v16.4s, v22.4s, v0.s[1]
109        FMLA v17.4s, v23.4s, v0.s[1]
110        FMLA v18.4s, v22.4s, v1.s[1]
111        FMLA v19.4s, v23.4s, v1.s[1]
112        FMLA v28.4s, v22.4s, v2.s[1]
113        FMLA v29.4s, v23.4s, v2.s[1]
        // k -= 8 is interleaved with the FMLAs; B.HS below consumes its flags
        // (the vector FMLAs in between do not change them).
114        SUBS x0, x0, 8
115        FMLA v30.4s, v22.4s, v3.s[1]
116        FMLA v31.4s, v23.4s, v3.s[1]
117        B.HS 1b
118
119        # Is there a remainder?- 1 floats of A (4 bytes)
        // x0 is now (bytes left - 8): with kc a multiple of 4 bytes that is
        // either -8 (bit 2 clear, nothing left) or -4 (bit 2 set, one float).
120        TBNZ x0, 2, 3f
121
1222:
123        # Clamp
        // nc -= 8 is issued early; its flags are consumed by B.LO 4f and
        // B.HI 0b further down. None of the FMAX/FMIN/ST1/SUB instructions in
        // between modify the flags.
124        FMAX v16.4s, v16.4s, v4.4s
125        SUBS x1, x1, 8
126        FMAX v17.4s, v17.4s, v4.4s
127        FMAX v18.4s, v18.4s, v4.4s
128        FMAX v19.4s, v19.4s, v4.4s
129        FMAX v28.4s, v28.4s, v4.4s
130        FMAX v29.4s, v29.4s, v4.4s
131        FMAX v30.4s, v30.4s, v4.4s
132        FMAX v31.4s, v31.4s, v4.4s
133        FMIN v16.4s, v16.4s, v5.4s
134        FMIN v17.4s, v17.4s, v5.4s
135        FMIN v18.4s, v18.4s, v5.4s
136        FMIN v19.4s, v19.4s, v5.4s
137        FMIN v28.4s, v28.4s, v5.4s
138        FMIN v29.4s, v29.4s, v5.4s
139        FMIN v30.4s, v30.4s, v5.4s
140        FMIN v31.4s, v31.4s, v5.4s
141
142        # Store full 4 x 8
        // Fewer than 8 columns were left: take the partial-width path instead.
143        B.LO 4f
144
        // Each ST1 writes 8 floats and then advances its C pointer by
        // cn_stride (x14). Each SUB rewinds an A pointer by kc bytes to the
        // start of its row for the next tile; w (x5) is never rewound.
        // The INC variant stores row 3 first and row 0 last, so when clamped
        // rows share an address the lowest row's result is written last. The
        // other variant stores row 0 first; there, rows sharing an address
        // were computed from the same A data and the same bias.
145        $if INC:
146          ST1 {v30.16b, v31.16b},  [x7], x14
147          SUB  x3,  x3, x2 // a0 -= kc
148          ST1 {v28.16b, v29.16b}, [x10], x14
149          SUB x11, x11, x2 // a1 -= kc
150          ST1 {v18.16b, v19.16b},  [x9], x14
151          SUB x12, x12, x2 // a2 -= kc
152          ST1 {v16.16b, v17.16b},  [x6], x14
153          SUB  x4,  x4, x2 // a3 -= kc
154        $else:
155          ST1 {v16.16b, v17.16b},  [x6], x14
156          SUB  x3,  x3, x2 // a0 -= kc
157          ST1 {v18.16b, v19.16b},  [x9], x14
158          SUB x11, x11, x2 // a1 -= kc
159          ST1 {v28.16b, v29.16b}, [x10], x14
160          SUB x12, x12, x2 // a2 -= kc
161          ST1 {v30.16b, v31.16b},  [x7], x14
162          SUB  x4,  x4, x2 // a3 -= kc
163
        // More than 8 columns were left before this tile: compute the next one.
164        B.HI 0b
165
166        RET
167
        // Reads one float per A row and 8 floats of w, accumulates them using
        // lane 0 only, then rejoins the clamp/store path at label 2.
168        # Remainder- 1 float of A (4 bytes)
1693:
170        LDR s0,  [x3], 4
171        LDP q20, q21, [x5], 32
172        LDR s1, [x11], 4
173        LDR s2, [x12], 4
174        LDR s3 , [x4], 4
175        FMLA v16.4s, v20.4s, v0.s[0]
176        FMLA v17.4s, v21.4s, v0.s[0]
177        FMLA v18.4s, v20.4s, v1.s[0]
178        FMLA v19.4s, v21.4s, v1.s[0]
179        FMLA v28.4s, v20.4s, v2.s[0]
180        FMLA v29.4s, v21.4s, v2.s[0]
181        FMLA v30.4s, v20.4s, v3.s[0]
182        FMLA v31.4s, v21.4s, v3.s[0]
183        B 2b
184
        // x1 holds (columns left - 8) here; subtracting 8 leaves bits 0-2
        // unchanged, so they still encode the 1-7 columns left to write.
        // Bit 2 set: store 4 floats per row (C pointers advance by 16 bytes)
        // and move the upper 4 accumulators of each row into the lower register.
185        # Store odd width
1864:
187        TBZ x1, 2, 5f
188        $if INC:
189          STR q30, [x7], 16
190          MOV v30.16b, v31.16b
191          STR q28, [x10], 16
192          MOV v28.16b, v29.16b
193          STR q18, [x9], 16
194          MOV v18.16b, v19.16b
195          STR q16, [x6], 16
196          MOV v16.16b, v17.16b
197        $else:
198          STR q16, [x6], 16
199          MOV v16.16b, v17.16b
200          STR q18, [x9], 16
201          MOV v18.16b, v19.16b
202          STR q28, [x10], 16
203          MOV v28.16b, v29.16b
204          STR q30, [x7], 16
205          MOV v30.16b, v31.16b
206
        // Bit 1 set: store 2 floats per row (C pointers advance by 8 bytes);
        // each DUP then copies the upper 64 bits of the register into its
        // low half.
2075:
208        TBZ x1, 1, 6f
209        $if INC:
210          STR d30, [x7], 8
211          DUP d30, v30.d[1]
212          STR d28, [x10], 8
213          DUP d28, v28.d[1]
214          STR d18, [x9], 8
215          DUP d18, v18.d[1]
216          STR d16, [x6], 8
217          DUP d16, v16.d[1]
218        $else:
219          STR d16, [x6], 8
220          DUP d16, v16.d[1]
221          STR d18, [x9], 8
222          DUP d18, v18.d[1]
223          STR d28, [x10], 8
224          DUP d28, v28.d[1]
225          STR d30, [x7], 8
226          DUP d30, v30.d[1]
227
        // Bit 0 set: store the final single float of each row.
2286:
229        TBZ x1, 0, 7f
230        $if INC:
231          STR s30,  [x7]
232          STR s28, [x10]
233          STR s18,  [x9]
234          STR s16,  [x6]
235        $else:
236          STR s16,  [x6]
237          STR s18,  [x9]
238          STR s28, [x10]
239          STR s30,  [x7]
2407:
241        RET
242
243END_FUNCTION xnn_f32_gemm${"inc" if INC else ""}_minmax_ukernel_4x8__aarch64_neonfma_ld64
244
// On ELF targets, emit an empty .note.GNU-stack section. By GNU toolchain
// convention this marks the object as not requiring an executable stack.
245#ifdef __ELF__
246.section ".note.GNU-stack","",%progbits
247#endif
248